Merge pull request #25595 from jakevdp:mv-shaped-abstractify

PiperOrigin-RevId: 707888615
This commit is contained in:
jax authors 2024-12-19 06:07:14 -08:00
commit 7680532512
8 changed files with 56 additions and 57 deletions

View File

@ -56,7 +56,14 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@ -64,8 +71,15 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x),
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
core.literalable_types.update(array_types)
@ -76,7 +90,13 @@ def _make_abstract_python_scalar(typ, val):
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
typ = type(x)
dtype = dtypes._scalar_type_to_dtype(typ, x)
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

View File

@ -20,14 +20,10 @@ import operator
from functools import partial, lru_cache
from typing import Any
import numpy as np
from jax._src import core
from jax._src import config
from jax._src import dtypes
from jax._src.state.types import AbstractRef
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
@ -587,54 +583,13 @@ def _dtype(x):
except ValueError:
return dtypes.result_type(getattr(x, 'dtype'))
def _shaped_abstractify_slow(x):
try:
return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
except TypeError:
pass
weak_type = getattr(x, 'weak_type', False)
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
# TODO(jakevdp): fix downstream consumers and remove this.
def shaped_abstractify(x):
handler = _shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
return core.shaped_abstractify(x)
_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
_shaped_abstractify_handlers[str] = _str_abstractify
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x),
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers.update((t, _np_scalar_abstractify)
for t in numpy_scalar_types)
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
typ = type(x)
dtype = dtypes._scalar_type_to_dtype(typ, x)
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
_shaped_abstractify_handlers.update((t, _python_scalar_abstractify)
for t in dtypes.python_scalar_dtypes)
# TODO(jakevdp): fix downstream consumers and remove this.
_shaped_abstractify_handlers = core.shaped_abstractify_handlers
# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.

View File

@ -1036,7 +1036,7 @@ def _get_aval_array(self):
else:
return self.aval
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
# TODO(jakevdp) replace this with true inheritance at the C++ level.

View File

@ -1400,6 +1400,29 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")
def _shaped_abstractify_slow(x):
try:
return x if isinstance(x, AbstractValue) else get_aval(x)
except TypeError:
pass
weak_type = getattr(x, 'weak_type', False)
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# TODO(jakevdp): deduplicate this with abstractify
def shaped_abstractify(x):
# This was originally api_util.shaped_abstractify; temporarily moved
# here in order to facilitate combining it with abstractify.
handler = shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
def abstractify(x):
for typ in type(x).__mro__:
aval_fn = pytype_aval_mappings.get(typ)
@ -1809,7 +1832,11 @@ class DShapedArray(UnshapedArray):
self.weak_type)
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
shaped_abstractify_handlers[str] = _str_abstractify
class DArray:
_aval: DShapedArray

View File

@ -16,7 +16,6 @@ from __future__ import annotations
import math
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import tree_util
@ -116,7 +115,7 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.pytype_aval_mappings[EArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
tree_util.dispatch_registry.register_node(

View File

@ -1572,7 +1572,7 @@ class DynamicJaxprTracer(core.Tracer):
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return x.aval
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()

View File

@ -43,7 +43,6 @@ import jax
from jax import errors
from jax import jit
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import deprecations
@ -192,7 +191,7 @@ class _ScalarMeta(type):
def _abstractify_scalar_meta(x):
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
meta = _ScalarMeta(np_scalar_type.__name__, (object,),

View File

@ -26,7 +26,6 @@ from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax._src import api_util
from jax._src import api
from jax._src import config as config
from jax._src import core
@ -303,7 +302,6 @@ _set_array_base_attributes(PRNGKeyArray, include=[
'at', 'flatten', 'ravel', 'reshape',
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
def prngkeyarray_flatten(x):
return (x._base_array,), x._impl
@ -463,6 +461,7 @@ class KeyTy(dtypes.ExtendedDType):
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x