
Mismatch Praxis: Rollout Settings and IS Corrections

Qi Fang*, Daanish Khazi*
*Equal contribution

TL;DR: Using separate backends for rollout and training creates distribution mismatch even with identical weights. Non-default sampling increases the mismatch: temperature is partially correctable by scaling logits consistently on both rollout and trainer (e.g., with processed_logprobs in vLLM), while top-p/top-k require actor-side truncation and renormalization. Importance sampling techniques correct for drift, but introduce complexity in practice. We show that sequence-level IS collapses when $T \times \text{KL} \geq \text{CLAMP}$ (VERL uses $\text{CLAMP}=20$), and often earlier due to variance and ESS collapse. On the other hand, token-level IS is stable and outperforms no correction despite its bias. Engineering fixes (FP16/FP8 unification, deterministic kernels) and rejection sampling variants (Seq-MIS, Geo-RS) remain as alternative solutions but add infrastructure or tuning complexity. In our experiments, Token-TIS with default sampling (temperature=1.0, top_p=1.0, top_k=-1) was the most robust choice.


1. The Mismatch Problem

Modern LLM-RL frameworks separate rollout generation and gradient computation across different backends, producing different policies despite equivalent parameters.

Let $\pi_{\text{rollout}}$ denote the inference policy (e.g., vLLM/SGLang) and $\pi_{\text{learner}}$ the training policy (e.g., FSDP/Megatron). In practice, the parameter update computed is:

$$\mathbb{E}_{x\sim \mathcal{D}}\;\mathbb{E}_{y\sim \pi_{\text{rollout}}(\cdot|x)}\big[\, R(x,y)\; \nabla_{\theta}\log \pi_{\text{learner}}(y|x)\,\big]$$

whereas the on-policy objective requires:

$$\mathbb{E}_{x\sim \mathcal{D}}\;\mathbb{E}_{y\sim \pi_{\text{learner}}(\cdot|x)}\big[\, R(x,y)\; \nabla_{\theta}\log \pi_{\text{learner}}(y|x)\,\big]$$

This turns training into an off-policy setting unless corrected. As Yao et al.[1] demonstrate, even identical model parameters can produce contradictory predictions where $\pi_{\text{vLLM}}(a) \approx 1$ while $\pi_{\text{FSDP}}(a) \approx 0$ for certain tokens.

Beyond precision differences, sampling transforms (temperature, top-p, top-k) create additional mismatch that interacts with importance sampling in subtle ways.

Engineering solutions have been introduced, including FP16 unification across trainer and sampler[2], casting the inference head to FP32[3], and batch-invariant kernels[4]. These fixes reduce mismatch at the source but require infrastructure changes. This post focuses on immediately available solutions that don't require modifying your training stack: (1) statistical corrections and (2) sampling transforms.


2. Sampling Transforms and Mismatch

Practical View: How VERL Handles Temperature

In VERL, the Actor scales logits by temperature before computing log-probabilities, e.g.:

# In dp_actor.py, line 272
logits.div_(temperature)

On the inference side, vLLM has two modes for returning logprobs:

  • raw_logprobs (default): Returns logprobs as if temperature=1.0, ignoring the actual sampling temperature
  • processed_logprobs: Returns logprobs after applying temperature scaling
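
As a rough illustration (this is just the underlying softmax math, not vLLM internals), the two modes correspond to log-softmax over unscaled versus temperature-scaled logits:

```python
import torch
import torch.nn.functional as F

# Toy logits for a 4-token vocabulary; `temperature` mirrors the sampling setting.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
temperature = 0.7

raw_logprobs = F.log_softmax(logits, dim=-1)                      # "as if temperature = 1.0"
processed_logprobs = F.log_softmax(logits / temperature, dim=-1)  # after temperature scaling

# Only `processed_logprobs` matches the distribution tokens were actually sampled from.
print(raw_logprobs.exp(), processed_logprobs.exp())
```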

The mismatch scenarios for temperature:

| vLLM Mode | Temperature | Actor Behavior | Result |
|---|---|---|---|
| raw_logprobs | $\tau = 1.0$ | Divides by 1.0 (no-op) | No mismatch |
| raw_logprobs | $\tau \neq 1.0$ | Divides by $\tau$ | Mismatch: Actor is scaled while Rollout is not |
| processed_logprobs | $\tau = 1.0$ | Divides by 1.0 (no-op) | No mismatch |
| processed_logprobs | $\tau \neq 1.0$ | Divides by $\tau$ | Corrected: Both use same temperature |

To enable processed logprobs in VERL:

+actor_rollout_ref.rollout.engine_kwargs.vllm.logprobs_mode: "processed_logprobs"

How Temperature Interacts with Mismatch

Even with processed_logprobs enabled, the sampling temperature $\tau$ changes how backend mismatches show up in practice.

In an idealized, infinite-precision world where the rollout and learner logits differ by a fixed perturbation $\epsilon$ (i.e., $z^\text{rollout} = z^\text{learner} + \epsilon$), the two distributions at temperature $\tau$ are

$$p_\tau = \text{softmax}\big((z+\epsilon)/\tau\big),\quad q_\tau = \text{softmax}\big(z/\tau\big).$$

As $\tau \to \infty$, both distributions converge to uniform, so the KL divergence $\mathrm{KL}(p_\tau \,\|\, q_\tau)$ actually goes to zero. In other words, purely at the level of “two softmaxes over the same logits plus a fixed perturbation,” increasing $\tau$ tends to reduce the KL between them.

What we care about in real systems, however, is not this clean mathematical limit but how temperature interacts with two kinds of mismatch:

  1. Head mismatch – differences among high‑probability tokens (often due to FP16/FP8 rounding, different attention kernels, etc.).
  2. Tail mismatch – differences among extremely low‑probability tokens (underflow to zero, different numerical cutoffs, or different kernel implementations).

Temperature moves mass between head and tail, changing which mismatch matters more.

  • $\tau < 1$ (sharpening).
    Sharpening concentrates probability on the head. If we linearize the log-probabilities for a given token,

    $$\log p_\tau(z) = \frac{z}{\tau} - \log \sum_j e^{z_j/\tau},$$

    then a small logit perturbation $\epsilon$ changes the log-probability by roughly

    $$\Delta \log p_\tau \approx O\Big(\tfrac{\|\epsilon\|}{\tau}\Big),$$

    so head discrepancies are amplified by roughly $1/\tau$. If the two backends disagree on which token is top-1, sharpening makes one assign probability $\approx 1$ and the other $\approx 0$, causing large local KL spikes.

  • $\tau > 1$ (flattening).
    Flattening dampens those head discrepancies but pushes more mass into the tail: tokens that were effectively zero become non-negligible. Any disagreement in the tail (e.g., one backend underflowing to 0 while the other keeps a tiny but nonzero probability) now contributes more to KL. In finite-precision systems, we often see new mismatch appear from these tail effects as $\tau$ grows.

In our GSM8K long-horizon experiments we found that $\tau \approx 1.0$ was a robust operating point: sharpening ($\tau \ll 1$) made head-level discrepancies much more damaging, while very large $\tau$ caused the rollout–trainer KL and $\chi^2$ metrics to grow due to tail noise. This is an empirical “sweet spot” for our setup, not a general theorem about temperature.
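
A small numerical sketch of the idealized picture above (a fixed logit perturbation at full precision; it does not capture the finite-precision tail effects that make large $\tau$ hurt in practice, and the perturbation magnitude is an arbitrary assumption):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Idealized setup: learner logits z, rollout logits z + eps with a small fixed perturbation.
vocab = 32000
z = torch.randn(vocab)
eps = 0.01 * torch.randn(vocab)  # perturbation scale chosen purely for illustration

for tau in [0.4, 0.7, 1.0, 2.0, 5.0]:
    log_p = F.log_softmax((z + eps) / tau, dim=-1)  # rollout distribution p_tau
    log_q = F.log_softmax(z / tau, dim=-1)          # learner distribution q_tau
    kl = torch.sum(log_p.exp() * (log_p - log_q))   # KL(p_tau || q_tau)
    print(f"tau={tau}: KL={kl.item():.3e}")         # shrinks as tau grows in this idealized model
```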

How VERL Handles Top-P and Top-K

Unlike temperature, the Actor does not replicate top-p/top-k truncation at all: it computes log-probabilities over the full vocabulary with a standard softmax.

The mismatch scenarios for top-p/top-k:

| vLLM Mode | Top-P/Top-K | Actor Behavior | Result |
|---|---|---|---|
| raw_logprobs | Default (1.0, -1) | Full softmax | No mismatch |
| raw_logprobs | Non-default | Full softmax | Mismatch: vLLM samples from truncated dist, Actor computes over full dist |
| processed_logprobs | Default (1.0, -1) | Full softmax | No mismatch |
| processed_logprobs | Non-default | Full softmax | Mismatch: vLLM returns correctly truncated probs, Actor returns full probs |

When you use non-default top-p/top-k, the importance ratio becomes biased regardless of logprobs_mode:

Without processed_logprobs: The rollout logprobs are computed as if from the full distribution, but the actual samples came from a truncated distribution. The importance weights are computed on the wrong distribution entirely.

With processed_logprobs: The rollout logprobs reflect the truncated distribution (renormalized over kept tokens), but the Actor logprobs are from the full distribution. The importance ratio $\pi_{\text{Actor}} / \pi_{\text{rollout}}$ is still biased.

A potential fix is to replicate the top-p/top-k truncation on the actor side by sorting the vocabulary and renormalizing.

What If the Actor Also Applies Top-p/Top-k?

One natural fix is to replicate the truncation on the trainer side: recompute the top-p/top-k mask from the trainer’s logits and renormalize. However, this approach is fragile. Because the trainer and rollout have slightly different logits, tokens near the cutoff boundary can flip in or out of the mask. Suppose top_k=50 during rollout, and the sampled token was the 49th most likely. If the trainer’s logits differ slightly and that token becomes the 51st, it falls outside the recomputed mask. The trainer assigns it zero probability, causing $\log \pi(a) = -\infty$ and broken gradients. As such, removing any top-p/top-k truncation is the simplest way to reduce mismatch amplification.

After we ran our experiments, DeepSeek [5] published a more elegant solution in their V3.2 publication yesterday: “Keep Sampling Mask”. Instead of recomputing the mask, you preserve the rollout’s mask and apply it to the trainer’s distribution. Both policies now share identical action subspaces.

This approach is motivated by a deliberate choice to “avoid sampling extremely low-probability tokens that would be used as optimization targets.” DeepSeek’s sampling mask strategy unlocks non-default top-p/top-k rollout settings for RL training and is a worthwhile addition to RL frameworks.
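
A minimal sketch of the idea (our own illustration, not DeepSeek’s or any framework’s implementation; `rollout_keep_mask` is a hypothetical tensor the rollout engine would need to export):

```python
import torch
import torch.nn.functional as F

def keep_mask_logprobs(trainer_logits, rollout_keep_mask, actions, temperature=1.0):
    """Renormalize the trainer's distribution over the token set the rollout kept.

    trainer_logits:    [batch, vocab] logits from the training backend
    rollout_keep_mask: [batch, vocab] bool mask of tokens that survived top-p/top-k at rollout
    actions:           [batch] sampled token ids (inside the kept set by construction)
    """
    logits = trainer_logits / temperature
    logits = logits.masked_fill(~rollout_keep_mask, float("-inf"))  # share the rollout's support
    logprobs = F.log_softmax(logits, dim=-1)                        # renormalize over kept tokens
    return logprobs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
```

Because the sampled token is always inside the rollout’s kept set, the gathered log-probability stays finite, avoiding the $-\infty$ failure mode of recomputing the mask on the trainer side.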

Summary: Sampling Configuration Tradeoffs

| Parameter | Safe Value | Effect of Deviation |
|---|---|---|
| temperature | 1.0 | Any $\tau \neq 1$ tends to increase KL. Correctable with processed_logprobs. |
| top_p | 1.0 | Creates mismatch. Correctable with actor-side replication. |
| top_k | -1 (disabled) | Creates mismatch. Correctable with actor-side replication. |

The simplest path: temperature=1.0, top_p=1.0, top_k=-1 with logprobs_mode="processed_logprobs". Non-default top-p/top-k settings require actor-side replication, ideally with DeepSeek’s Keep Sampling Mask approach.

Actor-Side Sampling Support (as of Dec 2, 2025)

| Framework | Actor Temp Scaling | Actor Top-P/K Mask |
|---|---|---|
| VERL | FSDP, Megatron | Not implemented |
| Prime-RL | FSDP | Not implemented |
| SkyRL | FSDP/DeepSpeed, Megatron | Not implemented |
| TRL | Accelerate | Not implemented |
| OAT | DeepSpeed, DeepSpeed (fused) | Not implemented |
| SLIME | FSDP, Megatron | Not implemented |

3. Importance Sampling Corrections

This section summarizes results from Liu & Li, Part 2: A Trial of Gradient Estimators[6]. We recommend reading their derivations in detail for intuition.

Importance sampling (IS) corrects for mismatch by reweighting samples from the rollout distribution $\mu$ to match the training distribution $\pi$:

$$\mathbb{E}_{y \sim \pi}[f(y)] = \mathbb{E}_{y \sim \mu}\left[\frac{\pi(y)}{\mu(y)} f(y)\right]$$

For sequences, the importance weight decomposes into per-token ratios:

$$\rho(y) = \prod_{t=1}^T \rho_t = \prod_{t=1}^T \frac{\pi(y_t|s_t)}{\mu(y_t|s_t)}$$

Token-Level Reweighting

Reweight each token independently:

$$\hat{g}_{\text{token}} = \sum_t \min\left(\frac{\pi(a_t|s_t)}{\mu(a_t|s_t)}, C\right) \cdot A_t \cdot \nabla_\theta \log \pi(a_t|s_t)$$

This fails to account for state visitation. Token probabilities depend on context; by ignoring how earlier tokens affect later states, token-level IS misses the mismatch in the state visitation distribution and only reweights the per-step action probabilities. As Liu & Li [6] show, this induces a structural bias that scales as $O(T^2 \Delta_{\max})$ when the rollout and learner policies diverge: in effect, we estimate gradients under the behavior state distribution $d_\mu(s)$ instead of the true on-policy distribution $d_\pi(s)$.
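
A minimal sketch of the weights this estimator uses (details such as padding masks and the interaction with PPO clipping are omitted; this is not VERL’s exact code):

```python
import torch

def token_tis_weights(learner_logprobs, rollout_logprobs, clip_c=2.0):
    """Per-token truncated importance weights min(rho_t, C).

    Both inputs are log-probs of the *sampled* tokens, shape [batch, T].
    The returned weights multiply the per-token terms A_t * grad log pi(a_t | s_t).
    """
    rho = torch.exp(learner_logprobs - rollout_logprobs)  # rho_t = pi(a_t|s_t) / mu(a_t|s_t)
    return torch.clamp(rho, max=clip_c)
```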

Sequence-Level Reweighting

To preserve the correct state distribution, use the full sequence-level product:

$$\hat{g}_{\text{seq}} = \left(\prod_t \rho_t\right) \cdot \sum_t A_t \cdot \nabla_\theta \log \pi(a_t|s_t)$$

This estimator is unbiased: $\mathbb{E}_\mu[\rho(y) \cdot f(y)] = \mathbb{E}_\pi[f(y)]$.

The tradeoff: pathological variance. Even with 1% per-token drift, a 1000-token sequence has weight $(0.99)^{1000} \approx 0.00004$. Variance scales as:

$$\text{Var}[\rho(y)] \approx (1 + \bar{\chi}^2)^T - 1$$

With $\chi^2 = 0.01$ and $T=1000$: $(1.01)^{1000} \approx 20{,}000\times$ variance amplification.
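
A quick check of the arithmetic behind these two numbers:

```python
weight = 0.99 ** 1000               # sequence weight under 1% per-token drift
variance_amp = (1 + 0.01) ** 1000   # variance amplification factor at chi^2 = 0.01, T = 1000
print(f"{weight:.1e}", f"{variance_amp:.0f}x")  # ~4.3e-05 and ~21,000x (rounded to 20,000x above)
```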

Truncation

To control variance, we can clip the sequence product at threshold $C$:

$$w = \min\left(\prod_t \rho_t, C\right)$$

This introduces bias but bounds variance. Liu & Li[6] show bias $O(T(1+\chi^2)/C)$ and variance $O(T^2 C^2)$, a trade-off that is controllable by adjusting $C$.
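
For contrast with the token-level sketch above, the sequence-level truncated weight (one scalar per trajectory), again as an illustrative sketch rather than VERL’s implementation:

```python
import torch

def seq_tis_weight(learner_logprobs, rollout_logprobs, clip_c=4.0):
    """Sequence-level truncated importance weight min(prod_t rho_t, C).

    Inputs are log-probs of the sampled tokens, shape [batch, T]; the product is
    accumulated in log-space before exponentiating.
    """
    log_w = (learner_logprobs - rollout_logprobs).sum(dim=-1)  # sum_t log rho_t
    return torch.clamp(torch.exp(log_w), max=clip_c)           # one weight per sequence
```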

Table 1: Theoretical Properties

| Method | Importance Weight | Bias | Variance | Failure Mode |
|---|---|---|---|---|
| Token IS (truncated) | $w_t = \min(\rho_t, C)$ per token | $O(T^2\Delta_{\max})$ | $O(T^2(1+\bar{\chi}^2))$ | Structural bias scaling with $T^2$ |
| Sequence IS (untruncated) | $w = \prod_t \rho_t$ per sequence | $0$ (unbiased) | $O((1+\bar{\chi}^2)^T)$ | Exponential variance explosion |
| Sequence IS (truncated) | $w = \min(\prod_t \rho_t, C)$ | $O(T(1+\chi^2)/C)$ | $O(T^2 C^2)$ | Controllable until clamp ($T \times \text{KL} \geq \text{CLAMP}$) |

Under the simplified assumptions in Liu & Li’s analysis (independent steps, bounded per‑token drift), sequence‑level IS is unbiased and truncated Seq‑TIS achieves a more favorable bias–variance trade‑off. In real long‑horizon LLM settings, however, we find that the variance explosion can dominate these theoretical advantages.


4. The T x KL Boundary

The Core Problem

Sequence-level IS multiplies per-token ratios across the whole sequence. Small drifts compound exponentially. The sequence weight in log-space is:

$$\log w = \sum_{t=1}^T \log \rho_t = \sum_{t=1}^T \log \frac{\pi_{\text{old}}(x_t)}{\pi_{\text{rollout}}(x_t)}$$

Linking to KL Divergence

Let $\rho_t = \frac{\pi_{\text{old}}(x_t)}{\pi_{\text{rollout}}(x_t)}$ be the per-token importance ratio, and

$$w = \prod_{t=1}^T \rho_t$$

the sequence‑level IS weight for a single trajectory.

VERL reports a rollout–train KL metric (rollout_corr/kl) that, up to Monte Carlo noise, tracks the expected per‑token drift

$$\mathrm{KL}_\text{global} = \mathbb{E}\big[\log \pi_{\text{rollout}}(x_t) - \log \pi_{\text{old}}(x_t)\big].$$

For a single trajectory $\tau = (x_1,\dots,x_T)$, we can form a path-wise Monte Carlo estimate

$$\widehat{\mathrm{KL}}(\tau) = \frac{1}{T}\sum_{t=1}^T \big(\log \pi_{\text{rollout}}(x_t) - \log \pi_{\text{old}}(x_t)\big).$$

Using $\log(a/b) = -\log(b/a)$, this is directly related to the log-weight:

$$\log w(\tau) = \sum_t \log \rho_t = \sum_t \big(\log \pi_{\text{old}}(x_t) - \log \pi_{\text{rollout}}(x_t)\big) = -T \cdot \widehat{\mathrm{KL}}(\tau).$$

So for each individual trajectory, the sequence‑level importance weight is

$$w(\tau) = \exp\big(-T \cdot \widehat{\mathrm{KL}}(\tau)\big),$$

and the global rollout_corr/kl metric can be viewed as the expectation of these per‑trajectory KL estimates. In other words, “total KL” along a trajectory directly controls the magnitude of its sequence‑level IS weight in log‑space.

Numerical Safety Clamps

Frameworks clamp the log-sum to prevent overflow and underflow given finite precision. Let $\text{CLAMP}$ denote this threshold (VERL uses $\text{CLAMP} = 20$; this could theoretically be adjusted based on precision constraints):

$$w_{\text{clamped}} = \exp\left(\text{clamp}\left(\log w, -\text{CLAMP}, \text{CLAMP}\right)\right)$$

With $\text{CLAMP} = 20$, this bounds weights to $[\exp(-20), \exp(20)] \approx [2 \times 10^{-9}, 4.85 \times 10^{8}]$.

Once the typical per-token KL drift is large enough that $T \cdot \mathrm{KL}_\text{global} \gg \text{CLAMP}$, most trajectories will have $\log w(\tau)$ pushed to the numerical floor or ceiling. In that regime, the effective IS correction largely disappears: almost all samples share (approximately) the same clamped weight.
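
A toy simulation of this regime, assuming (purely for illustration) Gaussian per-token log-ratios whose mean matches a per-token KL of 0.05 over 600-token sequences:

```python
import torch

torch.manual_seed(0)

CLAMP = 20.0
T, kl = 600, 0.05                   # T * KL = 30 > CLAMP
# Rough toy model: log rho_t ~ Normal(-KL, sqrt(2*KL)); only meant to show the clamp pile-up.
log_rho = torch.normal(-kl, (2 * kl) ** 0.5, size=(4096, T))
log_w = log_rho.sum(dim=-1)         # per-trajectory log weight, roughly -T*KL on average

frac_at_floor = (log_w <= -CLAMP).float().mean().item()
print(f"{frac_at_floor:.2f} of sequences are pinned at exp(-20)")  # ~0.9 in this toy model
```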

The Clamp Threshold

Sequence-level IS is viable only when $T \times \text{KL} < \text{CLAMP}$.

The maximum viable sequence length at $\text{CLAMP}=20$ is:

$$T_{\max} = \frac{\text{CLAMP}}{\text{KL}}$$

| KL (per token) | $T_{\max}$ (tokens) |
|---|---|
| 0.0006 | ~33,000 |
| 0.001 | ~20,000 |
| 0.002 | ~10,000 |
| 0.005 | ~4,000 |
| 0.01 | ~2,000 |
| 0.02 | ~1,000 |
| 0.05 | ~400 |

With typical mismatch (KL ~0.001–0.005) and sequences under 10K tokens, you’re usually safe from clamp.

Sequence IS Can Fail Before Clamp

Staying under clamp ($T \times \text{KL} < \text{CLAMP}$) is necessary but not sufficient.

Recall that the variance of the sequence IS estimator scales as $(1 + \bar{\chi}^2)^T$—exponential in sequence length. Even with truncation clipping high weights, low weights remain unconstrained. Most sequences contribute almost nothing to the gradient while a few dominate. This is an effective sample size (ESS) problem: $\text{ESS} = (\sum_i w_i)^2 / \sum_i w_i^2$ collapses when weights span orders of magnitude. In our experiments Sequence-TIS consistently underperformed Token-TIS.
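
The ESS diagnostic itself is a one-liner, and even a single moderate outlier in a batch of 32 collapses it (a sketch, normalized by batch size as in Figure 2(c)):

```python
import torch

def normalized_ess(w):
    """Normalized effective sample size ESS/N = (sum_i w_i)^2 / (N * sum_i w_i^2)."""
    return (w.sum() ** 2 / (len(w) * (w ** 2).sum())).item()

uniform = torch.ones(32)
heavy_tail = torch.tensor([1.0] * 31 + [100.0])
print(normalized_ess(uniform))     # 1.0: every sample contributes
print(normalized_ess(heavy_tail))  # ~0.05: one outlier dominates the batch
```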


5. Experimental Results

Setup and Task

We train Qwen‑3 8B on GSM8K in VERL with:

  • vLLM rollouts with logprobs_mode="processed_logprobs", FSDP training backend, default float precision
  • GRPO‑style advantage estimation
  • Learning rate $1\times10^{-6}$, train batch size 32, PPO mini-batch size 32 (one update per batch, effectively on-policy)
  • Rollout $n=8$, max response length 10000 tokens, KL regularization disabled
  • Raw truncated IS weights (rollout_is_batch_normalize=False)
  • 8xH100

Reward is the sum of a correctness term (GSM8K exact match on the #### answer line) and a length bonus that encourages longer post‑answer explanations. The reward structure is defined as follows:

  • Hard Gating: If the answer is incorrect or does not follow the required format (e.g., missing ####), the total_reward = 0.
  • Success Case: If the answer is correct and structural requirements are met:

$$\text{total\_reward} = 1.0 + \frac{\min(\text{tokens\_after\_answer}, \text{max\_length})}{\text{max\_length}}$$

In our experiments, this results in a total reward range of 0 to 2: 0 for failure, and a value between 1.0 and 2.0 for success (1.0 for correctness plus up to 1.0 for length). This design deliberately rewards long but correct outputs, pushing the model to use most of the 10k‑token budget while still solving the math problem.
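
A sketch of a reward function with this structure (illustrative only: the function name, the whitespace token counting, and the `max_length` default are our assumptions, not the exact training code):

```python
def correctness_plus_length_reward(response: str, gold_answer: str, max_length: int = 1000) -> float:
    """Hard-gated correctness reward plus a post-answer length bonus, as described above."""
    if "####" not in response:
        return 0.0                                        # hard gating: missing answer line
    _, _, after = response.partition("####")
    lines = after.strip().split("\n", 1)
    predicted = lines[0].strip()
    explanation = lines[1] if len(lines) > 1 else ""
    if predicted != gold_answer.strip():
        return 0.0                                        # hard gating: wrong answer
    tokens_after_answer = len(explanation.split())        # crude whitespace tokenization
    return 1.0 + min(tokens_after_answer, max_length) / max_length
```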

We use a hand‑crafted prompt that:

  • forces a structured answer with a single #### line,
  • separates a short pre‑answer reasoning segment from a longer post‑answer explanation,
  • disables any hidden/tool “thinking mode” so all reasoning appears in the visible response.

This setup allows us to stress‑test importance sampling in the true long‑horizon regime where sequence‑level IS is expected to exhibit variance explosion.

On top of this common setup, we vary:

  • Correction methods: no correction, Token‑TIS (rollout_is=token with truncation at 2.0), Seq‑TIS (rollout_is=sequence with truncation at 4.0, and log‑weight clamp at ±20)
  • Sampling configurations:
    • Temperature‑only: temperature ∈ {0.4, 1.0}, top_p=1.0, top_k=-1
    • Qwen3’s recommended sampling: temperature=0.7, top_p=0.8, top_k=20 (“non‑thinking default”)

This gives a 3×3 grid: three correction modes × three sampling configurations. All runs share the same optimizer, data, and initialization.

5.1 Final Performance

[Figure 1: panels (a) and (b)]

Figure 1. Training reward on GSM8K under temperature‑only configurations. (a) temperature=0.4. (b) temperature=1.0. Each panel compares no-correction, Token‑TIS, and Seq‑TIS.

Figure 1 shows the learning curves of the training reward. At temperature=1.0, Token-TIS achieves strong final scores and recovers more quickly from mid-training degradation than no-correction, eventually reaching comparable performance. In contrast, Seq-TIS is consistently the worst, lagging behind and plateauing at a noticeably lower value. When we lower the temperature to temperature=0.4, the relative ordering stays the same: Token-TIS and no-correction both steadily improve and converge to similar final scores of around $1.9$, while Seq-TIS plateaus at $\approx 1.4$ with slight degradation late in training. The temperature=0.4 run did not show obvious degradation in training reward relative to temperature=1.0.

5.2 Catastrophic Variance of Sequence-Level IS

[Figure 2: (a) Importance Weights Mean, (b) Sequence-Level Chi-Square, (c) Effective Sample Size, (d) Fraction of Low-Weight Samples]

Figure 2. Catastrophic variance of sequence-level IS (Seq-TIS) at rollout temperature=1.0. (a) Mean importance weight. (b) Sequence-level $\chi^2$ divergence (log scale, capped at $10^4$). (c) Effective sample size (ESS). Note that ESS reported here is normalized by batch size (ESS/N). (d) Fraction of low-weight samples.

The variance of sequence-level IS weights tells a consistent story of instability. While Figure 2(a) shows the mean weights, Figure 2(b) provides a direct view of the variance by plotting the sequence-level $\chi^2$ divergence. For Seq-TIS, $\chi^2$ routinely exceeds $10^3$ and occasionally spikes into the $10^5$–$10^6$ range (log scale, capped at $10^4$ in the plot), implying that the second moment $\mathbb{E}[w^2]$ of the sequence-level weights is dominated by a handful of extreme outliers. This behavior is symptomatic of catastrophically high-variance importance weights: a small number of trajectories receive extremely large weights, while the majority are heavily down-weighted.

Figure 2(c) makes this explicit by plotting the ESS for Token-TIS and Seq-TIS. Token-TIS maintains an ESS around $0.9$–$1.0$, indicating that almost all samples in the batch contribute meaningfully to the update. In contrast, Seq-TIS hovers around the $0.3$–$0.4$ range.

Figure 2(d) shows that for Seq-TIS, 80–90% of samples (and hence the majority of trajectories) fall into the low-weight bucket. We define the “low-weight fraction” as the percentage of sequences whose sequence-level IS weight falls below $1/C$ (where $C=4.0$ is the truncation threshold). This confirms that most sequences contribute negligibly, while a few outliers carry almost all the weight.

Putting these diagnostics together, Seq-TIS in our long-horizon setting consistently operates in a regime where the raw IS weights exhibit extremely heavy tails and the effective batch size is drastically reduced. This explains why Seq-TIS underperforms despite being theoretically less biased: the variance cost completely dominates any bias reduction.

5.3 Output Length and Seq-IS Explosion / Underflow

To isolate how sequence length interacts with Seq-IS, we bucket trajectories by response length and plot the low-weight fraction against response length for the two temperature-only runs (temperature=0.4 and temperature=1.0) in Figure 3.

[Figure 3: Length vs Seq-IS Failure]

Figure 3. Fraction of low-weight trajectories vs. mean response length. Trend lines show rolling average.

The trend is monotonic: for short answers, only about $20\%$–$30\%$ of sequences are low-weight, but as we approach the longest responses (near the 10k-token limit), over 80–90% of trajectories are assigned negligible Seq-IS weights. In other words, for the longest outputs almost every sample in the batch is effectively discarded, and the update is driven by a tiny number of surviving trajectories.

Lower temperature makes this failure mode slightly worse. The temperature=0.4 curve sits consistently above the temperature=1.0 curve, meaning that sharpening the rollout distribution accelerates the underflow of sequence-level weights.

5.4 Qwen3 Sampling Configuration: Universal Failure

We now turn to Qwen3’s recommended non‑thinking decoding configuration:

  • temperature=0.7, top_p=0.8, top_k=20

Using the same GSM8K setup and three correction modes, this sampling configuration causes training to fail across the board.

[Figure 4: panels (a), (b), and (c)]

Figure 4. Training reward and IS diagnostics under the Qwen non‑thinking decoding configuration (temperature=0.7, top_p=0.8, top_k=20). (a) Training reward collapses. (b) Nearly 100% of Seq-TIS trajectories are down-weighted. (c) ESS indicates collapse.

Figure 4(a) shows that all three correction strategies either collapse or remain flat—neither Token‑TIS nor Seq‑TIS can recover, unlike in the temperature‑only runs. Figure 4(b–c) explains why Seq-TIS failed:

  • The fraction of low‑weight trajectories (4b) under Seq‑TIS saturates near 1.0. The saturation means almost every sample is being heavily down-weighted.
  • The ESS (4c) shows a deceptive stability (near 1.0) which contradicts the collapse in reward. Normalized ESS $\approx 1$ usually implies high-quality, uniform weights. However, when $T \times \text{KL} \gg 20$, almost all weights $w_i$ hit the clamp value ($e^{-20}$ or $e^{20}$). When weights are identical (clamped), the variance is zero and ESS trivially returns 1. This is a “false positive”: the weights are uniform not because the distributions match, but because they have all been clipped to the same safety floor (see the sketch below).
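
The “false positive” is easy to reproduce: identical (clamped) weights give a perfect normalized ESS regardless of how badly the distributions actually mismatch.

```python
import math
import torch

w = torch.full((256,), math.exp(-20.0))                   # every weight pinned at the clamp floor
ess_n = (w.sum() ** 2 / (len(w) * (w ** 2).sum())).item()
print(ess_n)                                              # -> 1.0, despite zero useful correction
```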

This matches the earlier analysis of top‑p/top‑k mismatch: vLLM samples from a truncated distribution while the Actor computes log‑probs over the full vocabulary. Even with processed_logprobs enabled, the trainer never sees the truncated support, so importance ratios become unreliable.

We emphasize that this does not mean Qwen3’s recommended decoding is “wrong” in general. It works well for its intended inference‑time use cases. Our takeaway is: reusing deployment‑oriented decoding settings for long‑horizon RL training, without actor‑side truncation or IS‑aware tuning, can be very brittle.

[Figure 5: (a) KL Divergence, (b) Mean Response Length]

Figure 5. Diagnostics for Qwen recommended settings. (a) Rollout-train KL divergence and implied maximum sequence length. The right axis shows the theoretical maximum sequence length ($T_{\max} \approx 20/\text{KL}$). (b) Mean response length. Qwen3's recommended settings (Orange) plateau at ~600 tokens, matching the steady-state KL clamp limit.

To quantify the failure, Figure 5(a) compares the KL divergence across the three settings. While the safe temperature=1.0 baseline maintains a low KL ($\sim 0.001$), Qwen’s recommended settings cause the KL to spike and stabilize around $\sim 0.033$.

Recall the $T_{\max} \approx 20/\text{KL}$ rule of thumb.

  • Initial Optimism: At step 0, the KL starts at ~0.01, suggesting a theoretical limit of 2000 tokens.
  • Steady-State Reality: As training progresses, the KL stabilizes at ~0.033. This pushes the effective sequence length limit down to $20 / 0.033 \approx 600$ tokens.

Figure 5(b) shows this “prison”: the Qwen run (Orange) plateaus at exactly ~600 tokens. We hypothesize that most sequences crossing this length boundary either hit the clamp (uniform $e^{-20}$ weights) or receive extremely low IS weights that contribute negligible gradient. Note that VERL’s current implementation does not include a lower-bound truncation (e.g., to $1/C$), which means these underflowing weights are simply clamped to the safety log-barrier (e.g., $e^{-20}$). However, even if we were to add lower-bound truncation, we would face a dilemma: either (1) all weights hit the lower bound, causing IS to degenerate to uniform weighting (no correction), or (2) a sparse few high-weight sequences dominate, reproducing the variance collapse seen in Section 5.2. Both scenarios lead to training instability.


6. Alternatives: Rejection Sampling

Beyond truncated IS, Liu & Li[6] propose rejection sampling approaches:

Sequence-level rejection (Seq-MIS): Instead of clipping extreme weights, reject those sequences entirely (seq_is_rs in VERL). The intuition: extreme weights often indicate garbage (numerical glitches, reward hacks) rather than signal. However, this still uses the sequence-level estimator, so it inherits the same exponential variance problem.

Geometric rejection (Geo-RS): Use the geometric mean $\rho^{1/T}$ instead of the product for rejection decisions (geo_rs in VERL). This normalizes for length: a 100-token sequence with 1.1× per-token drift gets the same trust score as a 10-token sequence with 1.1× drift. However, the threshold is a finicky hyperparameter to tune: recommended values are 1.0002–1.001, but in our experiments the geometric means were 1.01–1.05, and every sample was rejected.
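
A sketch of the quantity Geo-RS thresholds on (the per-sequence geometric mean of token ratios); the rejection rule and padding handling are simplified here relative to VERL’s geo_rs implementation:

```python
import torch

def geometric_mean_ratio(learner_logprobs, rollout_logprobs):
    """rho^(1/T) = exp(mean_t log rho_t), one value per sequence (padding ignored for brevity)."""
    log_rho = learner_logprobs - rollout_logprobs        # [batch, T] log-ratios of sampled tokens
    return torch.exp(log_rho.mean(dim=-1))

# Simplified rejection decision: keep sequences whose geometric-mean drift stays below a threshold,
# e.g. keep = geometric_mean_ratio(lp_learner, lp_rollout) <= 1.001
```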

Liu & Li recommend combining both: Geo-RS-Seq-TIS uses the geometric mean for rejection decisions and the clipped sequence weight for gradient magnitude. This is theoretically elegant but still inherits the variance problems of sequence-level estimation along with tuning difficulty.


7. Conclusion

Modern LLM-RL separates inference from training for efficiency. Even with identical weights, precision differences and sampling transforms create off-policy drift.

Sampling parameters affect mismatch in different ways:

  • Temperature can be corrected with processed_logprobs, but any $\tau \neq 1$ tends to increase KL
  • Top-p/top-k are likely correctable with actor-side replication (not yet standard in most frameworks)

The standard fix is importance sampling. Theoretically, sequence-level IS is optimal. Practically, it fails in two ways:

  1. The $T \times \text{KL}$ boundary: When $T \times \text{KL} \geq \text{CLAMP}$, the safety clamp produces uniform weights—no correction signal.
  2. ESS collapse before clamp: Even with $T \times \text{KL} \ll \text{CLAMP}$, log-normal weight distributions crush effective sample size.

In our GSM8K experiments with Qwen‑3 8B, long (up to 10k‑token) responses, and hard‑gated rewards, sequence‑level TIS systematically underperformed token‑level TIS: it suffered from catastrophic variance, low effective sample size, and frequent weight clamping, while Token‑TIS remained stable and consistently improved over no‑correction.

We do not claim that Seq‑TIS is universally worse than Token‑TIS. In shorter‑horizon tasks, or with tighter KL control, shorter responses, or more aggressive rejection sampling, the theoretical advantages of sequence‑level correction may still be attainable. However, in the long‑horizon, high‑mismatch regime we study, the variance cost of Seq‑TIS dominates its bias benefits.

There are alternatives: (1) engineering fixes (FP16 unification, FP32 head casting, deterministic kernels) reduce mismatch at the source at the cost of speed and engineering complexity, and (2) rejection sampling variants (Seq-MIS, Geo-RS) filter outliers but require careful hyperparameter tuning. In our experiments, Token-TIS with default sampling parameters proved most robust.

What worked in our experiments:

  • Correction: Token-TIS
  • Sampling: temperature=1.0, top_p=1.0, top_k=-1
  • VERL config: +actor_rollout_ref.rollout.engine_kwargs.vllm.logprobs_mode="processed_logprobs"

References

  1. Yao, F., Liu, L., et al. (2025). “Your Efficient RL Framework Secretly Brings You Off-Policy RL Training.” https://fengyao.notion.site/off-policy-rl

  2. Qi, P., Liu, Z., Zhou, X., Pang, T., Du, C., Lee, W. S., & Lin, M. (2025). “Defeating the Training-Inference Mismatch via FP16.” https://github.com/sail-sg/Precision-RL

  3. MiniMax et al. (2025). “MiniMax-M1: Scaling Test-Time Compute Efficiently with Lightning Attention.” arXiv:2506.13585

  4. He, H., et al. (2025). “Defeating Nondeterminism in LLM Inference.” Thinking Machines. https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/

  5. DeepSeek-AI. (2025). “DeepSeek-V3.2-Exp: Boosting Long-Context Efficiency with DeepSeek Sparse Attention.”

  6. Liu, J., Li, Y., Fu, Y., Wang, J., Liu, Q., & Shen, Y. (2025). “When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.” https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda


Acknowledgements

Thank you to Rohit Rastogi for reading early drafts.

Thank you to the Baseten team (Raymond Cano, Aghilan Nathan) for making it a breeze to run experiments.


Citations

Citation:

Fang, Qi and Khazi, Daanish, "Mismatch Praxis: Rollout Settings and IS Corrections", 
LLM Data, Dec 2025.

BibTeX:

@misc{fangkhazi2025mismatch,
  author = {Qi Fang and Daanish Khazi},
  title = {Mismatch Praxis: Rollout Settings and IS Corrections},
  howpublished = {The LLM Data Company Blog},
  year = {2025},
  url = {https://www.llmdata.com/blog/mismatch-praxis/}
}