Cleanup: toward merging core.concrete_aval & xla.abstractify

This commit is contained in:
Jake VanderPlas 2024-12-17 09:27:00 -08:00
parent 772339ec60
commit 2c722d9b13
15 changed files with 79 additions and 59 deletions

View File

@ -23,8 +23,8 @@ import jax
from jax import lax
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src import core
from jax._src.lib import xla_client as xc
from jax.interpreters import xla
from jax._src import array
from jax._src import op_shardings
from jax._src.pjit import pjit_check_aval_sharding
@ -427,7 +427,7 @@ def bench_shaped_abstractify(state):
def _run_benchmark_for_xla_abstractify(arg, state):
while state:
xla.abstractify(arg)
core.abstractify(arg)
def bench_xla_abstractify():
_abstractify_args = [

View File

@ -50,23 +50,49 @@ def canonical_concrete_aval(val, weak_type=None):
sharding = core._get_abstract_sharding(val)
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)
def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
for t in array_types:
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = canonical_concrete_aval
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.literalable_types.update(array_types)
def _make_concrete_python_scalar(t, x):
dtype = dtypes._scalar_type_to_dtype(t, x)
weak_type = dtypes.is_weakly_typed(x)
return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)
def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

View File

@ -600,7 +600,7 @@ def _shaped_abstractify_slow(x):
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
def shaped_abstractify(x):
handler = _shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)

View File

@ -1029,7 +1029,7 @@ def make_array_from_single_device_arrays(
core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):

View File

@ -1386,7 +1386,16 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")
# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
# This is tricky because concrete_aval includes sharding information, and
# abstractify does not; further, because abstractify is in the dispatch path,
# performance is important and simply adding sharding there is not an option.
def concrete_aval(x):
# This differs from abstractify below in that the abstract values
# include sharding where applicable. Historically (before stackless)
# the returned avals were concrete, but after the stackless change
# this returns ShapedArray like abstractify.
# Rules are registered in pytype_aval_mappings.
for typ in type(x).__mro__:
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
@ -1396,6 +1405,22 @@ def concrete_aval(x):
"type")
def abstractify(x):
# Historically, this was called xla.abstractify. It differs from
# concrete_aval in that it excludes sharding information, and
# uses a more performant path for accessing avals. Rules are
# registered in xla_pytype_aval_mappings.
typ = type(x)
aval_fn = xla_pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = xla_pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
@ -1793,6 +1818,7 @@ class DShapedArray(UnshapedArray):
self.weak_type)
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
class DArray:
@ -1849,6 +1875,7 @@ class DArray:
pytype_aval_mappings[DArray] = \
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
xla_pytype_aval_mappings[DArray] = lambda x: x._aval
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
@ -1881,6 +1908,7 @@ class MutableArray:
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval
def mutable_array(init_val):
return mutable_array_p.bind(init_val)
@ -1934,6 +1962,7 @@ class Token:
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
xla_pytype_aval_mappings[Token] = lambda _: abstract_token
# TODO(dougalm): Deprecate these. They're just here for backwards compat.

View File

@ -457,7 +457,7 @@ def _device_put_impl(
" please provide a concrete Sharding with memory_kind.")
try:
aval = xla.abstractify(x)
aval = core.abstractify(x)
except TypeError as err:
raise TypeError(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err

View File

@ -35,7 +35,6 @@ import numpy as np
import opt_einsum
import jax
from jax.interpreters import xla
from jax._src import config
from jax._src import core
@ -1206,7 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
xla.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)
def _convertible_to_int(p: DimSize) -> bool:

View File

@ -1825,7 +1825,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return xla.abstractify(v.val)
return core.abstractify(v.val)
else:
return v.aval

View File

