Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.

PiperOrigin-RevId: 691929385
This commit is contained in:
Dougal Maclaurin 2024-10-31 14:06:08 -07:00 committed by jax authors
parent 8536eca46e
commit 48f24b6acb
22 changed files with 83 additions and 168 deletions

View File

@ -24,9 +24,7 @@ from jax._src import dtypes
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
UnshapedArray = core.UnshapedArray
ShapedArray = core.ShapedArray
ConcreteArray = core.ConcreteArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
@ -47,8 +45,11 @@ if dtypes.int2 is not None:
array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
weak_type=weak_type)
weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type
dtype = dtypes.canonicalize_dtype(np.result_type(val))
dtypes.check_valid_dtype(dtype)
sharding = core._get_abstract_sharding(val)
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)
def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "

View File

@ -56,7 +56,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
from jax._src.core import eval_jaxpr, ShapedArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
@ -2188,9 +2188,9 @@ def _infer_src_sharding(src, x) -> Sharding | None:
if isinstance(x, array.ArrayImpl):
return x.sharding
elif isinstance(x, core.Tracer):
aval = core.get_aval(x)
if isinstance(aval, ConcreteArray) and isinstance(aval.val, array.ArrayImpl):
return aval.val.sharding
val = x.to_concrete_value()
if val is not None and isinstance(val, array.ArrayImpl):
return val.sharding
return None

View File

@ -1184,7 +1184,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
global_aval, out_sharding, committed=committed, _skip_checks=True
)
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
# Only used for Arrays that come out of pmap.
def _array_local_result_handler(aval, sharding, indices):
@ -1197,7 +1196,6 @@ def _array_local_result_handler(aval, sharding, indices):
aval, sharding, committed=True, _skip_checks=True
)
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
# Token handlers

View File

@ -19,7 +19,7 @@ from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator,
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
import functools
from functools import partial, partialmethod, total_ordering
from functools import partial, total_ordering
import gc
import inspect
import itertools as it
@ -696,6 +696,10 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
def __len__(self):
return self.aval._len(self)
def to_concrete_value(self):
# Should return the concrete value if there is one, or else None.
return None
@property
def sharding(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
@ -739,10 +743,12 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
return self # Override for object equivalence checking
def __bool__(self):
if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_bool_conversion(self)
return self.aval._bool(self)
def __int__(self):
if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_scalar_conversion(self)
return self.aval._int(self)
@ -755,14 +761,17 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
return self.aval._complex(self)
def __hex__(self):
if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._hex(self)
def __oct__(self):
if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._oct(self)
def __index__(self):
if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._index(self)
@ -1393,12 +1402,16 @@ def get_aval(x):
else:
return concrete_aval(x)
def get_type(x):
aval = get_aval(x)
if isinstance(aval, ConcreteArray):
return raise_to_shaped(aval)
get_type = get_aval
def is_concrete(x):
return to_concrete_value(x) is not None
def to_concrete_value(x):
if isinstance(x, Tracer):
return x.to_concrete_value()
else:
return aval
return x
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
@ -1423,10 +1436,11 @@ def concrete_or_error(force: Any, val: Any, context=""):
if force is None:
force = lambda x: x
if isinstance(val, Tracer):
if isinstance(val.aval, ConcreteArray):
return force(val.aval.val)
else:
maybe_concrete = val.to_concrete_value()
if maybe_concrete is None:
raise ConcretizationTypeError(val, context)
else:
return force(maybe_concrete)
else:
return force(val)
@ -1578,7 +1592,7 @@ def _invalid_shape_error(shape: Shape, context: str=""):
msg += f" {context}."
if not config.dynamic_shapes.value and any(
isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
and not is_concrete(x) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
for x in shape:
@ -1677,10 +1691,6 @@ def _get_shape_sharding_str(shape, spec):
else:
yield f"{s1}@{s2}"
def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)
def _get_abstract_sharding(val):
from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error
@ -1690,59 +1700,6 @@ def _get_abstract_sharding(val):
val.sharding._normalized_spec(val.ndim))
return None
class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
def __init__(self, dtype, val, weak_type=None):
super().__init__(
np.shape(val), dtype,
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type,
sharding=_get_abstract_sharding(val))
dtypes.check_valid_dtype(self.dtype)
# Note: canonicalized self.dtype doesn't necessarily match self.val
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
self.val = val
def update(self, dtype=None, val=None, weak_type=None):
dtype = self.dtype if dtype is None else dtype
val = self.val if val is None else val
weak_type = self.weak_type if weak_type is None else weak_type
return ConcreteArray(dtype, val, weak_type)
def __eq__(self, other):
if (type(self) is type(other) and self.dtype == other.dtype
and self.shape == other.shape and self.weak_type == other.weak_type):
with eval_context(): # in case self.val is an Array
return (self.val == other.val).all()
else:
return False
def __hash__(self):
return id(self.val)
def join(self, other) -> AbstractValue:
if self == other:
return self
elif self.shape == other.shape and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
else:
raise TypeError(self, other)
def str_short(self, short_dtypes=False) -> str:
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return f'{self.val}, dtype={dt_str}'
_bool = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_index = partialmethod(_forward_to_value, operator.index)
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
def primal_dtype_to_tangent_dtype(primal_dtype):
if isinstance(primal_dtype, dtypes.ExtendedDType):
return primal_dtype._rules.tangent_dtype(primal_dtype)
@ -1817,14 +1774,6 @@ class DShapedArray(UnshapedArray):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
class DConcreteArray(DShapedArray):
__slots__ = ['val']
array_abstraction_level = 1
def __init__(self, shape, dtype, weak_type, val):
super().__init__(shape, dtype, weak_type)
self.val = val
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
@ -1881,8 +1830,7 @@ class DArray:
pytype_aval_mappings[DArray] = \
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
x._data)
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
@ -1984,10 +1932,7 @@ raise_to_shaped_mappings: dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
ShapedArray: _shaped_array_mapping,
DShapedArray: lambda aval, _: aval,
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),
DShapedArray: lambda aval, _: aval
}
### Operations on shapes and dimension sizes.
@ -2323,7 +2268,6 @@ AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
}

