Add On-Policy Distillation from Thinking Machines to paper index. #4410
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
config = GKDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # to ensure reverse KL as the loss function
Can someone please confirm whether beta should be 1.0? The Thinking Machines block mentions that we should use reverse KL, but I see that the HF blog recommends setting beta=0 for the distillation step, so I'm a bit confused.
From the Thinking Machines blog:
Reverse KL has natural synergy with RL, which generally optimizes a form of sequence-level reverse KL induced by the reward model. However, unlike most reward models in practice, the reverse KL is “unhackable” in the sense that low KL always corresponds to a high probability of desirable behavior from the teacher model’s point of view. Two other useful properties of reverse KL are that it is “mode seeking” (see Eric Jang’s post for more discussion of mode-seeking behaviors) — it learns one specific behavior (the teacher’s) instead of spreading its distribution across several suboptimal options — and it reduces exposure bias.
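For context, the two KL directions being debated (standard definitions, not from the blog itself; $\pi_\theta$ is the student and $\pi_T$ the teacher):

$$\mathrm{KL}(\pi_\theta \,\|\, \pi_T) = \mathbb{E}_{x \sim \pi_\theta}\big[\log \pi_\theta(x) - \log \pi_T(x)\big] \quad \text{(reverse KL, evaluated on student samples)}$$

$$\mathrm{KL}(\pi_T \,\|\, \pi_\theta) = \mathbb{E}_{x \sim \pi_T}\big[\log \pi_T(x) - \log \pi_\theta(x)\big] \quad \text{(forward KL, evaluated on teacher samples)}$$

The reverse direction is computed on samples drawn from the student, which is why it pairs naturally with on-policy rollouts.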
The HF blog's linked recipe for distillation:
accelerate launch \
--config_file examples/accelerate_configs/multi_gpu.yaml trl/experimental/gold/gold.py \
--model_name_or_path <sft-model> \
--dtype auto \
--attn_implementation kernels-community/flash-attn \
--dataset_name allenai/tulu-3-sft-mixture \
--dataset_train_split train \
--bf16 \
--learning_rate 1e-7 \
--gradient_checkpointing \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 64 \
--num_train_epochs 1 \
--eval_strategy steps \
--eval_steps 100 \
--temperature 1.0 \
--top_p 0.95 \
--top_k 0 \
--max_new_tokens 2048 \
--max_prompt_length 512 \
--lmbda 0.25 \
--beta 0.0 \
cc @kashif
Let's see: the GKD paper used beta values of 0.1, 0.5, and 0.9, and as beta → 1 the gradient of the loss behaves like the reverse (inverse) KL. In the code, when beta=1.0 we do jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True), which is KL(student || teacher).
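For anyone following along, here is a minimal standalone sketch of that call and its opposite direction on dummy logits (not the actual GKDTrainer code; recall that F.kl_div(input, target, log_target=True) computes KL(target || input) with both arguments in log space):

```python
# Minimal sketch illustrating the two KL directions discussed above (dummy logits, not TRL code).
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 32)  # 4 toy token positions, vocab size 32
teacher_logits = torch.randn(4, 32)

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

# Reverse KL, KL(student || teacher): the beta=1.0 call quoted above.
reverse_kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="batchmean", log_target=True)

# Forward KL, KL(teacher || student): the other direction, for contrast.
forward_kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)

print(f"reverse KL: {reverse_kl.item():.4f}, forward KL: {forward_kl.item():.4f}")
```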
Yes, so it seems like setting beta=1 is the right thing to do for reproducing what the Thinking Machines blog states. Was the decision to use beta=0 instead of 1.0 in the HF blog intentional?
That's a good catch.
Using beta=0 was intentional because it was more sample efficient than beta=1 in some ablations on the Countdown task. We also used lambda=0.25 instead of lambda=1.0, as lower lambda values run faster than lambda=1 while giving results comparable to those of an entirely online setup (see Figure 5 from the blog post).
You're right that for an exact reproduction of the Thinking Machines blog, we should use beta=1 and lambda=1. However, we wanted to include the parameters we used for the blog post in case someone wants to reproduce those results.
Running the experiments with the Thinking Machines setup will likely result in marginally better performance, but the conclusion would be essentially the same.
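For later readers, here is a rough sketch of the two settings being contrasted, using the GKDConfig parameter names from the snippet above (other arguments omitted; illustrative only, not a tested configuration):

```python
from trl import GKDConfig

# Exact reproduction of the Thinking Machines on-policy distillation setup discussed above:
# fully on-policy student rollouts with pure reverse KL.
tm_config = GKDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # reverse KL, KL(student || teacher)
)

# The HF blog recipe: more sample efficient in the Countdown ablations.
hf_config = GKDConfig(
    lmbda=0.25,  # only a fraction of batches use on-policy student rollouts
    beta=0.0,    # the forward-KL end of the generalized JSD
)
```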
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
What does this PR do?
Adds On-Policy Distillation from Thinking Machines to the paper index. https://thinkingmachines.ai/blog/on-policy-distillation/
Fixes # (issue)
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
@kashif @qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.