mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #25595 from jakevdp:mv-shaped-abstractify
PiperOrigin-RevId: 707888615
This commit is contained in:
commit
7680532512
@ -56,7 +56,14 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
|
||||
dtype = x.dtype
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape,
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
|
||||
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
|
||||
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
|
||||
|
||||
|
||||
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
|
||||
@ -64,8 +71,15 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
|
||||
dtype = np.dtype(x)
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x),
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
|
||||
for t in numpy_scalar_types:
|
||||
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
|
||||
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
|
||||
|
||||
core.literalable_types.update(array_types)
|
||||
|
||||
@ -76,7 +90,13 @@ def _make_abstract_python_scalar(typ, val):
|
||||
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
|
||||
weak_type=typ is not bool)
|
||||
|
||||
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
|
||||
typ = type(x)
|
||||
dtype = dtypes._scalar_type_to_dtype(typ, x)
|
||||
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
|
||||
|
||||
for t in dtypes.python_scalar_dtypes:
|
||||
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
|
||||
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
|
||||
|
||||
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
|
||||
|
@ -20,14 +20,10 @@ import operator
|
||||
from functools import partial, lru_cache
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.tree_util import (
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
|
||||
treedef_children, generate_key_paths, keystr, broadcast_prefix,
|
||||
@ -587,54 +583,13 @@ def _dtype(x):
|
||||
except ValueError:
|
||||
return dtypes.result_type(getattr(x, 'dtype'))
|
||||
|
||||
def _shaped_abstractify_slow(x):
|
||||
try:
|
||||
return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
weak_type = getattr(x, 'weak_type', False)
|
||||
if hasattr(x, 'dtype'):
|
||||
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {type(x)} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
|
||||
|
||||
# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
|
||||
# TODO(jakevdp): fix downstream consumers and remove this.
|
||||
def shaped_abstractify(x):
|
||||
handler = _shaped_abstractify_handlers.get(type(x), None)
|
||||
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
|
||||
return core.shaped_abstractify(x)
|
||||
|
||||
_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}
|
||||
|
||||
|
||||
def _str_abstractify(x):
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
|
||||
_shaped_abstractify_handlers[str] = _str_abstractify
|
||||
|
||||
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
|
||||
dtype = x.dtype
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape,
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
|
||||
|
||||
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
|
||||
dtype = np.dtype(x)
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x),
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
_shaped_abstractify_handlers.update((t, _np_scalar_abstractify)
|
||||
for t in numpy_scalar_types)
|
||||
|
||||
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
|
||||
typ = type(x)
|
||||
dtype = dtypes._scalar_type_to_dtype(typ, x)
|
||||
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
|
||||
_shaped_abstractify_handlers.update((t, _python_scalar_abstractify)
|
||||
for t in dtypes.python_scalar_dtypes)
|
||||
# TODO(jakevdp): fix downstream consumers and remove this.
|
||||
_shaped_abstractify_handlers = core.shaped_abstractify_handlers
|
||||
|
||||
# This decorator exists to make it easier to monkey-patch APIs in JAX.
|
||||
# By default it does nothing, but it can be monkey-patched to do other things.
|
||||
|
@ -1036,7 +1036,7 @@ def _get_aval_array(self):
|
||||
else:
|
||||
return self.aval
|
||||
|
||||
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
|
||||
core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
|
||||
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
|
||||
|
||||
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
||||
|
@ -1400,6 +1400,29 @@ def check_valid_jaxtype(x):
|
||||
f"Value {x!r} of type {type(x)} is not a valid JAX type")
|
||||
|
||||
|
||||
def _shaped_abstractify_slow(x):
|
||||
try:
|
||||
return x if isinstance(x, AbstractValue) else get_aval(x)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
weak_type = getattr(x, 'weak_type', False)
|
||||
if hasattr(x, 'dtype'):
|
||||
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {type(x)} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
|
||||
|
||||
# TODO(jakevdp): deduplicate this with abstractify
|
||||
def shaped_abstractify(x):
|
||||
# This was originally api_util.shaped_abstractify; temporarily moved
|
||||
# here in order to facilitate combining it with abstractify.
|
||||
handler = shaped_abstractify_handlers.get(type(x), None)
|
||||
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
|
||||
|
||||
|
||||
def abstractify(x):
|
||||
for typ in type(x).__mro__:
|
||||
aval_fn = pytype_aval_mappings.get(typ)
|
||||
@ -1809,7 +1832,11 @@ class DShapedArray(UnshapedArray):
|
||||
self.weak_type)
|
||||
|
||||
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
|
||||
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
|
||||
|
||||
def _str_abstractify(x):
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
|
||||
shaped_abstractify_handlers[str] = _str_abstractify
|
||||
|
||||
class DArray:
|
||||
_aval: DShapedArray
|
||||
|
@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import tree_util
|
||||
@ -116,7 +115,7 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
|
||||
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
|
||||
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
|
||||
|
||||
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
|
||||
core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
|
||||
core.pytype_aval_mappings[EArray] = lambda x: x.aval
|
||||
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
|
||||
tree_util.dispatch_registry.register_node(
|
||||
|
@ -1572,7 +1572,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
|
||||
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
|
||||
return x.aval
|
||||
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
|
||||
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
|
||||
|
||||
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
|
||||
sentinel = object()
|
||||
|
@ -43,7 +43,6 @@ import jax
|
||||
from jax import errors
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
@ -192,7 +191,7 @@ class _ScalarMeta(type):
|
||||
|
||||
def _abstractify_scalar_meta(x):
|
||||
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
|
||||
api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
|
||||
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
|
||||
|
||||
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
|
||||
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
|
||||
|
@ -26,7 +26,6 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import api
|
||||
from jax._src import config as config
|
||||
from jax._src import core
|
||||
@ -303,7 +302,6 @@ _set_array_base_attributes(PRNGKeyArray, include=[
|
||||
'at', 'flatten', 'ravel', 'reshape',
|
||||
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
|
||||
|
||||
api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
|
||||
|
||||
def prngkeyarray_flatten(x):
|
||||
return (x._base_array,), x._impl
|
||||
@ -463,6 +461,7 @@ class KeyTy(dtypes.ExtendedDType):
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
|
||||
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user