Cross Encoder Models: Architecture, Training, and RAG Reranking

Deep dive into cross encoder models for RAG — transformer architecture for relevance scoring, training with MS MARCO, fine-tuning, and production deployment.

Cross Encoder Models: The Precision Engine Behind RAG Reranking

If you’ve used a reranker in your RAG pipeline, you’ve most likely used a cross encoder. Understanding what these models actually do, and how they differ architecturally from the bi-encoders used in vector search, explains why they are typically more accurate (and more expensive) and when you should consider training your own.

Architecture Deep Dive

A cross encoder takes a pair of texts as input and outputs a single relevance score. It processes query and document jointly, with full transformer attention between every token in both sequences.

```text
Input:
[CLS] How does attention work in transformers? [SEP]
Transformers use scaled dot-product attention to compute
dependencies between tokens. The attention score is computed
as softmax(QK^T / sqrt(d_k)) * V, where Q, K, V are
learned projections of the input. [SEP]

Architecture:
┌────────────────────────────────────────────────────┐
│  [CLS] Q1 Q2 Q3 [SEP] D1 D2 D3 D4 ... [SEP]        │
│    ↕   ↕  ↕  ↕    ↕   ↕  ↕  ↕  ↕        ↕          │
│  12-layer BERT (or RoBERTa, DeBERTa)               │
│  Full self-attention across query and document     │
│  Q1 attends to D1, D2, ..., Dn simultaneously      │
│  D1 attends to Q1, Q2, ..., Qm simultaneously      │
└────────────────────────────────────────────────────┘
                  ↓ [CLS] representation
       Linear(768 → 1) → Relevance Score: 0.92
```
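To make the input layout concrete, here is a minimal sketch of how a query and a document are packed into one sequence with segment (token type) ids. It assumes whitespace splitting as a stand-in for a real subword tokenizer, and `pack_pair` is a hypothetical helper; in practice a Hugging Face tokenizer builds this layout when you pass it two texts. RoBERTa-style models use different special tokens and do not rely on segment ids.

```python
def pack_pair(query: str, document: str, max_len: int = 512):
    """Return (tokens, token_type_ids) in the BERT-style pair layout."""
    # Assumption: whitespace splitting stands in for a real subword tokenizer.
    q_tokens = query.lower().split()[: max_len - 3]
    d_tokens = document.lower().split()
    # Reserve 3 slots for [CLS] and the two [SEP]s; truncate the document,
    # since the query carries the intent and is usually short.
    d_tokens = d_tokens[: max_len - 3 - len(q_tokens)]
    tokens = ["[CLS]"] + q_tokens + ["[SEP]"] + d_tokens + ["[SEP]"]
    # Segment 0: [CLS] + query + first [SEP]; segment 1: document + final [SEP].
    token_type_ids = [0] * (len(q_tokens) + 2) + [1] * (len(d_tokens) + 1)
    return tokens, token_type_ids


tokens, types = pack_pair(
    "How does attention work?",
    "Transformers use scaled dot-product attention.",
)
for tok, seg in zip(tokens, types):
    print(seg, tok)
```

Because both texts sit in one sequence, the encoder's attention layers see query and document tokens together from the first layer onward.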

Every token in the query can directly attend to every token in the document. This lets the model pick up signals such as:

  • Version reasoning: the query asks about the “Python 3.10 walrus operator”, and the document explains that the walrus operator was introduced in Python 3.8, so it still applies
  • Entity coreference: “its” in the document can be resolved to “BERT” named in the query
  • Numeric comparison: the query asks for “less than 5ms” and the document says “average latency is 3.2ms”

None of this token-level interaction is available to a bi-encoder, which encodes query and document separately and compares only their final vectors.
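The difference can be sketched as an attention mask over the query and document positions. This toy version uses plain Python with no real model, and the helper names are hypothetical; it counts how many query-to-document attention edges each architecture allows:

```python
def cross_encoder_mask(n_q: int, n_d: int) -> list[list[bool]]:
    # One joint sequence: every position may attend to every other position.
    n = n_q + n_d
    return [[True] * n for _ in range(n)]


def bi_encoder_mask(n_q: int, n_d: int) -> list[list[bool]]:
    # Two separate encoder passes: query tokens only see query tokens and
    # document tokens only see document tokens (block-diagonal mask).
    n = n_q + n_d
    return [[(i < n_q) == (j < n_q) for j in range(n)] for i in range(n)]


def query_doc_links(mask: list[list[bool]], n_q: int) -> int:
    # Count attention edges from a query position to a document position.
    return sum(mask[i][j] for i in range(n_q) for j in range(n_q, len(mask)))


n_q, n_d = 3, 4
print(query_doc_links(cross_encoder_mask(n_q, n_d), n_q))  # 12
print(query_doc_links(bi_encoder_mask(n_q, n_d), n_q))     # 0
```

In a real bi-encoder the two sides only meet at the very end, through a similarity score (such as cosine) between the pooled vectors.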

Standard Backbone Models

DeBERTa: The Preferred Backbone

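DeBERTa (Decoding-enhanced BERT with disentangled attention) is often the preferred backbone for cross encoders because its disentangled attention, which represents content and relative position separately, handles relative positions better than BERT's absolute position embeddings. Loading a pretrained cross encoder looks the same regardless of backbone: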

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load a pretrained cross encoder. None of the checkpoints below is DeBERTa-based;
# swap in a DeBERTa cross-encoder checkpoint if you want that backbone.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"  # MiniLM backbone: small, fast
# model_name = "BAAI/bge-reranker-v2-m3"  # multilingual
# model_name = "cross-encoder/ms-marco-electra-base"  # ELECTRA backbone
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # inference mode: disables dropout
```

Model Family Comparison

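| Model | Backbone | Params | Latency (20 docs) | NDCG@10 |
|---|---|---|---|---|
| ms-marco-MiniLM-L-6-v2 | MiniLM | 22M | 25ms | 0.694 |
| ms-marco-MiniLM-L-12-v2 | MiniLM | 33M | 42ms | 0.700 |
| bge-reranker-base | XLM-RoBERTa-base | 278M | 85ms | 0.697 |
| bge-reranker-large | XLM-RoBERTa-large | 560M | 140ms | 0.702 |
| bge-reranker-v2-m3 | XLM-RoBERTa-large (multilingual) | 568M | 145ms | 0.721 |
| ms-marco-deberta-base | DeBERTa | 184M | 70ms | 0.708 |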
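One way to read the table is as a latency-budget problem. The sketch below, a hypothetical helper, picks the highest-NDCG@10 model that fits a given budget using the table's own numbers. Latency depends heavily on hardware, sequence length, and batch size, so treat these values as placeholders to re-measure on your own stack.

```python
# (model, latency for 20 docs in ms, NDCG@10), copied from the table above
MODELS = [
    ("ms-marco-MiniLM-L-6-v2", 25, 0.694),
    ("ms-marco-MiniLM-L-12-v2", 42, 0.700),
    ("bge-reranker-base", 85, 0.697),
    ("bge-reranker-large", 140, 0.702),
    ("bge-reranker-v2-m3", 145, 0.721),
    ("ms-marco-deberta-base", 70, 0.708),
]


def best_under_budget(budget_ms: float, models=MODELS):
    """Highest-NDCG@10 model whose 20-doc latency fits the budget, or None."""
    eligible = [m for m in models if m[1] <= budget_ms]
    return max(eligible, key=lambda m: m[2]) if eligible else None


for budget in (30, 50, 100, 150):
    print(budget, best_under_budget(budget))
```

With these figures, bge-reranker-base and bge-reranker-large never win: ms-marco-deberta-base is both faster and higher-scoring than either.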

Training Cross Encoders

Cross encoders are trained on labeled query-document pairs with binary or graded relevance labels. MS MARCO (Microsoft Machine Reading Comprehension) is the dominant training dataset.

MS MARCO Training Format

```python
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import torch

dataset = load_dataset("ms_marco", "v1.1")
# Each record holds a query plus the passages retrieved for it, with an
# is_selected flag per passage. Selected passages serve as positives and the
# rest as negatives; a pairwise setup groups them into (query, pos, neg) triplets.

class CrossEncoderDataset(torch.utils.data.Dataset):
    """Pointwise dataset: one (query, document) pair per example, float label."""

    def __init__(self, pairs, labels, tokenizer, max_length=512):
        # Tokenize query and document together as a single pair input
        self.encodings = tokenizer(
            [p[0] for p in pairs],  # queries
            [p[1] for p in pairs],  # documents
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )
        # Float labels; with num_labels=1, Hugging Face models treat this as regression (MSE loss)
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __getitem__(self, idx):
        # Include every tokenizer output (input_ids, attention_mask, token_type_ids if present)
        item = {key: tensor[idx] for key, tensor in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)
```
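The dataset class expects parallel lists of `(query, document)` pairs and labels. A minimal sketch of building them from one MS MARCO record, assuming the Hugging Face `ms_marco` v1.1 layout (a `query` field plus `passages` with parallel `passage_text` and `is_selected` lists); `record_to_pairs` and `max_negatives` are hypothetical names:

```python
def record_to_pairs(record: dict, max_negatives: int = 4):
    """Flatten one MS MARCO v1.1-style record into (query, passage) pairs and labels.

    Assumes record["passages"] has parallel lists "passage_text" and "is_selected".
    """
    query = record["query"]
    passages = record["passages"]
    pairs, labels = [], []
    n_negatives = 0
    for text, selected in zip(passages["passage_text"], passages["is_selected"]):
        if selected:
            pairs.append((query, text))
            labels.append(1.0)
        elif n_negatives < max_negatives:
            # Unselected passages were retrieved for this query, so they are
            # harder than random negatives (and occasionally unlabeled positives).
            pairs.append((query, text))
            labels.append(0.0)
            n_negatives += 1
    return pairs, labels


example = {
    "query": "what is a cross encoder",
    "passages": {
        "passage_text": ["A bi-encoder embeds texts separately.",
                         "A cross encoder scores a query-document pair jointly.",
                         "Paris is the capital of France."],
        "is_selected": [0, 1, 0],
    },
}
pairs, labels = record_to_pairs(example)
print(labels)  # [0.0, 1.0, 0.0]
```

Capping negatives per query keeps the labels from skewing too heavily toward zeros.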

Pairwise vs Pointwise Training

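Pointwise: train the model to predict a binary label (relevant = 1, not relevant = 0):

```python
import torch.nn.functional as F
# Loss: binary cross-entropy on raw logits (label is a float 0.0/1.0 tensor, same shape as score)
loss = F.binary_cross_entropy_with_logits(score, label)
```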

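Pairwise: train so that the positive document's score exceeds the negative document's score by a margin:

```python
import torch
import torch.nn.functional as F
# Loss: margin ranking; target=1 (must be a tensor) means pos_score should exceed neg_score
loss = F.margin_ranking_loss(pos_score, neg_score, torch.ones_like(pos_score), margin=1.0)
```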

Pairwise training often produces better ranking quality because the model directly optimizes the ordering relationship rather than absolute relevance scores.
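Both losses are simple enough to compute by hand. This plain-Python sketch mirrors the formulas behind PyTorch's `binary_cross_entropy_with_logits` and `margin_ranking_loss` (with target = 1) for single scores, to show where the pointwise and pairwise objectives differ:

```python
import math


def bce_with_logits(logit: float, label: float) -> float:
    # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def margin_ranking(pos: float, neg: float, margin: float = 1.0) -> float:
    # Margin ranking loss with target = 1: max(0, -(pos - neg) + margin)
    return max(0.0, -(pos - neg) + margin)


# Pointwise: a positive scored at logit 0.0 (probability 0.5) still costs log 2.
print(round(bce_with_logits(0.0, 1.0), 4))  # 0.6931
# Pairwise: correct order but inside the margin, so still penalized.
print(round(margin_ranking(0.5, 0.2), 4))   # 0.7
# Pairwise: correct order with enough separation gives zero loss.
print(margin_ranking(2.0, 0.5))             # 0.0
```

The middle case shows the pairwise objective at work: the ordering is already correct, but the loss keeps pushing until the gap reaches the margin, whereas the pointwise loss only looks at each score's absolute value.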

Fine-Tuning for Domain Adaptation

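A general MS MARCO cross encoder may not perform well on your domain. Fine-tuning on domain-specific data can produce substantial gains (on the order of 10–30%, though results vary with the domain, the metric, and the quality of the training data):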

```python
from sentence_transformers import CrossEncoder, InputExample
from torch.utils.data import DataLoader

# Your domain-specific training data: (query, document) pairs with 1.0/0.0 labels
training_examples = [
    InputExample(texts=["What is our API rate limit?",
                        "The API rate limit is 1000 requests per minute."], label=1.0),
    InputExample(texts=["What is our API rate limit?",
                        "Our company was founded in 2015."], label=0.0),
    # ... hundreds more examples
]

# Start from a pretrained cross encoder; num_labels=1 gives a single relevance logit
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", num_labels=1)

# Fine-tune with the classic CrossEncoder.fit API (check your sentence-transformers
# version; newer releases also offer a Trainer-based workflow)
train_dataloader = DataLoader(training_examples, shuffle=True, batch_size=16)
model.fit(
    train_dataloader=train_dataloader,
    epochs=3,
    warmup_steps=100,
    output_path="./models/domain-reranker",
)
```

Weak Supervision for Fine-Tuning Data

Generating labeled training data is expensive. Weak supervision reduces this cost:

  1. Collect query-document pairs where users clicked (implicit positive labels)
  2. Use BM25 to generate hard negatives (documents that match keywords but aren’t relevant)
  3. Use an LLM to label a small subset for quality anchoring
  4. Train on the weakly labeled dataset, validate on the LLM-labeled set
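Step 2 can be sketched with a small Okapi BM25 scorer: rank the corpus for a query, then keep the highest-scoring documents the user did not click. This is a self-contained illustration (regex tokenization, no stemming or stopword handling, hypothetical function names); a production pipeline would query an existing BM25 index instead.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    # Lowercase word tokens; no stemming or stopword removal
    return re.findall(r"\w+", text.lower())


def bm25_scores(query: str, docs: list[str], k1: float = 1.5, b: float = 0.75) -> list[float]:
    """Okapi BM25 score of each document for the query."""
    tokenized = [tokenize(d) for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    # Document frequency: number of documents containing each term
    df = Counter(term for toks in tokenized for term in set(toks))
    query_terms = set(tokenize(query))
    scores = []
    for toks in tokenized:
        tf = Counter(toks)
        score = 0.0
        for term in query_terms:
            if term not in tf:
                continue
            # Smoothed IDF variant that stays positive for very common terms
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            norm = k1 * (1 - b + b * len(toks) / avgdl)
            score += idf * tf[term] * (k1 + 1) / (tf[term] + norm)
        scores.append(score)
    return scores


def mine_hard_negatives(query: str, docs: list[str], clicked: set[int], k: int = 2) -> list[int]:
    """Indices of the top-k BM25 matches that the user did not click."""
    scores = bm25_scores(query, docs)
    ranked = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
    # Skip clicked docs (positives) and docs sharing no terms with the query
    return [i for i in ranked if i not in clicked and scores[i] > 0][:k]


docs = [
    "The API rate limit is 1000 requests per minute.",  # clicked
    "Rate limit errors return HTTP 429 from the API gateway.",
    "Our company was founded in 2015.",
    "The limit on API keys per account is five.",
]
print(mine_hard_negatives("API rate limit", docs, clicked={0}))  # [1, 3]
```

A high BM25 score with no click is only a proxy for "not relevant"; some mined negatives will be unclicked positives, which is one reason to keep the LLM-labeled anchor set from step 3.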

Batched Inference for Production

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class CrossEncoderReranker:
    def __init__(self, model_name: str, device: str = "cuda", batch_size: int = 32):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()  # disable dropout for deterministic scores
        self.device = device
        self.batch_size = batch_size

    def rerank(self, query: str, documents: list[str], top_n: int = 5) -> list[tuple[str, float]]:
        pairs = [[query, doc] for doc in documents]
        all_scores = []
        # Score in fixed-size batches to bound GPU memory
        for i in range(0, len(pairs), self.batch_size):
            batch = pairs[i:i + self.batch_size]
            features = self.tokenizer(
                [p[0] for p in batch], [p[1] for p in batch],
                padding=True, truncation=True,
                max_length=512, return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                # One logit per pair; higher means more relevant
                scores = self.model(**features).logits.flatten().cpu().tolist()
            all_scores.extend(scores)
        # Sort (document, score) pairs by score, best first
        ranked = sorted(zip(documents, all_scores), key=lambda x: x[1], reverse=True)
        return ranked[:top_n]
```
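Because `padding=True` pads every batch to its longest pair, mixing very short and very long documents in one batch wastes compute on padding. A common mitigation is to sort pairs by length before batching and then restore the original order. A minimal sketch, where `score_batch` is a hypothetical stand-in for the tokenizer and model call inside `rerank`:

```python
from typing import Callable, Sequence


def length_bucketed_scores(
    pairs: Sequence[tuple[str, str]],
    score_batch: Callable[[list[tuple[str, str]]], list[float]],
    batch_size: int = 32,
) -> list[float]:
    """Score pairs in length-sorted batches, returning scores in input order."""
    # Character length is a cheap proxy for token length.
    order = sorted(range(len(pairs)), key=lambda i: len(pairs[i][0]) + len(pairs[i][1]))
    scores = [0.0] * len(pairs)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        batch_scores = score_batch([pairs[i] for i in idx])
        # Write each score back to its original position
        for i, s in zip(idx, batch_scores):
            scores[i] = s
    return scores


# Toy scorer for illustration only: "relevance" is just the document length.
def toy_score_batch(batch):
    return [float(len(doc)) for _, doc in batch]


print(length_bucketed_scores([("q", "aaaa"), ("q", "a"), ("q", "aaa")], toy_score_batch, batch_size=2))
# [4.0, 1.0, 3.0]
```

How much this saves depends on how varied your document lengths are; with uniformly sized chunks the benefit is small.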

2025 Trend: LLM-as-Reranker

Large language models themselves are being used as rerankers, leveraging their deep language understanding for relevance assessment:

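```python
import re
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

def llm_rerank(query: str, candidates: list[str], top_n: int = 3) -> list[str]:
    # Number each candidate and truncate it to 300 characters to bound prompt size
    doc_list = "\n\n".join(f"[{i + 1}] {doc[:300]}" for i, doc in enumerate(candidates))
    response = client.messages.create(
        model="claude-haiku-4-5-20251001",
        max_tokens=100,
        messages=[{
            "role": "user",
            "content": f"""Rank these documents by relevance to the query.
Return ONLY the document numbers in order of relevance, most relevant first.
Query: {query}
Documents:
{doc_list}
Ranking (numbers only, comma separated):""",
        }],
    )
    ranking_text = response.content[0].text
    # Parse defensively: extract every integer, keep valid 1-based indices only,
    # and drop duplicates (the model may add brackets, spaces, or repeats)
    seen, indices = set(), []
    for num in re.findall(r"\d+", ranking_text):
        i = int(num) - 1
        if 0 <= i < len(candidates) and i not in seen:
            seen.add(i)
            indices.append(i)
    return [candidates[i] for i in indices[:top_n]]
```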

LLM-as-reranker is slower and more expensive than a cross encoder, but it can outperform cross encoders on complex reasoning tasks where domain understanding matters more than surface semantic similarity. It is increasingly used as a final reranking stage after cross-encoder prefiltering.
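That cascade can be sketched as a two-stage function: the cross encoder scores every candidate, and only the top `prefilter_k` reach the LLM. The two scoring callables are hypothetical hooks; for example, `cross_score` could wrap `CrossEncoderReranker` and `llm_order` could wrap `llm_rerank`.

```python
from typing import Callable


def two_stage_rerank(
    query: str,
    candidates: list[str],
    cross_score: Callable[[str, str], float],
    llm_order: Callable[[str, list[str]], list[str]],
    prefilter_k: int = 10,
    top_n: int = 3,
) -> list[str]:
    """Cross-encoder prefilter over all candidates, then LLM reranking of the shortlist."""
    # Stage 1: the fast cross encoder scores every candidate
    shortlist = sorted(candidates, key=lambda d: cross_score(query, d), reverse=True)[:prefilter_k]
    # Stage 2: the slower, costlier LLM only sees prefilter_k documents
    return llm_order(query, shortlist)[:top_n]
```

The main knob is `prefilter_k`: it bounds LLM cost and prompt length per query, at the risk of discarding a relevant document the cross encoder under-scored.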

Cross encoder models are the precision layer that turns good RAG into great RAG. For most teams, using a pretrained cross encoder (FlashRank for speed, BGE-reranker-large for quality) is the right starting point before considering fine-tuning.