Remove the lattice system from JAX, in particular raise_to_shaped (kept only as a no-op for backwards compatibility)

PiperOrigin-RevId: 692557993
Dougal Maclaurin, 2024-11-02 17:02:02 -07:00, committed by jax authors
parent d679c0abaa
commit ec39b592f7
24 changed files with 96 additions and 211 deletions
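In brief: lattice_join and the per-aval join methods are deleted, raise_to_shaped is reduced to an identity function kept only for backwards compatibility, and call sites switch to core.get_aval plus strip_weak_type / update_weak_type, with typematch assertions where a join used to be computed. A minimal sketch of the call-site migration (the abstractify helper mirrors the diff below; the scalar example and its dtype are illustrative assumptions):

from jax._src import core

def abstractify(x):
  # Before this change: core.raise_to_shaped(core.get_aval(x)).
  # After: get_aval already returns a fully shaped aval, so the wrapper goes away.
  return core.get_aval(x)

# Weak-type handling moves from raise_to_shaped(aval, weak_type=False) onto the
# aval itself (illustrative scalar input; the concrete dtype depends on the x64 config):
aval = core.get_aval(1.0)                      # weak_type=True for Python scalars
assert aval.strip_weak_type() == aval.update_weak_type(False)

# raise_to_shaped survives only as a no-op for backwards compatibility:
assert core.raise_to_shaped(aval) is aval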

View File

@ -43,7 +43,8 @@ raw_jaxval_adders = {} # type: ignore
@add_jaxvals_p.def_abstract_eval
def add_abstract(x, y):
return core.lattice_join(x, y)
assert core.typematch(x, y)
return x
def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)

View File

@ -368,7 +368,7 @@ class Var:
def __init__(self, suffix: str, aval: AbstractValue):
self.count = next(_var_counter)
self.suffix = suffix
self.aval = raise_to_shaped(aval)
self.aval = aval
# TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
# care about variable ordering, but the downstream package kfac_jax does.
@ -662,7 +662,7 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
def _error_repr(self):
if self.aval is None:
return f"traced array with aval {self.aval}"
return f"traced array with shape {raise_to_shaped(self.aval).str_short()}"
return f"traced array with shape {self.aval.str_short()}"
def __array__(self, *args, **kw):
raise TracerArrayConversionError(self)
@ -1302,11 +1302,11 @@ class AbstractValue:
except AttributeError:
return self.__class__.__name__
def strip_weak_type(self) -> AbstractValue:
def update_weak_type(self, weak_type):
return self
def join(self, other):
raise NotImplementedError("must override")
def strip_weak_type(self) -> AbstractValue:
return self.update_weak_type(False)
def update(self, **kwargs):
raise NotImplementedError("must override")
@ -1314,7 +1314,6 @@ class AbstractValue:
def str_short(self, short_dtypes=False):
return str(self)
# For type signatures involving dynamic shapes, we use lists of abstract values
# which may contain (reverse) de Bruijn indices in their shapes.
class DBIdx(NamedTuple):
@ -1348,26 +1347,10 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
for v in jaxpr.invars]
return tuple(out)
class Bot(AbstractValue): pass
bot = Bot()
def lattice_join(x: AbstractValue | None,
y: AbstractValue | None) -> AbstractValue:
if x is None:
assert y is not None
return y
elif y is None:
return x
elif isinstance(x, type(y)):
return y.join(x)
elif isinstance(y, type(x)):
return x.join(y)
elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray):
# TODO(mattjj): remove this special case after dynamic shapes are integrated
return x.join(y)
else:
raise TypeError(x, y)
# TODO(dougalm): Deprecate. This is here for backwards compat.
def lattice_join(x, y):
assert typematch(x, y)
return x
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
@ -1530,9 +1513,8 @@ class UnshapedArray(AbstractValue):
def str_short(self, short_dtypes=False) -> str:
return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
def strip_weak_type(self):
"""Returns a copy of the aval with weak_type=False."""
return self.update(weak_type=False)
def update_weak_type(self, weak_type):
return self.update(weak_type=weak_type)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
@ -1656,13 +1638,6 @@ class ShapedArray(UnshapedArray):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
def join(self, other):
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
else:
raise TypeError(self, other)
def str_short(self, short_dtypes=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
self.dtype.name)
@ -1762,14 +1737,6 @@ class DShapedArray(UnshapedArray):
def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))
def join(self, other):
if (definitely_equal_shape(self.shape, other.shape) and
self.dtype == other.dtype):
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
else:
raise TypeError(self, other)
def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
@ -1881,16 +1848,11 @@ def mutable_array_abstract_eval(init_aval):
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
aval = raise_to_shaped(get_aval(init_val))
aval = get_aval(init_val)
return MutableArray(AbstractRef(aval), init_val)
class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
return self
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self, short_dtypes=False): return 'Tok'
def to_tangent_aval(self): return self
abstract_token: AbstractToken = AbstractToken()
@ -1910,30 +1872,9 @@ class Token:
pytype_aval_mappings[Token] = lambda _: abstract_token
def raise_to_shaped(aval: AbstractValue, weak_type=None):
aval_type = type(aval)
if aval_type is ShapedArray and weak_type is None:
return aval
if aval_type is DShapedArray and weak_type is None:
return aval
if weak_type is None:
weak_type = getattr(aval, 'weak_type', False)
for typ in aval_type.__mro__:
handler = raise_to_shaped_mappings.get(typ)
if handler: return handler(aval, weak_type)
raise TypeError(type(aval))
def _shaped_array_mapping(aval, weak_type):
if config.sharding_in_types.value:
return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding)
return ShapedArray(aval.shape, aval.dtype, weak_type)
raise_to_shaped_mappings: dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
ShapedArray: _shaped_array_mapping,
DShapedArray: lambda aval, _: aval
}
# TODO(dougalm): Deprecate. This is just here for backwards compat.
def raise_to_shaped(aval):
return aval
### Operations on shapes and dimension sizes.
@ -2341,18 +2282,23 @@ def typecheck(aval: AbstractValue, x) -> bool:
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
"""Determine whether `aval` conforms to `aval_ref`. Ignores weak_type."""
try:
return typematch(aval_ref, lattice_join(aval_ref, aval))
return typematch(aval_ref, aval)
except TypeError:
return False
def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
"""Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type."""
if aval1 == aval2: return True
# unequal avals may still represent the same type, because type is represented
# by avals at the shaped level, and because weak type tags aren't considered
# part of the type
return (raise_to_shaped(aval1, weak_type=False) ==
raise_to_shaped(aval2, weak_type=False))
def typematch(t1: AbstractValue, t2: AbstractValue) -> bool:
"""Determine whether `t1` and `t2` are equivalent. Ignores weak_type."""
t1 = t1.strip_weak_type()
t2 = t2.strip_weak_type()
if t1 == t2:
return True
elif (isinstance(t1, (ShapedArray, DShapedArray)) and
isinstance(t2, (ShapedArray, DShapedArray))):
# This case handles DShapedArray and shape polynomials. Alternatively we
# could try normalizing first and then doing simple equality.
return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape)
else:
return False
class JaxprTypeError(TypeError): pass
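For reference, a small hedged example of the new typematch/typecompat semantics defined above (the ShapedArray values here are illustrative, not taken from the diff):

import numpy as np
from jax._src import core

x = core.ShapedArray((3,), np.dtype('float32'), weak_type=True)
y = core.ShapedArray((3,), np.dtype('float32'), weak_type=False)
assert core.typematch(x, y)                    # weak_type tags are ignored
assert core.typecompat(y, x)                   # typecompat now defers to typematch
assert not core.typematch(x, core.ShapedArray((4,), np.dtype('float32')))  # shapes differ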

View File

@ -31,7 +31,6 @@ from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -81,7 +80,7 @@ def _flatten_fun_nokwargs(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
ans_flat, ans_tree = tree_flatten(ans)
ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat]
ans_avals = [core.get_aval(x) for x in ans_flat]
yield ans_flat, (ans_tree, ans_avals)
@ -287,7 +286,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
py_primals_out, py_tangents_out = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
tangents_out, out_tree2 = tree_flatten(py_tangents_out)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
primal_avals = [core.get_aval(x) for x in primals_out]
if out_tree != out_tree2:
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce primal and tangent outputs with equal container (pytree) "
@ -327,11 +326,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out]
expected_tangent_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval()
core.get_aval(x).strip_weak_type().to_tangent_aval()
for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
tangent_avals_out = [core.get_aval(t).strip_weak_type()
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if expected_tangent_avals_out != tangent_avals_out:
@ -606,7 +605,7 @@ class custom_vjp(Generic[ReturnValue]):
f_, dyn_args = lu.wrap_init(self.fun), args
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
in_avals = [core.get_aval(x) for x in args_flat]
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
@ -674,7 +673,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
py_primals_out, res = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
res, res_tree = tree_flatten(res)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
primal_avals = [core.get_aval(x) for x in primals_out]
# If the primal function already ran, check out_tree agreement.
try: out_type_ = maybe_out_type()
except lu.StoreException: out_type_ = None
@ -772,7 +771,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
msg = ("Custom VJP bwd rule must produce an output with the same "
"shape/dtypes as the args tuple of the primal function, but at "
f"output{keystr(kp)} the bwd rule produced an output of "
f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding "
f"shape/dtype {a_.str_short()} corresponding "
f"to an input of shape/dtype {a.str_short()}.")
raise ValueError(msg)
results.append(ct)
@ -831,7 +830,7 @@ def _custom_vjp_call_jaxpr_jvp(
_, res_tree = out_trees()
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
@ -1110,7 +1109,7 @@ def partition_list(choice, lst):
return out, merge
def abstractify(x):
return core.raise_to_shaped(core.get_aval(x))
return core.get_aval(x)
### Custom transposition
@ -1211,7 +1210,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
lin_avals = map(abstractify, operands_lin)
f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
f_jaxpr = _close_jaxpr(f_jaxpr)
out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)
out_avals = f_jaxpr.out_avals
t_in_tree = treedef_tuple((res_tree, out_tree()))
t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)
@ -1265,7 +1264,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose,
return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out
def _linear_call_abstract_eval(*args, **kwargs):
return map(core.raise_to_shaped, kwargs['callee'].out_avals)
return kwargs['callee'].out_avals
linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
@ -1398,7 +1397,7 @@ def optimize_remat_of_custom_vjp_fwd(
in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
in_avals = [core.get_aval(x) for x in args_flat]
fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
prim_tree, res_tree = out_trees()

View File

@ -33,8 +33,7 @@ from jax._src.ad_util import (
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval)
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
@ -362,7 +361,7 @@ class JVPTrace(Trace):
_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
with core.set_current_trace(self.parent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
@ -434,8 +433,8 @@ class JVPTracer(Tracer):
def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
primal_aval = get_aval(primal).strip_weak_type()
tangent_aval = get_aval(tangent).strip_weak_type()
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)

View File

@ -29,7 +29,7 @@ from jax._src import linear_util as lu
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros,
add_jaxvals, add_jaxvals_p)
from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName
from jax._src.core import Trace, Tracer, TraceTag, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
@ -217,7 +217,7 @@ def _update_annotation(
for d in a.shape))
if type(a) is core.DShapedArray else a for a, e in orig_type if e]
new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
new_avals = [core.get_aval(s) for s in segment_lens]
sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
for a, d in zip(avals, explicit_in_dims):
if isinstance(d, RaggedAxis):
@ -387,7 +387,7 @@ class BatchTracer(Tracer):
if config.enable_checks.value:
assert type(batch_dim) in (NotMapped, int, RaggedAxis)
if type(batch_dim) is int:
aval = raise_to_shaped(core.get_aval(val))
aval = core.get_aval(val)
assert 0 <= batch_dim < len(aval.shape)
self._trace = trace
self.val = val
@ -396,7 +396,7 @@ class BatchTracer(Tracer):
@property
def aval(self):
aval = raise_to_shaped(core.get_aval(self.val))
aval = core.get_aval(self.val)
if self.batch_dim is not_mapped:
return aval
elif type(self.batch_dim) is int:

View File

@ -40,7 +40,7 @@ from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
fun_sourceinfo)
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
Var, DropVar, raise_to_shaped, Atom,
Var, DropVar, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
@ -162,8 +162,7 @@ class JaxprTrace(Trace['JaxprTracer']):
def new_instantiated_literal(self, val) -> JaxprTracer:
aval = get_aval(val)
return JaxprTracer(self, PartialVal.unknown(aval),
Literal(val, raise_to_shaped(aval)))
return JaxprTracer(self, PartialVal.unknown(aval), Literal(val, aval))
def new_instantiated_const(self, val) -> JaxprTracer:
aval = get_aval(val)
@ -201,7 +200,7 @@ class JaxprTrace(Trace['JaxprTracer']):
if const is None:
return tracer
else:
aval = raise_to_shaped(get_aval(const), np.isscalar(const))
aval = get_aval(const).update_weak_type(np.isscalar(const))
return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))
def process_primitive(self, primitive, tracers, params):
@ -715,7 +714,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
len(params["in_axes"]) == len(params["call_jaxpr"].invars))
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
out_avals = [t.aval for t in out_tracers]
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),
config.threefry_partitionable.value,
@ -936,7 +935,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
f, in_pvals, instantiate=instantiate)
jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
out_unknowns = [not pval.is_known() for pval in out_pvals]
res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals]
res_avals = [core.get_aval(r) for r in residuals]
cell.append((out_unknowns, jaxpr_unknown, res_avals))
known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()]
return [*known_vals_out, *residuals]
@ -1567,7 +1566,7 @@ class DynamicJaxprTracer(core.Tracer):
return self if val is None else get_referent(val)
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return core.raise_to_shaped(x.aval)
return x.aval
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
@ -1827,7 +1826,9 @@ class DynamicJaxprTrace(core.Trace):
# TODO(mattjj): for ints, or hashable consts, don't rely on id
tracer = self.frame.constid_to_tracer.get(id(c))
if tracer is None:
aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
aval = get_aval(c)
if hasattr(aval, "weak_type"):
aval = aval.update_weak_type(dtypes.is_weakly_typed(c))
aval = self._lift_tracers_in_aval(aval)
tracer = self._new_const(aval, c)
return tracer
@ -1892,8 +1893,7 @@ class DynamicJaxprTrace(core.Trace):
def process_call(self, call_primitive, f, explicit_tracers, params):
if f.in_type is None:
f = lu.annotate(f, tuple((raise_to_shaped(get_aval(t)), True)
for t in explicit_tracers))
f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers))
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
@ -2291,7 +2291,7 @@ def _collect_implicit(
for i, name in spec.items():
if name not in idxs and id(x.shape[i]) not in explicit_tracers:
idxs[name] = DBIdx(next(counter))
implicit_types.append(raise_to_shaped(get_aval(x.shape[i])))
implicit_types.append(get_aval(x.shape[i]))
if isinstance(x, Tracer):
explicit_tracers.setdefault(id(x), explicit_idx) # use the first
@ -2310,7 +2310,7 @@ def _arg_type(
) -> AbstractValue:
# Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames.
aval = get_aval(x) # aval.shape could contain Tracers
if not spec: return core.raise_to_shaped(aval)
if not spec: return aval
shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d
for i, d in enumerate(aval.shape)]
assert not any(isinstance(d, Tracer) for d in shape)

View File

@ -35,7 +35,7 @@ from jax._src import source_info_util
from jax._src import util
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import raise_to_shaped, replace_jaxpr_effects
from jax._src.core import replace_jaxpr_effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -328,7 +328,7 @@ def _cond_abstract_eval(*avals, branches, **_):
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
return map(raise_to_shaped, branches[0].out_avals), joined_effects
return branches[0].out_avals, joined_effects
def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
@ -676,7 +676,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
def _transpose_cond_jaxpr(jaxpr, num_res):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = map(raise_to_shaped, primal_avals)
@lu.wrap_init
def transposed(*args):
@ -693,7 +692,7 @@ def _cond_transpose(cts, *args, branches):
index, *ops = args
assert type(index) is not ad.UndefinedPrimal
linear = [type(x) is ad.UndefinedPrimal for x in ops]
in_avals = map(raise_to_shaped, branches[0].in_avals)
in_avals = branches[0].in_avals
num_res = len(ops) - sum(linear)
if any(isinstance(eff, RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
@ -701,8 +700,7 @@ def _cond_transpose(cts, *args, branches):
branches_trans = tuple(
_transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
lin_in_avals = [raise_to_shaped(a, weak_type=False)
for a, l in zip(in_avals, linear) if l]
lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l]
assert all(core.typematch(out_aval, lin_in_aval)
for jaxpr in branches_trans
for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

View File

@ -35,7 +35,7 @@ from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.api_util import shaped_abstractify
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -262,7 +262,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
stacked_y = tree_map(stack, *maybe_reversed(ys))
return carry, stacked_y
xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
xs_avals = [core.get_aval(x) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
def _create_jaxpr(init):
@ -1370,7 +1370,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts,
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
return body_jaxpr.out_avals, joined_effects
def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr,

View File

@ -23,7 +23,6 @@ from jax._src import api
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -300,7 +299,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
if num_aux > 0:
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
return _map(raise_to_shaped, args_to_raise)
return args_to_raise
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):

View File

@ -48,7 +48,7 @@ from jax._src import state
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
raise_to_shaped, abstract_token, canonicalize_shape)
abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -3044,7 +3044,7 @@ def _to_edtype_abstract_eval(x, *, edtype):
f" has a representation shape {rep_aval.shape} while the given "
f"representation array has shape {x.shape}, so the shape suffix "
f"does not match: given {shape_suffix} but required {rep_aval.shape}.")
return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype)
return x.update(shape=shape_prefix, dtype=edtype)
to_edtype_p = Primitive('to_edtype')
to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p))
@ -5246,7 +5246,7 @@ _INT_DTYPES = {
def _sort_abstract_eval(*args, **kwargs):
args = tuple(raise_to_shaped(arg) for arg in args)
args = tuple(args)
if any(arg.shape != args[0].shape for arg in args[1:]):
shapes = " ".join(str(a.shape) for a in args)
raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
@ -6196,7 +6196,7 @@ def _eq_meet(a, b):
def _abstractify(x):
return raise_to_shaped(core.get_aval(x))
return core.get_aval(x)
def empty(dtype):

View File

@ -33,7 +33,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape)
Primitive, ShapedArray, is_constant_dim, is_constant_shape)
from jax._src.extend import ffi
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -1289,7 +1289,6 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size):
def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
pivots = raise_to_shaped(pivots)
if isinstance(pivots, ShapedArray):
if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32):
raise ValueError(
@ -1421,7 +1420,6 @@ def _lu_impl(operand):
return lu, pivot, perm
def _lu_abstract_eval(operand):
operand = raise_to_shaped(operand)
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to LU decomposition must have ndims >= 2")

View File

@ -27,7 +27,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src.core import AxisName, ShapedArray, raise_to_shaped
from jax._src.core import AxisName, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -636,7 +636,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes),
arg.dtype) for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
@ -817,7 +817,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
_check_axis_names(axis_name)
return raise_to_shaped(x)
return x
ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval)
@ -1019,13 +1019,12 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in,
def _all_to_all_effectful_abstract_eval(
x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled
):
del tiled # expand_dims and squeeze is done in `all_to_all` if `True`
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
_check_axis_names(axis_name)
input_aval = raise_to_shaped(x)
shape = list(input_aval.shape)
axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
@ -1169,12 +1168,11 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
def _all_gather_effectful_abstract_eval(
x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
_check_axis_names(axis_name)
x_aval = raise_to_shaped(x)
new_shape = list(x_aval.shape)
if tiled:
new_shape[all_gather_dimension] *= axis_size
@ -1298,12 +1296,11 @@ def _reduce_scatter_lowering(
def _reduce_scatter_effectful_abstract_eval(
x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
_check_axis_names(axis_name)
x_aval = core.raise_to_shaped(x)
new_shape = list(x_aval.shape)
scatter_dim_input_size = x_aval.shape[scatter_dimension]
if tiled:

View File

@ -140,13 +140,6 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
self.memory_space,
))
def at_least_vspace(self):
"""Vector space method needed for AD."""
raise NotImplementedError
def join(self, other):
raise NotImplementedError
def str_short(self, short_dtypes=False):
dt_str = \
dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
@ -226,11 +219,6 @@ class AbstractMemoryRef(state.AbstractRef):
def __repr__(self) -> str:
return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
def join(self, other):
assert isinstance(other, AbstractMemoryRef)
return AbstractMemoryRef(self.inner_aval.join(other.inner_aval),
self.memory_space)
def update(self, inner_aval=None, memory_space=None):
inner_aval = self.inner_aval if inner_aval is None else inner_aval
memory_space = self.memory_space if memory_space is None else memory_space
@ -262,13 +250,6 @@ class MemorySpace(enum.Enum):
return self.value
def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
return AbstractMemoryRef(
jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
ref_aval.memory_space)
jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped
@dataclasses.dataclass(frozen=True)
class PallasGridContext:
grid: GridMappingGrid

View File

@ -174,15 +174,6 @@ class SemaphoreType(enum.Enum):
class AbstractSemaphore(jax_core.AbstractValue):
sem_type: SemaphoreType
def join(self, other):
if not isinstance(other, AbstractSemaphore):
raise ValueError
if other.sem_type != self.sem_type:
raise ValueError
return self
jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval
@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):

View File

@ -529,7 +529,8 @@ assume_p.def_impl(lambda x, y: x)
@assume_p.def_abstract_eval
def _assume_abstract_eval(x, y):
return x.join(y)
assert jax_core.typematch(x, y)
return x
def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y):
return y if ctx.lowering_context.for_verification else x

View File

@ -458,15 +458,9 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef):
def __repr__(self) -> str:
return f'Accumulator{{{self.inner_aval.str_short()}}}'
def join(self, other):
return _as_accum(super().join(other))
def update(self, inner_aval=None, memory_space=None):
return _as_accum(super().update(inner_aval=None, memory_space=None))
def at_least_vspace(self):
return _as_accum(super().at_least_vspace())
def _getitem(self, tracer, idx):
from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error
arr = wgmma_accumulator_deref(tracer)
@ -483,10 +477,6 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
memory_space=ref.memory_space, # pytype: disable=attribute-error
)
def _ref_raise_to_shaped(ref_aval, weak_type):
return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type))
jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped
_WARPGROUP_AXIS_NAME = object()

View File

@ -567,7 +567,7 @@ def wgmma_accumulator_deref(acc):
@wgmma_accumulator_deref_p.def_effectful_abstract_eval
def _wgmma_accumulator_deref_abstract_eval(acc):
# Dereferencing implies flushing so we have a wgmma pipeline effect.
ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc
ret = acc.inner_aval if isinstance(acc, state.AbstractRef) else acc
assert isinstance(ret, jax_core.ShapedArray), acc
return ret, {gpu_core._wgmma_pipeline_effect}

View File

@ -230,7 +230,6 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms)
expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
@ -262,7 +261,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef,
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
val_aval = core.raise_to_shaped(val_aval)
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
assert isinstance(val_aval, core.ShapedArray)

View File

@ -291,15 +291,14 @@ class AbstractRef(core.AbstractValue):
raise AttributeError
return self.inner_aval.weak_type
def update_weak_type(self, weak_type):
return AbstractRef(self.inner_aval.update_weak_type(weak_type))
def update(self, inner_aval=None):
if inner_aval is None:
return AbstractRef(self.inner_aval)
return AbstractRef(inner_aval)
def join(self, other):
assert isinstance(other, AbstractRef)
return AbstractRef(self.inner_aval.join(other.inner_aval))
ndim = property(lambda self: len(self.shape))
size = property(lambda self: math.prod(self.shape))
@ -365,10 +364,6 @@ class AbstractRef(core.AbstractValue):
def __hash__(self):
return hash((self.__class__, self.inner_aval))
def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type):
return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type))
core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped
def _map_ref(size, axis, ref_aval):
return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))

View File

@ -105,7 +105,6 @@ from jax._src.core import (
primitive_uses_outfeed as primitive_uses_outfeed,
pytype_aval_mappings as pytype_aval_mappings,
raise_to_shaped as raise_to_shaped,
raise_to_shaped_mappings as raise_to_shaped_mappings,
reset_trace_state as reset_trace_state,
set_current_trace as set_current_trace,
str_eqn_compact as str_eqn_compact,

View File

@ -533,15 +533,6 @@ class JaxprTypeChecks(jtu.JaxTestCase):
r"Variable '.+_test' not defined\n\nin equation:",
lambda: core.check_jaxpr(jaxpr))
@parameterized.parameters(
{'value': 0, 'weak_type': True},
{'value': np.int32(0), 'weak_type': False},
{'value': np.array([0]), 'weak_type': False}
)
def test_raise_to_shaped_weak_type(self, value, weak_type):
aval = core.raise_to_shaped(core.get_aval(value))
self.assertEqual(aval.weak_type, weak_type)
@jtu.with_config(jax_dynamic_shapes=True)
class DynamicShapesTest(jtu.JaxTestCase):

View File

@ -3821,7 +3821,7 @@ def shard_foo_array_handler(xs, shardings, layouts):
results = []
for x, sharding in safe_zip(xs, shardings):
device, = sharding._addressable_device_assignment
aval = core.raise_to_shaped(core.get_aval(x.data))
aval = core.get_aval(x.data)
results.append(pxla.batched_device_put(
aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]))
return results

View File

@ -870,8 +870,9 @@ class PallasCallDMATest(PallasBaseTest):
pl.run_scoped(scope)
return []
aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
in_avals = [aref, aref]
aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
in_avals = [aref1, aref2]
stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = state_discharge.discharge_state(

View File

@ -746,9 +746,10 @@ class StateDischargeTest(jtu.JaxTestCase):
b_ref[...] = jnp.array(1., dtype=jnp.float32)
return a_ref[...], b_ref[...]
scalar_ref = shaped_array_ref((), jnp.float32)
scalar_ref_1 = shaped_array_ref((), jnp.float32)
scalar_ref_2 = shaped_array_ref((), jnp.float32)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [scalar_ref, scalar_ref])
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)