Commit 861ef12

Authored by danielhanchen, Datta0, and mmathew23

Bug fixes (#347)
* Iterative fixes and updates across attention_sink.py, gpt_oss.py, patching_utils.py, mxfp4.py, bitsandbytes.py, rl_replacements.py, common.py, utils.py, loss_utils.py, saving_utils.py, misc.py, rl_environments.py, empty_model.py, vllm_utils.py, compiler.py, cross_entropy_loss.py, __init__.py, and pyproject.toml
* prefer_nd_tiling, flex_attention_with_sink, compile Flex Attention, compiled mask creation, torch compile, flash attn, dropout_p, inplace ops, has_static_cache, device type, bitsandbytes patch, versioning
* Fix Gemma 3
* Qwen3 VL vLLM (#324): Qwen3 VL additional layers, fused vision QKV, refactor for handling Qwen3 VL, fix backward pass issues, out hidden size change, Qwen 2.5 and Qwen 3 Conv3d -> Linear vLLM changes
* Fix CE compile
* DeepSeekOCR fix: save single model shard (#346): check for the safetensors_list shard naming convention; turned off the shard padding length check because DeepSeek's padding is different; handle the error thrown when copying index.json over an identical existing file

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: DoubleMathew <mmathew23@gmail.com>
1 parent 6690af3 commit 861ef12

File tree

8 files changed: +71 -21 lines changed

unsloth_zoo/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
-__version__ = "2025.11.1"
+__version__ = "2025.11.2"
 
 import os
 import warnings

unsloth_zoo/fused_losses/cross_entropy_loss.py

Lines changed: 1 addition & 7 deletions

@@ -95,13 +95,7 @@ def compute_fused_ce_loss(
 
     # Calculate cross entropy loss
     reduction = "sum" if n_items is not None else "mean"
-    # Since we overwrite torch.compile(torch.nn.functional.cross_entropy)
-    # We might get double compile errors, so use the uncompiled version
-    cross_entropy = \
-        torch.nn.functional._uncompiled_cross_entropy if \
-        hasattr(torch.nn.functional, "_uncompiled_cross_entropy") else \
-        torch.nn.functional.cross_entropy
-    loss = cross_entropy(
+    loss = torch.nn.functional.cross_entropy(
         input = logits.view(-1, vocab_size).float().contiguous(),
         target = labels.view(-1).to(device).contiguous(),
         reduction = reduction,
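
For context, a minimal standalone sketch of the reduction logic this hunk touches; the helper name ce_loss_sketch and the toy tensors are illustrative, not code from this file:

import torch

def ce_loss_sketch(logits, labels, n_items = None):
    # "sum" when the caller supplies a global token count (so we normalize
    # manually below), otherwise let cross_entropy take the mean itself
    reduction = "sum" if n_items is not None else "mean"
    loss = torch.nn.functional.cross_entropy(
        input = logits.view(-1, logits.shape[-1]).float(),
        target = labels.view(-1),
        reduction = reduction,
    )
    if n_items is not None:
        loss = loss / n_items
    return loss

print(ce_loss_sketch(torch.randn(2, 3, 11), torch.randint(0, 11, (2, 3))))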

unsloth_zoo/loss_utils.py

Lines changed: 5 additions & 1 deletion

@@ -106,7 +106,11 @@ def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None,
         ignore_index = ignore_index,
         reduction = reduction,
     )
-    if reduction == "sum": loss = loss / num_items_in_batch
+    if reduction == "sum":
+        # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.to(loss.device)
+        loss = loss / num_items_in_batch
     return loss
 pass
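
To see what the guard protects against: num_items_in_batch may arrive as a plain int (e.g. from a custom trainer) or as a tensor on another device, and dividing a GPU loss by a CPU tensor can raise a device-mismatch error. A small runnable sketch of the guarded division, with toy tensors:

import torch

loss = torch.nn.functional.cross_entropy(
    torch.randn(4, 10), torch.randint(0, 10, (4,)), reduction = "sum",
)

# An int divides fine as-is; a tensor is moved to the loss's device first,
# which is exactly what the patch above adds
for num_items_in_batch in (4, torch.tensor([4])):
    if torch.is_tensor(num_items_in_batch):
        num_items_in_batch = num_items_in_batch.to(loss.device)
    print(loss / num_items_in_batch)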

unsloth_zoo/patch_torch_functions.py

Lines changed: 5 additions & 3 deletions

@@ -171,9 +171,11 @@ def patch_torch_functions():
     if not hasattr(torch.nn.functional, "_uncompiled_layer_norm"):
         torch.nn.functional._uncompiled_layer_norm = torch.nn.functional.layer_norm
     torch.nn.functional.layer_norm = layer_norm
-    if not hasattr(torch.nn.functional, "_uncompiled_cross_entropy"):
-        torch.nn.functional._uncompiled_cross_entropy = torch.nn.functional.cross_entropy
-    torch.nn.functional.cross_entropy = cross_entropy
+    # Remove compiling cross_entropy since too many errors
+    # We already compile this most likely anyways
+    # if not hasattr(torch.nn.functional, "_uncompiled_cross_entropy"):
+    #     torch.nn.functional._uncompiled_cross_entropy = torch.nn.functional.cross_entropy
+    # torch.nn.functional.cross_entropy = cross_entropy
 pass
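
The layer_norm patch that stays follows a stash-then-replace pattern. A minimal sketch of that pattern, assuming a plain torch.compile wrapper stands in for this module's own layer_norm replacement:

import torch

# Stash the original exactly once so repeated patching stays idempotent,
# then swap in a compiled wrapper (a stand-in for the module's version)
if not hasattr(torch.nn.functional, "_uncompiled_layer_norm"):
    torch.nn.functional._uncompiled_layer_norm = torch.nn.functional.layer_norm
    torch.nn.functional.layer_norm = torch.compile(
        torch.nn.functional._uncompiled_layer_norm, dynamic = True,
    )

The removed cross_entropy branch used the same pattern; per the diff comments, dropping it avoids double-compile errors since cross_entropy is most likely compiled elsewhere already.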

unsloth_zoo/rl_replacements.py

Lines changed: 14 additions & 2 deletions

@@ -126,7 +126,6 @@ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
 pass
 RL_REPLACEMENTS["left_pack_padding"] = left_pack_padding
 
-import torch
 
 def align_logprobs_with_mask(
     logprob_tensor: torch.Tensor,

@@ -176,10 +175,23 @@ def align_logprobs_with_mask(
     padded_logprobs[valid_rows, valid_cols] = valid_vals
 
     return padded_logprobs
-
+pass
 RL_REPLACEMENTS["align_logprobs_with_mask"] = align_logprobs_with_mask
 
 
+def grpo_update_SamplingParams(SamplingParams, generation_kwargs, vllm_sampling_params = None):
+    good_sampling_params_keys = inspect.signature(SamplingParams).parameters.keys()
+    if vllm_sampling_params is not None:
+        for key in good_sampling_params_keys:
+            if hasattr(vllm_sampling_params, key):
+                overwrited_key = getattr(vllm_sampling_params, key)
+                if overwrited_key is not None and (type(overwrited_key) in (list, tuple,) and len(overwrited_key) != 0):
+                    generation_kwargs[key] = overwrited_key
+    return generation_kwargs
+pass
+RL_REPLACEMENTS["grpo_update_SamplingParams"] = grpo_update_SamplingParams
+
+
 # Custom compiled GRPO loss - creates 3 Triton kernels
 def grpo_compute_loss(
     ref_logits,
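
grpo_update_SamplingParams copies a user's vLLM sampling overrides into generation_kwargs, but only for attributes in SamplingParams' signature whose values are non-empty lists or tuples. A runnable sketch with a stand-in dataclass in place of vllm.SamplingParams (the function body is copied from the hunk above):

import inspect
from dataclasses import dataclass

@dataclass
class SamplingParams:
    # Illustrative stand-in for vllm.SamplingParams
    temperature: float = 1.0
    stop: list = None

def grpo_update_SamplingParams(SamplingParams, generation_kwargs, vllm_sampling_params = None):
    good_sampling_params_keys = inspect.signature(SamplingParams).parameters.keys()
    if vllm_sampling_params is not None:
        for key in good_sampling_params_keys:
            if hasattr(vllm_sampling_params, key):
                overwrited_key = getattr(vllm_sampling_params, key)
                if overwrited_key is not None and (type(overwrited_key) in (list, tuple,) and len(overwrited_key) != 0):
                    generation_kwargs[key] = overwrited_key
    return generation_kwargs

# stop (a non-empty list) is merged in; temperature (a float) is left untouched
print(grpo_update_SamplingParams(SamplingParams, {"temperature": 0.7}, SamplingParams(stop = ["</answer>"])))
# {'temperature': 0.7, 'stop': ['</answer>']}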

unsloth_zoo/saving_utils.py

Lines changed: 31 additions & 3 deletions

@@ -920,6 +920,26 @@ def fix_tokenizer_config_json(tokenizer, saved_folder):
     return
 pass
 
+def is_hf_sharded_safetensors(filenames: list[str]) -> bool:
+    """Check if filenames follow HF sharded naming: model-00001-of-00005.safetensors"""
+    pattern = re.compile(r'^(.+?)-(\d+)-of-(\d+)\.safetensors$')
+
+    matches = [pattern.match(f) for f in filenames]
+    if not all(matches):
+        return False
+
+    # Keep strings to check padding
+    parsed = [(m.group(1), m.group(2), m.group(3)) for m in matches]
+
+    # shard and total have same padding: turned off as deepseekocr padding is different
+    # for prefix, shard_str, total_str in parsed:
+    #     if len(shard_str) != len(total_str):
+    #         return False
+
+    # same prefix and total
+    prefixes, _, totals = zip(*parsed)
+    return len(set(prefixes)) == 1 and len(set(totals)) == 1
+
 @torch.inference_mode
 def merge_and_overwrite_lora(
     get_model_name,

@@ -1170,7 +1190,8 @@ def upload_items(filename = None):
     _hf_cache_dir = _get_hf_cache_dir()
     copied_all_from_cache = False
     copied_tokenizer_model_from_cache = False
-    safe_tensor_index_files = ["model.safetensors.index.json"] if len(safetensors_list) > 1 else []
+    is_hf_sharded = is_hf_sharded_safetensors(safetensors_list)
+    safe_tensor_index_files = ["model.safetensors.index.json"] if (len(safetensors_list) > 1 or is_hf_sharded) else []
 
     # ONLY download/copy the original index if we are NOT dequantizing an MXFP4 model
     if (not (base_model_is_quantized and quant_type == "mxfp4") or (base_model_is_quantized and quant_type == "mxfp4" and save_method == "mxfp4")) and not needs_splitting:

@@ -1180,7 +1201,13 @@ def upload_items(filename = None):
     if safe_tensor_index_files:
         local_index_path = os.path.join(model_name, "model.safetensors.index.json")
         if os.path.exists(local_index_path):
-            shutil.copy2(local_index_path, os.path.join(save_directory, "model.safetensors.index.json"))
+            try:
+                shutil.copy2(local_index_path, os.path.join(save_directory, "model.safetensors.index.json"))
+            except shutil.SameFileError:
+                pass
+            except Exception as e:
+                print(f"Error copying model.safetensors.index.json: {e}")
+                raise e
     else:
         # Download from HF
         if "model.safetensors.index.json" in [f for f in safe_tensor_index_files]:

@@ -1282,7 +1309,8 @@ def upload_items(filename = None):
     if needs_splitting:
         final_safetensors_list = renumber_safetensor_files(final_safetensors_list, save_directory)
 
-    regenerate_index = ((base_model_is_quantized and quant_type == "mxfp4") or needs_splitting) and len(final_safetensors_list) > 1 and save_method != "mxfp4"
+    is_final_safetensors_list_sharded = is_hf_sharded_safetensors(final_safetensors_list)
+    regenerate_index = ((base_model_is_quantized and quant_type == "mxfp4") or needs_splitting) and (len(final_safetensors_list) > 1 or is_final_safetensors_list_sharded) and save_method != "mxfp4"
     weight_map = {}
 
     for filename in ProgressBar(final_safetensors_list, desc=f'Unsloth: Merging weights into {"mxfp4" if save_method=="mxfp4" else "16bit"}'):
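
A quick demonstration of what is_hf_sharded_safetensors accepts; the function here is condensed from the hunk above (the commented-out padding check is omitted, as it is disabled there too) and the filenames are illustrative:

import re

def is_hf_sharded_safetensors(filenames: list[str]) -> bool:
    # Condensed copy of the helper added above, for illustration only
    pattern = re.compile(r'^(.+?)-(\d+)-of-(\d+)\.safetensors$')
    matches = [pattern.match(f) for f in filenames]
    if not all(matches):
        return False
    parsed = [(m.group(1), m.group(2), m.group(3)) for m in matches]
    prefixes, _, totals = zip(*parsed)
    return len(set(prefixes)) == 1 and len(set(totals)) == 1

print(is_hf_sharded_safetensors(
    ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]))  # True
print(is_hf_sharded_safetensors(["model.safetensors"]))                          # False
# A single, unpadded shard in DeepSeek-OCR style still counts as sharded,
# which is why the index file is now copied even when the list has one entry
print(is_hf_sharded_safetensors(["model-1-of-1.safetensors"]))                   # True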

unsloth_zoo/training_utils.py

Lines changed: 4 additions & 2 deletions

@@ -74,14 +74,16 @@ def fix_zero_training_loss(model, tokenizer, train_dataset):
             "Unsloth: All labels in your dataset are -100. Training losses will be all 0.\n"\
             "For example, are you sure you used `train_on_responses_only` correctly?\n"\
             "Or did you mask our tokens incorrectly? Maybe this is intended?\n"\
-            "Maybe you're using a Llama chat template on a non Llama model for example?"
+            "Maybe you're using a Llama chat template on a non Llama model for example?"\
+            "If you used `train_on_responses_only`, confirm your user and assistant parts are correct!"
         )
     elif seen_bad / (seen_bad + seen_good) >= 0.9:
         print(
             "Unsloth: Nearly all labels in your dataset are -100. Training losses will be all 0.\n"\
             "For example, are you sure you used `train_on_responses_only` correctly?\n"\
             "Or did you mask our tokens incorrectly? Maybe this is intended?\n"\
-            "Maybe you're using a Llama chat template on a non Llama model for example?"
+            "Maybe you're using a Llama chat template on a non Llama model for example?"\
+            "If you used `train_on_responses_only`, confirm your user and assistant parts are correct!"
         )
     pass
 pass
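
For reference, the warning above fires on a simple ratio of -100 labels; a tiny sketch of that threshold check, assuming per-token counts for seen_bad and seen_good (the toy labels tensor is illustrative):

import torch

labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, 42])
seen_bad  = int((labels == -100).sum())
seen_good = int((labels != -100).sum())
# 9 / 10 = 0.9, which crosses the >= 0.9 threshold, so the warning is printed
print(seen_bad / (seen_bad + seen_good) >= 0.9)  # True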

unsloth_zoo/vllm_utils.py

Lines changed: 10 additions & 2 deletions

@@ -28,6 +28,7 @@
     "generate_batches",
     "convert_lora_modules",
     "return_lora_modules",
+    "get_lora_supported_ranks",
 ]
 
 from typing import Optional, List, Tuple, Dict, Any

@@ -1463,8 +1464,8 @@ def approximate_vllm_memory_usage(
 pass
 
 
-def determine_max_lora_rank(lora_rank = 16):
-    """vLLM doesn't allow any LoRA rank, so we need to get the next largest"""
+@functools.cache
+def get_lora_supported_ranks():
     possible_max_ranks = [8, 16, 32, 64, 128, 256, 320, 512]
     try:
         import vllm.config.lora

@@ -1482,6 +1483,13 @@ def determine_max_lora_rank(lora_rank = 16):
     if type(possible_max_ranks) is str:
         possible_max_ranks = re.findall(r"[\d]{1,}", possible_max_ranks)
     possible_max_ranks = [int(x) for x in possible_max_ranks]
+    return possible_max_ranks
+pass
+
+
+def determine_max_lora_rank(lora_rank = 16):
+    """vLLM doesn't allow any LoRA rank, so we need to get the next largest"""
+    possible_max_ranks = get_lora_supported_ranks()
     for max_lora_rank in possible_max_ranks:
         if max_lora_rank >= lora_rank:
             return max_lora_rank
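
Splitting the probe out means the supported-rank list is computed once (functools.cache) while the rank lookup stays a cheap scan. A minimal sketch using only the hardcoded fallback list from the hunk above, without the vLLM probe:

import functools

@functools.cache
def get_lora_supported_ranks():
    # Fallback list from the diff; the real function asks vLLM when available
    return [8, 16, 32, 64, 128, 256, 320, 512]

def determine_max_lora_rank(lora_rank = 16):
    # Return the smallest supported rank that is >= the requested rank
    for max_lora_rank in get_lora_supported_ranks():
        if max_lora_rank >= lora_rank:
            return max_lora_rank

print(determine_max_lora_rank(48))  # 64, the next supported rank upwards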
