Facebook's Knowledge-Assisted NLP

Facebook recently published a blog post about their Retrieval-Augmented Generation (RAG) paper (published in May 2020). The blog post is light on detail, but, as usual, the news coverage is much worse (filled with ads and poorly written). I decided to dive in and figure out what the work was about.

Over the past year or so, Facebook has published several papers about knowledge-aided NLP (KILT, RAG, BART, DPR, BLINK), so in this blog post I'd like to go through what all these acronyms mean and how they fit together.

To summarize:

  • Dense Passage Retrieval (DPR): A model which, given a question, retrieves relevant passages from a database of passages built from Wikipedia.
  • BART: A sequence-to-sequence (seq2seq) version of BERT.
  • BLINK: A DPR-based entity linker.
  • Retrieval-Augmented Generation (RAG): A DPR- and BART-based question answering model.
  • Knowledge Intensive Language Tasks (KILT): A benchmark for evaluating models on knowledge-intensive tasks, with results reported for baselines and Facebook's own models.

I'll go through these models in more detail below.

Dense Passage Retrieval (DPR)

One of the core components of knowledge-based question answering is passage retrieval: given a question, select candidate passages from a large database that may be relevant to it.

Traditional approaches to this problem include TF-IDF and BM25, which, at their core, are fairly simple models that score how well a document matches a query based on word frequencies. Many refinements of the core models help them work well in practice, such as stemming, smoothing, and stopword removal.
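To make "simple models based on word frequencies" concrete, here is a minimal Okapi BM25 scorer. It is a sketch under my own assumptions (whitespace tokenization, no stemming or stopword removal, k1 = 1.2 and b = 0.75, and the "+1" smoothed IDF that Lucene uses), not the exact baseline configuration from the DPR paper.

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=1.2, b=0.75):
    """Score each tokenized document against a tokenized query with Okapi BM25."""
    n_docs = len(docs)
    avgdl = sum(len(d) for d in docs) / n_docs
    # Document frequency: how many documents contain each term at least once.
    df = Counter(term for d in docs for term in set(d))
    scores = []
    for doc in docs:
        tf = Counter(doc)
        score = 0.0
        for term in query:
            freq = tf[term]
            if freq == 0:
                continue  # terms absent from this document contribute nothing
            # Smoothed IDF; the "+ 1" keeps it positive even for very common terms.
            idf = math.log(1.0 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # Term-frequency saturation (k1) and document-length normalization (b).
            denom = freq + k1 * (1.0 - b + b * len(doc) / avgdl)
            score += idf * freq * (k1 + 1.0) / denom
        scores.append(score)
    return scores


docs = [s.split() for s in [
    "the eiffel tower is in paris",
    "paris is the capital of france",
    "bananas are yellow",
]]
print(bm25_scores("where is the eiffel tower".split(), docs))
```

The first document scores highest because it matches the rarer terms "eiffel" and "tower"; the second only matches the common words "is" and "the", and the third matches nothing and scores zero.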

Dense Passage Retrieval (DPR) (April 2020), a neural-network-based passage retriever, is Facebook's approach to this challenge. Although it's published as an independent paper, it seems to be part of the same effort as RAG: the two were published less than two months apart with significant author overlap.

The DPR model consists of two parts: a passage encoder and a question encoder, each initialized from BERT and fine-tuned. Each encoder returns a dense vector representation of its input by taking the output embedding of the [CLS] token. The similarity between a question and a passage is the dot product of their two vectors.

A single training sample for this model consists of a question, a positive passage (the passage which has the answer to the question), and a set of negative passages (which are irrelevant to the question). The model is trained to maximize similarity of the question vector to the positive passage vector, and minimize similarity to negative passage vectors. (The loss is equivalent to using the passage similarities as logits and then using a softmax cross-entropy loss.)
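The training objective described above fits in a few lines. This plain-Python sketch assumes the two encoders have already produced the vectors (there is no BERT here) and scores one question against one positive and several negatives; a real implementation would batch this in a tensor library.

```python
import math


def dot(u, v):
    """Similarity used by DPR: the dot product of two embedding vectors."""
    return sum(a * b for a, b in zip(u, v))


def dpr_loss(question_vec, positive_vec, negative_vecs):
    """Softmax cross-entropy with similarities as logits and the positive as target.

    The vectors stand in for the [CLS] embeddings produced by the question and
    passage encoders; here they are just lists of floats.
    """
    logits = [dot(question_vec, positive_vec)]
    logits += [dot(question_vec, neg) for neg in negative_vecs]
    # Numerically stable log-sum-exp over all candidate passages.
    m = max(logits)
    log_partition = m + math.log(sum(math.exp(x - m) for x in logits))
    # -log softmax(logits)[0]: small when the positive dominates the negatives.
    return log_partition - logits[0]


q = [1.0, 0.0]
print(dpr_loss(q, positive_vec=[2.0, 0.0], negative_vecs=[[0.0, 1.0], [-1.0, 0.0]]))
```

This prints roughly 0.17; making the positive passage more similar to the question (or the negatives less similar) drives the loss toward zero.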

Models trained like this depend heavily on the quality of the negative passages. If all the negatives are completely random, unrelated passages, the task is too easy: the model learns shallow features based on word frequency and the like, and won't generalize well. To learn high-quality representations, the negatives for a question should mix arbitrary unrelated passages (in DPR, the positive passages of the other questions in the same training batch) with passages that are close to the positive but still wrong; the latter are called "hard negatives". DPR is trained with a single hard negative per question, sourced by running a BM25 retriever and choosing a highly ranked passage that does not contain the answer.
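A sketch of this negative-sampling recipe, assuming bm25_ranked is a list of passages already sorted by descending BM25 score. Matching answers by case-insensitive substring is my simplification, and the function names are hypothetical. The easy negatives are drawn at random from a corpus here for illustration; DPR itself reuses the other questions' positives from the same batch.

```python
import random


def pick_bm25_hard_negative(bm25_ranked, answers):
    """Return the highest-ranked BM25 passage that mentions none of the answers."""
    for passage in bm25_ranked:  # assumed sorted by descending BM25 score
        text = passage.lower()
        if not any(answer.lower() in text for answer in answers):
            return passage
    return None  # every retrieved passage contains an answer string


def build_example(question, positive, bm25_ranked, answers, corpus, n_easy=3, rng=random):
    """One training example: a positive, random easy negatives, and one hard negative."""
    hard = pick_bm25_hard_negative(bm25_ranked, answers)
    pool = [p for p in corpus if p not in (positive, hard)]
    negatives = rng.sample(pool, min(n_easy, len(pool)))
    if hard is not None:
        negatives.append(hard)  # lexically close to the question, but wrong
    return {"question": question, "positive": positive, "negatives": negatives}


ranked = ["Paris is the capital of France.",
          "The capital of Germany is Berlin.",
          "Lyon is a city in France."]
print(pick_bm25_hard_negative(ranked, ["Paris"]))
```

The top BM25 hit contains the answer, so it is skipped; the Berlin passage shares words like "capital" with the question yet cannot answer it, which is exactly what makes it a hard negative.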

During inference, passage embeddings are generated by the passage encoder and cached. To retrieve passages related to a question, the question is encoded with the question encoder, and then a fast similarity search library (FAISS) is used to find the top $k$ cached passage embeddings with maximal similarity to the question embedding.
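The retrieval step amounts to maximum inner-product search over precomputed vectors. The sketch below does it exactly by scanning every cached embedding, which is also what FAISS's flat inner-product index (IndexFlatIP) computes; FAISS's approximate indexes exist to avoid that full scan at Wikipedia scale. The class name and the toy encoder are hypothetical stand-ins for the BERT passage encoder.

```python
import heapq


class PassageIndex:
    """Exact maximum-inner-product search over cached passage embeddings."""

    def __init__(self, passage_encoder, passages):
        self.passages = list(passages)
        # Encode every passage once, offline; only the question is encoded per query.
        self.embeddings = [passage_encoder(p) for p in self.passages]

    def search(self, question_vec, k):
        """Return the k (passage, score) pairs with the largest dot product."""
        scored = (
            (sum(q * e for q, e in zip(question_vec, emb)), i)
            for i, emb in enumerate(self.embeddings)
        )
        return [(self.passages[i], score) for score, i in heapq.nlargest(k, scored)]


def toy_encoder(text):
    """Hypothetical stand-in for a learned encoder: counts two keywords."""
    return [text.count("cat"), text.count("dog")]


index = PassageIndex(toy_encoder, ["cat cat", "dog", "cat dog"])
print(index.search([1, 0], k=2))  # [('cat cat', 2), ('cat dog', 1)]
```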

In the end, this neural passage retriever works better on most datasets than a Lucene-based TF-IDF or BM25 retriever. Given recent trends in NLP, this isn't too surprising.

In the DPR paper, they implement not only the passage retriever but also an extractive question answerer, which selects an answer span from the retrieved passages. This is an alternative to RAG, which instead generates the answer with a seq2seq model. As we'll see below, RAG is, in some sense, just DPR with a slightly more advanced answer-generating model.

BART: BERT for Seq2Seq Models

BART (Oct 2019) is a model from Facebook that attempts to answer the question: How do you do BERT, but for seq2seq models?

BERT is (effectively) a denoising autoencoder for text: it recovers the original tokens at positions corrupted with [MASK]. BART is a denoising autoencoder as well, but one whose noise function can change the sequence length, so it uses a seq2seq (encoder-decoder) transformer instead of BERT's encoder-only transformer. In some sense, BART is an extension of BERT, since it allows a strictly more general family of noise functions.

The central question, then, is: what noise function yields a useful pretrained model in this setup?

In this paper, the following noise functions are evaluated:

  • Token Masking: Replace tokens with [MASK] (as in BERT).
  • Token Deletion: Delete tokens. Do not replace them with [MASK].
  • Text Infilling: Replace a span of text (of length 0 or more, with lengths drawn from a Poisson distribution with λ = 3) with a single [MASK]. (Thus, one [MASK] may stand for zero, one, or many tokens.)
  • Sentence Permutation: Shuffle sentences, delimited by periods. (Not ultimately helpful.)
  • Document Rotation: "A token is chosen uniformly at random, and the document is rotated so that it begins with that token. This task trains the model to identify the start of the document." (Not ultimately helpful.)
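The five noise functions are easy to sketch on a list of word tokens. These are my own simplified versions: the <mask> string, whitespace tokens, and the per-position probability p are assumptions, whereas the paper masks about 30% of tokens and uses subword tokens. Span lengths for infilling follow the Poisson(λ = 3) distribution, sampled here with Knuth's method.

```python
import math
import random

MASK = "<mask>"


def sample_poisson(lam, rng):
    """Knuth's method: count uniforms until their product drops below e^-lam."""
    threshold, k, prod = math.exp(-lam), 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def token_masking(tokens, p, rng):
    """BERT-style: replace individual tokens with <mask>; length is unchanged."""
    return [MASK if rng.random() < p else t for t in tokens]


def token_deletion(tokens, p, rng):
    """Drop tokens with no marker; the model must infer where input is missing."""
    return [t for t in tokens if rng.random() >= p]


def text_infilling(tokens, p, rng, lam=3.0):
    """With probability p per position, replace a Poisson(lam)-length span with one <mask>."""
    out, i = [], 0
    while i < len(tokens):
        if rng.random() < p:
            span = sample_poisson(lam, rng)
            out.append(MASK)
            if span == 0:
                out.append(tokens[i])  # zero-length span: a bare <mask> insertion
                i += 1
            else:
                i += span  # skip the masked tokens (the span may run past the end)
        else:
            out.append(tokens[i])
            i += 1
    return out


def sentence_permutation(text, rng):
    """Split on periods and shuffle the sentences."""
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    rng.shuffle(sentences)
    return " ".join(s + "." for s in sentences)


def document_rotation(tokens, rng):
    """Rotate the document so it starts at a uniformly random token."""
    start = rng.randrange(len(tokens))
    return tokens[start:] + tokens[:start]


rng = random.Random(0)
toks = "the quick brown fox jumps over the lazy dog".split()
print(text_infilling(toks, 0.2, rng))
```

In every case the training target is the original, uncorrupted sequence, which the decoder reproduces autoregressively.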

The encoder encodes a noised input, and the decoder (autoregressively) predicts the original input.

At first glance, this model felt weird to me. The added flexibility in noise functions is obviously a positive, but using a seq2seq model to reconstruct the input feels like overkill. However, the decoder effectively learns a smarter copy function, one that alternates appropriately between copying text and generating it.

The pretrained model can be fine-tuned for a variety of tasks, including sequence classification (feeding the final decoder hidden state at the last timestep into a classifier), token classification (using the top-layer decoder state for each token), and sequence generation (using the full model). Machine translation into English is also tried, by replacing the encoder's embedding layer with a new, randomly initialized encoder and keeping the rest of the model.

Interpreting the results here is hard. Document rotation and sentence permutation do not improve performance, which is unsurprising given that they resemble next sentence prediction (NSP) in BERT, a loss that has been shown to be unnecessary or even harmful. Text infilling seems to be superior to the other noise functions, which isn't too surprising: it generalizes single-token masking and, like deletion, forces the model to infer how many tokens are missing. BART does well on almost all the tested tasks; the exception (ELI5) seems to be an outlier, as it is best handled by a plain language model. BART isn't any better than the state of the art (SotA) on SQuAD and GLUE, but it isn't any worse either. BART works better for summarization than other approaches, likely because summarization is a seq2seq task that involves a lot of copying.

All in all, it's a valuable data point, but I don't see this approach becoming popular outside of Facebook, with the possible exception of summarization. The performance isn't generally superior, and there are too many details to interpret; it's hard to tell a single cohesive story about the paper. Regardless, BART is one of the building blocks of RAG, the paper that prompted this blog post.

BLINK: Zero-Shot Entity Linking

BLINK (September 2020) is a recent Facebook model for entity linking. Entity linking is the task of connecting a short span of text (a "mention") to an entity: an object in some database with an associated description. Entity linking is inherently knowledge-based, since there can be millions of candidate entities. In some sense, question answering with a database (RAG) and entity linking are very similar tasks, with the caveat that entity linking links each mention to exactly one entity, whereas knowledge-assisted question answering may require multiple supporting sources.

More specifically, BLINK is a zero-shot entity linker, making it even more similar to knowledge-assisted QA. Zero-shot, in this case, means that the entities are not part of the model, so the set of entities at inference time can differ from the one used in training. You won't find any learned entity embeddings in this paper; instead, there is an entity encoder, so adding an entity just means running its description through that encoder. (In fact, to evaluate this fairly, the training entity set is disjoint from the test entity set.)

BLINK operates in two phases for performance reasons. The first phase chooses a set of candidate entities for each mention. The second phase links precisely one of those candidates to the mention. Since the first phase needs to consider millions of entities, it must be incredibly fast, while the second phase can involve more computation for each entity-mention candidate.

Model-wise, BLINK is roughly what you would expect. Phase one is nearly identical to DPR (see above), except that its hard negatives are mined by running the phase-one model itself and taking its top-scoring wrong entities (rather than by running BM25, as in DPR).

Phase two of BLINK is yet another transformer (initialized with BERT), this time taking a mention (with its surrounding context) and a candidate entity (title and description) together as a single input and outputting a single vector, the [CLS] output embedding. A fully-connected layer reduces each pair's vector to a single logit, and these logits are trained with a softmax loss whose target is the correct candidate entity. The candidates for each mention come from phase one, so any time phase one is retrained, phase two must be retrained too: the training distribution for phase two depends on phase one's behavior.
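The two phases can be sketched end to end in plain Python. Assumptions: the mention and entity vectors are precomputed by the phase-one encoders, exact search stands in for FAISS, and cross_score is a hypothetical callable standing in for the BERT cross-encoder plus its final linear layer (the usage below substitutes a toy word-overlap score).

```python
import heapq
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def link_mention(mention_vec, entity_vecs, mention_text, entity_texts, cross_score, k=10):
    """Phase 1: retrieve k candidates by dot product. Phase 2: rerank them jointly."""
    # Phase 1: cheap bi-encoder scoring over every entity.
    top = heapq.nlargest(k, ((dot(mention_vec, e), i) for i, e in enumerate(entity_vecs)))
    candidates = [i for _, i in top]
    # Phase 2: expensive joint scoring of each (mention, candidate) pair, one logit each.
    logits = [cross_score(mention_text, entity_texts[i]) for i in candidates]
    m = max(logits)
    z = sum(math.exp(x - m) for x in logits)
    probs = [math.exp(x - m) / z for x in logits]
    best = candidates[logits.index(m)]
    return best, dict(zip(candidates, probs))


def phase_two_loss(probs, gold_entity):
    """Softmax cross-entropy; only defined if phase one retrieved the gold entity."""
    return -math.log(probs[gold_entity])


def overlap(mention, entity):
    """Toy stand-in for the cross-encoder: number of shared words."""
    return float(len(set(mention.split()) & set(entity.split())))


entity_vecs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
entity_texts = ["paris texas city", "golden retriever dog", "paris capital of france"]
best, probs = link_mention([1.0, 0.0], entity_vecs, "paris france capital",
                           entity_texts, overlap, k=2)
print(best)  # 2
```

Phase one ranks entity 0 slightly above entity 2, but the joint scorer in phase two, which sees both texts at once, picks entity 2. The KeyError that phase_two_loss raises for an unretrieved gold entity mirrors the dependency noted above: phase two can only learn from what phase one retrieves.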

As with DPR, selecting the top $k$ candidates in phase one is done by fast approximate nearest neighbor search (FAISS). A hyperparameter sweep suggests $k=10$ is optimal, and searching through 5.9M entities takes just 2ms at inference time.

To summarize the results: apply this at scale (5.9M entities from Wikipedia), and it works great. As usual, train on a large dataset, then fine-tune on your smaller dataset. As is often the case in NLP lately, simple model designs and scale dominate the benchmarks.

Retrieval Augmented Generation (RAG)

Now, finally, onto the paper that spawned this blog post.

Retrieval-Augmented Generation (RAG) is a question answering model. It's roughly what you would get if you took DPR and then used your retrieved passages (along with your question) as input to a seq2seq model (pretrained via BART), which was trained to generate your answer.

If you have a passage retriever, you could feed its outputs to your seq2seq model, trained to generate the answers to your questions. However, this means the two models must be trained in sequence, with the second depending on the trained first model. Pipelines like this are operationally harder and generally more brittle, so RAG instead trains the whole system end-to-end.

This is the key question for RAG, as I see it: How do you jointly train a passage retriever and a seq2seq answer generator?

To train this model end-to-end, you cannot simply pick the top passage from DPR and use it. Instead, RAG marginalizes over the top-$k$ passages, and does so in two different ways, which the paper calls RAG-Sequence and RAG-Token. This bit is crucial, so it's worth spelling out:

In both variants, we sum over the probabilities given the different top-$k$ passages, weighted by the probability the retriever assigns to each passage. In sequence-level marginalization (RAG-Sequence), we compute the probability of the entire target sequence conditioned on each passage, and then take the weighted average of those probabilities. In token-level marginalization (RAG-Token), the probability of the sequence is the product of per-token probabilities, where each token probability is itself the weighted average of the token probabilities conditioned on the different retrieved passages.
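Written out (my transcription of the paper's definitions, with $p_\eta(z \mid x)$ the retriever's probability for passage $z$ and $p_\theta$ the generator, and both sums over the top-$k$ passages), the sequence-level variant (RAG-Sequence in the paper) is $p(y \mid x) \approx \sum_{z} p_\eta(z \mid x) \prod_i p_\theta(y_i \mid x, z, y_{1:i-1})$, and the token-level variant (RAG-Token) is $p(y \mid x) \approx \prod_i \sum_{z} p_\eta(z \mid x)\, p_\theta(y_i \mid x, z, y_{1:i-1})$. The toy sketch below evaluates both on made-up probabilities; the nested-list layout is my own assumption, and a real implementation would work in log space.

```python
import math


def rag_sequence_prob(passage_probs, token_probs):
    """p(y|x) = sum_z p(z|x) * prod_i p(y_i | x, z, y_<i)."""
    return sum(pz * math.prod(probs) for pz, probs in zip(passage_probs, token_probs))


def rag_token_prob(passage_probs, token_probs):
    """p(y|x) = prod_i sum_z p(z|x) * p(y_i | x, z, y_<i)."""
    n_tokens = len(token_probs[0])
    return math.prod(
        sum(pz * token_probs[z][i] for z, pz in enumerate(passage_probs))
        for i in range(n_tokens)
    )


# Two equally likely passages; each one supports a different half of the answer.
passage_probs = [0.5, 0.5]
token_probs = [[0.9, 0.1],   # p(y_i | x, z=0, y_<i) for each of the 2 tokens
               [0.1, 0.9]]   # p(y_i | x, z=1, y_<i)
print(rag_sequence_prob(passage_probs, token_probs))  # about 0.09
print(rag_token_prob(passage_probs, token_probs))     # about 0.25
```

Because the token-level variant can draw each token from a different passage, it assigns this answer much higher probability, which matches the intuition that token-level marginalization suits answers that combine several sources. (math.prod requires Python 3.8 or later.)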

Decoding from these models must be done in different ways. With token-level marginalization, decoding is easy: we can compute token probabilities (marginalized over passages) directly, so a standard beam search works. With sequence-level marginalization, a single beam search doesn't work. Instead, we run $k$ separate beam searches, one per passage, and take the union of their final candidates. We then evaluate every candidate's likelihood under every conditioning passage and score each candidate by the weighted sum of those likelihoods. (Unfortunately, this is much slower, since each candidate must be evaluated with each conditioning passage: with $n$ candidates per search, you might need $k^2n$ evaluations, because each of the $k$ beam searches yields $n$ candidates, each of which must be evaluated under all $k$ passages.)
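The rescoring step for sequence-level marginalization (which the RAG paper calls "Thorough Decoding") can be sketched as follows. Beam search itself is omitted: beam_candidates is assumed to hold each per-passage beam's final hypotheses, and seq_prob is a hypothetical callable returning $p(y \mid x, z)$ from the generator.

```python
def rescore_sequence_candidates(passage_probs, beam_candidates, seq_prob):
    """Score the union of per-passage beam outputs under every passage.

    passage_probs[z]   : retriever probability p(z | x) for passage z
    beam_candidates[z] : final hypotheses of the beam search conditioned on z
    seq_prob(y, z)     : generator probability p(y | x, z) (hypothetical callable)
    """
    candidates = set()
    for hyps in beam_candidates:
        candidates.update(hyps)
    scores = {}
    for y in candidates:
        # One generator evaluation per (candidate, passage) pair: up to k * (k * n) total.
        scores[y] = sum(pz * seq_prob(y, z) for z, pz in enumerate(passage_probs))
    best = max(scores, key=scores.get)
    return best, scores


table = {("a", 0): 0.5, ("b", 0): 0.4, ("c", 0): 0.0,
         ("a", 1): 0.0, ("b", 1): 0.3, ("c", 1): 0.6}
best, scores = rescore_sequence_candidates(
    [0.6, 0.4], [["a", "b"], ["b", "c"]], lambda y, z: table[(y, z)])
print(best)  # b
```

Candidate "b" is not the top hypothesis for either passage on its own, yet it wins once its likelihood is summed across both, which is why a single beam search cannot replace this rescoring.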

Much of the paper focuses on evaluating the model and its decoding schemes, and in general the results are good. They show that swapping the knowledge base (using a Wikipedia snapshot from a different year) significantly changes the answers, which is important because it demonstrates that the retrieved passages are actually being used. There's no clear conclusion as to whether token-level or sequence-level marginalization is preferable in general; it depends somewhat on the task. Tasks that require combining information from multiple sources (such as Jeopardy question generation) favor token-level marginalization, but when the answer is usually contained in a single passage, the results are less clear.

It's worth noting that even though this system appears to be state-of-the-art on many metrics, it still only achieves about 50% accuracy (give or take) in human evaluations of the Jeopardy questions it generates. So we're still pretty far from a simple end-to-end system that can reliably synthesize correct, cohesive Jeopardy content using Wikipedia as its database.

Knowledge Intensive Language Tasks (KILT)

Machine learning research is driven not only by model development but also by a variety of other factors, such as hardware developments, datasets, and metrics. In NLP, several commonly used benchmarks exist for assessing model quality, such as SQuAD (for question answering) and GLUE (for general language understanding). These benchmarks are crucial for measuring the quality of individual models and the progress of the field as a whole. Additionally, benchmarks are key to every researcher's dream: claiming state-of-the-art performance on their task of choice.

Knowledge Intensive Language Tasks (KILT) (September 2020) is a new benchmark from Facebook to assess progress in NLP for areas that require access to a large database of factual information. KILT is based on a snapshot of Wikipedia with five key tasks: fact checking, entity linking, slot filling, question answering, and dialogue.

Facebook's desire for a new benchmark is easy to understand, given all the work described above. On one hand, a benchmark is necessary for them to evaluate their own models and measure modeling progress, and given that a benchmark is necessary, it might as well be public and have an associated publication. On the other hand, this benchmark is practically tailor-made for Facebook's models to do well on, so it should come as no surprise that the best-performing models in the KILT paper are Facebook's BLINK, DPR, BART, and RAG. The paper ends up being half benchmark and half showcase for these models.

Even though the benchmark seems tailored to Facebook's prior work, it nonetheless seems like a very useful addition. Judging by the best reported results, there is likely still room to improve: the best accuracy on any task peaks at about 80%. We can't know for sure, as the benchmark doesn't include a human evaluation; it's possible that human performance would be no higher than that of current models (although I doubt it). The release also includes a library, so that future papers can evaluate their models on KILT using the same data and evaluation criteria.

It's too early to tell whether this ends up being useful for the field. In the month since publication, there have been no citations, and the GitHub project is fairly quiet, with 12 commits and 3 (closed) issues. However, it's only been a month, and the benchmark is also available through Hugging Face's datasets library, which sees heavy use and may drive adoption. We'll see in 3 to 6 months whether this benchmark gets any uptake.

Summary

In the past year, folks at Facebook have done a ton of good work on knowledge-aided NLP. Almost all of the work is based on taking snapshots of Wikipedia, chunking it up into small BERT-sized passages, and then using BERT-based encoders and dot-product similarity to look up passages relevant to various target tasks. There's clearly room for improvement on all fronts, but explicitly incorporating knowledge databases into neural NLP seems like a great direction, and the results generally support that.
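The chunking step shared by these systems is simple to sketch. The DPR paper splits Wikipedia articles into disjoint 100-word passages; the whitespace tokenization, the function name, and the dictionary layout below are my own simplifications rather than any released preprocessing code.

```python
def chunk_article(title, text, words_per_passage=100):
    """Split an article into disjoint fixed-size word windows, each tagged with its title."""
    words = text.split()
    return [
        {"title": title, "text": " ".join(words[i:i + words_per_passage])}
        for i in range(0, len(words), words_per_passage)
    ]


article = " ".join(f"word{i}" for i in range(250))
passages = chunk_article("Example article", article)
print(len(passages))  # 3 (100 + 100 + 50 words)
```

Keeping the title next to each chunk matters because a 100-word window often never names its own subject.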

References