mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
This commit is contained in:
parent
8536eca46e
commit
48f24b6acb
@ -24,9 +24,7 @@ from jax._src import dtypes
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
UnshapedArray = core.UnshapedArray
|
||||
ShapedArray = core.ShapedArray
|
||||
ConcreteArray = core.ConcreteArray
|
||||
AbstractToken = core.AbstractToken
|
||||
abstract_token = core.abstract_token
|
||||
canonicalize_shape = core.canonicalize_shape
|
||||
@ -47,8 +45,11 @@ if dtypes.int2 is not None:
|
||||
array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
|
||||
|
||||
def canonical_concrete_aval(val, weak_type=None):
|
||||
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
|
||||
weak_type=weak_type)
|
||||
weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type
|
||||
dtype = dtypes.canonicalize_dtype(np.result_type(val))
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
sharding = core._get_abstract_sharding(val)
|
||||
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)
|
||||
|
||||
def masked_array_error(*args, **kwargs):
|
||||
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
|
||||
|
@ -56,7 +56,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import pjit
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
|
||||
from jax._src.core import eval_jaxpr, ShapedArray
|
||||
from jax._src.api_util import (
|
||||
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
|
||||
flatten_axes, donation_vector,
|
||||
@ -2188,9 +2188,9 @@ def _infer_src_sharding(src, x) -> Sharding | None:
|
||||
if isinstance(x, array.ArrayImpl):
|
||||
return x.sharding
|
||||
elif isinstance(x, core.Tracer):
|
||||
aval = core.get_aval(x)
|
||||
if isinstance(aval, ConcreteArray) and isinstance(aval.val, array.ArrayImpl):
|
||||
return aval.val.sharding
|
||||
val = x.to_concrete_value()
|
||||
if val is not None and isinstance(val, array.ArrayImpl):
|
||||
return val.sharding
|
||||
return None
|
||||
|
||||
|
||||
|
@ -1184,7 +1184,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
|
||||
global_aval, out_sharding, committed=committed, _skip_checks=True
|
||||
)
|
||||
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
|
||||
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
|
||||
|
||||
# Only used for Arrays that come out of pmap.
|
||||
def _array_local_result_handler(aval, sharding, indices):
|
||||
@ -1197,7 +1196,6 @@ def _array_local_result_handler(aval, sharding, indices):
|
||||
aval, sharding, committed=True, _skip_checks=True
|
||||
)
|
||||
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
|
||||
pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
|
||||
|
||||
|
||||
# Token handlers
|
||||
|
108
jax/_src/core.py
108
jax/_src/core.py
@ -19,7 +19,7 @@ from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator,
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
from functools import partial, partialmethod, total_ordering
|
||||
from functools import partial, total_ordering
|
||||
import gc
|
||||
import inspect
|
||||
import itertools as it
|
||||
@ -696,6 +696,10 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
def __len__(self):
|
||||
return self.aval._len(self)
|
||||
|
||||
def to_concrete_value(self):
|
||||
# Should return the concrete value if there is one, or else None.
|
||||
return None
|
||||
|
||||
@property
|
||||
def sharding(self):
|
||||
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
|
||||
@ -739,10 +743,12 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
return self # Override for object equivalence checking
|
||||
|
||||
def __bool__(self):
|
||||
if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types
|
||||
check_bool_conversion(self)
|
||||
return self.aval._bool(self)
|
||||
|
||||
def __int__(self):
|
||||
if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types
|
||||
check_scalar_conversion(self)
|
||||
return self.aval._int(self)
|
||||
|
||||
@ -755,14 +761,17 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
return self.aval._complex(self)
|
||||
|
||||
def __hex__(self):
|
||||
if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types
|
||||
check_integer_conversion(self)
|
||||
return self.aval._hex(self)
|
||||
|
||||
def __oct__(self):
|
||||
if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types
|
||||
check_integer_conversion(self)
|
||||
return self.aval._oct(self)
|
||||
|
||||
def __index__(self):
|
||||
if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types
|
||||
check_integer_conversion(self)
|
||||
return self.aval._index(self)
|
||||
|
||||
@ -1393,12 +1402,16 @@ def get_aval(x):
|
||||
else:
|
||||
return concrete_aval(x)
|
||||
|
||||
def get_type(x):
|
||||
aval = get_aval(x)
|
||||
if isinstance(aval, ConcreteArray):
|
||||
return raise_to_shaped(aval)
|
||||
get_type = get_aval
|
||||
|
||||
def is_concrete(x):
|
||||
return to_concrete_value(x) is not None
|
||||
|
||||
def to_concrete_value(x):
|
||||
if isinstance(x, Tracer):
|
||||
return x.to_concrete_value()
|
||||
else:
|
||||
return aval
|
||||
return x
|
||||
|
||||
def concretization_function_error(fun, suggest_astype=False):
|
||||
fname = getattr(fun, "__name__", fun)
|
||||
@ -1423,10 +1436,11 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
if force is None:
|
||||
force = lambda x: x
|
||||
if isinstance(val, Tracer):
|
||||
if isinstance(val.aval, ConcreteArray):
|
||||
return force(val.aval.val)
|
||||
else:
|
||||
maybe_concrete = val.to_concrete_value()
|
||||
if maybe_concrete is None:
|
||||
raise ConcretizationTypeError(val, context)
|
||||
else:
|
||||
return force(maybe_concrete)
|
||||
else:
|
||||
return force(val)
|
||||
|
||||
@ -1578,7 +1592,7 @@ def _invalid_shape_error(shape: Shape, context: str=""):
|
||||
msg += f" {context}."
|
||||
if not config.dynamic_shapes.value and any(
|
||||
isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
|
||||
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
|
||||
and not is_concrete(x) for x in shape):
|
||||
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
|
||||
"smaller subfunctions.")
|
||||
for x in shape:
|
||||
@ -1677,10 +1691,6 @@ def _get_shape_sharding_str(shape, spec):
|
||||
else:
|
||||
yield f"{s1}@{s2}"
|
||||
|
||||
|
||||
def _forward_to_value(self, fun, ignored_tracer, *args):
|
||||
return fun(self.val, *args)
|
||||
|
||||
def _get_abstract_sharding(val):
|
||||
from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error
|
||||
|
||||
@ -1690,59 +1700,6 @@ def _get_abstract_sharding(val):
|
||||
val.sharding._normalized_spec(val.ndim))
|
||||
return None
|
||||
|
||||
class ConcreteArray(ShapedArray):
|
||||
__slots__ = ['val']
|
||||
array_abstraction_level = 0
|
||||
|
||||
def __init__(self, dtype, val, weak_type=None):
|
||||
super().__init__(
|
||||
np.shape(val), dtype,
|
||||
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type,
|
||||
sharding=_get_abstract_sharding(val))
|
||||
dtypes.check_valid_dtype(self.dtype)
|
||||
# Note: canonicalized self.dtype doesn't necessarily match self.val
|
||||
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
|
||||
self.val = val
|
||||
|
||||
def update(self, dtype=None, val=None, weak_type=None):
|
||||
dtype = self.dtype if dtype is None else dtype
|
||||
val = self.val if val is None else val
|
||||
weak_type = self.weak_type if weak_type is None else weak_type
|
||||
return ConcreteArray(dtype, val, weak_type)
|
||||
|
||||
def __eq__(self, other):
|
||||
if (type(self) is type(other) and self.dtype == other.dtype
|
||||
and self.shape == other.shape and self.weak_type == other.weak_type):
|
||||
with eval_context(): # in case self.val is an Array
|
||||
return (self.val == other.val).all()
|
||||
else:
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return id(self.val)
|
||||
|
||||
def join(self, other) -> AbstractValue:
|
||||
if self == other:
|
||||
return self
|
||||
elif self.shape == other.shape and self.dtype == other.dtype:
|
||||
weak_type = self.weak_type and other.weak_type
|
||||
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
return f'{self.val}, dtype={dt_str}'
|
||||
|
||||
_bool = partialmethod(_forward_to_value, bool)
|
||||
_int = partialmethod(_forward_to_value, int)
|
||||
_hex = partialmethod(_forward_to_value, hex)
|
||||
_oct = partialmethod(_forward_to_value, oct)
|
||||
_index = partialmethod(_forward_to_value, operator.index)
|
||||
|
||||
_float = concretization_function_error(float, True)
|
||||
_complex = concretization_function_error(complex, True)
|
||||
|
||||
def primal_dtype_to_tangent_dtype(primal_dtype):
|
||||
if isinstance(primal_dtype, dtypes.ExtendedDType):
|
||||
return primal_dtype._rules.tangent_dtype(primal_dtype)
|
||||
@ -1817,14 +1774,6 @@ class DShapedArray(UnshapedArray):
|
||||
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type)
|
||||
|
||||
class DConcreteArray(DShapedArray):
|
||||
__slots__ = ['val']
|
||||
array_abstraction_level = 1
|
||||
def __init__(self, shape, dtype, weak_type, val):
|
||||
super().__init__(shape, dtype, weak_type)
|
||||
self.val = val
|
||||
|
||||
|
||||
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
|
||||
|
||||
|
||||
@ -1881,8 +1830,7 @@ class DArray:
|
||||
|
||||
|
||||
pytype_aval_mappings[DArray] = \
|
||||
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
|
||||
x._data)
|
||||
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class bint(dtypes.ExtendedDType):
|
||||
@ -1984,10 +1932,7 @@ raise_to_shaped_mappings: dict[type, Callable] = {
|
||||
AbstractToken: lambda aval, _: aval,
|
||||
Bot: lambda aval, _: aval,
|
||||
ShapedArray: _shaped_array_mapping,
|
||||
DShapedArray: lambda aval, _: aval,
|
||||
DConcreteArray: lambda aval, weak_type: DShapedArray(
|
||||
aval.shape, aval.dtype, weak_type
|
||||
),
|
||||
DShapedArray: lambda aval, _: aval
|
||||
}
|
||||
|
||||
### Operations on shapes and dimension sizes.
|
||||
@ -2323,7 +2268,6 @@ AvalMapHandlerPair = tuple[Callable, Callable]
|
||||
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
|
||||
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
|
||||
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
|
||||
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
|
||||
AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
|
||||
}
|
||||
|
||||
|
@ -429,6 +429,9 @@ class JVPTracer(Tracer):
|
||||
else:
|
||||
return self
|
||||
|
||||
def to_concrete_value(self):
|
||||
return core.to_concrete_value(self.primal)
|
||||
|
||||
def _primal_tangent_shapes_match(primal, tangent):
|
||||
if type(tangent) is not Zero:
|
||||
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
|
||||
|
@ -230,7 +230,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes:
|
||||
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
|
||||
|
||||
ir_type_handlers[core.ShapedArray] = _array_ir_types
|
||||
ir_type_handlers[core.ConcreteArray] = _array_ir_types
|
||||
ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get()
|
||||
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
|
||||
|
||||
|
@ -40,7 +40,7 @@ from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
|
||||
fun_sourceinfo)
|
||||
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
|
||||
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
|
||||
ConcreteArray, Var, DropVar, raise_to_shaped, Atom,
|
||||
Var, DropVar, raise_to_shaped, Atom,
|
||||
JaxprEqn, Primitive, ShapedArray, DShapedArray,
|
||||
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent, JaxprEqnContext)
|
||||
@ -299,7 +299,6 @@ class JaxprTrace(Trace['JaxprTracer']):
|
||||
# With dynamic shapes, we may need to substitute Tracers into avals.
|
||||
out_tracers = []
|
||||
for aval, _ in out_type:
|
||||
assert not isinstance(aval, ConcreteArray)
|
||||
if type(aval) is DShapedArray:
|
||||
shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val]
|
||||
if type(d) is InDBIdx else d for d in aval.shape]
|
||||
|
@ -25,7 +25,7 @@ import numpy as np
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
|
||||
from jax._src.typing import Shape
|
||||
@ -101,7 +101,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
|
||||
_xla_shape_handlers: dict[type[core.AbstractValue],
|
||||
Callable[[Any], Sequence[xc.Shape]]] = {
|
||||
ShapedArray: _make_array_shape,
|
||||
ConcreteArray: _make_array_shape,
|
||||
}
|
||||
_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
|
||||
|
||||
|
@ -35,7 +35,7 @@ from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
|
||||
from jax._src.state.types import AbstractRef, RefEffect
|
||||
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
|
||||
from jax._src.core import raise_to_shaped, replace_jaxpr_effects
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -130,8 +130,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
hi = np.array(len(branches) - 1, np.int32)
|
||||
index = lax.clamp(lo, index, hi)
|
||||
|
||||
if (config.disable_jit.value and
|
||||
isinstance(core.get_aval(index), ConcreteArray)):
|
||||
if (config.disable_jit.value and core.is_concrete(index)):
|
||||
return branches[int(index)](*operands)
|
||||
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
@ -220,7 +219,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
msg = ("Pred type must be either boolean or number, got {}.")
|
||||
raise TypeError(msg.format(pred_dtype))
|
||||
|
||||
if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray):
|
||||
if config.disable_jit.value and core.is_concrete(pred):
|
||||
if pred:
|
||||
return true_fun(*operands)
|
||||
else:
|
||||
|
@ -35,7 +35,7 @@ from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src import util
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax._src.core import ShapedArray, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -2015,12 +2015,11 @@ def fori_loop(lower, upper, body_fun, init_val,
|
||||
|
||||
# If we can specialize on the trip count, call scan instead of a while_loop
|
||||
# to enable efficient reverse-mode differentiation.
|
||||
if (isinstance(core.get_aval(lower), ConcreteArray) and
|
||||
isinstance(core.get_aval(upper), ConcreteArray)):
|
||||
if core.is_concrete(lower) and core.is_concrete(upper):
|
||||
try:
|
||||
lower_ = int(lower)
|
||||
upper_ = int(upper)
|
||||
except TypeError:
|
||||
except (TypeError, core.InconclusiveDimensionOperation):
|
||||
use_scan = False
|
||||
else:
|
||||
use_scan = True
|
||||
|
@ -47,7 +47,7 @@ from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src import util
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
|
||||
raise_to_shaped, abstract_token, canonicalize_shape)
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -582,8 +582,7 @@ def _convert_element_type(
|
||||
|
||||
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
|
||||
isinstance(operand, Array) and
|
||||
not (isinstance(operand, core.Tracer) and
|
||||
isinstance(core.get_aval(operand), core.ConcreteArray)) and
|
||||
not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and
|
||||
(sharding is None or getattr(operand, 'sharding', None) == sharding)):
|
||||
return operand
|
||||
else:
|
||||
@ -1438,23 +1437,24 @@ def _get_monoid_reducer(monoid_op: Callable,
|
||||
x, = xs
|
||||
aval = core.get_aval(x)
|
||||
dtype = _dtype(x)
|
||||
if (type(aval) is ConcreteArray) and aval.shape == ():
|
||||
if core.is_concrete(x) and aval.shape == ():
|
||||
val = core.to_concrete_value(x)
|
||||
# allow bitwise reductions for boolean and integer types
|
||||
_is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer)
|
||||
if monoid_op is add:
|
||||
return _reduce_sum if np.equal(aval.val, 0) else None
|
||||
return _reduce_sum if np.equal(val, 0) else None
|
||||
elif monoid_op is mul:
|
||||
return _reduce_prod if np.equal(aval.val, 1) else None
|
||||
return _reduce_prod if np.equal(val, 1) else None
|
||||
elif monoid_op is bitwise_or and _is_intlike:
|
||||
return _reduce_or if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None
|
||||
return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
|
||||
elif monoid_op is bitwise_and and _is_intlike:
|
||||
return _reduce_and if np.equal(aval.val, _get_bitwise_and_identity(dtype)) else None
|
||||
return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
|
||||
elif monoid_op is bitwise_xor and _is_intlike:
|
||||
return _reduce_xor if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None
|
||||
return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
|
||||
elif monoid_op is max:
|
||||
return _reduce_max if np.equal(aval.val, _get_max_identity(dtype)) else None
|
||||
return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None
|
||||
elif monoid_op is min:
|
||||
return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None
|
||||
return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None
|
||||
return None
|
||||
|
||||
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
|
@ -52,10 +52,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
|
||||
assert not prim.multiple_results
|
||||
weak_type = weak_type_rule(*avals, **kwargs)
|
||||
least_specialized = type(max(avals, key=_get_array_abstraction_level))
|
||||
if least_specialized is core.ConcreteArray:
|
||||
out = prim.impl(*[x.val for x in avals], **kwargs)
|
||||
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
|
||||
elif least_specialized is core.ShapedArray:
|
||||
if least_specialized is core.ShapedArray:
|
||||
return core.ShapedArray(
|
||||
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
|
||||
weak_type=weak_type,
|
||||
@ -77,11 +74,7 @@ def standard_multi_result_abstract_eval(
|
||||
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
|
||||
least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
|
||||
weak_types = weak_type_rule(*avals, **kwargs)
|
||||
if least_specialized is core.ConcreteArray:
|
||||
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
|
||||
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
|
||||
for val, weak_type in zip(out_vals, weak_types)]
|
||||
elif least_specialized is core.ShapedArray:
|
||||
if least_specialized is core.ShapedArray:
|
||||
out_shapes = shape_rule(*avals, **kwargs)
|
||||
out_dtypes = dtype_rule(*avals, **kwargs)
|
||||
return [core.ShapedArray(s, d, weak_type=weak_type)
|
||||
|
@ -23,7 +23,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -142,14 +142,15 @@ def _get_monoid_window_reducer(
|
||||
return None
|
||||
x, = xs
|
||||
aval = core.get_aval(x)
|
||||
if (type(aval) is ConcreteArray) and aval.shape == ():
|
||||
if core.is_concrete(x) and aval.shape == ():
|
||||
val = core.to_concrete_value(x)
|
||||
if monoid_op is lax.add:
|
||||
return aval.val == 0 and _reduce_window_sum
|
||||
return val == 0 and _reduce_window_sum
|
||||
elif monoid_op is lax.max:
|
||||
return (aval.val == lax._get_max_identity(aval.dtype)
|
||||
return (val == lax._get_max_identity(aval.dtype)
|
||||
and _reduce_window_max)
|
||||
elif monoid_op is lax.min:
|
||||
return (aval.val == lax._get_min_identity(aval.dtype)
|
||||
return (val == lax._get_min_identity(aval.dtype)
|
||||
and _reduce_window_min)
|
||||
return None
|
||||
|
||||
|
@ -276,7 +276,7 @@ def _check_input_type(in_type: core.InputType) -> None:
|
||||
# Check that in_type is syntactically well-formed
|
||||
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
|
||||
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
|
||||
for a, b in in_type)
|
||||
|
||||
def valid_size(d) -> bool:
|
||||
if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
|
||||
|
@ -52,7 +52,7 @@ from jax._src import dtypes
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.custom_derivatives import custom_jvp
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax.lax import ( PrecisionLike,_array_copy,
|
||||
@ -11789,7 +11789,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
except TypeError:
|
||||
abstract_i = None
|
||||
# Handle basic int indexes.
|
||||
if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i):
|
||||
if isinstance(abstract_i, ShapedArray) and _int(abstract_i):
|
||||
if core.definitely_equal(x_shape[x_axis], 0):
|
||||
# XLA gives error when indexing into an axis of size 0
|
||||
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
|
||||
@ -11945,7 +11945,7 @@ def _expand_bool_indices(idx, shape):
|
||||
i = array(i)
|
||||
abstract_i = core.get_aval(i)
|
||||
|
||||
if not type(abstract_i) is ConcreteArray:
|
||||
if not core.is_concrete(i):
|
||||
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
|
||||
raise errors.NonConcreteBooleanIndexError(abstract_i)
|
||||
elif _ndim(i) == 0:
|
||||
@ -11975,7 +11975,7 @@ def _is_slice_element_none_or_constant_or_symbolic(elt):
|
||||
if elt is None: return True
|
||||
if core.is_symbolic_dim(elt): return True
|
||||
try:
|
||||
return type(core.get_aval(elt)) is ConcreteArray
|
||||
return core.is_concrete(elt)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
@ -2512,7 +2512,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
# lax.pow.
|
||||
|
||||
# Case 1: concrete integer scalar powers:
|
||||
if isinstance(core.get_aval(x2), core.ConcreteArray):
|
||||
if core.is_concrete(x2):
|
||||
try:
|
||||
x2 = operator.index(x2) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
|
@ -123,12 +123,7 @@ def _maybe_concretize(x: Any):
|
||||
# This is roughly the same logic as core.concrete_or_error, but we avoid
|
||||
# calling that because constructing the ConcretizationTypeError can be
|
||||
# expensive as the size of the tracing context (i.e. the jaxpr) grows.
|
||||
if isinstance(x, core.Tracer):
|
||||
if isinstance(x.aval, core.ConcreteArray):
|
||||
return x.aval.val
|
||||
else:
|
||||
return None
|
||||
return x
|
||||
return core.to_concrete_value(x)
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass
|
||||
|
@ -24,7 +24,6 @@ from jax._src.core import (
|
||||
AxisName as AxisName,
|
||||
CallPrimitive as CallPrimitive,
|
||||
ClosedJaxpr as ClosedJaxpr,
|
||||
ConcreteArray as ConcreteArray,
|
||||
ConcretizationTypeError as ConcretizationTypeError,
|
||||
DShapedArray as DShapedArray,
|
||||
DropVar as DropVar,
|
||||
@ -84,6 +83,7 @@ from jax._src.core import (
|
||||
get_aval as get_aval,
|
||||
get_type as get_type,
|
||||
get_referent as get_referent,
|
||||
is_concrete as is_concrete,
|
||||
is_constant_dim as is_constant_dim,
|
||||
is_constant_shape as is_constant_shape,
|
||||
jaxpr_as_fun as jaxpr_as_fun,
|
||||
|
@ -911,13 +911,14 @@ class ShardMapTracer(core.Tracer):
|
||||
@property
|
||||
def aval(self):
|
||||
aval = core.get_aval(self.val)
|
||||
if (isinstance(aval, core.ConcreteArray) and
|
||||
self.rep == set(self._trace.mesh.axis_names)):
|
||||
return core.mapped_aval(self._trace.mesh.size, 0, aval)
|
||||
|
||||
def to_concrete_value(self):
|
||||
if self.rep == set(self._trace.mesh.axis_names):
|
||||
with core.eval_context():
|
||||
return core.get_aval(self.val[0])
|
||||
return core.to_concrete_value(self.val[0])
|
||||
else:
|
||||
aval = core.raise_to_shaped(aval)
|
||||
return core.mapped_aval(self._trace.mesh.size, 0, aval)
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
with core.eval_context():
|
||||
@ -1768,6 +1769,9 @@ class RewriteTracer(core.Tracer):
|
||||
def aval(self) -> core.AbstractValue:
|
||||
return core.get_aval(self.val)
|
||||
|
||||
def to_concrete_value(self):
|
||||
return core.to_concrete_value(self.val)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.val) # TODO(mattjj): could show replication info here
|
||||
__repr__ = __str__ # for debuggers, like `p x`
|
||||
|
@ -89,7 +89,7 @@ def xprod(xs: Iterable[XInt]) -> XInt:
|
||||
return xmul(*list(xs))
|
||||
|
||||
def static_int(x: XInt) -> bool:
|
||||
return isinstance(core.get_aval(x), core.ConcreteArray)
|
||||
return core.is_concrete(x)
|
||||
|
||||
def static_shape(s: DShape) -> bool:
|
||||
return all(map(static_int, s))
|
||||
|
@ -3691,18 +3691,6 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
g(1)
|
||||
|
||||
def test_join_concrete_arrays_with_omnistaging(self):
|
||||
# https://github.com/jax-ml/jax/issues/4622
|
||||
x = jnp.array([1., 2., 3.])
|
||||
y = jnp.array([1., 2., 4.])
|
||||
|
||||
@jit
|
||||
def f():
|
||||
core.lattice_join(core.ConcreteArray(x.dtype, x),
|
||||
core.ConcreteArray(y.dtype, y))
|
||||
|
||||
f() # doesn't crash
|
||||
|
||||
def test_linearize_aux(self):
|
||||
def fn(x):
|
||||
return x * 2 - 3, x > 0
|
||||
|
@ -347,13 +347,6 @@ class CoreTest(jtu.JaxTestCase):
|
||||
'This BatchTracer with object id'):
|
||||
g_vmap(jnp.ones((1, )))
|
||||
|
||||
def test_concrete_array_string_representation(self):
|
||||
# https://github.com/jax-ml/jax/issues/5364
|
||||
self.assertEqual(
|
||||
str(core.ConcreteArray(np.dtype(np.int32),
|
||||
np.array([1], dtype=np.int32))),
|
||||
'ConcreteArray([1], dtype=int32)')
|
||||
|
||||
def test_dropvar_avals(self):
|
||||
def f(x):
|
||||
def body(c, _):
|
||||
|
Loading…
x
Reference in New Issue
Block a user