1from __future__ import annotations
2from abc import ABC, ABCMeta, abstractmethod
3
4import logging
5from typing import Any, Callable
6
7import numpy as np
8from numpy.typing import DTypeLike
9
10
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
12
13
class LazyMeta(ABCMeta):
    """Metaclass that gives a lazy tensor class the attribute and operator surface of its eager type.

    Classes created through this metaclass are expected to define
    ``_tensor_type`` (the eager tensor class) and a ``_wrap_fn`` classmethod,
    and their instances are expected to carry a ``_meta`` attribute.

    At class-creation time the metaclass injects into the class namespace:

    * a ``__getattr__`` that resolves otherwise-missing attributes through
      ``self._meta``, and
    * one forwarding method per supported special method (comparison,
      arithmetic, in-place, reflected, unary, item access and ``len``), each
      routed through ``_wrap_fn``.
    """

    def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
        def __getattr__(self, name: str) -> Any:
            """Resolve an attribute that normal lookup did not find via ``self._meta``.

            Callables and tensor-valued attributes are wrapped with
            ``_wrap_fn``; any other value is returned as found on ``_meta``.

            Raises:
                AttributeError: if ``_meta`` itself is missing, or if ``_meta``
                    has no attribute called ``name``.
            """
            # `self._meta` is read just below. If `_meta` has not been set
            # (e.g. an instance created with `__new__` only), that read would
            # re-enter this method without end, so report it as missing.
            if name == "_meta":
                raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

            meta_attr = getattr(self._meta, name)
            if callable(meta_attr):
                return type(self)._wrap_fn(
                    (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                    use_self=self,
                )
            elif isinstance(meta_attr, self._tensor_type):
                # e.g. self.T with torch.Tensor should still be wrapped
                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
            else:
                # no need to wrap non-tensor properties,
                # and they likely don't depend on the actual contents of the tensor
                return meta_attr

        namespace["__getattr__"] = __getattr__

        # A builder function is needed so that each generated method captures
        # its own `op_name`. Closures created directly inside the loops below
        # would all see the loop variable's final value (late binding).
        def mk_wrap(op_name: str, *, meta_noop: bool = False):
            # need to wrap the wrapper to get self
            def wrapped_special_op(self, *args, **kwargs):
                return type(self)._wrap_fn(
                    getattr(type(self)._tensor_type, op_name),
                    meta_noop=meta_noop,
                )(self, *args, **kwargs)
            return wrapped_special_op

        # special methods bypass __getattr__, so they need to be added manually
        # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
        # NOTE: doing this from a metaclass is very convenient
        # TODO: make this even more comprehensive
        for binary_op in (
            "lt", "le", "eq", "ne", "ge", "gt",
            "add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
            "or", "pow", "rshift", "sub", "truediv", "xor",
            "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
            "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
        ):
            attr_name = f"__{binary_op}__"
            # evaluation on the meta tensor is needed in case there's broadcasting
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        # NOTE: `__not__` is not a special method the interpreter looks up
        # (`not x` goes through `__bool__`/`__len__`); the entry is kept so the
        # generated namespace stays the same as before.
        for unary_op in ("not", "abs", "invert", "neg", "pos"):
            attr_name = f"__{unary_op}__"
            # the result of these operators usually has the same shape and dtype as the input,
            # so evaluation on the meta tensor can be skipped.
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)

        for special_op in (
            "getitem", "setitem", "len",
        ):
            attr_name = f"__{special_op}__"
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        return super().__new__(cls, name, bases, namespace, **kwargs)
74
75
76# Tree of lazy tensors
class LazyBase(ABC, metaclass=LazyMeta):
    """Node in a tree of lazily evaluated tensors.

    A node either already holds its eager value in ``_data``, or holds a
    function ``_func`` together with the ``_args``/``_kwargs`` to call it
    with. ``_meta`` is a placeholder object whose ``dtype`` and ``shape`` are
    compared against the eager result once it is computed.

    Concrete subclasses must define ``_tensor_type`` in their own class body
    and implement ``meta_with_dtype_and_shape``.
    """

    _tensor_type: type
    _meta: Any
    _data: Any | None
    _args: tuple
    _kwargs: dict[str, Any]
    _func: Callable[[Any], Any] | None

    def __init__(
        self,
        *,
        meta: Any,
        data: Any | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        func: Callable[[Any], Any] | None = None,
    ):
        """Create a node from either eager ``data`` or a deferred ``func`` call.

        Args:
            meta: placeholder describing the result.
            data: the eager value, when already known.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` means no keyword arguments.
            func: function producing the eager value.
        """
        super().__init__()
        self._meta = meta
        self._data = data
        self._args = args
        self._kwargs = kwargs if kwargs is not None else {}
        self._func = func
        # a node must be computable: it needs either a value or a way to get one
        assert self._func is not None or self._data is not None

    def __init_subclass__(cls) -> None:
        """Reject subclasses that do not define ``_tensor_type`` in their own body."""
        if "_tensor_type" not in cls.__dict__:
            raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
        return super().__init_subclass__()

    @staticmethod
    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
        """Apply ``fn`` to every ``LazyBase`` found in nested lists/tuples of ``o``.

        Lists and tuples are rebuilt (as a new ``list`` or ``tuple``);
        any other non-lazy object is returned unchanged.
        """
        # TODO: dict and set
        if isinstance(o, (list, tuple)):
            items = [LazyBase._recurse_apply(item, fn) for item in o]
            return tuple(items) if isinstance(o, tuple) else items
        elif isinstance(o, LazyBase):
            return fn(o)
        else:
            return o

    @classmethod
    def _wrap_fn(
        cls,
        fn: Callable,
        *,
        use_self: LazyBase | None = None,
        meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False,
    ) -> Callable[[Any], Any]:
        """Wrap ``fn`` so that calling it builds lazy nodes instead of computing.

        Args:
            fn: the function to defer.
            use_self: when given, it is prepended to the positional arguments
                of every call of the wrapper.
            meta_noop: how the result's placeholder is obtained.
                ``False`` runs ``fn`` on the placeholders of the arguments.
                ``True`` reuses the placeholder of the first argument.
                A dtype keeps the first argument's shape but uses that dtype.
                A ``(dtype, shape_fn)`` tuple uses that dtype and
                ``shape_fn(first_argument_shape)``.

        Returns:
            A function that returns a lazy ``cls`` instance when the result is
            a ``_tensor_type``, a tuple of lazy instances when the result is a
            tuple of ``_tensor_type``, and otherwise the result of calling
            ``fn`` on the eagerly evaluated arguments.
        """
        def wrapped_fn(*args, **kwargs):
            args = ((use_self,) if use_self is not None else ()) + args

            meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
            # TODO: maybe handle tensors in kwargs too

            if isinstance(meta_noop, bool) and not meta_noop:
                try:
                    res = fn(*meta_args, **kwargs)
                except NotImplementedError:
                    # running some operations on PyTorch's Meta tensors can cause this exception
                    res = None
            else:
                # some operators don't need to actually run on the meta tensors
                assert len(args) > 0
                res = args[0]
                assert isinstance(res, cls)
                res = res._meta
                # allow operations to override the dtype and shape
                if meta_noop is not True:
                    if isinstance(meta_noop, tuple):
                        dtype, shape = meta_noop
                        assert callable(shape)
                        res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
                    else:
                        res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

            if isinstance(res, cls._tensor_type):
                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
            elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
                # Share the evaluation between lazy tuple elements.
                # The cache lives in this closure rather than inside `args`,
                # because `_recurse_apply` rebuilds every list/tuple it walks:
                # a list stored in `args` would be copied for each element
                # and `fn` would run once per element.
                shared_result: list = []

                def eager_tuple_element(i: int, /, *a, **kw) -> Any:
                    if not shared_result:
                        shared_result.append(fn(*a, **kw))
                    return shared_result[0][i]

                return tuple(
                    cls(meta=cls.eager_to_meta(r), args=(i, *args), kwargs=kwargs, func=eager_tuple_element)
                    for i, r in enumerate(res)
                )
            else:
                del res  # not needed
                # non-tensor return likely relies on the contents of the args
                # (e.g. the result of torch.equal)
                eager_args = cls.to_eager(args)
                return fn(*eager_args, **kwargs)
        return wrapped_fn

    @classmethod
    def to_eager(cls, t: Any) -> Any:
        """Evaluate every lazy node in ``t``, keeping the list/tuple structure.

        Computed values are stored on the nodes, so each node is evaluated at
        most once.
        """
        def simple_to_eager(_t: LazyBase) -> Any:
            if _t._data is not None:
                return _t._data

            # NOTE: there's a recursion limit in Python (usually 1000)

            assert _t._func is not None
            # the evaluated arguments replace the lazy ones on the node
            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
            _t._data = _t._func(*_t._args, **_t._kwargs)
            # sanity check
            assert _t._data is not None
            assert _t._data.dtype == _t._meta.dtype
            assert _t._data.shape == _t._meta.shape

            return _t._data

        # recurse into lists and/or tuples, keeping their structure
        return cls._recurse_apply(t, simple_to_eager)

    @classmethod
    def eager_to_meta(cls, t: Any) -> Any:
        """Return a placeholder with the ``dtype`` and ``shape`` of ``t``."""
        return cls.meta_with_dtype_and_shape(t.dtype, t.shape)

    # must be overridden, meta tensor init is backend-specific
    @classmethod
    @abstractmethod
    def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any:
        """Build a placeholder for the given ``dtype`` and ``shape``."""
        pass

    @classmethod
    def from_eager(cls, t: Any) -> Any:
        """Wrap an eager tensor in a lazy node; instances of exactly ``cls`` pass through.

        Raises:
            TypeError: if ``t`` is neither exactly a ``cls`` nor an instance
                of ``cls._tensor_type``.
        """
        if type(t) is cls:
            # already lazy
            return t
        elif isinstance(t, cls._tensor_type):
            return cls(meta=cls.eager_to_meta(t), data=t)
        else:
            # the error used to be returned instead of raised
            raise TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
