From 48f24b6acb9fe67dfe227ff3349787b4045c09ff Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Oct 2024 14:06:08 -0700 Subject: [PATCH] Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it. PiperOrigin-RevId: 691929385 --- jax/_src/abstract_arrays.py | 9 +- jax/_src/api.py | 8 +- jax/_src/array.py | 2 - jax/_src/core.py | 108 ++++++---------------- jax/_src/interpreters/ad.py | 3 + jax/_src/interpreters/mlir.py | 1 - jax/_src/interpreters/partial_eval.py | 3 +- jax/_src/interpreters/xla.py | 3 +- jax/_src/lax/control_flow/conditionals.py | 7 +- jax/_src/lax/control_flow/loops.py | 7 +- jax/_src/lax/lax.py | 22 ++--- jax/_src/lax/utils.py | 11 +-- jax/_src/lax/windowed_reductions.py | 11 ++- jax/_src/linear_util.py | 2 +- jax/_src/numpy/lax_numpy.py | 8 +- jax/_src/numpy/ufuncs.py | 2 +- jax/_src/state/indexing.py | 7 +- jax/core.py | 2 +- jax/experimental/shard_map.py | 14 ++- jax/experimental/slab/slab.py | 2 +- tests/api_test.py | 12 --- tests/core_test.py | 7 -- 22 files changed, 83 insertions(+), 168 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 9a49a09c7..95216fb6f 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -24,9 +24,7 @@ from jax._src import dtypes from jax._src import traceback_util traceback_util.register_exclusion(__file__) -UnshapedArray = core.UnshapedArray ShapedArray = core.ShapedArray -ConcreteArray = core.ConcreteArray AbstractToken = core.AbstractToken abstract_token = core.abstract_token canonicalize_shape = core.canonicalize_shape @@ -47,8 +45,11 @@ if dtypes.int2 is not None: array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic def canonical_concrete_aval(val, weak_type=None): - return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val, - weak_type=weak_type) + weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type + dtype = dtypes.canonicalize_dtype(np.result_type(val)) + dtypes.check_valid_dtype(dtype) + sharding = core._get_abstract_sharding(val) + return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding) def masked_array_error(*args, **kwargs): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " diff --git a/jax/_src/api.py b/jax/_src/api.py index 652542571..0b3cb0871 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -56,7 +56,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import pjit from jax._src import xla_bridge as xb -from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray +from jax._src.core import eval_jaxpr, ShapedArray from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, flatten_axes, donation_vector, @@ -2188,9 +2188,9 @@ def _infer_src_sharding(src, x) -> Sharding | None: if isinstance(x, array.ArrayImpl): return x.sharding elif isinstance(x, core.Tracer): - aval = core.get_aval(x) - if isinstance(aval, ConcreteArray) and isinstance(aval.val, array.ArrayImpl): - return aval.val.sharding + val = x.to_concrete_value() + if val is not None and isinstance(val, array.ArrayImpl): + return val.sharding return None diff --git a/jax/_src/array.py b/jax/_src/array.py index 2f29f1376..30fedf4cf 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1184,7 +1184,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed): global_aval, out_sharding, committed=committed, _skip_checks=True ) pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler -pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler # Only used for Arrays that come out of pmap. def _array_local_result_handler(aval, sharding, indices): @@ -1197,7 +1196,6 @@ def _array_local_result_handler(aval, sharding, indices): aval, sharding, committed=True, _skip_checks=True ) pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler -pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler # Token handlers diff --git a/jax/_src/core.py b/jax/_src/core.py index 43cb5cc1e..7d912e3c2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -19,7 +19,7 @@ from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator, from contextlib import contextmanager, ExitStack from dataclasses import dataclass import functools -from functools import partial, partialmethod, total_ordering +from functools import partial, total_ordering import gc import inspect import itertools as it @@ -696,6 +696,10 @@ class Tracer(typing.Array, metaclass=StrictABCMeta): def __len__(self): return self.aval._len(self) + def to_concrete_value(self): + # Should return the concrete value if there is one, or else None. + return None + @property def sharding(self): # This attribute is part of the jax.Array API, but only defined on concrete arrays. @@ -739,10 +743,12 @@ class Tracer(typing.Array, metaclass=StrictABCMeta): return self # Override for object equivalence checking def __bool__(self): + if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_bool_conversion(self) return self.aval._bool(self) def __int__(self): + if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_scalar_conversion(self) return self.aval._int(self) @@ -755,14 +761,17 @@ class Tracer(typing.Array, metaclass=StrictABCMeta): return self.aval._complex(self) def __hex__(self): + if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._hex(self) def __oct__(self): + if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._oct(self) def __index__(self): + if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._index(self) @@ -1393,12 +1402,16 @@ def get_aval(x): else: return concrete_aval(x) -def get_type(x): - aval = get_aval(x) - if isinstance(aval, ConcreteArray): - return raise_to_shaped(aval) +get_type = get_aval + +def is_concrete(x): + return to_concrete_value(x) is not None + +def to_concrete_value(x): + if isinstance(x, Tracer): + return x.to_concrete_value() else: - return aval + return x def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) @@ -1423,10 +1436,11 @@ def concrete_or_error(force: Any, val: Any, context=""): if force is None: force = lambda x: x if isinstance(val, Tracer): - if isinstance(val.aval, ConcreteArray): - return force(val.aval.val) - else: + maybe_concrete = val.to_concrete_value() + if maybe_concrete is None: raise ConcretizationTypeError(val, context) + else: + return force(maybe_concrete) else: return force(val) @@ -1578,7 +1592,7 @@ def _invalid_shape_error(shape: Shape, context: str=""): msg += f" {context}." if not config.dynamic_shapes.value and any( isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) - and not isinstance(get_aval(x), ConcreteArray) for x in shape): + and not is_concrete(x) for x in shape): msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " "smaller subfunctions.") for x in shape: @@ -1677,10 +1691,6 @@ def _get_shape_sharding_str(shape, spec): else: yield f"{s1}@{s2}" - -def _forward_to_value(self, fun, ignored_tracer, *args): - return fun(self.val, *args) - def _get_abstract_sharding(val): from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error @@ -1690,59 +1700,6 @@ def _get_abstract_sharding(val): val.sharding._normalized_spec(val.ndim)) return None -class ConcreteArray(ShapedArray): - __slots__ = ['val'] - array_abstraction_level = 0 - - def __init__(self, dtype, val, weak_type=None): - super().__init__( - np.shape(val), dtype, - weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type, - sharding=_get_abstract_sharding(val)) - dtypes.check_valid_dtype(self.dtype) - # Note: canonicalized self.dtype doesn't necessarily match self.val - assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype) - self.val = val - - def update(self, dtype=None, val=None, weak_type=None): - dtype = self.dtype if dtype is None else dtype - val = self.val if val is None else val - weak_type = self.weak_type if weak_type is None else weak_type - return ConcreteArray(dtype, val, weak_type) - - def __eq__(self, other): - if (type(self) is type(other) and self.dtype == other.dtype - and self.shape == other.shape and self.weak_type == other.weak_type): - with eval_context(): # in case self.val is an Array - return (self.val == other.val).all() - else: - return False - - def __hash__(self): - return id(self.val) - - def join(self, other) -> AbstractValue: - if self == other: - return self - elif self.shape == other.shape and self.dtype == other.dtype: - weak_type = self.weak_type and other.weak_type - return ShapedArray(self.shape, self.dtype, weak_type=weak_type) - else: - raise TypeError(self, other) - - def str_short(self, short_dtypes=False) -> str: - dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - return f'{self.val}, dtype={dt_str}' - - _bool = partialmethod(_forward_to_value, bool) - _int = partialmethod(_forward_to_value, int) - _hex = partialmethod(_forward_to_value, hex) - _oct = partialmethod(_forward_to_value, oct) - _index = partialmethod(_forward_to_value, operator.index) - - _float = concretization_function_error(float, True) - _complex = concretization_function_error(complex, True) - def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): return primal_dtype._rules.tangent_dtype(primal_dtype) @@ -1817,14 +1774,6 @@ class DShapedArray(UnshapedArray): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) -class DConcreteArray(DShapedArray): - __slots__ = ['val'] - array_abstraction_level = 1 - def __init__(self, shape, dtype, weak_type, val): - super().__init__(shape, dtype, weak_type) - self.val = val - - pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} @@ -1881,8 +1830,7 @@ class DArray: pytype_aval_mappings[DArray] = \ - lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, - x._data) + lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) @dataclass(frozen=True) class bint(dtypes.ExtendedDType): @@ -1984,10 +1932,7 @@ raise_to_shaped_mappings: dict[type, Callable] = { AbstractToken: lambda aval, _: aval, Bot: lambda aval, _: aval, ShapedArray: _shaped_array_mapping, - DShapedArray: lambda aval, _: aval, - DConcreteArray: lambda aval, weak_type: DShapedArray( - aval.shape, aval.dtype, weak_type - ), + DShapedArray: lambda aval, _: aval } ### Operations on shapes and dimension sizes. @@ -2323,7 +2268,6 @@ AvalMapHandlerPair = tuple[Callable, Callable] aval_mapping_handlers: dict[type, AvalMapHandlerPair] = { DShapedArray: (_map_dshaped_array, _unmap_dshaped_array), ShapedArray: (_map_shaped_array, _unmap_shaped_array), - ConcreteArray: (_map_shaped_array, _unmap_shaped_array), AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a) } diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index b9cace3de..47c788237 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -429,6 +429,9 @@ class JVPTracer(Tracer): else: return self + def to_concrete_value(self): + return core.to_concrete_value(self.primal) + def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = raise_to_shaped(get_aval(primal), weak_type=False) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c71e52385..10154bbd6 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -230,7 +230,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes: raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err ir_type_handlers[core.ShapedArray] = _array_ir_types -ir_type_handlers[core.ConcreteArray] = _array_ir_types ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get() ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 9e5e1ee9b..2f63eb386 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -40,7 +40,7 @@ from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, fun_sourceinfo) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - ConcreteArray, Var, DropVar, raise_to_shaped, Atom, + Var, DropVar, raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) @@ -299,7 +299,6 @@ class JaxprTrace(Trace['JaxprTracer']): # With dynamic shapes, we may need to substitute Tracers into avals. out_tracers = [] for aval, _ in out_type: - assert not isinstance(aval, ConcreteArray) if type(aval) is DShapedArray: shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] if type(d) is InDBIdx else d for d in aval.shape] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 14635a46e..46bc7bef7 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -25,7 +25,7 @@ import numpy as np from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -101,7 +101,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: _xla_shape_handlers: dict[type[core.AbstractValue], Callable[[Any], Sequence[xc.Shape]]] = { ShapedArray: _make_array_shape, - ConcreteArray: _make_array_shape, } _xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d189dc0bd..8dae3433e 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -35,7 +35,7 @@ from jax._src import source_info_util from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects +from jax._src.core import raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -130,8 +130,7 @@ def switch(index, branches: Sequence[Callable], *operands, hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) - if (config.disable_jit.value and - isinstance(core.get_aval(index), ConcreteArray)): + if (config.disable_jit.value and core.is_concrete(index)): return branches[int(index)](*operands) ops, ops_tree = tree_flatten(operands) @@ -220,7 +219,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, msg = ("Pred type must be either boolean or number, got {}.") raise TypeError(msg.format(pred_dtype)) - if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray): + if config.disable_jit.value and core.is_concrete(pred): if pred: return true_fun(*operands) else: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6d6338b0b..ddbbe0213 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,7 +35,7 @@ from jax._src import source_info_util from jax._src import state from jax._src import util from jax._src.api_util import shaped_abstractify -from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped +from jax._src.core import ShapedArray, raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -2015,12 +2015,11 @@ def fori_loop(lower, upper, body_fun, init_val, # If we can specialize on the trip count, call scan instead of a while_loop # to enable efficient reverse-mode differentiation. - if (isinstance(core.get_aval(lower), ConcreteArray) and - isinstance(core.get_aval(upper), ConcreteArray)): + if core.is_concrete(lower) and core.is_concrete(upper): try: lower_ = int(lower) upper_ = int(upper) - except TypeError: + except (TypeError, core.InconclusiveDimensionOperation): use_scan = False else: use_scan = True diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 2a44d9ec9..e6dbcbb12 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -47,7 +47,7 @@ from jax._src import source_info_util from jax._src import state from jax._src import util from jax._src.abstract_arrays import array_types -from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, +from jax._src.core import (Primitive, UnshapedArray, ShapedArray, raise_to_shaped, abstract_token, canonicalize_shape) from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -582,8 +582,7 @@ def _convert_element_type( if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array) and - not (isinstance(operand, core.Tracer) and - isinstance(core.get_aval(operand), core.ConcreteArray)) and + not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and (sharding is None or getattr(operand, 'sharding', None) == sharding)): return operand else: @@ -1438,23 +1437,24 @@ def _get_monoid_reducer(monoid_op: Callable, x, = xs aval = core.get_aval(x) dtype = _dtype(x) - if (type(aval) is ConcreteArray) and aval.shape == (): + if core.is_concrete(x) and aval.shape == (): + val = core.to_concrete_value(x) # allow bitwise reductions for boolean and integer types _is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer) if monoid_op is add: - return _reduce_sum if np.equal(aval.val, 0) else None + return _reduce_sum if np.equal(val, 0) else None elif monoid_op is mul: - return _reduce_prod if np.equal(aval.val, 1) else None + return _reduce_prod if np.equal(val, 1) else None elif monoid_op is bitwise_or and _is_intlike: - return _reduce_or if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None + return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is bitwise_and and _is_intlike: - return _reduce_and if np.equal(aval.val, _get_bitwise_and_identity(dtype)) else None + return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None elif monoid_op is bitwise_xor and _is_intlike: - return _reduce_xor if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None + return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is max: - return _reduce_max if np.equal(aval.val, _get_max_identity(dtype)) else None + return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None elif monoid_op is min: - return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None + return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None return None def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray: diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index deb3c19c0..82804c796 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -52,10 +52,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) least_specialized = type(max(avals, key=_get_array_abstraction_level)) - if least_specialized is core.ConcreteArray: - out = prim.impl(*[x.val for x in avals], **kwargs) - return core.ConcreteArray(out.dtype, out, weak_type=weak_type) - elif least_specialized is core.ShapedArray: + if least_specialized is core.ShapedArray: return core.ShapedArray( shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs), weak_type=weak_type, @@ -77,11 +74,7 @@ def standard_multi_result_abstract_eval( assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals least_specialized = max(map(type, avals), key=_get_array_abstraction_level) weak_types = weak_type_rule(*avals, **kwargs) - if least_specialized is core.ConcreteArray: - out_vals = prim.impl(*[x.val for x in avals], **kwargs) - return [core.ConcreteArray(val.dtype, val, weak_type=weak_type) - for val, weak_type in zip(out_vals, weak_types)] - elif least_specialized is core.ShapedArray: + if least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) return [core.ShapedArray(s, d, weak_type=weak_type) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 089a77de2..462e5fbed 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -23,7 +23,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import util -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -142,14 +142,15 @@ def _get_monoid_window_reducer( return None x, = xs aval = core.get_aval(x) - if (type(aval) is ConcreteArray) and aval.shape == (): + if core.is_concrete(x) and aval.shape == (): + val = core.to_concrete_value(x) if monoid_op is lax.add: - return aval.val == 0 and _reduce_window_sum + return val == 0 and _reduce_window_sum elif monoid_op is lax.max: - return (aval.val == lax._get_max_identity(aval.dtype) + return (val == lax._get_max_identity(aval.dtype) and _reduce_window_max) elif monoid_op is lax.min: - return (aval.val == lax._get_min_identity(aval.dtype) + return (val == lax._get_min_identity(aval.dtype) and _reduce_window_min) return None diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index dd8f671c6..08f94c6e8 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -276,7 +276,7 @@ def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) assert all(isinstance(a, core.AbstractValue) and type(b) is bool - and not isinstance(a, core.ConcreteArray) for a, b in in_type) + for a, b in in_type) def valid_size(d) -> bool: if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ee33be8a1..f79d6bc07 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -52,7 +52,7 @@ from jax._src import dtypes from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax as lax_internal from jax._src.lax.lax import ( PrecisionLike,_array_copy, @@ -11789,7 +11789,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], except TypeError: abstract_i = None # Handle basic int indexes. - if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i): + if isinstance(abstract_i, ShapedArray) and _int(abstract_i): if core.definitely_equal(x_shape[x_axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") @@ -11945,7 +11945,7 @@ def _expand_bool_indices(idx, shape): i = array(i) abstract_i = core.get_aval(i) - if not type(abstract_i) is ConcreteArray: + if not core.is_concrete(i): # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete raise errors.NonConcreteBooleanIndexError(abstract_i) elif _ndim(i) == 0: @@ -11975,7 +11975,7 @@ def _is_slice_element_none_or_constant_or_symbolic(elt): if elt is None: return True if core.is_symbolic_dim(elt): return True try: - return type(core.get_aval(elt)) is ConcreteArray + return core.is_concrete(elt) except TypeError: return False diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index ade9cb206..8692c30a3 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2512,7 +2512,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # lax.pow. # Case 1: concrete integer scalar powers: - if isinstance(core.get_aval(x2), core.ConcreteArray): + if core.is_concrete(x2): try: x2 = operator.index(x2) # type: ignore[arg-type] except TypeError: diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index cb653547b..538f3f8e4 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -123,12 +123,7 @@ def _maybe_concretize(x: Any): # This is roughly the same logic as core.concrete_or_error, but we avoid # calling that because constructing the ConcretizationTypeError can be # expensive as the size of the tracing context (i.e. the jaxpr) grows. - if isinstance(x, core.Tracer): - if isinstance(x.aval, core.ConcreteArray): - return x.aval.val - else: - return None - return x + return core.to_concrete_value(x) @tree_util.register_pytree_node_class @dataclasses.dataclass diff --git a/jax/core.py b/jax/core.py index 6869f747b..2880e42c6 100644 --- a/jax/core.py +++ b/jax/core.py @@ -24,7 +24,6 @@ from jax._src.core import ( AxisName as AxisName, CallPrimitive as CallPrimitive, ClosedJaxpr as ClosedJaxpr, - ConcreteArray as ConcreteArray, ConcretizationTypeError as ConcretizationTypeError, DShapedArray as DShapedArray, DropVar as DropVar, @@ -84,6 +83,7 @@ from jax._src.core import ( get_aval as get_aval, get_type as get_type, get_referent as get_referent, + is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, jaxpr_as_fun as jaxpr_as_fun, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 2fa028b2f..615fd3128 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -911,13 +911,14 @@ class ShardMapTracer(core.Tracer): @property def aval(self): aval = core.get_aval(self.val) - if (isinstance(aval, core.ConcreteArray) and - self.rep == set(self._trace.mesh.axis_names)): + return core.mapped_aval(self._trace.mesh.size, 0, aval) + + def to_concrete_value(self): + if self.rep == set(self._trace.mesh.axis_names): with core.eval_context(): - return core.get_aval(self.val[0]) + return core.to_concrete_value(self.val[0]) else: - aval = core.raise_to_shaped(aval) - return core.mapped_aval(self._trace.mesh.size, 0, aval) + return None def __str__(self) -> str: with core.eval_context(): @@ -1768,6 +1769,9 @@ class RewriteTracer(core.Tracer): def aval(self) -> core.AbstractValue: return core.get_aval(self.val) + def to_concrete_value(self): + return core.to_concrete_value(self.val) + def __str__(self) -> str: return str(self.val) # TODO(mattjj): could show replication info here __repr__ = __str__ # for debuggers, like `p x` diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py index af7b079ee..8324e4c55 100644 --- a/jax/experimental/slab/slab.py +++ b/jax/experimental/slab/slab.py @@ -89,7 +89,7 @@ def xprod(xs: Iterable[XInt]) -> XInt: return xmul(*list(xs)) def static_int(x: XInt) -> bool: - return isinstance(core.get_aval(x), core.ConcreteArray) + return core.is_concrete(x) def static_shape(s: DShape) -> bool: return all(map(static_int, s)) diff --git a/tests/api_test.py b/tests/api_test.py index 197784d99..e98f4299c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3691,18 +3691,6 @@ class APITest(jtu.JaxTestCase): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) - def test_join_concrete_arrays_with_omnistaging(self): - # https://github.com/jax-ml/jax/issues/4622 - x = jnp.array([1., 2., 3.]) - y = jnp.array([1., 2., 4.]) - - @jit - def f(): - core.lattice_join(core.ConcreteArray(x.dtype, x), - core.ConcreteArray(y.dtype, y)) - - f() # doesn't crash - def test_linearize_aux(self): def fn(x): return x * 2 - 3, x > 0 diff --git a/tests/core_test.py b/tests/core_test.py index 387000372..1471e334c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -347,13 +347,6 @@ class CoreTest(jtu.JaxTestCase): 'This BatchTracer with object id'): g_vmap(jnp.ones((1, ))) - def test_concrete_array_string_representation(self): - # https://github.com/jax-ml/jax/issues/5364 - self.assertEqual( - str(core.ConcreteArray(np.dtype(np.int32), - np.array([1], dtype=np.int32))), - 'ConcreteArray([1], dtype=int32)') - def test_dropvar_avals(self): def f(x): def body(c, _):