Skip to content

Commit 4903c55

Browse files
authored
Add conv2d support for IntxUnpackedToInt8Tensor (#3371)
**Summary:** This enables `Int8DynamicActivationIntxWeightConfig` and `IntxWeightOnlyConfig` for conv2d.

**Test Plan:**
```
python test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py -k test_conv2d
```
1 parent ab6bc89 commit 4903c55

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

‎test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ def test_linear(self):
6161
error = compute_error(original, quantized)
6262
self.assertTrue(error > 20)
6363

64+
def test_conv2d(self):
65+
dtype = torch.bfloat16
66+
device = "cpu"
67+
input = torch.randn(1, 128, 224, 224, dtype=dtype, device=device)
68+
conv = torch.nn.Conv2d(128, 64, 3, dtype=dtype, device=device)
69+
original = conv(input)
70+
is_conv = lambda n, _: isinstance(n, torch.nn.Conv2d)
71+
quantize_(conv, self.config, filter_fn=is_conv)
72+
quantized = conv(input)
73+
error = compute_error(original, quantized)
74+
self.assertGreater(error, 15)
75+
6476
def test_hqq_intx_weight_only_config(self):
6577
dtype = torch.bfloat16
6678
device = "cpu"

‎torchao/quantization/quant_api.py‎

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,20 +2297,32 @@ def _intx_weight_only_quantize_tensor(
22972297
intx_packing_format = config.intx_packing_format
22982298
intx_choose_qparams_algorithm = config.intx_choose_qparams_algorithm
22992299

2300-
assert weight.dim() == 2, (
2301-
f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
2302-
)
2300+
if weight.dim() == 2:
2301+
input_dim = -1
2302+
elif weight.dim() == 4:
2303+
# conv2d: N, C_in, H, W
2304+
input_dim = 1
2305+
else:
2306+
raise ValueError(
2307+
f"IntxWeightOnlyConfig only works for 2-d and 4-d Tensors, got: {weight.dim()}"
2308+
)
2309+
23032310
if isinstance(granularity, PerGroup):
23042311
group_size = granularity.group_size
23052312
elif isinstance(granularity, PerAxis):
23062313
assert granularity.axis == 0, (
23072314
f"axis must be 0 with PerAxis, but got {granularity.axis}"
23082315
)
2309-
group_size = weight.shape[-1]
2316+
group_size = weight.shape[input_dim]
23102317
else:
23112318
raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
23122319

2313-
block_size = (1, group_size)
2320+
if weight.dim() == 2:
2321+
block_size = (1, group_size)
2322+
else:
2323+
# conv2d: N, C_in, H, W
2324+
assert weight.dim() == 4
2325+
block_size = (1, group_size, 1, 1)
23142326

23152327
if config.version == 2:
23162328
if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:

‎torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py‎

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,36 @@ def _(func, types, args, kwargs):
343343
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
344344

345345

346+
@implements(aten.conv2d.default)
347+
@implements_torch_function(torch.nn.functional.conv2d)
348+
def _(func, types, args, kwargs):
349+
(
350+
input_tensor,
351+
weight_tensor,
352+
bias,
353+
stride,
354+
padding,
355+
dilation,
356+
groups,
357+
) = fill_defaults(args, 7, [None, [1, 1], [0, 0], [1, 1], 1])
358+
assert isinstance(weight_tensor, IntxUnpackedToInt8Tensor)
359+
360+
# Apply dynamic activation quant
361+
if weight_tensor.activation_quantization is not None:
362+
if (
363+
weight_tensor.activation_quantization
364+
== IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
365+
):
366+
input_tensor = _apply_int8_act_asym_per_token_quant_dequant(input_tensor)
367+
else:
368+
raise NotImplementedError(
369+
f"Unsupported activation quantization: {weight_tensor.activation_quantization}"
370+
)
371+
372+
weight_tensor = weight_tensor.dequantize()
373+
return func(input_tensor, weight_tensor, bias, stride, padding, dilation, groups)
374+
375+
346376
@implements(aten.embedding.default)
347377
@implements_torch_function(torch.nn.functional.embedding)
348378
def _(func, types, args, kwargs):

0 commit comments

Comments
 (0)