@@ -174,21 +174,26 @@ def get_dequantize_per_tensor_activation_pattern(
     output_dtype=KeywordArg("w_dtype"),
 )

+
 def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
     return _may_generate_pattern_with_dtype_convert(
         dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )

+
 def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
         aten.clone.default,
         dequant_wgt_pattern,
         memory_format=KeywordArg("memory_format"),
     )

+
 def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
-    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
+    return get_dequantize_clone_weight_pattern(
+        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
+    )


 def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
@@ -450,14 +455,18 @@ def fn(match):
                     break
             assert extra_input_of_binary_node is not None
             # Extra input of binary node comes from dequant pattern
-            if not is_fp8 and extra_input_from_dequant and (
-                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
-                or (
-                    extra_input_of_binary_node.target
-                    not in [
-                        quantized_decomposed.dequantize_per_tensor.default,
-                        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-                    ]
+            if (
+                not is_fp8
+                and extra_input_from_dequant
+                and (
+                    (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+                    or (
+                        extra_input_of_binary_node.target
+                        not in [
+                            quantized_decomposed.dequantize_per_tensor.default,
+                            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                        ]
+                    )
                 )
             ):
                 return False
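
As a side note on what the rewrapped condition checks: for the non-fp8 path, the extra input feeding the binary node must itself be an FX node produced by one of the listed dequantize ops, otherwise the match is rejected. A minimal, illustrative sketch of that predicate follows; the helper name and the `dequant_ops` parameter are stand-ins, not part of the pass.

```python
import torch


def extra_input_is_dequant(node, dequant_ops) -> bool:
    # Mirrors the membership test above: the binary op's extra input must be
    # a torch.fx.Node whose target is one of the allowed dequantize overloads.
    return isinstance(node, torch.fx.Node) and node.target in dequant_ops
```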
@@ -692,7 +701,9 @@ def _inner(match):
     return _inner


-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
+def _register_qconv_weight_prepack_pass(
+    pattern, pass_number, dtype=torch.float32, is_fp8=False
+):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -776,7 +787,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if is_fp8:
             # For float8, we assume the scales are from aten.full.default instead of
             # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
+            assert (
+                w_scale.target is torch.ops.aten.full.default
+                and x_scale.target is torch.ops.aten.full.default
+            )
             with torch.utils._python_dispatch._disable_current_modes():
                 w_scale_tensor = torch.tensor([w_scale.args[1]])
                 match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
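
For context on the strengthened assertion: `aten.full.default(size, fill_value, ...)` carries the scalar as its second positional argument, which is why the prepack code can recover the float8 scale from `w_scale.args[1]` and register it as a buffer. A minimal sketch of that recovery under the same assumption; the helper name is illustrative, and `full_node` is any FX node produced by `aten.full.default`.

```python
import torch


def scale_tensor_from_full_node(full_node: torch.fx.Node) -> torch.Tensor:
    # aten.full.default(size, fill_value, ...): args[0] is the size and
    # args[1] is the fill value, i.e. the scale baked into the graph.
    assert full_node.target is torch.ops.aten.full.default
    return torch.tensor([full_node.args[1]])
```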
@@ -1446,8 +1460,12 @@ def _register_dequant_promotion():


 def _register_qconv_weight_prepack():
-    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
+    for dtype, is_fp8 in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, is_fp8=is_fp8
+        )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
@@ -2050,7 +2068,13 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
+        assert output_dtype in [
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+            torch.float32,
+            torch.bfloat16,
+        ]
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
             # For float8, we assume the scale is from aten.full.default instead of
@@ -2297,7 +2321,9 @@ def _register_qconv_unary_fusion():


 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
+        [False, True], [False, True]
+    ):
         qconv_binary_op = (
             torch.ops.onednn.qconv2d_pointwise.binary_tensor
             if x_scale_zp_are_tensors
@@ -2306,7 +2332,9 @@ def _register_qconv_binary_fusion():
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
+        for swap_inputs, is_fp8 in itertools.product(
+            swap_binary_inputs_list, [False, True]
+        ):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
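
Both registration loops in this commit enumerate their flag combinations with `itertools.product`, so each (dtype, is_fp8) and (swap_inputs, is_fp8) pair gets its own pattern registration. A standalone sketch of that enumeration, with print statements standing in for the actual registration calls:

```python
import itertools

import torch

# Four weight-prepack registrations: fp32/bf16 x fp8/non-fp8.
for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
    print(f"prepack pattern: dtype={dtype}, is_fp8={is_fp8}")

# Four binary-fusion pattern variants: swapped/unswapped inputs x fp8/non-fp8.
for swap_inputs, is_fp8 in itertools.product([False, True], [False, True]):
    print(f"binary pattern: swap_inputs={swap_inputs}, is_fp8={is_fp8}")
```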