
Conversation

@pramodith (Collaborator)

What does this PR do?

Adds On-Policy Distillation from the Thinking Machines Lab blog to the paper index. https://thinkingmachines.ai/blog/on-policy-distillation/

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@kashif @qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


config = GKDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # to ensure reverse KL as the loss function
)

@pramodith (Collaborator, Author) commented on Oct 30, 2025:

Can someone please confirm whether beta should be 1.0? The Thinking Machines blog says we should use reverse KL, but the HF blog recommends setting beta=0 for the distillation step, so I'm a bit confused.

From the Thinking Machines blog:

Reverse KL has natural synergy with RL, which generally optimizes a form of sequence-level reverse KL induced by the reward model. However, unlike most reward models in practice, the reverse KL is “unhackable” in the sense that low KL always corresponds to a high probability of desirable behavior from the teacher model’s point of view. Two other useful properties of reverse KL are that it is “mode seeking” (see Eric Jang’s post for more discussion of mode-seeking behaviors) — it learns one specific behavior (the teacher’s) instead of spreading its distribution across several suboptimal options — and it reduces exposure bias.
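
For reference (not quoting either blog), the two KL directions under discussion written out; the beta=0 / beta=1 mapping is my reading of TRL's parameterization:

\mathrm{KL}(\pi_{\text{teacher}} \,\|\, \pi_{\text{student}}) = \mathbb{E}_{x \sim \pi_{\text{teacher}}}\!\left[\log \pi_{\text{teacher}}(x) - \log \pi_{\text{student}}(x)\right] \quad \text{(forward KL, mass-covering; the } \beta = 0 \text{ endpoint)}

\mathrm{KL}(\pi_{\text{student}} \,\|\, \pi_{\text{teacher}}) = \mathbb{E}_{x \sim \pi_{\text{student}}}\!\left[\log \pi_{\text{student}}(x) - \log \pi_{\text{teacher}}(x)\right] \quad \text{(reverse KL, mode-seeking; the } \beta = 1 \text{ endpoint)}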

The HF blog's linked recipe for distillation:

accelerate launch \
  --config_file examples/accelerate_configs/multi_gpu.yaml trl/experimental/gold/gold.py \
  --model_name_or_path <sft-model> \
  --dtype auto \
  --attn_implementation kernels-community/flash-attn \
  --dataset_name allenai/tulu-3-sft-mixture \
  --dataset_train_split train \
  --bf16 \
  --learning_rate 1e-7 \
  --gradient_checkpointing \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 64 \
  --num_train_epochs 1 \
  --eval_strategy steps \
  --eval_steps 100 \
  --temperature 1.0 \
  --top_p 0.95 \
  --top_k 0 \
  --max_new_tokens 2048 \
  --max_prompt_length 512 \
  --lmbda 0.25 \
  --beta 0.0 \

Member:
cc @kashif

Collaborator:

Let's see: the GKD paper used beta values of 0.1, 0.5, and 0.9, and as beta -> 1 the gradient of the loss behaves like reverse KL. In the code, when beta=1.0 we do jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True), which is KL(student || teacher).
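
Quick sanity check of that direction (just a toy PyTorch sketch with made-up distributions, not TRL code):

import torch
import torch.nn.functional as F

# Two toy next-token distributions over a 5-token vocabulary.
student_log_probs = torch.log_softmax(torch.randn(5), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(5), dim=-1)

# The beta=1.0 branch: F.kl_div(input=teacher_log_probs, target=student_log_probs).
loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)

# Written out by hand: sum_x p_student(x) * (log p_student(x) - log p_teacher(x)),
# i.e. KL(student || teacher), the reverse KL.
reverse_kl = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum()

print(torch.allclose(loss, reverse_kl))  # True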

Collaborator (Author):

Yes, so it seems that setting beta=1 is the right thing to do to reproduce what the Thinking Machines blog states. Was the decision to use beta=0 instead of 1.0 in the HF blog intentional?

Contributor:

That's a good catch.

Using beta=0 was intentional: it was more sample efficient than beta=1 in some ablations on the Countdown task. We also used lambda=0.25 instead of lambda=1.0, since lower lambda values are faster to run and give results comparable to a fully online setup (see Figure 5 of the blog post).

You're right that for an exact reproduction of the Thinking Machines blog we should use beta=1 and lambda=1. However, we wanted to include the parameters we used for the blog post in case someone wants to reproduce those results.

Running the experiments with the Thinking Machines setup will likely result in marginally better performance, but the conclusion would be essentially the same.
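
For anyone following along, here is a rough sketch of the two settings being contrasted, assuming the beta/lmbda flags in the recipe map onto TRL's GKDConfig arguments the same way as in the snippet under review (other arguments, e.g. output_dir, omitted):

from trl import GKDConfig

# Faithful to the Thinking Machines on-policy distillation setup:
# fully on-policy rollouts and pure reverse KL.
thinking_machines_style = GKDConfig(
    lmbda=1.0,  # student generates the rollouts for every batch
    beta=1.0,   # reverse KL, i.e. KL(student || teacher)
)

# Settings from the HF blog recipe above: cheaper to run, partly off-policy,
# forward KL, which was more sample efficient in the Countdown ablations.
hf_blog_style = GKDConfig(
    lmbda=0.25,  # roughly a quarter of the batches use student rollouts
    beta=0.0,    # forward KL, i.e. KL(teacher || student)
)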

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>