Skip to content

Commit d707bd4

Browse files
authored
Handle TRL version compatibility in rl_replacements.py (#3540)
1 parent 2267b5c commit d707bd4

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

‎unsloth/models/rl_replacements.py‎

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,19 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
215215
# The new multi-line string that will replace the line above
216216
replacement_lines = """
217217
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
218-
if not has_images:
219-
# Left pad prompt before calculation old and ref hidden states
220-
prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
221-
self.model.for_training()"""
218+
try:
219+
#TRL 0.23.1 and below path
220+
if not has_images:
221+
# Left pad prompt before calculation old and ref hidden states
222+
prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
223+
self.model.for_training()
224+
except:
225+
#TRL 0.24.0 and above path
226+
if images is None:
227+
# Left pad prompt before calculation old and ref hidden states
228+
prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
229+
self.model.for_training()"""
222230

223-
if "has_images" not in function:
224-
raise NotImplementedError("Unsloth: For now we support `trl<=0.23.1`. Please downgrade!")
225231
function = function.replace(line_to_replace, replacement_lines)
226232

227233
pattern_to_find = re.compile(

0 commit comments

Comments
 (0)