
Commit a5428d3

Remove internal executorch dependency on torchao.quantization.subclass
Differential Revision: D84921134
Pull Request resolved: #15223
1 parent 331d771 commit a5428d3

File tree

3 files changed: +174, -2 lines changed


examples/models/llama/experimental/__init__.py

Whitespace-only changes.

examples/models/llama/experimental/load_gguf_q4_0.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@
 from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
 from gguf import ReaderTensor
 from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .subclass import QuantizedLinearWeightBase
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
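
The switch to a relative import relies on examples/models/llama/experimental/__init__.py (the whitespace-only file listed above) making the directory a regular package, so .subclass resolves to the sibling module subclass.py rather than torchao.quantization.subclass. A minimal sketch of the equivalent absolute import (the executorch.examples... path is an assumed source-tree layout, not taken from this commit):

# Assumed absolute module path; inside the package the commit uses the
# relative form "from .subclass import QuantizedLinearWeightBase".
from executorch.examples.models.llama.experimental.subclass import (
    QuantizedLinearWeightBase,
)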

examples/models/llama/experimental/subclass.py

Lines changed: 172 additions & 1 deletion
@@ -20,7 +20,11 @@
 #
 # This layout is handled internally in the tensor subclass.
 import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from typing_extensions import deprecated
+
+
+aten = torch.ops.aten
 
 
 def down_size(size):
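
The hunk above replaces the torchao import with, among others, typing_extensions.deprecated, which marks the vendored class added below: instantiating a class decorated this way emits a DeprecationWarning at runtime. A minimal sketch of that behavior (VendoredBase is a hypothetical stand-in, not part of this commit):

import warnings

from typing_extensions import deprecated


# Hypothetical stand-in mirroring the decorator usage in this commit.
@deprecated("VendoredBase is deleted from upstream. DO NOT USE!")
class VendoredBase:
    pass


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    VendoredBase()  # instantiation triggers the deprecation warning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)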
@@ -129,6 +133,173 @@ def to_float(
     return a * scale.unsqueeze(1)
 
 
+@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!")
+class QuantizedLinearWeightBase(torch.Tensor):
+    """
+    *** LEGACY TORCHAO TENSOR SUBCLASS ***
+
+    Note: this subclass no longer exists in torchao. No one should be importing or extending this
+    subclass anymore. It is added back here just for internal executorch BC. DO NOT USE!
+
+    Base quantized tensor subclass for quantized linear weights. When the from_float method is
+    used to create an instance of any QuantizedLinearWeightBase, we assume the input
+    weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels.
+
+    The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
+    regardless of the internal representation's type or orientation.
+    """
+
+    @staticmethod
+    def __new__(cls, int_data, transposed, shape, *args, **kwargs):
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        assert "dtype" in kwargs
+        assert not kwargs.get("requires_grad", False)
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, int_data, transposed, *args, **kwargs):
+        self.int_data = int_data
+
+        self.transposed = transposed
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        pass
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
+
+    def dequantize(self):
+        pass
+
+    def int_repr(self):
+        pass
+
+    def q_params(self):
+        pass
+
+    def half(self):
+        return self.to(torch.float16)
+
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def _apply_fn_to_data(self, fn):
+        pass
+
+    def _change_shape(self):
+        pass
+
+    def __tensor_flatten__(self):
+        pass
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        pass
+
+    @classmethod
+    def from_float(cls, input_float):
+        pass
+
+    # __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            mat1, w_qtensor, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            assert not w_qtensor.transposed
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except Exception:
+            print(f"ERR: subclass doesn't implement {func}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        # two scenarios where we currently fall back to vanilla mm:
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        # for consistency and to allow people to test
+        # 2 - we're given non-floats - quantizing long to int8 is crazy
+        if (
+            func in [aten.mm.default, aten.addmm.default]
+            and args[0].is_floating_point()
+            and args[0].is_cuda
+        ):
+            if func == aten.addmm.default:
+                assert args[1].shape[-1] == args[2].shape[0], (
+                    f"need mat1 shape: {args[1].shape} final "
+                    f"dim to match mat2 shape: {args[2].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[1],
+                    args[2],
+                    args[0],
+                )
+            else:
+                assert args[0].shape[-1] == args[1].shape[0], (
+                    f"need mat1 shape: {args[0].shape} final dim "
+                    f"to match mat2 shape: {args[1].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[0],
+                    args[1],
+                    None if len(args) == 2 else args[2],
+                )
+            # call the quantized op for the specific type
+            # of quantized tensor subclass
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            args[0].transposed = not args[0].transposed
+            new = args[0]._change_shape(args[0].shape[::-1])
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        if func is aten._to_copy.default:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+            )
+
+
 class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
     """
     A Tensor subclass that when applied to a weight used in a linear op/module,
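
For anyone reading the vendored base, a minimal sketch of how a concrete weight type such as GGMLInt4LinearWeight is expected to plug in: implement _quantized_op and from_float, and the base's __torch_function__ routes torch.nn.functional.linear through the quantized op. The Int8DemoWeight class and its row-wise int8 scheme below are hypothetical illustrations, not part of this commit:

import torch


class Int8DemoWeight(QuantizedLinearWeightBase):
    # Hypothetical subclass; instantiating it warns via the @deprecated base.

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        # Demo fallback: dequantize the weight and run a regular linear.
        w = w_qtensor.int_data.to(act_mat.dtype) * w_qtensor.scale
        return torch.nn.functional.linear(act_mat, w, bias)

    @classmethod
    def from_float(cls, input_float):
        # Row-wise symmetric int8 quantization (illustrative scheme).
        scale = input_float.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
        int_data = torch.round(input_float / scale).to(torch.int8)
        obj = cls(int_data, False, input_float.shape, dtype=input_float.dtype)
        obj.scale = scale  # stored alongside int_data for dequantization
        return obj


w = Int8DemoWeight.from_float(torch.randn(8, 16))
# torch.nn.functional.linear sees the subclass and is routed to _quantized_op.
y = torch.nn.functional.linear(torch.randn(2, 16), w)
assert y.shape == (2, 8)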
