
Conversation

@Pringled
Member

Torch 2.8.0 shows significant performance drops across tasks when distilling with MPS. For this reason, MPS distillation is disabled until this has been resolved in Torch.
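For reference, a minimal sketch of the kind of version guard this implies, using a hypothetical `select_device` helper (illustrative only, not the actual code in this PR):

```python
# Illustrative sketch (not the actual model2vec code): fall back to CPU
# when MPS would otherwise be selected on torch >= 2.8.0.
import torch
from packaging.version import Version


def select_device(requested: str | None = None) -> str:
    """Pick a device, avoiding MPS on torch >= 2.8.0 where distillation regresses."""
    if requested is not None:
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available() and Version(torch.__version__) < Version("2.8.0"):
        return "mps"
    return "cpu"
```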

@codecov

codecov bot commented Sep 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

| Files with missing lines | Coverage Δ |
| --- | --- |
| model2vec/distill/utils.py | 100.00% <100.00%> (ø) |
@Pringled merged commit e10118e into main Sep 30, 2025
6 checks passed
@Pringled deleted the fix-mps-issue branch September 30, 2025 16:32
@davidlowryduda

Can you share the benchmark that you used for this test? I'd be interested in monitoring MPS behavior.

@Pringled
Member Author

Hi @davidlowryduda, for sure! I ran the following script (make sure you install our evaluation package with `pip install git+https://github.com/MinishLab/evaluation.git@main`):

```python
from evaluation import CustomMTEB, get_tasks, parse_mteb_results, make_leaderboard, summarize_results, TaskType
from model2vec.distill import distill

# Distill a static model from the teacher.
model_name = "m2v-bge-base-en-v1.5"
model = distill(model_name="BAAI/bge-base-en-v1.5")

# Evaluate on the WordSim tasks only (the fastest task type to run).
task_types = [TaskType.WORDSIM]
tasks = get_tasks(task_types=task_types)
evaluation = CustomMTEB(tasks=tasks)
results = evaluation.run(model, eval_splits=["test"], output_folder=f"local/results/temp/{model_name}", overwrite_results=True)

# Summarize the scores and print the WordSim column of the leaderboard.
parsed_results = parse_mteb_results(mteb_results=results, model_name=model_name)
task_scores = summarize_results(parsed_results)
leaderboard = make_leaderboard(task_scores)
print(leaderboard["WordSim"])
```

This does a quick eval on the WordSim tasks, which are the fastest to run; you can also run it on e.g. classification or retrieval, but those are much slower. To test this behavior, `pip install torch==2.7.1`, run the script, then `pip install torch==2.8.0` and run it again to compare. You'll see a massive performance degradation.
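Between runs, a quick sanity check confirms which torch build is active and whether MPS is available (just a convenience snippet, not part of the benchmark):

```python
# Confirm the active torch version and MPS availability before comparing runs.
import torch

print(torch.__version__, torch.backends.mps.is_available())
```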

I have since figured out why: torch 2.8.0 introduced a fast MPS SDPA (scaled dot-product attention) kernel for short sequences that produces different vectors, and during distillation all of our sequences are just individual tokens (plus bos/eos tokens). In my experiments the regression can be avoided by padding the sequences to a length of 9, but that is extremely hacky and ugly, so for now I'll keep the code as is and simply disable MPS distillation for torch >= 2.8.0. When I have time I'd like to dive into this on the torch side and perhaps fix it there.
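As a rough illustration (not the exact experiment I ran), you can compare SDPA outputs on CPU vs. MPS for a very short sequence like the ones used during distillation:

```python
# Minimal illustration: compare scaled_dot_product_attention on CPU vs. MPS
# for a 3-token sequence (bos + token + eos), the shape used in distillation.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 12, 3, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 12, 3, 64)
v = torch.randn(1, 12, 3, 64)

cpu_out = F.scaled_dot_product_attention(q, k, v)
if torch.backends.mps.is_available():
    mps_out = F.scaled_dot_product_attention(
        q.to("mps"), k.to("mps"), v.to("mps")
    ).cpu()
    # If the new fast kernel changes results, this diff should grow on 2.8.0.
    print("max abs diff vs. CPU:", (cpu_out - mps_out).abs().max().item())
```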