View File

@ -429,6 +429,9 @@ class JVPTracer(Tracer):
else:
return self
def to_concrete_value(self):
return core.to_concrete_value(self.primal)
def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)

View File

@ -230,7 +230,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes:
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get()
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types

View File

@ -40,7 +40,7 @@ from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
fun_sourceinfo)
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, Var, DropVar, raise_to_shaped, Atom,
Var, DropVar, raise_to_shaped, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
@ -299,7 +299,6 @@ class JaxprTrace(Trace['JaxprTracer']):
# With dynamic shapes, we may need to substitute Tracers into avals.
out_tracers = []
for aval, _ in out_type:
assert not isinstance(aval, ConcreteArray)
if type(aval) is DShapedArray:
shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val]
if type(d) is InDBIdx else d for d in aval.shape]

View File

@ -25,7 +25,7 @@ import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.core import ShapedArray
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Shape
@ -101,7 +101,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
_xla_shape_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[xc.Shape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)

View File

@ -35,7 +35,7 @@ from jax._src import source_info_util
from jax._src import util
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.core import raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -130,8 +130,7 @@ def switch(index, branches: Sequence[Callable], *operands,
hi = np.array(len(branches) - 1, np.int32)
index = lax.clamp(lo, index, hi)
if (config.disable_jit.value and
isinstance(core.get_aval(index), ConcreteArray)):
if (config.disable_jit.value and core.is_concrete(index)):
return branches[int(index)](*operands)
ops, ops_tree = tree_flatten(operands)
@ -220,7 +219,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred_dtype))
if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray):
if config.disable_jit.value and core.is_concrete(pred):
if pred:
return true_fun(*operands)
else:

View File

