From c560f8e06c0cfb17bd27f7a1ad65c07963f332c7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 19 Dec 2024 11:40:11 -0800 Subject: [PATCH] Unify abstractify & shaped_abstractify rules --- jax/_src/abstract_arrays.py | 24 ++++-------------------- jax/_src/api.py | 1 + jax/_src/core.py | 12 +++++++++--- jax/_src/export/shape_poly.py | 1 + jax/_src/numpy/lax_numpy.py | 1 + 5 files changed, 16 insertions(+), 23 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 1cc8c4e48..2502b705b 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -49,6 +49,7 @@ def masked_array_error(*args, **kwargs): "Use arr.filled() to convert the value to a standard numpy array.") core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error +core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: @@ -56,14 +57,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: dtypes.check_valid_dtype(dtype) return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype)) -def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray: - dtype = x.dtype - dtypes.check_valid_dtype(dtype) - return ShapedArray(x.shape, - dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) - core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array -core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify +core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: @@ -71,15 +66,9 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: dtypes.check_valid_dtype(dtype) return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype)) -def _np_scalar_abstractify(x: np.generic) -> ShapedArray: - dtype = np.dtype(x) - dtypes.check_valid_dtype(dtype) - return ShapedArray(np.shape(x), - dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) - for t in numpy_scalar_types: core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar - core.shaped_abstractify_handlers[t] = _np_scalar_abstractify + core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar core.literalable_types.update(array_types) @@ -90,13 +79,8 @@ def _make_abstract_python_scalar(typ, val): return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val), weak_type=typ is not bool) -def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray: - typ = type(x) - dtype = dtypes._scalar_type_to_dtype(typ, x) - return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types) - for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) - core.shaped_abstractify_handlers[t] = _python_scalar_abstractify + core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t) core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4bf964a72..38ba4fd2d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2564,6 +2564,7 @@ def _sds_aval_mapping(x): x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=x.weak_type) core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping +core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping @api_boundary diff --git a/jax/_src/core.py b/jax/_src/core.py index 13fbd78eb..5f351bd46 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1832,10 +1832,11 @@ class DShapedArray(UnshapedArray): self.weak_type) pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} -shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {} +shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {} def _str_abstractify(x): raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") +pytype_aval_mappings[str] = _str_abstractify shaped_abstractify_handlers[str] = _str_abstractify class DArray: @@ -1889,9 +1890,12 @@ class DArray: data = self._data[slices] return data +def _darray_aval(x): + return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) + +pytype_aval_mappings[DArray] = _darray_aval +shaped_abstractify_handlers[DArray] = _darray_aval -pytype_aval_mappings[DArray] = \ - lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) @dataclass(frozen=True) class bint(dtypes.ExtendedDType): @@ -1924,6 +1928,7 @@ class MutableArray: def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) pytype_aval_mappings[MutableArray] = lambda x: x._aval +shaped_abstractify_handlers[MutableArray] = lambda x: x._aval def mutable_array(init_val): return mutable_array_p.bind(init_val) @@ -1979,6 +1984,7 @@ class Token: def block_until_ready(self): self._buf.block_until_ready() pytype_aval_mappings[Token] = lambda _: abstract_token +shaped_abstractify_handlers[Token] = lambda _: abstract_token # TODO(dougalm): Deprecate these. They're just here for backwards compat. diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 5462723c8..b82890cab 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1205,6 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool: f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}") core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval +core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval dtypes._weak_types.append(_DimExpr) def _convertible_to_int(p: DimSize) -> bool: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b587fd0c5..d191907be 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -191,6 +191,7 @@ class _ScalarMeta(type): def _abstractify_scalar_meta(x): raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.") +core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: