mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Unify abstractify & shaped_abstractify rules
This commit is contained in:
parent
20efbd965f
commit
c560f8e06c
@ -49,6 +49,7 @@ def masked_array_error(*args, **kwargs):
|
||||
"Use arr.filled() to convert the value to a standard numpy array.")
|
||||
|
||||
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
|
||||
core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error
|
||||
|
||||
|
||||
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
|
||||
@ -56,14 +57,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
|
||||
dtype = x.dtype
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape,
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
|
||||
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
|
||||
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
|
||||
core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array
|
||||
|
||||
|
||||
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
|
||||
@ -71,15 +66,9 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
|
||||
dtype = np.dtype(x)
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x),
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
|
||||
for t in numpy_scalar_types:
|
||||
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
|
||||
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
|
||||
core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar
|
||||
|
||||
core.literalable_types.update(array_types)
|
||||
|
||||
@ -90,13 +79,8 @@ def _make_abstract_python_scalar(typ, val):
|
||||
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
|
||||
weak_type=typ is not bool)
|
||||
|
||||
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
|
||||
typ = type(x)
|
||||
dtype = dtypes._scalar_type_to_dtype(typ, x)
|
||||
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
|
||||
|
||||
for t in dtypes.python_scalar_dtypes:
|
||||
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
|
||||
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
|
||||
core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t)
|
||||
|
||||
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
|
||||
|
@ -2564,6 +2564,7 @@ def _sds_aval_mapping(x):
|
||||
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
|
||||
weak_type=x.weak_type)
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
|
||||
core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping
|
||||
|
||||
|
||||
@api_boundary
|
||||
|
@ -1832,10 +1832,11 @@ class DShapedArray(UnshapedArray):
|
||||
self.weak_type)
|
||||
|
||||
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
|
||||
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
|
||||
shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {}
|
||||
|
||||
def _str_abstractify(x):
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
|
||||
pytype_aval_mappings[str] = _str_abstractify
|
||||
shaped_abstractify_handlers[str] = _str_abstractify
|
||||
|
||||
class DArray:
|
||||
@ -1889,9 +1890,12 @@ class DArray:
|
||||
data = self._data[slices]
|
||||
return data
|
||||
|
||||
def _darray_aval(x):
|
||||
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
|
||||
|
||||
pytype_aval_mappings[DArray] = _darray_aval
|
||||
shaped_abstractify_handlers[DArray] = _darray_aval
|
||||
|
||||
pytype_aval_mappings[DArray] = \
|
||||
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class bint(dtypes.ExtendedDType):
|
||||
@ -1924,6 +1928,7 @@ class MutableArray:
|
||||
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
|
||||
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
|
||||
pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
shaped_abstractify_handlers[MutableArray] = lambda x: x._aval
|
||||
|
||||
def mutable_array(init_val):
|
||||
return mutable_array_p.bind(init_val)
|
||||
@ -1979,6 +1984,7 @@ class Token:
|
||||
def block_until_ready(self):
|
||||
self._buf.block_until_ready()
|
||||
pytype_aval_mappings[Token] = lambda _: abstract_token
|
||||
shaped_abstractify_handlers[Token] = lambda _: abstract_token
|
||||
|
||||
|
||||
# TODO(dougalm): Deprecate these. They're just here for backwards compat.
|
||||
|
@ -1205,6 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
|
||||
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
|
||||
|
||||
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
|
||||
core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval
|
||||
dtypes._weak_types.append(_DimExpr)
|
||||
|
||||
def _convertible_to_int(p: DimSize) -> bool:
|
||||
|
@ -191,6 +191,7 @@ class _ScalarMeta(type):
|
||||
|
||||
def _abstractify_scalar_meta(x):
|
||||
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
|
||||
core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
|
||||
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
|
||||
|
||||
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
|
||||
|
Loading…
x
Reference in New Issue
Block a user