Commit 256f8fe

Merge branch 'main' into nightly
2 parents: 8dbd008 + dee6de9

File tree

6 files changed: +337 -22 lines changed

tests/utils/test_qat.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
from unsloth import FastLanguageModel

from typing import Dict

import pytest
import torch
from torchao.quantization.qat import FakeQuantizedLinear
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizerBase,
    Float8FakeQuantizer,
    Int4WeightPreshuffledFakeQuantizer,
)


class _CountingFakeQuantizer(torch.nn.Module):
    """
    Dummy fake quantizer that counts the number of times it has been called.
    """
    def __init__(self):
        super().__init__()
        self.count = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count += 1
        return x


def _get_model(qat_scheme: str, full_finetuning: bool):
    """
    Return a 2-tuple of (model, tokenizer), where the model has been configured
    to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3-1.7B",
        load_in_4bit = False,
        full_finetuning = full_finetuning,
        qat_scheme = qat_scheme if full_finetuning else None,
    )
    if not full_finetuning:
        model = FastLanguageModel.get_peft_model(
            model,
            qat_scheme = qat_scheme,
        )
    return model, tokenizer


def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
    """
    Verify that the given linear contains fake quantizers according to the `qat_scheme`.
    """
    if qat_scheme == "fp8-int4":
        act_fq_class = Float8FakeQuantizer
        weight_fq_class = Int4WeightPreshuffledFakeQuantizer
        min_in_features = 128
    elif qat_scheme == "fp8-fp8":
        act_fq_class = Float8FakeQuantizer
        weight_fq_class = Float8FakeQuantizer
        min_in_features = -1
    else:
        raise ValueError(f"Unknown qat_scheme: {qat_scheme}")

    # Check base layer activations and weights
    base_layer = getattr(linear, "base_layer", linear)
    if base_layer.in_features >= min_in_features:
        assert isinstance(base_layer, FakeQuantizedLinear)
        assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
        assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)

    # Check lora A and B (only for full_finetuning=False)
    if hasattr(linear, "lora_A") and hasattr(linear, "lora_B"):
        lora_A = linear.lora_A.default
        lora_B = linear.lora_B.default
        if lora_A.in_features >= min_in_features:
            assert isinstance(lora_A, FakeQuantizedLinear)
            assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
            assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)
        if lora_B.in_features >= min_in_features:
            assert isinstance(lora_B, FakeQuantizedLinear)
            assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
            assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)


def _test_fake_quantizers_are_called(
    model: torch.nn.Module,
    example_inputs: Dict,
    full_finetuning: bool,
):
    """
    Verify that the fake quantizers are actually called when the model is called.
    """
    def _swap_fake_quantizers(model: torch.nn.Module):
        for name, child in model.named_children():
            if isinstance(child, FakeQuantizerBase):
                setattr(model, name, _CountingFakeQuantizer())

    def _assert_fake_quantizers_are_called(model: torch.nn.Module):
        for name, child in model.named_children():
            if full_finetuning:
                if isinstance(child, FakeQuantizedLinear):
                    assert child.activation_fake_quantizer.count == 1
                    assert child.weight_fake_quantizer.count == 1
            else:
                # For LoRA, we only fake quantize the input activations once per block:
                # For self_attn, we only fake quantize the q_proj's input activations
                # For mlp, we only fake quantize the gate_proj's input activations
                if name == "self_attn":
                    base_layer = child.q_proj.base_layer
                    assert hasattr(base_layer, "activation_fake_quantizer")
                    assert base_layer.activation_fake_quantizer.count == 1
                elif name == "mlp":
                    base_layer = child.gate_proj.base_layer
                    assert hasattr(base_layer, "activation_fake_quantizer")
                    assert base_layer.activation_fake_quantizer.count == 1
                elif isinstance(child, FakeQuantizedLinear):
                    # Weight fake quantizers should always be called
                    assert child.weight_fake_quantizer.count == 1

    for k, v in example_inputs.items():
        example_inputs[k] = v.cuda()
    model.apply(_swap_fake_quantizers)
    model(**example_inputs)
    model.apply(_assert_fake_quantizers_are_called)


def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
    """
    Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
    """
    model, tokenizer = _get_model(qat_scheme, full_finetuning)
    if full_finetuning:
        model = model.model
    else:
        model = model.base_model.model.model
    for layer in model.layers:
        _test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
    inputs = tokenizer("How are you?", return_tensors="pt")
    _test_fake_quantizers_are_called(model, inputs, full_finetuning)


# TODO: there are bad interactions across tests right now, need to figure out
# how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def _test_full_model_fake_quantize(qat_scheme: str):
    _test_model_fake_quantize(qat_scheme, full_finetuning=True)


@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def test_lora_model_fake_quantize(qat_scheme: str):
    _test_model_fake_quantize(qat_scheme, full_finetuning=False)
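
For reference, a minimal usage sketch of the flow these tests exercise (not part of the commit): load the model with FastLanguageModel.from_pretrained, then pass qat_scheme to get_peft_model for the LoRA path, exactly as _get_model does above. The model name, scheme names, and prompt come from the test itself; everything else here is illustrative.

import torch
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-1.7B",   # same model the tests use
    load_in_4bit = False,
    full_finetuning = False,             # LoRA path, as in test_lora_model_fake_quantize
)
model = FastLanguageModel.get_peft_model(
    model,
    qat_scheme = "fp8-int4",             # or "fp8-fp8"
)

# Quick smoke test mirroring _test_fake_quantizers_are_called:
inputs = tokenizer("How are you?", return_tensors = "pt").to("cuda")
with torch.no_grad():
    model(**inputs)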

unsloth/dataprep/synthetic.py

Lines changed: 125 additions & 17 deletions
@@ -16,13 +16,16 @@
     "SyntheticDataKit",
 ]
 import subprocess
+import threading
+from collections import deque
 import time
 import os
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 import requests
 import torch
 import gc
 import time
+import re
 from unsloth_zoo.vllm_utils import (
     load_vllm,
     patch_vllm,
@@ -35,6 +38,100 @@
     synthetic_qa_config,
 )
 
+def terminate_tree(proc: subprocess.Popen, timeout=15):
+    if proc is None or proc.poll() is not None:
+        return
+
+    try:
+        import psutil
+        parent = psutil.Process(proc.pid)
+        for child in parent.children(recursive=True):
+            child.terminate()
+        parent.terminate()
+        parent.wait(timeout=timeout/2)
+        return
+    except:
+        pass
+
+    if os.name == 'nt':
+        try:
+            subprocess.run(
+                ['taskkill', '/T', '/F', '/PID', str(proc.pid)],
+                capture_output=True,
+                timeout=5
+            )
+            proc.wait(timeout=1)
+            return
+        except:
+            pass
+
+    proc.kill()
+    try:
+        proc.wait(timeout=5)
+    except:
+        pass
+
+class PipeCapture:
+    """Non-blocking pipe capture"""
+    def __init__(self, pipe, keep_lines=2000, echo=False, name="", text=True, encoding='utf-8', errors='replace', ready_regex=None):
+        self.pipe = pipe
+        self.buf = deque(maxlen=keep_lines)
+        self.lock = threading.Lock()
+        self.echo = echo
+        self.name = name
+        self.text = text
+        self.encoding = encoding
+        self.errors = errors
+
+        self.ready_event = threading.Event()
+        self.closed_event = threading.Event()
+
+        self.ready_regex = None
+        if ready_regex is not None:
+            if not hasattr(ready_regex, "search"):
+                ready_regex = re.compile(ready_regex)
+            self.ready_regex = ready_regex
+
+        self.t = threading.Thread(target=self._reader, daemon=True)
+        self.t.start()
+
+    def _reader(self):
+        try:
+            sentinel = '' if self.text else b''
+            for raw_line in iter(self.pipe.readline, sentinel):
+                if not self.text:
+                    line = raw_line.decode(self.encoding, self.errors)
+                else:
+                    line = raw_line
+                line = line.rstrip('\r\n')
+                if self.echo:
+                    if "platform is" not in line:
+                        print(f"{self.name}: {line}")
+
+                with self.lock:
+                    self.buf.append(line)
+
+                if self.ready_regex is not None and self.ready_regex.search(line):
+                    self.ready_event.set()
+
+        finally:
+            try: self.pipe.close()
+            except Exception: pass
+            self.closed_event.set()
+
+    def wait_for_ready(self, timeout=None):
+        return self.ready_event.wait(timeout)
+
+    def has_closed(self):
+        return self.closed_event.is_set()
+
+    def wait_until_closed(self, timeout=None):
+        return self.closed_event.wait(timeout)
+
+    def tail(self, n=200):
+        with self.lock:
+            return '\n'.join(list(self.buf)[-n:])
+
 class SyntheticDataKit:
     def __init__(
         self,
@@ -44,6 +141,7 @@ def __init__(
         float8_kv_cache = False,
         conservativeness = 1.0,
         token = None,
+        timeout = 1200,  # may not be enough for large models that still need to download
         **kwargs,
     ):
         assert(type(model_name) is str)
@@ -128,30 +226,40 @@ def __init__(
             stderr = subprocess.PIPE,
             start_new_session = True,
        )
+        ready_re = re.compile(r"Starting vLLM API server(?:\s+\d+)?\s+on\b")
         self.vllm_process = vllm_process
+        self.stdout_capture = PipeCapture(vllm_process.stdout, keep_lines = 1000,
+                                          echo = True, name = "vLLM STDOUT",
+                                          ready_regex = ready_re, text = False)
+        self.stderr_capture = PipeCapture(vllm_process.stderr, keep_lines = 2000,
+                                          echo = False, name = "vLLM STDERR",
+                                          ready_regex = None, text = False)
+        # we don't print stderr to the console, but self.stderr_capture.tail(200) returns the last 200 lines
 
-        ready_message_part = b"Starting vLLM API server on"
-        ready = False
-        while vllm_process.poll() is None:
-            output = vllm_process.stdout.readline()
-            if not output:
+        ready = self.stdout_capture.wait_for_ready(timeout = timeout)
+        if not ready:
+            if self.stdout_capture.has_closed() or self.vllm_process.poll() is not None:
                 print("Stdout stream ended before readiness message detected.")
-                break
-            output_str = output.decode('utf-8', errors='ignore').strip()
-            if "platform is" not in output_str:
-                print(f"vLLM STDOUT: {output_str}")
-            if ready_message_part in output:
-                print(f"\n--- vLLM Server Ready (Detected: '{ready_message_part.decode()}') ---")
-                ready = True
-                break
-            pass
+                print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
+                print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
+            else:
+                print(f"Unsloth: vllm_process failed to load! (timeout={timeout})")
+                print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
+                print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
+            terminate_tree(self.vllm_process)
+            return
+        else:
+            print("vLLM Server Ready Detected")
         pass
-        if vllm_process is None:
-            raise RuntimeError("Unsloth: vllm_process failed to load!")
+
         trial = 0
         while not self.check_vllm_status():
             if trial >= 100:
-                raise RuntimeError("Unsloth: vllm_process failed to load!")
+                print("Unsloth: vllm_process failed to load!")
+                print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
+                print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
+                terminate_tree(self.vllm_process)
+                return
            trial += 1
            time.sleep(1)
        return
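
Context for the change above: PipeCapture reads the child process's pipe on a daemon thread into a bounded deque, sets ready_event when the readiness regex matches a line, and sets closed_event when the pipe ends, so __init__ can wait with a timeout instead of blocking on readline. A rough usage sketch under assumed inputs (not part of the commit; the command and regex below are made up, while the real code watches vLLM's "Starting vLLM API server ... on" message):

import re
import subprocess

proc = subprocess.Popen(
    ["python", "-u", "-c", "print('server ready'); import time; time.sleep(60)"],
    stdout = subprocess.PIPE,
    stderr = subprocess.PIPE,
)
capture = PipeCapture(proc.stdout, keep_lines = 1000, echo = True,
                      name = "CHILD STDOUT",
                      ready_regex = re.compile(r"server ready"), text = False)

if capture.wait_for_ready(timeout = 30):
    print("Child process is ready")
else:
    print(capture.tail(50))    # last captured lines, for debugging
    terminate_tree(proc)       # kill the whole process tree, as __init__ now does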

unsloth/kernels/fast_lora.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import torch
 from .utils import (
+    _maybe_fake_quantize_activations,
     fast_dequantize,
     QUANT_STATE,
     get_lora_parameters,
@@ -175,6 +176,7 @@ def backward(ctx, dY : torch.Tensor):
 
 from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
 def apply_lora_mlp_swiglu(self, X, inplace = True):
+    X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
     upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@@ -190,6 +192,7 @@ def apply_lora_mlp_swiglu(self, X, inplace = True):
 
 from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
 def apply_lora_mlp_geglu_exact(self, X, inplace = True):
+    X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
     upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@@ -205,6 +208,7 @@ def apply_lora_mlp_geglu_exact(self, X, inplace = True):
 
 from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
 def apply_lora_mlp_geglu_approx(self, X):
+    X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
     upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@@ -360,6 +364,7 @@ def backward(ctx, dQ, dK, dV):
 
 
 def apply_lora_qkv(self, X, inplace = True):
+    X = _maybe_fake_quantize_activations(X, self.q_proj)
     QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
     KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
     VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
@@ -453,6 +458,7 @@ def backward(ctx, dY : torch.Tensor):
 
 
 def apply_lora_o(self, X):
+    X = _maybe_fake_quantize_activations(X, self.o_proj)
     OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
     O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
     return O
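
Each fused LoRA MLP/QKV/O path now routes its input through _maybe_fake_quantize_activations before the LoRA parameters are unpacked, so QAT activation fake-quantization also applies on Unsloth's fast kernels, matching what the new tests assert (one activation fake-quantize per attention/MLP block, on q_proj and gate_proj respectively). The helper itself is defined in unsloth/kernels/utils.py and is not shown in this diff; the following is a purely hypothetical sketch of its shape, not the actual implementation:

def _maybe_fake_quantize_activations(X, proj):
    # Hypothetical sketch: if QAT wrapped the projection's base layer in a
    # torchao FakeQuantizedLinear, run the input through its activation fake
    # quantizer; otherwise return X unchanged.
    base_layer = getattr(proj, "base_layer", proj)
    fake_quantizer = getattr(base_layer, "activation_fake_quantizer", None)
    if fake_quantizer is not None:
        X = fake_quantizer(X)
    return X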
