This blog post takes you through the mathematical foundations of retrieval-augmented generation (RAG), from its original formulation to advanced domain adaptation techniques.
The basic idea behind RAG is simple yet powerful (a code sketch follows this list):
- Encode all text chunks in your dataset into vector representations
- Encode your prompt using the same model
- Find the most similar text chunks to your prompt
- Provide these chunks as additional context to the LLM
- Allow the LLM to generate a response using both its internal knowledge and this retrieved context
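Here is a minimal sketch of that loop in Python. The `embed` and `llm` callables are hypothetical placeholders for an embedding model and an LLM client, not any specific library:

```python
import numpy as np

# Hypothetical helpers (placeholders, not a specific library):
#   embed(texts) -> np.ndarray of shape (len(texts), d)
#   llm(prompt)  -> str

def answer(query: str, chunks: list[str], embed, llm, k: int = 3) -> str:
    # Steps 1-2: encode all chunks and the query with the same model.
    chunk_vecs = embed(chunks)            # (n, d)
    query_vec = embed([query])[0]         # (d,)

    # Step 3: rank chunks by similarity (dot product here).
    scores = chunk_vecs @ query_vec
    top_k = np.argsort(scores)[::-1][:k]

    # Step 4: provide the top-k chunks as extra context.
    context = "\n\n".join(chunks[i] for i in top_k)

    # Step 5: let the LLM answer using its own knowledge plus the context.
    return llm(f"Context:\n{context}\n\nQuestion: {query}\nAnswer:")
```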
This process resembles how we use search engines — when you don’t know something, you search for relevant information before formulating a complete answer.
The Mathematical Foundation of RAG
At its core, RAG combines two probabilistic components: a retriever and a generator (or language model). Let’s break down the mathematics behind each.
The Retriever Component
The retriever is mathematically represented as p_η(z|x), which reads as “the probability of retrieving document z given query x, with retriever parameters η.” In practice, this retriever returns the top-k most relevant documents from a corpus.
Most RAG systems use what’s called a Dense Passage Retriever (DPR), which consists of:
- A document encoder: converts documents into vector embeddings
- A query encoder: converts queries into vector embeddings in the same vector space
Both encoders are typically based on models like BERT. The similarity between a document and query is calculated as the dot product between their respective embeddings.
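As a sketch, the two-encoder scoring step is a single matrix multiplication. The `query_encoder` and `doc_encoder` callables here are hypothetical stand-ins for BERT-style encoders that return one vector per input text:

```python
import numpy as np

def dpr_scores(queries, documents, query_encoder, doc_encoder):
    # Hypothetical encoders: each returns an array of shape (n_texts, d).
    q = query_encoder(queries)      # (m, d) query embeddings
    d = doc_encoder(documents)      # (n, d) document embeddings
    # DPR similarity is the dot product, so the full query-document
    # score matrix is one matrix multiplication.
    return q @ d.T                  # scores[i, j] = sim(query_i, doc_j)

def top_k_indices(scores, k):
    # Indices of the k highest-scoring documents for each query.
    return np.argsort(scores, axis=1)[:, ::-1][:, :k]
```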
Mathematically, when training a retriever, we aim to maximize the probability of retrieving relevant documents and minimize the probability of retrieving irrelevant ones. This is often done using contrastive learning techniques.
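One common instantiation is the in-batch negatives loss used by DPR-style retrievers: each query’s gold document is the positive, and the other documents in the same batch serve as negatives. A minimal PyTorch sketch, with the encoders themselves omitted:

```python
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    # q_emb, d_emb: (B, d); row i of q_emb pairs with row i of d_emb.
    scores = q_emb @ d_emb.T                               # (B, B) similarities
    targets = torch.arange(q_emb.size(0), device=q_emb.device)  # positives on the diagonal
    # Cross-entropy over each row maximizes p(positive doc | query)
    # while pushing down the scores of the in-batch negatives.
    return F.cross_entropy(scores, targets)
```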
The Generator Component
The generator (LLM) is mathematically represented as p_θ(yᵢ|x, z, y₁:ᵢ₋₁), which reads as “the probability of generating the next token yᵢ given the query x, retrieved documents z, and previously generated tokens y₁:ᵢ₋₁, using parameters θ.”
In the simplest case, the generator concatenates the query and retrieved documents and generates the output token by token. The overall probability of generating the entire response y given query x and retrieved documents z is:
p(y|x,z) = ∏ᵢ p_θ(yᵢ|x, z, y₁:ᵢ₋₁)
This means the probability of the entire output is the product of the probabilities of each token given all previous tokens, the query, and retrieved documents.
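As a toy numeric check of this factorization, summing log-probabilities is the numerically stable form of the product:

```python
import math

def sequence_log_prob(token_log_probs: list[float]) -> float:
    # log p(y|x,z) = sum_i log p_θ(y_i | x, z, y_{1:i-1})
    return sum(token_log_probs)

# Three tokens with per-token probabilities 0.9, 0.8, 0.95:
log_p = sequence_log_prob([math.log(0.9), math.log(0.8), math.log(0.95)])
print(math.exp(log_p))  # ~0.684 = 0.9 * 0.8 * 0.95
```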
Advanced RAG: Domain Adaptation
One limitation of basic RAG systems is that they may not perform well on specialized domains without appropriate tuning. This is where domain adaptation comes into play.
Asynchronous Re-indexing for Domain Adaptation
When adapting RAG to specific domains, one challenge is updating document embeddings whenever the query encoder changes. A clever solution is asynchronous re-indexing:
The idea is to use separate threads or GPUs to periodically re-encode and re-index the document collection while the main training loop continues. This enables training both the query encoder and document encoder together without computational bottlenecks.
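A minimal sketch of the pattern with Python threads. The `doc_encoder.encode`, `corpus`, and `index.rebuild` names are hypothetical placeholders for whatever encoder and vector index the system actually uses:

```python
import threading

def reindex_loop(doc_encoder, corpus, index, every_s, stop: threading.Event):
    # Background worker: periodically re-encode the corpus with the latest
    # document-encoder weights and swap in a fresh index, while the main
    # thread keeps training against the (slightly stale) current index.
    while not stop.is_set():
        embeddings = doc_encoder.encode(corpus)  # snapshot of current weights
        index.rebuild(embeddings)                # assumed to swap in atomically
        stop.wait(every_s)                       # sleep, but wake early on stop

stop_event = threading.Event()
# threading.Thread(target=reindex_loop,
#                  args=(doc_encoder, corpus, index, 600.0, stop_event),
#                  daemon=True).start()
```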
Finetuning the Retriever
To adapt RAG to specialized domains, we can finetune the retriever itself. One approach optimizes the retriever based on how useful each retrieved document is for answering the query:
p_LSR(c|x,y) = exp(p_LM(y|c ∘ x)/τ) / ∑_{c’∈C} exp(p_LM(y|c’ ∘ x)/τ)
This formula gives the probability of selecting context c given input x and desired output y, where ∘ denotes concatenation and τ is a temperature controlling how peaked the distribution is. It measures how much each context contributes to generating the correct answer.
The goal is to have the retriever’s output distribution p_R(c|x) match this ideal distribution p_LSR(c|x,y) by minimizing their Kullback-Leibler divergence:
𝓛(𝓓_R) = 𝔼_{(x,y)∈𝓓_R} KL(p_R(c|x) || p_LSR(c|x,y))
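In code, both distributions are softmaxes and the KL term can be computed directly. A PyTorch sketch, assuming `retriever_scores[c]` holds the retriever’s similarity for context c and `lm_log_likelihoods[c]` holds log p_LM(y|c ∘ x) computed with the language model held fixed:

```python
import torch
import torch.nn.functional as F

def retriever_distillation_loss(retriever_scores, lm_log_likelihoods, tau=1.0):
    # Shapes: (batch, num_contexts). The LM likelihoods act as fixed
    # targets, so gradients flow only into the retriever scores.
    log_p_r = F.log_softmax(retriever_scores, dim=-1)             # log p_R(c|x)
    log_p_lsr = F.log_softmax(lm_log_likelihoods.detach() / tau,  # log p_LSR(c|x,y)
                              dim=-1)
    # KL(p_R || p_LSR) = sum_c p_R * (log p_R - log p_LSR)
    return (log_p_r.exp() * (log_p_r - log_p_lsr)).sum(dim=-1).mean()
```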
The Dual Training Approach
For larger models, end-to-end training becomes impractical. A more efficient approach is dual training, in which we:
- Train the language model to better utilize retrieved information (p_LM)
- Train the retriever to return more relevant content (p_R)
This is captured by the following formula:
p_LM(y|x,C’) = ∑_{c∈C’} p_LM(y|c ∘ x) · p_R(c|x)
This equation tells us that the probability of generating output y given input x and retrieved contexts C’ is a weighted sum over contexts: the probability of generating y from each individual context, weighted by how relevant the retriever judges that context to be.
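Computed in log space for numerical stability, this marginalization is a single `logsumexp`. A PyTorch sketch with the same (hypothetical) shape conventions as above:

```python
import torch
import torch.nn.functional as F

def marginal_log_likelihood(lm_log_likelihoods, retriever_scores):
    # lm_log_likelihoods[c] = log p_LM(y | c ∘ x); retriever scores are
    # unnormalized, so a softmax gives p_R(c | x).
    log_p_r = F.log_softmax(retriever_scores, dim=-1)
    # log sum_c exp(log p_LM + log p_R) = log sum_c p_LM(y|c ∘ x) * p_R(c|x)
    return torch.logsumexp(lm_log_likelihoods + log_p_r, dim=-1)
```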
The language model is trained with this loss function:
𝓛(𝓓_L) = -∑_i ∑_j log p_LM(y_i|c_{ij} ∘ x_i)
This function encourages the model to generate the correct output regardless of which chunk was retrieved, teaching it to exploit helpful context while ignoring irrelevant or misleading context.
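As a sketch, if `log_likelihoods[i, j]` holds log p_LM(y_i|c_{ij} ∘ x_i) for example i and retrieved chunk j (a hypothetical tensor layout), the loss is just a negated sum:

```python
import torch

def lm_finetuning_loss(log_likelihoods: torch.Tensor) -> torch.Tensor:
    # log_likelihoods[i, j] = log p_LM(y_i | c_ij ∘ x_i).
    # Summing over every (example, chunk) pair trains the model to
    # produce y_i whether or not chunk c_ij is actually helpful.
    return -log_likelihoods.sum()
```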
Practical Improvements and Results
When implemented correctly, these mathematical foundations translate to measurable improvements in model performance:
- Enhanced knowledge retrieval: RAG systems consistently outperform standalone LLMs on knowledge-intensive tasks.
- Adaptability to new domains: Domain-adapted RAG models can achieve significantly better performance on specialized tasks.
- Source attribution: Retrieved passages provide natural attribution for model outputs.
- Reduced hallucination: By grounding responses in retrieved text, RAG systems are less likely to generate factually incorrect information.
Conclusion
The mathematics behind RAG systems reveals why they’ve become so critical for extending LLM capabilities.
Understanding these mathematical foundations gives insight into how these systems can be optimized and extended to create even more powerful AI assistants.