@ -35,7 +35,7 @@ from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.api_util import shaped_abstractify
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -2015,12 +2015,11 @@ def fori_loop(lower, upper, body_fun, init_val,
# If we can specialize on the trip count, call scan instead of a while_loop
# to enable efficient reverse-mode differentiation.
if (isinstance(core.get_aval(lower), ConcreteArray) and
isinstance(core.get_aval(upper), ConcreteArray)):
if core.is_concrete(lower) and core.is_concrete(upper):
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
except (TypeError, core.InconclusiveDimensionOperation):
use_scan = False
else:
use_scan = True

View File

@ -47,7 +47,7 @@ from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -582,8 +582,7 @@ def _convert_element_type(
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
isinstance(operand, Array) and
not (isinstance(operand, core.Tracer) and
isinstance(core.get_aval(operand), core.ConcreteArray)) and
not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and
(sharding is None or getattr(operand, 'sharding', None) == sharding)):
return operand
else:
@ -1438,23 +1437,24 @@ def _get_monoid_reducer(monoid_op: Callable,
x, = xs
aval = core.get_aval(x)
dtype = _dtype(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if core.is_concrete(x) and aval.shape == ():
val = core.to_concrete_value(x)
# allow bitwise reductions for boolean and integer types
_is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer)
if monoid_op is add:
return _reduce_sum if np.equal(aval.val, 0) else None
return _reduce_sum if np.equal(val, 0) else None
elif monoid_op is mul:
return _reduce_prod if np.equal(aval.val, 1) else None
return _reduce_prod if np.equal(val, 1) else None
elif monoid_op is bitwise_or and _is_intlike:
return _reduce_or if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None
return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
elif monoid_op is bitwise_and and _is_intlike:
return _reduce_and if np.equal(aval.val, _get_bitwise_and_identity(dtype)) else None
return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
elif monoid_op is bitwise_xor and _is_intlike:
return _reduce_xor if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None
return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
elif monoid_op is max:
return _reduce_max if np.equal(aval.val, _get_max_identity(dtype)) else None
return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None
elif monoid_op is min:
return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None
return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None
return None
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:

View File

@ -52,10 +52,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
assert not prim.multiple_results
weak_type = weak_type_rule(*avals, **kwargs)
least_specialized = type(max(avals, key=_get_array_abstraction_level))
if least_specialized is core.ConcreteArray:
out = prim.impl(*[x.val for x in avals], **kwargs)
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
elif least_specialized is core.ShapedArray:
if least_specialized is core.ShapedArray:
return core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type,
@ -77,11 +74,7 @@ def standard_multi_result_abstract_eval(
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
weak_types = weak_type_rule(*avals, **kwargs)
if least_specialized is core.ConcreteArray:
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
for val, weak_type in zip(out_vals, weak_types)]
elif least_specialized is core.ShapedArray:
if least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
return [core.ShapedArray(s, d, weak_type=weak_type)

View File

@ -23,7 +23,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -142,14 +142,15 @@ def _get_monoid_window_reducer(
return None
x, = xs
aval = core.get_aval(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if core.is_concrete(x) and aval.shape == ():
val = core.to_concrete_value(x)
if monoid_op is lax.add:
return aval.val == 0 and _reduce_window_sum
return val == 0 and _reduce_window_sum
elif monoid_op is lax.max:
return (aval.val == lax._get_max_identity(aval.dtype)
return (val == lax._get_max_identity(aval.dtype)
and _reduce_window_max)
elif monoid_op is lax.min:
return (aval.val == lax._get_min_identity(aval.dtype)
return (val == lax._get_min_identity(aval.dtype)
and _reduce_window_min)
return None

View File

@ -276,7 +276,7 @@ def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
for a, b in in_type)
def valid_size(d) -> bool:
if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:

View File

@ -52,7 +52,7 @@ from jax._src import dtypes
from jax._src import xla_bridge
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.core import ShapedArray
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import ( PrecisionLike,_array_copy,
@ -11789,7 +11789,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
except TypeError:
abstract_i = None
# Handle basic int indexes.
if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i):
if isinstance(abstract_i, ShapedArray) and _int(abstract_i):
if core.definitely_equal(x_shape[x_axis], 0):
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
@ -11945,7 +11945,7 @@ def _expand_bool_indices(idx, shape):
i = array(i)
abstract_i = core.get_aval(i)
if not type(abstract_i) is ConcreteArray:
if not core.is_concrete(i):
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
elif _ndim(i) == 0:
@ -11975,7 +11975,7 @@ def _is_slice_element_none_or_constant_or_symbolic(elt):
if elt is None: return True
if core.is_symbolic_dim(elt): return True
try:
return type(core.get_aval(elt)) is ConcreteArray
return core.is_concrete(elt)
except TypeError:
return False

View File

@ -2512,7 +2512,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# lax.pow.
# Case 1: concrete integer scalar powers:
if isinstance(core.get_aval(x2), core.ConcreteArray):
if core.is_concrete(x2):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:

View File

@ -123,12 +123,7 @@ def _maybe_concretize(x: Any):
# This is roughly the same logic as core.concrete_or_error, but we avoid
# calling that because constructing the ConcretizationTypeError can be
# expensive as the size of the tracing context (i.e. the jaxpr) grows.
if isinstance(x, core.Tracer):
if isinstance(x.aval, core.ConcreteArray):
return x.aval.val
else:
return None
return x
return core.to_concrete_value(x)
@tree_util.register_pytree_node_class
@dataclasses.dataclass

View File

@ -24,7 +24,6 @@ from jax._src.core import (
AxisName as AxisName,
CallPrimitive as CallPrimitive,
ClosedJaxpr as ClosedJaxpr,
ConcreteArray as ConcreteArray,
ConcretizationTypeError as ConcretizationTypeError,
DShapedArray as DShapedArray,
DropVar as DropVar,
@ -84,6 +83,7 @@ from jax._src.core import (
get_aval as get_aval,
get_type as get_type,
get_referent as get_referent,
is_concrete as is_concrete,
is_constant_dim as is_constant_dim,
is_constant_shape as is_constant_shape,
jaxpr_as_fun as jaxpr_as_fun,

View File

@ -911,13 +911,14 @@ class ShardMapTracer(core.Tracer):
@property
def aval(self):
aval = core.get_aval(self.val)
if (isinstance(aval, core.ConcreteArray) and
self.rep == set(self._trace.mesh.axis_names)):
return core.mapped_aval(self._trace.mesh.size, 0, aval)
def to_concrete_value(self):
if self.rep == set(self._trace.mesh.axis_names):
with core.eval_context():
return core.get_aval(self.val[0])
return core.to_concrete_value(self.val[0])
else:
aval = core.raise_to_shaped(aval)
return core.mapped_aval(self._trace.mesh.size, 0, aval)
return None
def __str__(self) -> str:
with core.eval_context():
@ -1768,6 +1769,9 @@ class RewriteTracer(core.Tracer):
def aval(self) -> core.AbstractValue:
return core.get_aval(self.val)
def to_concrete_value(self):
return core.to_concrete_value(self.val)
def __str__(self) -> str:
return str(self.val) # TODO(mattjj): could show replication info here
__repr__ = __str__ # for debuggers, like `p x`

View File

@ -89,7 +89,7 @@ def xprod(xs: Iterable[XInt]) -> XInt:
return xmul(*list(xs))
def static_int(x: XInt) -> bool:
return isinstance(core.get_aval(x), core.ConcreteArray)
return core.is_concrete(x)
def static_shape(s: DShape) -> bool:
return all(map(static_int, s))

View File

@ -3691,18 +3691,6 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
def test_join_concrete_arrays_with_omnistaging(self):
# https://github.com/jax-ml/jax/issues/4622
x = jnp.array([1., 2., 3.])
y = jnp.array([1., 2., 4.])
@jit
def f():
core.lattice_join(core.ConcreteArray(x.dtype, x),
core.ConcreteArray(y.dtype, y))
f() # doesn't crash
def test_linearize_aux(self):
def fn(x):
return x * 2 - 3, x > 0

View File

@ -347,13 +347,6 @@ class CoreTest(jtu.JaxTestCase):
'This BatchTracer with object id'):
g_vmap(jnp.ones((1, )))
def test_concrete_array_string_representation(self):
# https://github.com/jax-ml/jax/issues/5364
self.assertEqual(
str(core.ConcreteArray(np.dtype(np.int32),
np.array([1], dtype=np.int32))),
'ConcreteArray([1], dtype=int32)')
def test_dropvar_avals(self):
def f(x):
def body(c, _):