Commit 316ef03

Fix style after #3261 (#3397)
1 parent eb2a8fc commit 316ef03

2 files changed: +92 -30 lines changed

test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 48 additions & 14 deletions
@@ -139,10 +139,22 @@ def forward(self, input):
 
 
 class FP8QDQConv2d(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
         super().__init__()
         self.qtype = torch.float8_e4m3fn
-        self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
+        self.weight = torch.randn(
+            (out_channels, in_channels // groups, *kernel_size)
+        ).to(self.qtype)
         self.weight_scale = 2.0
         self.scale = 2.0
         self.bias = None
@@ -170,7 +182,16 @@ def forward(self, input):
             output_dtype=torch.float,
         )
 
-        return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return torch.nn.functional.conv2d(
+            dq_input,
+            weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
 
 def qdq(input, scale):
     dtype = input.dtype
@@ -205,9 +226,7 @@ def create_mod_info_recursion(parent):
     parent_child_mod_dict = generate_model_info(model)
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
-        if mod_type_str not in [
-            "Linear", "Conv2d"
-        ]:
+        if mod_type_str not in ["Linear", "Conv2d"]:
             continue
         param = mod.weight
         xmax = torch.max(param)
@@ -225,7 +244,16 @@ def create_mod_info_recursion(parent):
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
         elif mod_type_str in ["Conv2d"]:
-            patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
+            patched_mod = FP8QDQConv2d(
+                mod.in_channels,
+                mod.out_channels,
+                mod.kernel_size,
+                mod.stride,
+                mod.padding,
+                mod.dilation,
+                mod.groups,
+                False,
+            )
             patched_mod.bias = mod.bias
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
@@ -610,7 +638,9 @@ def test_qconv2d_relu6_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU6 pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -627,7 +657,9 @@ def test_qconv2d_hardtanh_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardtanh pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -678,7 +710,9 @@ def test_qconv2d_hardswish_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardswish pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -731,7 +765,9 @@ def test_qconv2d_silu_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->SiLU pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -911,9 +947,7 @@ def forward(self, x, x2, x3):
         add_fn_list = quantization_add_fn_list
         if not is_fp8:
             add_fn_list = add_fn_list + quantization_inplace_add_fn_list
-        for add_fn, swap_inputs in itertools.product(
-            add_fn_list, [False, True]
-        ):
+        for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
                 (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 44 additions & 16 deletions
@@ -174,21 +174,26 @@ def get_dequantize_per_tensor_activation_pattern(
     output_dtype=KeywordArg("w_dtype"),
 )
 
+
 def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
     return _may_generate_pattern_with_dtype_convert(
         dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
 
+
 def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
         aten.clone.default,
         dequant_wgt_pattern,
         memory_format=KeywordArg("memory_format"),
     )
 
+
 def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
-    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
+    return get_dequantize_clone_weight_pattern(
+        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
+    )
 
 
 def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
@@ -450,14 +455,18 @@ def fn(match):
                 break
         assert extra_input_of_binary_node is not None
         # Extra input of binary node comes from dequant pattern
-        if not is_fp8 and extra_input_from_dequant and (
-            (not isinstance(extra_input_of_binary_node, torch.fx.Node))
-            or (
-                extra_input_of_binary_node.target
-                not in [
-                    quantized_decomposed.dequantize_per_tensor.default,
-                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-                ]
+        if (
+            not is_fp8
+            and extra_input_from_dequant
+            and (
+                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+                or (
+                    extra_input_of_binary_node.target
+                    not in [
+                        quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                    ]
+                )
             )
         ):
             return False
@@ -692,7 +701,9 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
+def _register_qconv_weight_prepack_pass(
+    pattern, pass_number, dtype=torch.float32, is_fp8=False
+):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -776,7 +787,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if is_fp8:
             # For float8, we assume the scales are from aten.full.default instead of
             # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
+            assert (
+                w_scale.target is torch.ops.aten.full.default
+                and x_scale.target is torch.ops.aten.full.default
+            )
             with torch.utils._python_dispatch._disable_current_modes():
                 w_scale_tensor = torch.tensor([w_scale.args[1]])
                 match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
@@ -1446,8 +1460,12 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
+    for dtype, is_fp8 in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, is_fp8=is_fp8
+        )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
@@ -2050,7 +2068,13 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
+        assert output_dtype in [
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+            torch.float32,
+            torch.bfloat16,
+        ]
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
             # For float8, we assume the scale is from aten.full.default instead of
@@ -2297,7 +2321,9 @@ def _register_qconv_unary_fusion():
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
+        [False, True], [False, True]
+    ):
         qconv_binary_op = (
             torch.ops.onednn.qconv2d_pointwise.binary_tensor
             if x_scale_zp_are_tensors
@@ -2306,7 +2332,9 @@ def _register_qconv_binary_fusion():
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
+        for swap_inputs, is_fp8 in itertools.product(
+            swap_binary_inputs_list, [False, True]
+        ):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(

0 commit comments