@ -349,7 +349,7 @@ def xla_pmap_impl_lazy(
donated_invars=donated_invars,
is_explicit_global_axis_size=is_explicit_global_axis_size)
return _emap_apply_fn
abstract_args = unsafe_map(xla.abstractify, args)
abstract_args = unsafe_map(core.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars,
@ -360,7 +360,7 @@ def xla_pmap_impl_lazy(
distributed_debug_log(("Running pmapped function", name),
("python function", fun.f),
("devices", devices),
("abstract args", map(xla.abstractify, args)),
("abstract args", map(core.abstractify, args)),
("fingerprint", fingerprint))
return compiled_fun
@ -598,7 +598,7 @@ class MapTracer(core.Tracer):
@property
def aval(self):
aval = xla.abstractify(self.val)
aval = core.abstractify(self.val)
shard_axes = dict(self.shard_axes)
for axis_idx in sorted(shard_axes.values())[::-1]:
aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
@ -1145,7 +1145,7 @@ class PmapExecutable(stages.XlaExecutable):
@profiler.annotate_function
def call(self, *args):
# TODO(frostig): do we need to check sharding and sharded avals?
arg_avals = map(xla.abstractify, args)
arg_avals = map(core.abstractify, args)
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable
@ -3090,7 +3090,7 @@ class MeshExecutable(stages.XlaExecutable):
ref_avals = self._all_args_info.in_avals
debug_info = self._all_args_info.debug_info
all_arg_avals = map(xla.abstractify, kept_args)
all_arg_avals = map(core.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
check_array_xla_sharding_layout_match(
args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,

View File

@ -146,44 +146,12 @@ canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.DArray] = identity
canonicalize_dtype_handlers[core.MutableArray] = identity
# TODO(jakevdp): deprecate and remove this.
def abstractify(x) -> Any:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
pytype_aval_mappings[core.DArray] = lambda x: x._aval
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
return core.abstractify(x)
# TODO(jakevdp): deprecate and remove this.
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings
initial_style_primitives: set[core.Primitive] = set()

View File

@ -1689,7 +1689,7 @@ def _pjit_call_impl_python(
("out_shardings", out_shardings),
("in_layouts", in_layouts),
("out_layouts", out_layouts),
("abstract args", map(xla.abstractify, args)),
("abstract args", map(core.abstractify, args)),
("fingerprint", fingerprint))
try:
return compiled.unsafe_call(*args), compiled, pgle_profiler

View File

@ -463,7 +463,7 @@ class KeyTy(dtypes.ExtendedDType):
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

View File

@ -38,7 +38,6 @@ from jax import tree_util
from jax import sharding
from jax import export
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import ad_util
@ -1153,7 +1152,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
else:
return val, jax_dtype
else: # A constant
jax_dtype = jax_dtype or xla.abstractify(val).dtype
jax_dtype = jax_dtype or core.abstractify(val).dtype
# TODO(document): We assume that the value of a constant does not
# change through the scope of the function. But it may be an ndarray, ...
# JAX has the same problem when generating HLO.

View File

@ -26,7 +26,6 @@ from jax._src import traceback_util
from jax._src import util
from jax._src.api import make_jaxpr
from jax._src.interpreters.partial_eval import dce_jaxpr
from jax._src.interpreters.xla import abstractify
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
from jax.experimental import shard_map
@ -142,14 +141,14 @@ def _roofline_interpreter(
def read(v: core.Atom) -> RooflineShape:
if type(v) is core.Literal:
return RooflineShape.from_aval(abstractify(v.val))
return RooflineShape.from_aval(core.abstractify(v.val))
else:
assert isinstance(v, core.Var)
return env[v]
def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return abstractify(v.val)
return core.abstractify(v.val)
else:
return v.aval

View File

@ -3932,7 +3932,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
core.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
xla.pytype_aval_mappings[FooArray] = \
core.xla_pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
mlir._constant_handlers[FooArray] = foo_array_constant_handler
@ -3946,7 +3946,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
def tearDown(self):
del core.pytype_aval_mappings[FooArray]
del xla.canonicalize_dtype_handlers[FooArray]
del xla.pytype_aval_mappings[FooArray]
del core.xla_pytype_aval_mappings[FooArray]
del mlir._constant_handlers[FooArray]
del mlir._lowerings[make_p]
del mlir._lowerings[bake_p]