204
205
class LazyNumpyTensor(LazyBase):
    """Lazy tensor whose eager counterpart is ``np.ndarray``."""

    _tensor_type = np.ndarray

    shape: tuple[int, ...]  # Makes the type checker happy in quants.py

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
        """Build a placeholder array of the requested ``dtype`` and ``shape``.

        A single zero element is viewed with all-zero strides, so every index
        of the placeholder refers to that one element.
        """
        # NaN was the first choice of fill value, but integer types
        # such as np.int16 cannot hold it, hence zero.
        single_zero = np.zeros(1, dtype)
        zero_strides = (0 for _ in shape)
        return np.lib.stride_tricks.as_strided(single_zero, shape, zero_strides)

    def astype(self, dtype, *args, **kwargs):
        """Return a lazy node that casts this tensor to ``dtype`` when evaluated."""
        lazy_cls = type(self)
        cast_meta = lazy_cls.meta_with_dtype_and_shape(dtype, self._meta.shape)

        def deferred_cast(a, *cast_args, **cast_kwargs):
            return a.astype(*cast_args, **cast_kwargs)

        return lazy_cls(
            meta=cast_meta,
            args=(self, dtype, *args),
            kwargs=kwargs,
            func=deferred_cast,
        )

    def tofile(self, *args, **kwargs):
        """Evaluate this tensor and forward to the eager array's ``tofile``."""
        materialized = LazyNumpyTensor.to_eager(self)
        return materialized.tofile(*args, **kwargs)

    # TODO: __array_function__