Unify abstractify & shaped_abstractify rules

This commit is contained in:
Jake VanderPlas 2024-12-19 11:40:11 -08:00
parent 20efbd965f
commit c560f8e06c
5 changed files with 16 additions and 23 deletions

View File

@ -49,6 +49,7 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
@ -56,14 +57,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@ -71,15 +66,9 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x),
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar
core.literalable_types.update(array_types)
@ -90,13 +79,8 @@ def _make_abstract_python_scalar(typ, val):
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
typ = type(x)
dtype = dtypes._scalar_type_to_dtype(typ, x)
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t)
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

View File

@ -2564,6 +2564,7 @@ def _sds_aval_mapping(x):
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping
@api_boundary

View File

@ -1832,10 +1832,11 @@ class DShapedArray(UnshapedArray):
self.weak_type)
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
shaped_abstractify_handlers[str] = _str_abstractify
class DArray:
@ -1889,9 +1890,12 @@ class DArray:
data = self._data[slices]
return data
def _darray_aval(x):
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
pytype_aval_mappings[DArray] = _darray_aval
shaped_abstractify_handlers[DArray] = _darray_aval
pytype_aval_mappings[DArray] = \
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
@ -1924,6 +1928,7 @@ class MutableArray:
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
shaped_abstractify_handlers[MutableArray] = lambda x: x._aval
def mutable_array(init_val):
return mutable_array_p.bind(init_val)
@ -1979,6 +1984,7 @@ class Token:
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
shaped_abstractify_handlers[Token] = lambda _: abstract_token
# TODO(dougalm): Deprecate these. They're just here for backwards compat.

View File

@ -1205,6 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)
def _convertible_to_int(p: DimSize) -> bool:

View File

@ -191,6 +191,7 @@ class _ScalarMeta(type):
def _abstractify_scalar_meta(x):
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: