From 676070f4cd587bd6a1d84ffc8d82439639735fd0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Dec 2024 15:18:00 -0800 Subject: [PATCH] Refactor: move shaped_abstractify to core --- jax/_src/abstract_arrays.py | 20 ++++++++++ jax/_src/api_util.py | 53 ++------------------------- jax/_src/array.py | 2 +- jax/_src/core.py | 27 ++++++++++++++ jax/_src/earray.py | 3 +- jax/_src/interpreters/partial_eval.py | 2 +- jax/_src/numpy/lax_numpy.py | 3 +- jax/_src/prng.py | 3 +- 8 files changed, 56 insertions(+), 57 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 8ddc33fd8..1cc8c4e48 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -56,7 +56,14 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: dtypes.check_valid_dtype(dtype) return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype)) +def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray: + dtype = x.dtype + dtypes.check_valid_dtype(dtype) + return ShapedArray(x.shape, + dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) + core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array +core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: @@ -64,8 +71,15 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: dtypes.check_valid_dtype(dtype) return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype)) +def _np_scalar_abstractify(x: np.generic) -> ShapedArray: + dtype = np.dtype(x) + dtypes.check_valid_dtype(dtype) + return ShapedArray(np.shape(x), + dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) + for t in numpy_scalar_types: core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar + core.shaped_abstractify_handlers[t] = _np_scalar_abstractify core.literalable_types.update(array_types) @@ -76,7 +90,13 @@ def _make_abstract_python_scalar(typ, val): return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val), weak_type=typ is not bool) +def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray: + typ = type(x) + dtype = dtypes._scalar_type_to_dtype(typ, x) + return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types) + for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) + core.shaped_abstractify_handlers[t] = _python_scalar_abstractify core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index b2bed4d4b..8cdd84be9 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -20,14 +20,10 @@ import operator from functools import partial, lru_cache from typing import Any -import numpy as np - from jax._src import core from jax._src import config from jax._src import dtypes from jax._src.state.types import AbstractRef -from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ShapedArray from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, keystr, broadcast_prefix, @@ -587,54 +583,13 @@ def _dtype(x): except ValueError: return dtypes.result_type(getattr(x, 'dtype')) -def _shaped_abstractify_slow(x): - try: - return x if isinstance(x, core.AbstractValue) else core.get_aval(x) - except TypeError: - pass - - weak_type = getattr(x, 'weak_type', False) - if hasattr(x, 'dtype'): - dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) - else: - raise TypeError( - f"Cannot interpret value of type {type(x)} as an abstract array; it " - "does not have a dtype attribute") - return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type) - # TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior +# TODO(jakevdp): fix downstream consumers and remove this. def shaped_abstractify(x): - handler = _shaped_abstractify_handlers.get(type(x), None) - return handler(x) if handler is not None else _shaped_abstractify_slow(x) + return core.shaped_abstractify(x) -_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {} - - -def _str_abstractify(x): - raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") -_shaped_abstractify_handlers[str] = _str_abstractify - -def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray: - dtype = x.dtype - dtypes.check_valid_dtype(dtype) - return ShapedArray(x.shape, - dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) -_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify - -def _np_scalar_abstractify(x: np.generic) -> ShapedArray: - dtype = np.dtype(x) - dtypes.check_valid_dtype(dtype) - return ShapedArray(np.shape(x), - dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) -_shaped_abstractify_handlers.update((t, _np_scalar_abstractify) - for t in numpy_scalar_types) - -def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray: - typ = type(x) - dtype = dtypes._scalar_type_to_dtype(typ, x) - return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types) -_shaped_abstractify_handlers.update((t, _python_scalar_abstractify) - for t in dtypes.python_scalar_dtypes) +# TODO(jakevdp): fix downstream consumers and remove this. +_shaped_abstractify_handlers = core.shaped_abstractify_handlers # This decorator exists to make it easier to monkey-patch APIs in JAX. # By default it does nothing, but it can be monkey-patched to do other things. diff --git a/jax/_src/array.py b/jax/_src/array.py index e5d6902d1..db5bfe1bf 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1036,7 +1036,7 @@ def _get_aval_array(self): else: return self.aval -api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array +core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array core.pytype_aval_mappings[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. diff --git a/jax/_src/core.py b/jax/_src/core.py index c75206894..13fbd78eb 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1400,6 +1400,29 @@ def check_valid_jaxtype(x): f"Value {x!r} of type {type(x)} is not a valid JAX type") +def _shaped_abstractify_slow(x): + try: + return x if isinstance(x, AbstractValue) else get_aval(x) + except TypeError: + pass + + weak_type = getattr(x, 'weak_type', False) + if hasattr(x, 'dtype'): + dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) + else: + raise TypeError( + f"Cannot interpret value of type {type(x)} as an abstract array; it " + "does not have a dtype attribute") + return ShapedArray(np.shape(x), dtype, weak_type=weak_type) + +# TODO(jakevdp): deduplicate this with abstractify +def shaped_abstractify(x): + # This was originally api_util.shaped_abstractify; temporarily moved + # here in order to facilitate combining it with abstractify. + handler = shaped_abstractify_handlers.get(type(x), None) + return handler(x) if handler is not None else _shaped_abstractify_slow(x) + + def abstractify(x): for typ in type(x).__mro__: aval_fn = pytype_aval_mappings.get(typ) @@ -1809,7 +1832,11 @@ class DShapedArray(UnshapedArray): self.weak_type) pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} +shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {} +def _str_abstractify(x): + raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") +shaped_abstractify_handlers[str] = _str_abstractify class DArray: _aval: DShapedArray diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 7bade8171..25c2bc2bf 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -16,7 +16,6 @@ from __future__ import annotations import math -from jax._src import api_util from jax._src import basearray from jax._src import core from jax._src import tree_util @@ -116,7 +115,7 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics): return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler -api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval +core.shaped_abstractify_handlers[EArray] = lambda self: self.aval core.pytype_aval_mappings[EArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[EArray] = lambda x: x tree_util.dispatch_registry.register_node( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index f188687f1..c90d65bcb 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1572,7 +1572,7 @@ class DynamicJaxprTracer(core.Tracer): def _dynamic_jaxpr_tracer_shaped_abstractify(x): return x.aval -api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify +core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: sentinel = object() diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 9aa131420..b587fd0c5 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -43,7 +43,6 @@ import jax from jax import errors from jax import jit from jax import lax -from jax._src import api_util from jax._src import config from jax._src import core from jax._src import deprecations @@ -192,7 +191,7 @@ class _ScalarMeta(type): def _abstractify_scalar_meta(x): raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.") -api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta +core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), diff --git a/jax/_src/prng.py b/jax/_src/prng.py index d0f3b644b..d29bad5d5 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -26,7 +26,6 @@ from jax import lax from jax import numpy as jnp from jax import tree_util -from jax._src import api_util from jax._src import api from jax._src import config as config from jax._src import core @@ -303,7 +302,6 @@ _set_array_base_attributes(PRNGKeyArray, include=[ 'at', 'flatten', 'ravel', 'reshape', 'squeeze', 'swapaxes', 'take', 'transpose', 'T']) -api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval') def prngkeyarray_flatten(x): return (x._base_array,), x._impl @@ -463,6 +461,7 @@ class KeyTy(dtypes.ExtendedDType): core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval +core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x