Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.
Cleanup: toward merging core.concrete_aval & xla.abstractify
commit 2c722d9b13
parent 772339ec60
@@ -23,8 +23,8 @@ import jax
 from jax import lax
 from jax._src.api_util import shaped_abstractify  # technically not an api fn
 from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
+from jax._src import core
 from jax._src.lib import xla_client as xc
-from jax.interpreters import xla
 from jax._src import array
 from jax._src import op_shardings
 from jax._src.pjit import pjit_check_aval_sharding
@@ -427,7 +427,7 @@ def bench_shaped_abstractify(state):

 def _run_benchmark_for_xla_abstractify(arg, state):
   while state:
-    xla.abstractify(arg)
+    core.abstractify(arg)

 def bench_xla_abstractify():
   _abstractify_args = [
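Aside: these benchmark files use the google_benchmark Python bindings, where iterating the state object drives the timed loop. A minimal standalone sketch of the idiom (the benchmark name is illustrative; core.abstractify exists only after this commit):

import numpy as np
import google_benchmark as benchmark
from jax._src import core

@benchmark.register
def bench_core_abstractify(state):
    arg = np.zeros((2, 2), dtype=np.float32)
    while state:              # each pass through the loop is one timed iteration
        core.abstractify(arg)

if __name__ == "__main__":
    benchmark.main()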
@@ -50,23 +50,49 @@ def canonical_concrete_aval(val, weak_type=None):
   sharding = core._get_abstract_sharding(val)
   return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)


 def masked_array_error(*args, **kwargs):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")

 core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error

-for t in array_types:
+
+def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
+  dtype = x.dtype
+  dtypes.check_valid_dtype(dtype)
+  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
+
+core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
+core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
+
+
+def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
+  dtype = np.dtype(x)
+  dtypes.check_valid_dtype(dtype)
+  return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
+
+for t in numpy_scalar_types:
   core.pytype_aval_mappings[t] = canonical_concrete_aval
+  core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar

 core.literalable_types.update(array_types)


 def _make_concrete_python_scalar(t, x):
   dtype = dtypes._scalar_type_to_dtype(t, x)
   weak_type = dtypes.is_weakly_typed(x)
   return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)


+def _make_abstract_python_scalar(typ, val):
+  # Note: all python scalar types are weak except bool, because bool only
+  # comes in a single width.
+  return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
+                     weak_type=typ is not bool)
+
 for t in dtypes.python_scalar_dtypes:
   core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
+  core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)

 core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
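The effect of these dual registrations can be exercised directly. A hedged sketch of a post-commit session (exact reprs vary with JAX version and the jax_enable_x64 setting):

import numpy as np
from jax._src import core  # core.abstractify and both registries exist after this commit

print(core.concrete_aval(np.arange(3)))  # ShapedArray((3,), int32): concrete numpy array
print(core.concrete_aval(1.0))           # weak_type=True: Python floats are weakly typed
print(core.concrete_aval(True))          # weak_type=False: bool comes in a single width
print(core.abstractify(np.float32(0)))   # numpy scalar -> rank-0 ShapedArray, no sharding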
@@ -600,7 +600,7 @@ def _shaped_abstractify_slow(x):
         "does not have a dtype attribute")
   return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)

-# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
+# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
 def shaped_abstractify(x):
   handler = _shaped_abstractify_handlers.get(type(x), None)
   return handler(x) if handler is not None else _shaped_abstractify_slow(x)
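shaped_abstractify is importable from jax._src.api_util (as the benchmark file above does), so the fast-path/slow-path split is easy to see. A sketch, where DuckArray is a hypothetical class with no registered handler that falls through to the duck-typed slow path:

import numpy as np
import jax.numpy as jnp
from jax._src.api_util import shaped_abstractify

print(shaped_abstractify(jnp.ones((2, 3))))       # fast path: registered handler
print(shaped_abstractify(np.zeros(4, np.int32)))  # ShapedArray((4,), int32)

class DuckArray:  # hypothetical: no handler, but duck-typed shape/dtype
    shape, dtype, weak_type = (2,), np.dtype('float32'), False

print(shaped_abstractify(DuckArray()))  # slow path reads shape/dtype attributes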
@@ -1029,7 +1029,7 @@ def make_array_from_single_device_arrays(


 core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
-xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
+core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
 xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
 def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
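op.attrgetter('aval') is the cheap path here: a concrete jax.Array already carries a precomputed aval, so no shape or dtype reconstruction is needed. Sketch:

import operator as op
import jax.numpy as jnp

get_aval = op.attrgetter('aval')  # equivalent to lambda x: x.aval, but faster
x = jnp.ones((2, 2))
print(get_aval(x))                # ShapedArray((2, 2), float32)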
@@ -1386,7 +1386,16 @@ def check_valid_jaxtype(x):
       f"Value {x!r} of type {type(x)} is not a valid JAX type")


+# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
+# This is tricky because concrete_aval includes sharding information, and
+# abstractify does not; further, because abstractify is in the dispatch path,
+# performance is important and simply adding sharding there is not an option.
 def concrete_aval(x):
+  # This differs from abstractify below in that the abstract values
+  # include sharding where applicable. Historically (before stackless)
+  # the returned avals were concrete, but after the stackless change
+  # this returns ShapedArray like abstractify.
+  # Rules are registered in pytype_aval_mappings.
   for typ in type(x).__mro__:
     handler = pytype_aval_mappings.get(typ)
     if handler: return handler(x)
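The __mro__ walk means a subclass of a registered type is handled without its own registry entry. An illustrative sketch (MyArray is hypothetical, not part of JAX):

import numpy as np
from jax._src import core

class MyArray(np.ndarray):  # hypothetical subclass with no registry entry
    pass

x = np.arange(4).view(MyArray)
print(core.concrete_aval(x))  # resolved via the np.ndarray entry in pytype_aval_mappings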
@@ -1396,6 +1405,22 @@ def concrete_aval(x):
       "type")


+def abstractify(x):
+  # Historically, this was called xla.abstractify. It differs from
+  # concrete_aval in that it excludes sharding information, and
+  # uses a more performant path for accessing avals. Rules are
+  # registered in xla_pytype_aval_mappings.
+  typ = type(x)
+  aval_fn = xla_pytype_aval_mappings.get(typ)
+  if aval_fn: return aval_fn(x)
+  for typ in typ.__mro__:
+    aval_fn = xla_pytype_aval_mappings.get(typ)
+    if aval_fn: return aval_fn(x)
+  if hasattr(x, '__jax_array__'):
+    return abstractify(x.__jax_array__())
+  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
+
+
 def get_aval(x):
   if isinstance(x, Tracer):
     return x.aval
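The __jax_array__ hook is the final fallback before the TypeError: an unregistered object that knows how to convert itself gets abstractified via its converted form. Sketch with a hypothetical wrapper class:

import jax.numpy as jnp
from jax._src import core

class Box:  # hypothetical: not registered, but convertible
    def __init__(self, v):
        self.v = v
    def __jax_array__(self):
        return jnp.asarray(self.v)

print(core.abstractify(Box([1.0, 2.0])))  # recurses on the converted jax.Array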
@@ -1793,6 +1818,7 @@ class DShapedArray(UnshapedArray):
                         self.weak_type)

 pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
+xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}


 class DArray:
@@ -1849,6 +1875,7 @@ class DArray:

 pytype_aval_mappings[DArray] = \
     lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
+xla_pytype_aval_mappings[DArray] = lambda x: x._aval

 @dataclass(frozen=True)
 class bint(dtypes.ExtendedDType):
@@ -1881,6 +1908,7 @@ class MutableArray:
   def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
   def __repr__(self) -> str: return 'Mutable' + repr(self[...])
 pytype_aval_mappings[MutableArray] = lambda x: x._aval
+xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval

 def mutable_array(init_val):
   return mutable_array_p.bind(init_val)
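mutable_array is experimental, but the registration above is what lets refs participate in abstraction alongside ordinary arrays. A hedged sketch (API surface may differ across versions):

import jax.numpy as jnp
from jax._src import core

ref = core.mutable_array(jnp.zeros(3))
ref[0] = 1.0                  # in-place update via __setitem__
print(ref[...])               # read back the current contents
print(core.abstractify(ref))  # both registries map MutableArray to x._aval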
@@ -1934,6 +1962,7 @@ class Token:
   def block_until_ready(self):
     self._buf.block_until_ready()
 pytype_aval_mappings[Token] = lambda _: abstract_token
+xla_pytype_aval_mappings[Token] = lambda _: abstract_token


 # TODO(dougalm): Deprecate these. They're just here for backwards compat.
@@ -457,7 +457,7 @@ def _device_put_impl(
         " please provide a concrete Sharding with memory_kind.")

   try:
-    aval = xla.abstractify(x)
+    aval = core.abstractify(x)
   except TypeError as err:
     raise TypeError(
         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
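This aval check is what turns unsupported inputs into the user-visible error. For example:

import jax

try:
    jax.device_put(object())  # not convertible to an aval
except TypeError as e:
    print(e)  # Argument ... of type <class 'object'> is not a valid JAX type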
@@ -35,7 +35,6 @@ import numpy as np
 import opt_einsum

 import jax
-from jax.interpreters import xla

 from jax._src import config
 from jax._src import core
@@ -1206,7 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
       f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")

 core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
-xla.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
+core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
 dtypes._weak_types.append(_DimExpr)

 def _convertible_to_int(p: DimSize) -> bool:
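Registering _DimExpr in both registries lets symbolic dimensions flow wherever avals are computed. A hedged sketch of where that matters, using the jax.export API of this era:

import jax
import jax.numpy as jnp
from jax import export

b, = export.symbolic_shape("b")  # b is a _DimExpr
exp = export.export(jax.jit(jnp.sin))(
    jax.ShapeDtypeStruct((b, 4), jnp.float32))
print(exp.in_avals)  # avals with a symbolic leading dimension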
@@ -1825,7 +1825,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,

   def aval(v: core.Atom) -> core.AbstractValue:
     if type(v) is core.Literal:
-      return xla.abstractify(v.val)
+      return core.abstractify(v.val)
     else:
       return v.aval

@@ -349,7 +349,7 @@ def xla_pmap_impl_lazy(
         donated_invars=donated_invars,
         is_explicit_global_axis_size=is_explicit_global_axis_size)
     return _emap_apply_fn
-  abstract_args = unsafe_map(xla.abstractify, args)
+  abstract_args = unsafe_map(core.abstractify, args)
   compiled_fun, fingerprint = parallel_callable(
       fun, backend, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars,
@@ -360,7 +360,7 @@ def xla_pmap_impl_lazy(
   distributed_debug_log(("Running pmapped function", name),
                         ("python function", fun.f),
                         ("devices", devices),
-                        ("abstract args", map(xla.abstractify, args)),
+                        ("abstract args", map(core.abstractify, args)),
                         ("fingerprint", fingerprint))
   return compiled_fun

@@ -598,7 +598,7 @@ class MapTracer(core.Tracer):

   @property
   def aval(self):
-    aval = xla.abstractify(self.val)
+    aval = core.abstractify(self.val)
     shard_axes = dict(self.shard_axes)
     for axis_idx in sorted(shard_axes.values())[::-1]:
       aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
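core.mapped_aval strips one mapped axis per application, which is why the loop above walks the sharded axes from highest index down (lower indices stay valid as axes are removed). Sketch:

import numpy as np
from jax._src import core

aval = core.ShapedArray((8, 4), np.dtype('float32'))
print(core.mapped_aval(8, 0, aval))  # ShapedArray((4,), float32): axis 0 removed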
@@ -1145,7 +1145,7 @@ class PmapExecutable(stages.XlaExecutable):
   @profiler.annotate_function
   def call(self, *args):
     # TODO(frostig): do we need to check sharding and sharded avals?
-    arg_avals = map(xla.abstractify, args)
+    arg_avals = map(core.abstractify, args)
     check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

@@ -3090,7 +3090,7 @@ class MeshExecutable(stages.XlaExecutable):
       ref_avals = self._all_args_info.in_avals
       debug_info = self._all_args_info.debug_info

-    all_arg_avals = map(xla.abstractify, kept_args)
+    all_arg_avals = map(core.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
     check_array_xla_sharding_layout_match(
         args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
@@ -146,44 +146,12 @@ canonicalize_dtype_handlers[core.Token] = identity
 canonicalize_dtype_handlers[core.DArray] = identity
 canonicalize_dtype_handlers[core.MutableArray] = identity

+# TODO(jakevdp): deprecate and remove this.
 def abstractify(x) -> Any:
-  typ = type(x)
-  aval_fn = pytype_aval_mappings.get(typ)
-  if aval_fn: return aval_fn(x)
-  for typ in typ.__mro__:
-    aval_fn = pytype_aval_mappings.get(typ)
-    if aval_fn: return aval_fn(x)
-  if hasattr(x, '__jax_array__'):
-    return abstractify(x.__jax_array__())
-  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
-
-def _make_abstract_python_scalar(typ, val):
-  # Note: all python scalar types are weak except bool, because bool only
-  # comes in a single width.
-  return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
-                     weak_type=typ is not bool)
-
-def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
-  dtype = np.dtype(x)
-  dtypes.check_valid_dtype(dtype)
-  return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
-
-def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
-  dtype = x.dtype
-  dtypes.check_valid_dtype(dtype)
-  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
-
-
-pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
-pytype_aval_mappings[core.DArray] = lambda x: x._aval
-pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
-pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
-pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
-                            for t in numpy_scalar_types)
-pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
-pytype_aval_mappings.update(
-    (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
+  return core.abstractify(x)

+# TODO(jakevdp): deprecate and remove this.
+pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings

 initial_style_primitives: set[core.Primitive] = set()

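Because the alias binds the same dict object rather than a copy, registrations made through the legacy name remain visible to core.abstractify. Sketch (FooType is a hypothetical extension type):

import numpy as np
from jax._src import core
from jax._src.interpreters import xla

assert xla.pytype_aval_mappings is core.xla_pytype_aval_mappings  # one shared dict

class FooType:  # hypothetical extension type
    pass

xla.pytype_aval_mappings[FooType] = lambda x: core.ShapedArray((), np.dtype('float32'))
print(core.abstractify(FooType()))     # served by the shared registry
del xla.pytype_aval_mappings[FooType]  # tidy up, mirroring the tests below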
@@ -1689,7 +1689,7 @@ def _pjit_call_impl_python(
                           ("out_shardings", out_shardings),
                           ("in_layouts", in_layouts),
                           ("out_layouts", out_layouts),
-                          ("abstract args", map(xla.abstractify, args)),
+                          ("abstract args", map(core.abstractify, args)),
                           ("fingerprint", fingerprint))
   try:
     return compiled.unsafe_call(*args), compiled, pgle_profiler
@@ -463,7 +463,7 @@ class KeyTy(dtypes.ExtendedDType):


 core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
-xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
+core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval

 xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

@@ -38,7 +38,6 @@ from jax import tree_util
 from jax import sharding
 from jax import export
 from jax.experimental.jax2tf import impl_no_xla
-from jax.interpreters import xla

 from jax._src import ad_checkpoint
 from jax._src import ad_util
@@ -1153,7 +1152,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
       else:
         return val, jax_dtype
   else:  # A constant
-    jax_dtype = jax_dtype or xla.abstractify(val).dtype
+    jax_dtype = jax_dtype or core.abstractify(val).dtype
     # TODO(document): We assume that the value of a constant does not
     # change through the scope of the function. But it may be an ndarray, ...
     # JAX has the same problem when generating HLO.
@@ -26,7 +26,6 @@ from jax._src import traceback_util
 from jax._src import util
 from jax._src.api import make_jaxpr
 from jax._src.interpreters.partial_eval import dce_jaxpr
-from jax._src.interpreters.xla import abstractify
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
 from jax.experimental import shard_map
@@ -142,14 +141,14 @@ def _roofline_interpreter(

   def read(v: core.Atom) -> RooflineShape:
     if type(v) is core.Literal:
-      return RooflineShape.from_aval(abstractify(v.val))
+      return RooflineShape.from_aval(core.abstractify(v.val))
     else:
       assert isinstance(v, core.Var)
       return env[v]

   def aval(v: core.Atom) -> core.AbstractValue:
     if type(v) is core.Literal:
-      return abstractify(v.val)
+      return core.abstractify(v.val)
     else:
       return v.aval

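The Literal-vs-Var branch appears twice here because jaxpr atoms come in two kinds: literals carry a raw value that still needs abstractifying, while variables already store an aval. Sketch of the same pattern over a tiny jaxpr:

import jax
import jax.numpy as jnp
from jax._src import core

jaxpr = jax.make_jaxpr(lambda x: x * 2.0)(jnp.ones(3)).jaxpr
for v in jaxpr.eqns[0].invars:
    if type(v) is core.Literal:
        print("literal:", core.abstractify(v.val))  # the constant 2.0
    else:
        print("var:", v.aval)                       # the traced input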
@@ -3932,7 +3932,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
     core.pytype_aval_mappings[FooArray] = \
         lambda x: core.ShapedArray(x.shape, FooTy())
     xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
-    xla.pytype_aval_mappings[FooArray] = \
+    core.xla_pytype_aval_mappings[FooArray] = \
         lambda x: core.ShapedArray(x.shape, FooTy())
     pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
     mlir._constant_handlers[FooArray] = foo_array_constant_handler
@@ -3946,7 +3946,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
   def tearDown(self):
     del core.pytype_aval_mappings[FooArray]
     del xla.canonicalize_dtype_handlers[FooArray]
-    del xla.pytype_aval_mappings[FooArray]
+    del core.xla_pytype_aval_mappings[FooArray]
     del mlir._constant_handlers[FooArray]
     del mlir._lowerings[make_p]
     del mlir._lowerings[bake_p]