Rejection Sampling in Alignment
Rejection sampling (Best-of-N) generates N candidate responses from a language model, scores each with a reward model, and selects the highest-scoring output – providing an implicit KL-constrained policy improvement that captured most of the alignment gains in Llama 2, often matching PPO while being far simpler.
What Is Rejection Sampling?
RL algorithms like PPO and GRPO update model weights to produce better outputs. But there is a much simpler way to improve outputs: generate many candidates and pick the best one.
[Diagram: N responses generated → scored by reward model → best selected]
This is rejection sampling, also called Best-of-N. Generate $N$ complete responses, score each with a reward model, and return the highest-scoring one.
The analogy: imagine writing an important email. You do not send the first draft. You write several versions, reread them, and send the best. Rejection sampling does exactly this, with a reward model as the editor.
What makes this powerful is its theoretical grounding. Selecting the best of $N$ samples induces a new output distribution whose KL divergence from the base policy is bounded, so Best-of-N acts as a KL-constrained policy improvement (the same kind of objective PPO optimizes in RLHF) without any gradient updates. The KL is approximately $\log N - \frac{N-1}{N}$ (more precisely, an upper bound), so increasing $N$ gives diminishing returns: a moderate $N$ captures ~80% of the available improvement, and going much larger adds relatively little.
How It Works
[Figure: Rejection sampling performance curves, showing diminishing returns as N increases]
Best-of-N at Inference Time
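- Sample: Generate $N$ independent completions $y_1, \dots, y_N \sim \pi(\cdot \mid x)$ from policy $\pi$ for prompt $x$.
- Score: Evaluate each with the reward model: $r_i = r(x, y_i)$.
- Select: Return $y^* = y_{i^*}$, where $i^* = \arg\max_i r_i$.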
This requires $N$ independent generations (parallelizable) and $N$ reward-model evaluations. No backward passes, no weight updates.
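The three steps can be sketched in a few lines of Python. This is a minimal illustration, not a production implementation: `best_of_n`, `toy_sample`, and `toy_reward` are hypothetical names, the "policy" just picks from canned strings, and the "reward model" scores by length. A real setup would call an LLM's batched generation and a trained reward model in their place.

```python
import random
from typing import Callable, List, Tuple


def best_of_n(
    prompt: str,
    sample: Callable[[str], str],
    reward: Callable[[str, str], float],
    n: int,
) -> Tuple[str, List[float]]:
    """Best-of-N: draw n completions, score each, return the top one and all scores."""
    # Sample: n independent completions (in practice, one batched generate call).
    candidates = [sample(prompt) for _ in range(n)]
    # Score: one reward-model evaluation per candidate; no gradients anywhere.
    scores = [reward(prompt, y) for y in candidates]
    # Select: index of the highest score (ties go to the earliest candidate).
    best = max(range(n), key=lambda i: scores[i])
    return candidates[best], scores


# Hypothetical stand-ins for a real policy and a trained reward model.
rng = random.Random(0)
RESPONSES = ["ok", "a fine answer", "a clear, correct, well-sourced answer"]


def toy_sample(prompt: str) -> str:
    return rng.choice(RESPONSES)


def toy_reward(prompt: str, response: str) -> float:
    return float(len(response))  # longer = better, purely for illustration


answer, scores = best_of_n("Explain KL divergence.", toy_sample, toy_reward, n=8)
print(answer, scores)
```

Because selection only uses `max`, any strictly increasing transform of the reward scores picks the same response, which is why Best-of-N needs only a reward model that ranks well, not one that is calibrated.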
The Implicit KL Constraint
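The theoretical insight (Stiennon et al., 2020; Gao et al., 2022): Best-of-N applied to a policy $\pi$ amounts to sampling from a reweighted (“tilted”) distribution $\pi_{\text{BoN}}$ that shifts probability mass toward high-reward outputs, with
$$\mathrm{KL}(\pi_{\text{BoN}} \,\|\, \pi) = \log N - \frac{N-1}{N}.$$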
Quality improvement grows only slowly with $N$: the KL rises roughly logarithmically, and the fraction of the maximum improvement captured climbs from ~60% to ~80%, ~90%, and ~95% at successively larger values of $N$.
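This implicit constraint limits reward hacking: even with large $N$, the selected response is one the model itself generated, so it stays within the model’s natural output distribution. The protection is partial, though; Gao et al. (2022) still observe overoptimization, with true quality eventually declining, as $N$ grows large.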
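To make the logarithmic growth concrete, the closed-form KL can be tabulated directly. This is a small sketch; `bon_kl` is a hypothetical helper name and the list of $N$ values is chosen only for illustration.

```python
import math


def bon_kl(n: int) -> float:
    """Closed-form KL(pi_BoN || pi) in nats for Best-of-N sampling.

    This is the commonly cited formula log N - (N - 1) / N; in general
    it is an upper bound on the true KL.
    """
    if n < 1:
        raise ValueError("n must be >= 1")
    return math.log(n) - (n - 1) / n


# Illustrative values of N only.
for n in [1, 2, 4, 16, 64, 256]:
    print(f"N={n:4d}  KL={bon_kl(n):.3f} nats")
```

Algebraically, each doubling of $N$ adds exactly $\log 2 - \frac{1}{2N}$ nats, which is always below $\log 2 \approx 0.69$, so the distance selection can move away from the base policy grows only logarithmically in $N$.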
Iterative Rejection Sampling for Training
The real power is using rejection sampling to generate training data. The Llama 2 pipeline:
- Sample: Generate $N$ completions per prompt.
- Score and select: Use the reward model to select the best completion per prompt.
- Fine-tune: Train on selected completions with standard SFT (cross-entropy loss).
- Iterate: Use the improved model to sample new completions; repeat.
Each iteration distills the reward model’s preferences into the policy’s weights. Llama 2 found that iterative rejection sampling captured most alignment improvement, with PPO adding only modest additional gains.
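The loop can be simulated end to end with toy components. In this sketch every piece is a hypothetical stand-in: the "policy" is a categorical distribution over three canned responses (so prompts are ignored), the "reward model" is a lookup table, and `sft_update` moves probability mass toward the empirical distribution of the selected responses as a crude proxy for what cross-entropy fine-tuning on them would do.

```python
import random
from collections import Counter
from typing import Dict, List

# Hypothetical lookup-table "reward model".
REWARD = {"unhelpful": 0.0, "partial": 0.5, "helpful": 1.0}


def sample(policy: Dict[str, float], rng: random.Random) -> str:
    responses = list(policy)
    return rng.choices(responses, weights=[policy[r] for r in responses])[0]


def rejection_sample_round(
    policy: Dict[str, float], prompts: List[str], n: int, rng: random.Random
) -> List[str]:
    """Sample n completions per prompt and keep the reward model's favorite."""
    selected = []
    for _ in prompts:
        candidates = [sample(policy, rng) for _ in range(n)]
        selected.append(max(candidates, key=REWARD.__getitem__))
    return selected


def sft_update(policy: Dict[str, float], data: List[str], lr: float = 0.5) -> Dict[str, float]:
    """Stand-in for SFT: interpolate toward the empirical distribution of `data`."""
    counts = Counter(data)
    return {r: (1 - lr) * p + lr * counts[r] / len(data) for r, p in policy.items()}


rng = random.Random(0)
policy = {"unhelpful": 0.6, "partial": 0.3, "helpful": 0.1}
prompts = [f"prompt-{i}" for i in range(200)]
for it in range(3):
    # Each round samples from the *current* policy, keeping the data on-policy.
    data = rejection_sample_round(policy, prompts, n=4, rng=rng)
    policy = sft_update(policy, data)
    expected = sum(p * REWARD[r] for r, p in policy.items())
    print(f"iteration {it + 1}: expected reward = {expected:.3f}")
```

Expected reward rises each round because the selected data is better than the policy's average output, and the next round samples from the improved policy.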
Rejection Sampling + DPO (Llama 3)
Llama 3 paired rejection sampling with DPO across iterative post-training rounds. An on-policy recipe in that spirit:
- Generate many responses per prompt with the current policy.
- Score with the reward model; take the best and worst as a preference pair $(y_w, y_l)$.
- Fine-tune with DPO on these on-policy preference pairs.
- Iterate with the improved model.
This “online DPO” addresses DPO’s off-policy weakness by generating fresh, on-policy data each iteration.
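A sketch of the pair-building step plus the standard DPO loss for one pair. `best_worst_pair` and `softplus` are hypothetical helper names, the scores are made up, and `dpo_loss` takes summed sequence log-probabilities as plain floats rather than calling any real training library; `beta=0.1` is only an illustrative default.

```python
import math
from typing import List, Optional, Tuple


def best_worst_pair(
    candidates: List[str], scores: List[float]
) -> Optional[Tuple[str, str]]:
    """Turn N scored on-policy samples into one (chosen, rejected) pair."""
    best = max(range(len(scores)), key=scores.__getitem__)
    worst = min(range(len(scores)), key=scores.__getitem__)
    if scores[best] == scores[worst]:
        return None  # all samples tied: no preference signal for this prompt
    return candidates[best], candidates[worst]


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return z + math.log1p(math.exp(-z)) if z > 0 else math.log1p(math.exp(z))


def dpo_loss(
    logp_w: float, logp_l: float, ref_logp_w: float, ref_logp_l: float, beta: float = 0.1
) -> float:
    """DPO loss for one pair, from summed sequence log-probs under the
    policy (logp_*) and the frozen reference model (ref_logp_*)."""
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return softplus(-margin)  # equals -log(sigmoid(margin))


cands = ["resp A", "resp B", "resp C", "resp D"]
scores = [0.2, 1.3, -0.4, 0.9]  # hypothetical reward-model scores
chosen, rejected = best_worst_pair(cands, scores)
print(chosen, rejected)  # resp B resp C
print(dpo_loss(-12.0, -15.0, -13.0, -14.0))
```

When the policy equals the reference model the margin is zero and the loss is $\log 2$; it falls as the policy raises the chosen response's likelihood relative to the rejected one.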
Why It Matters
- Simplicity: No RL algorithm, no critic, no PPO hyperparameters. Generate-score-select in a few dozen lines of code.
- Stability: Selection itself involves no training; the only weight update is standard SFT, which avoids the instabilities of RL training.
- Effectiveness: Llama 2 ablations showed rejection sampling captured the large majority of alignment improvement.
- Implicit regularization: The roughly $\log N$ KL bound provides natural, if partial, protection against reward hacking.
- Training data generation: Widely used beyond alignment for synthetic data curation.
Practical Considerations
- Temperature: Higher sampling temperatures increase diversity, making high-quality outliers more likely; too high a temperature produces incoherent outputs.
- Batched generation: All $N$ samples for a prompt can be generated as a single batch, which is efficient on modern GPUs.
- Reward model calibration: Only relative rankings matter for selection, not absolute scores – a weaker requirement than PPO needs.
- Storage: $N$ samples for each of 100K+ prompts means at least $N \times 10^5$ responses, plus their scores, to store and filter. Efficient data pipelines are essential.
- Verifier combination: For math/code, rejection sampling pairs naturally with execution-based verification.
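For math, verifier-filtered rejection sampling can be sketched as "keep only candidates with the correct final answer, then pick the top-scoring survivor". Everything here is a hypothetical illustration: `extract_answer` naively treats the last token as the answer, `verified_best_of_n` is not a library function, and `len` stands in for a reward model.

```python
from typing import Callable, List, Optional


def verified_best_of_n(
    candidates: List[str],
    verify: Callable[[str], bool],
    score: Callable[[str], float],
) -> Optional[str]:
    """Keep only candidates the verifier accepts, then return the top-scoring one."""
    passing = [c for c in candidates if verify(c)]
    if not passing:
        return None  # reject the whole prompt rather than train on a wrong answer
    return max(passing, key=score)


def extract_answer(text: str) -> Optional[int]:
    """Hypothetical parser: treat the last whitespace-separated token as the answer."""
    tokens = text.strip().split()
    try:
        return int(tokens[-1].rstrip("."))
    except (IndexError, ValueError):
        return None


candidates = [
    "17 * 23 = 391",
    "17 * 23 is roughly 400",
    "17 * 20 = 340 and 17 * 3 = 51, so the answer is 391",
    "The answer is 381.",
]
ground_truth = 17 * 23
pick = verified_best_of_n(
    candidates,
    verify=lambda c: extract_answer(c) == ground_truth,
    score=len,  # placeholder for a reward model that favors worked steps
)
print(pick)
```

The verifier removes wrong answers outright, so the reward model only has to rank among correct ones; for code, `verify` would instead run the candidate against test cases.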
Key Technical Details
- Llama 2: $N$ samples per prompt, iterative RS fine-tuning, reward model re-trained between iterations.
- Compute: $N$ generations per prompt. At 70B scale with a large $N$, this is substantial but easily parallelized.
- Reward model ceiling: Selected outputs can only be as good as the reward model’s ranking ability.
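- Diminishing returns: Each doubling of $N$ adds less than $\log 2 \approx 0.69$ nats of KL, and for roughly Gaussian reward scores the expected best score grows only like $\sqrt{2 \log N}$. Doubling $N$ from 32 to 64 therefore provides much less improvement than doubling from 2 to 4.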
- RS vs. PPO: Llama 2 reported comparable alignment quality, with PPO showing small advantages mainly on safety benchmarks.
- On-policy freshness: Iterative sampling keeps data on-policy, avoiding distribution shift that degrades off-policy methods.
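The 2→4 versus 32→64 comparison can be checked numerically under a simplifying assumption: that the reward scores of the $N$ samples are i.i.d. standard normal. Real reward-model scores need not be Gaussian, so this is only an illustration; `expected_best_of_n` is a hypothetical helper that integrates the density of the maximum with the trapezoidal rule.

```python
import math


def expected_best_of_n(n: int, lo: float = -10.0, hi: float = 10.0, steps: int = 20000) -> float:
    """E[max of n i.i.d. N(0, 1) scores], by trapezoidal integration of
    x * n * phi(x) * Phi(x)**(n - 1), the density of the maximum times x."""
    h = (hi - lo) / steps
    total = 0.0
    for k in range(steps + 1):
        x = lo + k * h
        phi = math.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)
        cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2)))
        w = 0.5 if k in (0, steps) else 1.0  # trapezoid endpoint weights
        total += w * x * n * phi * cdf ** (n - 1)
    return total * h


for a, b in [(2, 4), (32, 64)]:
    gain = expected_best_of_n(b) - expected_best_of_n(a)
    print(f"N {a} -> {b}: expected gain {gain:.3f} standard deviations")
```

Under this assumption the gain from doubling $N$ shrinks steadily, consistent with the $\sqrt{2 \log N}$ asymptotic growth of the expected maximum.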
Common Misconceptions
- “Rejection sampling is not real alignment.” It performs KL-constrained optimization – the same type as PPO, via sampling rather than gradients. Llama 2 validated this empirically.
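- “More samples are always better.” Because the KL grows only logarithmically in $N$, returns diminish rapidly beyond moderate $N$, and very large $N$ can even hurt through reward-model overoptimization (Gao et al., 2022).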
- “Only useful at inference time.” Most impactful use is training data generation, where sampling cost is amortized over the dataset’s lifetime.
- “PPO makes it unnecessary.” Llama 2 used rejection sampling alongside PPO, and Llama 3 used it alongside DPO. The approaches are complementary.
Connections to Other Concepts
- RLHF: Rejection sampling can replace or complement PPO as the policy optimization step.
- GRPO: Both sample multiple outputs per prompt. GRPO uses all of them for policy gradients; RS uses only the best for SFT.
- DPO: Best and worst samples form on-policy preference pairs for DPO, an approach related to Llama 3’s iterative RS + DPO rounds.
- Reward Modeling: RS quality is entirely dependent on the reward model’s ranking ability.
- Synthetic Data: RS is a common technique for high-quality synthetic data generation beyond alignment.
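The GRPO contrast can be made concrete by computing, for one group of scored samples, the weight each method assigns. This sketch assumes GRPO's group-normalized advantage $(r_i - \text{mean}) / \text{std}$ with the population standard deviation (implementations differ on this detail); `grpo_advantages` and `rs_weights` are hypothetical names and the rewards are made up.

```python
import statistics
from typing import List


def grpo_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Group-relative advantages: every sample gets a signed weight."""
    mu = statistics.fmean(rewards)
    sd = statistics.pstdev(rewards)
    return [(r - mu) / (sd + eps) for r in rewards]


def rs_weights(rewards: List[float]) -> List[float]:
    """Rejection sampling: the best sample becomes an SFT target (weight 1), the rest 0."""
    best = max(range(len(rewards)), key=rewards.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(rewards))]


rewards = [0.1, 0.9, 0.4, 0.6]  # hypothetical reward-model scores for one prompt
print(grpo_advantages(rewards))
print(rs_weights(rewards))  # [0.0, 1.0, 0.0, 0.0]
```

GRPO pushes below-average samples down as well as above-average ones up, while RS discards everything except the winner.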
Further Reading
- “Llama 2: Open Foundation and Fine-Tuned Chat Models” (Touvron et al., 2023, arXiv:2307.09288): Demonstrates iterative rejection sampling as a core alignment technique, with comparisons to PPO.
- “The Llama 3 Herd of Models” (Dubey et al., 2024, arXiv:2407.21783): Uses rejection sampling and DPO together in iterative post-training rounds.
- “Scaling Laws for Reward Model Overoptimization” (Gao et al., 2022, arXiv:2210.10760): Establishes empirical scaling relationships between Best-of-N, KL divergence, and reward improvement, including overoptimization at large KL.