Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)

PiperOrigin-RevId: 692557993

commit ec39b592f7 (parent d679c0abaa)
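For downstream callers, the practical effect of this commit is that core.get_aval already returns a shaped-level aval, so wrapping it in raise_to_shaped is no longer needed (the wrapper survives only as an identity function), and lattice_join degenerates into a type-agreement check. A minimal sketch of the migration, using the internal jax._src.core module exactly as the touched files do (internal API, so subject to change):

import jax.numpy as jnp
from jax._src import core

x = jnp.ones((2, 3), dtype=jnp.float32)

# Before: aval = core.raise_to_shaped(core.get_aval(x))
# After:  get_aval already yields a ShapedArray, so the wrapper is dropped.
aval = core.get_aval(x)
assert core.raise_to_shaped(aval) is aval   # kept as a backwards-compat no-op

# lattice_join no longer computes a join; it just asserts that the two types
# agree (weak_type is ignored) and returns the first argument.
assert core.typematch(aval, core.get_aval(x))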
@@ -43,7 +43,8 @@ raw_jaxval_adders = {}  # type: ignore
 @add_jaxvals_p.def_abstract_eval
 def add_abstract(x, y):
-  return core.lattice_join(x, y)
+  assert core.typematch(x, y)
+  return x
 
 def zeros_like_aval(aval: core.AbstractValue) -> Array:
   return aval_zeros_likers[type(aval)](aval)
jax/_src/core.py (112)
@@ -368,7 +368,7 @@ class Var:
   def __init__(self, suffix: str, aval: AbstractValue):
     self.count = next(_var_counter)
     self.suffix = suffix
-    self.aval = raise_to_shaped(aval)
+    self.aval = aval
 
 # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
 # care about variable ordering, but the downstream package kfac_jax does.
@@ -662,7 +662,7 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
   def _error_repr(self):
     if self.aval is None:
       return f"traced array with aval {self.aval}"
-    return f"traced array with shape {raise_to_shaped(self.aval).str_short()}"
+    return f"traced array with shape {self.aval.str_short()}"
 
   def __array__(self, *args, **kw):
     raise TracerArrayConversionError(self)
@@ -1302,11 +1302,11 @@ class AbstractValue:
     except AttributeError:
       return self.__class__.__name__
 
-  def strip_weak_type(self) -> AbstractValue:
+  def update_weak_type(self, weak_type):
     return self
 
-  def join(self, other):
-    raise NotImplementedError("must override")
+  def strip_weak_type(self) -> AbstractValue:
+    return self.update_weak_type(False)
 
   def update(self, **kwargs):
     raise NotImplementedError("must override")
@@ -1314,7 +1314,6 @@ class AbstractValue:
   def str_short(self, short_dtypes=False):
     return str(self)
 
-
 # For type signatures involving dynamic shapes, we use lists of abstract values
 # which may contain (reverse) de Bruijn indices in their shapes.
 class DBIdx(NamedTuple):
@@ -1348,26 +1347,10 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
          for v in jaxpr.invars]
   return tuple(out)
 
-class Bot(AbstractValue): pass
-bot = Bot()
-
-
-def lattice_join(x: AbstractValue | None,
-                 y: AbstractValue | None) -> AbstractValue:
-  if x is None:
-    assert y is not None
-    return y
-  elif y is None:
-    return x
-  elif isinstance(x, type(y)):
-    return y.join(x)
-  elif isinstance(y, type(x)):
-    return x.join(y)
-  elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray):
-    # TODO(mattjj): remove this special case after dynamic shapes are integrated
-    return x.join(y)
-  else:
-    raise TypeError(x, y)
+# TODO(dougalm): Deprecate. This is here for backwards compat.
+def lattice_join(x, y):
+  assert typematch(x, y)
+  return x
 
 # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
 Value = Any
@@ -1530,9 +1513,8 @@ class UnshapedArray(AbstractValue):
   def str_short(self, short_dtypes=False) -> str:
     return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
 
-  def strip_weak_type(self):
-    """Returns a copy of the aval with weak_type=False."""
-    return self.update(weak_type=False)
+  def update_weak_type(self, weak_type):
+    return self.update(weak_type=weak_type)
 
 def _canonicalize_dimension(dim: DimSize) -> DimSize:
   # Dimensions are most commonly integral (by far), so we check that first.
@@ -1656,13 +1638,6 @@ class ShapedArray(UnshapedArray):
     return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                        self.weak_type)
 
-  def join(self, other):
-    if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
-      weak_type = self.weak_type and other.weak_type
-      return self.update(weak_type=weak_type)
-    else:
-      raise TypeError(self, other)
-
   def str_short(self, short_dtypes=False):
     dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
               self.dtype.name)
@@ -1762,14 +1737,6 @@ class DShapedArray(UnshapedArray):
   def __hash__(self):
     return hash((self.shape, self.dtype, self.weak_type))
 
-  def join(self, other):
-    if (definitely_equal_shape(self.shape, other.shape) and
-        self.dtype == other.dtype):
-      weak_type = self.weak_type and other.weak_type
-      return self.update(weak_type=weak_type)
-    else:
-      raise TypeError(self, other)
-
   def to_tangent_aval(self):
     return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                         self.weak_type)
@@ -1881,16 +1848,11 @@ def mutable_array_abstract_eval(init_aval):
 @mutable_array_p.def_impl
 def _mutable_array_impl(init_val):
   from jax._src.state.types import AbstractRef  # pytype: disable=import-error
-  aval = raise_to_shaped(get_aval(init_val))
+  aval = get_aval(init_val)
   return MutableArray(AbstractRef(aval), init_val)
 
 
 class AbstractToken(AbstractValue):
-  def join(self, other):
-    if isinstance(other, AbstractToken):
-      return self
-    else:
-      assert False, f"Cannot join {self} with {other}"
   def str_short(self, short_dtypes=False): return 'Tok'
   def to_tangent_aval(self): return self
 abstract_token: AbstractToken = AbstractToken()
@@ -1910,30 +1872,9 @@ class Token:
 pytype_aval_mappings[Token] = lambda _: abstract_token
 
 
-def raise_to_shaped(aval: AbstractValue, weak_type=None):
-  aval_type = type(aval)
-  if aval_type is ShapedArray and weak_type is None:
-    return aval
-  if aval_type is DShapedArray and weak_type is None:
-    return aval
-  if weak_type is None:
-    weak_type = getattr(aval, 'weak_type', False)
-  for typ in aval_type.__mro__:
-    handler = raise_to_shaped_mappings.get(typ)
-    if handler: return handler(aval, weak_type)
-  raise TypeError(type(aval))
-
-def _shaped_array_mapping(aval, weak_type):
-  if config.sharding_in_types.value:
-    return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding)
-  return ShapedArray(aval.shape, aval.dtype, weak_type)
-
-raise_to_shaped_mappings: dict[type, Callable] = {
-  AbstractToken: lambda aval, _: aval,
-  Bot: lambda aval, _: aval,
-  ShapedArray: _shaped_array_mapping,
-  DShapedArray: lambda aval, _: aval
-}
+# TODO(dougalm): Deprecate. This is just here for backwards compat.
+def raise_to_shaped(aval):
+  return aval
 
 ### Operations on shapes and dimension sizes.
 
@@ -2341,18 +2282,23 @@ def typecheck(aval: AbstractValue, x) -> bool:
 def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
   """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type."""
   try:
-    return typematch(aval_ref, lattice_join(aval_ref, aval))
+    return typematch(aval_ref, aval)
   except TypeError:
     return False
 
-def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
-  """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type."""
-  if aval1 == aval2: return True
-  # unequal avals may still represent the same type, because type is represented
-  # by avals at the shaped level, and because weak type tags aren't considered
-  # part of the type
-  return (raise_to_shaped(aval1, weak_type=False) ==
-          raise_to_shaped(aval2, weak_type=False))
+def typematch(t1: AbstractValue, t2: AbstractValue) -> bool:
+  """Determine whether `t1` and `t2` are equivalent. Ignores weak_type."""
+  t1 = t1.strip_weak_type()
+  t2 = t2.strip_weak_type()
+  if t1 == t2:
+    return True
+  elif (isinstance(t1, (ShapedArray, DShapedArray)) and
+        isinstance(t2, (ShapedArray, DShapedArray))):
+    # This case handles DShapedArray and shape polynomials. Alternatively we
+    # could try normalizing first and then doing simple equality.
+    return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape)
+  else:
+    return False
 
 
 class JaxprTypeError(TypeError): pass
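The new typematch above replaces the lattice-based comparison: it strips weak types and then compares dtypes and shapes (via definitely_equal_shape) for shaped arrays. A small illustrative sketch using the internal classes shown in this hunk (internal API, behavior may differ between versions):

import numpy as np
from jax._src import core

a = core.ShapedArray((3, 4), np.dtype('float32'), weak_type=True)
b = core.ShapedArray((3, 4), np.dtype('float32'), weak_type=False)
c = core.ShapedArray((3, 5), np.dtype('float32'))

# weak_type is ignored, so a and b match; a differing shape does not.
assert core.typematch(a, b)
assert not core.typematch(a, c)

# typecompat now delegates to typematch instead of computing a lattice join.
assert core.typecompat(a, b)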
@@ -31,7 +31,6 @@ from jax._src.ad_util import (
     stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
 from jax._src.api_util import (
     argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
-from jax._src.core import raise_to_shaped
 from jax._src.errors import UnexpectedTracerError
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -81,7 +80,7 @@ def _flatten_fun_nokwargs(in_tree, *args_flat):
   py_args = tree_unflatten(in_tree, args_flat)
   ans = yield py_args, {}
   ans_flat, ans_tree = tree_flatten(ans)
-  ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat]
+  ans_avals = [core.get_aval(x) for x in ans_flat]
   yield ans_flat, (ans_tree, ans_avals)
 
 
@@ -287,7 +286,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
   py_primals_out, py_tangents_out = pair_out
   primals_out, out_tree = tree_flatten(py_primals_out)
   tangents_out, out_tree2 = tree_flatten(py_tangents_out)
-  primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
+  primal_avals = [core.get_aval(x) for x in primals_out]
   if out_tree != out_tree2:
     msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
            "produce primal and tangent outputs with equal container (pytree) "
@@ -327,11 +326,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
              "shapes/dtypes of:\n"
              f"""    {str(ty_tree_).replace("'", "")}""")
       raise TypeError(m)
-  primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
+  primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out]
   expected_tangent_avals_out = [
-    raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval()
+    core.get_aval(x).strip_weak_type().to_tangent_aval()
     for x in primals_out]
-  tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
+  tangent_avals_out = [core.get_aval(t).strip_weak_type()
                        if type(t) is not SymbolicZero else t.aval.strip_weak_type()
                        for t in tangents_out]
   if expected_tangent_avals_out != tangent_avals_out:
@@ -606,7 +605,7 @@ class custom_vjp(Generic[ReturnValue]):
       f_, dyn_args = lu.wrap_init(self.fun), args
     fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
                                        fwd_name, in_tree, out_type)
@@ -674,7 +673,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
   py_primals_out, res = pair_out
   primals_out, out_tree = tree_flatten(py_primals_out)
   res, res_tree = tree_flatten(res)
-  primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
+  primal_avals = [core.get_aval(x) for x in primals_out]
   # If the primal function already ran, check out_tree agreement.
   try: out_type_ = maybe_out_type()
   except lu.StoreException: out_type_ = None
@@ -772,7 +771,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
       msg = ("Custom VJP bwd rule must produce an output with the same "
              "shape/dtypes as the args tuple of the primal function, but at "
             f"output{keystr(kp)} the bwd rule produced an output of "
-            f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding "
+            f"shape/dtype {a_.str_short()} corresponding "
            f"to an input of shape/dtype {a.str_short()}.")
       raise ValueError(msg)
     results.append(ct)
@@ -831,7 +830,7 @@ def _custom_vjp_call_jaxpr_jvp(
   _, res_tree = out_trees()
   res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
   res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
-  avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
+  avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
   args_dot = map(ad.instantiate_zeros, args_dot)
   tangents_out = ad.custom_lin_p.bind(
       *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
@@ -1110,7 +1109,7 @@ def partition_list(choice, lst):
   return out, merge
 
 def abstractify(x):
-  return core.raise_to_shaped(core.get_aval(x))
+  return core.get_aval(x)
 
 
 ### Custom transposition
@@ -1211,7 +1210,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
   lin_avals = map(abstractify, operands_lin)
   f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
   f_jaxpr = _close_jaxpr(f_jaxpr)
-  out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)
+  out_avals = f_jaxpr.out_avals
 
   t_in_tree = treedef_tuple((res_tree, out_tree()))
   t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)
@@ -1265,7 +1264,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose,
   return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out
 
 def _linear_call_abstract_eval(*args, **kwargs):
-  return map(core.raise_to_shaped, kwargs['callee'].out_avals)
+  return kwargs['callee'].out_avals
 
 linear_call_p = core.Primitive('linear_call')
 linear_call_p.multiple_results = True
@@ -1398,7 +1397,7 @@ def optimize_remat_of_custom_vjp_fwd(
                                  in_tree, out_type)
     flat_fwd = _fix_fwd_args(flat_fwd)
 
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
    fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
    prim_tree, res_tree = out_trees()
@@ -33,8 +33,7 @@ from jax._src.ad_util import (
     replace_rule_output_symbolic_zeros, Zero, zeros_like_aval)
 from jax._src.ad_util import zeros_like_p, add_jaxvals_p  # noqa: F401
 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
-from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
-                           raise_to_shaped)
+from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
 from jax._src.dtypes import dtype, float0
 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
                            as_hashable_function, weakref_lru_cache,
@@ -362,7 +361,7 @@ class JVPTrace(Trace):
 
     _, res_tree = out_trees()
     res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
-    avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
+    avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
     # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
     with core.set_current_trace(self.parent_trace):
       tangents_in = map(instantiate_zeros, tangents_in)
@@ -434,8 +433,8 @@ class JVPTracer(Tracer):
 
 def _primal_tangent_shapes_match(primal, tangent):
   if type(tangent) is not Zero:
-    primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
-    tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
+    primal_aval = get_aval(primal).strip_weak_type()
+    tangent_aval = get_aval(tangent).strip_weak_type()
     assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
     expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
     assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
@@ -29,7 +29,7 @@ from jax._src import linear_util as lu
 from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
                               replace_rule_output_symbolic_zeros,
                               add_jaxvals, add_jaxvals_p)
-from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName
+from jax._src.core import Trace, Tracer, TraceTag, AxisName
 from jax._src.interpreters import partial_eval as pe
 from jax._src.tree_util import (tree_unflatten, tree_flatten,
                                 register_pytree_node)
@@ -217,7 +217,7 @@ def _update_annotation(
                        for d in a.shape))
           if type(a) is core.DShapedArray else a for a, e in orig_type if e]
 
-  new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
+  new_avals = [core.get_aval(s) for s in segment_lens]
   sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
   for a, d in zip(avals, explicit_in_dims):
     if isinstance(d, RaggedAxis):
@@ -387,7 +387,7 @@ class BatchTracer(Tracer):
     if config.enable_checks.value:
       assert type(batch_dim) in (NotMapped, int, RaggedAxis)
       if type(batch_dim) is int:
-        aval = raise_to_shaped(core.get_aval(val))
+        aval = core.get_aval(val)
         assert 0 <= batch_dim < len(aval.shape)
     self._trace = trace
     self.val = val
@@ -396,7 +396,7 @@ class BatchTracer(Tracer):
 
   @property
   def aval(self):
-    aval = raise_to_shaped(core.get_aval(self.val))
+    aval = core.get_aval(self.val)
     if self.batch_dim is not_mapped:
       return aval
     elif type(self.batch_dim) is int:
@@ -40,7 +40,7 @@ from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
                                fun_sourceinfo)
 from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
                            AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
-                           Var, DropVar, raise_to_shaped, Atom,
+                           Var, DropVar, Atom,
                            JaxprEqn, Primitive, ShapedArray, DShapedArray,
                            mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
                            InputType, OutputType, get_referent, JaxprEqnContext)
@@ -162,8 +162,7 @@ class JaxprTrace(Trace['JaxprTracer']):
 
   def new_instantiated_literal(self, val) -> JaxprTracer:
     aval = get_aval(val)
-    return JaxprTracer(self, PartialVal.unknown(aval),
-                       Literal(val, raise_to_shaped(aval)))
+    return JaxprTracer(self, PartialVal.unknown(aval), Literal(val, aval))
 
   def new_instantiated_const(self, val) -> JaxprTracer:
     aval = get_aval(val)
@@ -201,7 +200,7 @@ class JaxprTrace(Trace['JaxprTracer']):
     if const is None:
       return tracer
     else:
-      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
+      aval = get_aval(const).update_weak_type(np.isscalar(const))
       return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))
 
   def process_primitive(self, primitive, tracers, params):
@@ -715,7 +714,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
             len(params["in_axes"]) == len(params["call_jaxpr"].invars))
     assert ("donated_invars" in params and
             len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
-  out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
+  out_avals = [t.aval for t in out_tracers]
   ctx = ctx or JaxprEqnContext(
       compute_on.current_compute_type(),
       config.threefry_partitionable.value,
@@ -936,7 +935,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
         f, in_pvals, instantiate=instantiate)
     jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
     out_unknowns = [not pval.is_known() for pval in out_pvals]
-    res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals]
+    res_avals = [core.get_aval(r) for r in residuals]
     cell.append((out_unknowns, jaxpr_unknown, res_avals))
     known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()]
     return [*known_vals_out, *residuals]
@@ -1567,7 +1566,7 @@ class DynamicJaxprTracer(core.Tracer):
     return self if val is None else get_referent(val)
 
 def _dynamic_jaxpr_tracer_shaped_abstractify(x):
-  return core.raise_to_shaped(x.aval)
+  return x.aval
 api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
 
 def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
@@ -1827,7 +1826,9 @@ class DynamicJaxprTrace(core.Trace):
     # TODO(mattjj): for ints, or hashable consts, don't rely on id
     tracer = self.frame.constid_to_tracer.get(id(c))
     if tracer is None:
-      aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
+      aval = get_aval(c)
+      if hasattr(aval, "weak_type"):
+        aval = aval.update_weak_type(dtypes.is_weakly_typed(c))
       aval = self._lift_tracers_in_aval(aval)
       tracer = self._new_const(aval, c)
     return tracer
@@ -1892,8 +1893,7 @@ class DynamicJaxprTrace(core.Trace):
 
   def process_call(self, call_primitive, f, explicit_tracers, params):
     if f.in_type is None:
-      f = lu.annotate(f, tuple((raise_to_shaped(get_aval(t)), True)
-                               for t in explicit_tracers))
+      f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers))
     implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
     in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
     # TODO(mattjj): check in_tracers are consistent with f.in_type annotation
@@ -2291,7 +2291,7 @@ def _collect_implicit(
     for i, name in spec.items():
       if name not in idxs and id(x.shape[i]) not in explicit_tracers:
         idxs[name] = DBIdx(next(counter))
-        implicit_types.append(raise_to_shaped(get_aval(x.shape[i])))
+        implicit_types.append(get_aval(x.shape[i]))
     if isinstance(x, Tracer):
       explicit_tracers.setdefault(id(x), explicit_idx)  # use the first
 
@@ -2310,7 +2310,7 @@ def _arg_type(
 ) -> AbstractValue:
   # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames.
   aval = get_aval(x)  # aval.shape could contain Tracers
-  if not spec: return core.raise_to_shaped(aval)
+  if not spec: return aval
   shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d
                               for i, d in enumerate(aval.shape)]
   assert not any(isinstance(d, Tracer) for d in shape)
@@ -35,7 +35,7 @@ from jax._src import source_info_util
 from jax._src import util
 from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
 from jax._src.state.types import AbstractRef, RefEffect
-from jax._src.core import raise_to_shaped, replace_jaxpr_effects
+from jax._src.core import replace_jaxpr_effects
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -328,7 +328,7 @@ def _cond_abstract_eval(*avals, branches, **_):
   if disallowed_effects:
     raise NotImplementedError(
         f'Effects not supported in `cond`: {disallowed_effects}')
-  return map(raise_to_shaped, branches[0].out_avals), joined_effects
+  return branches[0].out_avals, joined_effects
 
 def _bcast_select(pred, on_true, on_false):
   if np.ndim(pred) != np.ndim(on_true):
@@ -676,7 +676,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
 
 def _transpose_cond_jaxpr(jaxpr, num_res):
   res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
-  primal_avals = map(raise_to_shaped, primal_avals)
 
   @lu.wrap_init
   def transposed(*args):
@@ -693,7 +692,7 @@ def _cond_transpose(cts, *args, branches):
   index, *ops = args
   assert type(index) is not ad.UndefinedPrimal
   linear = [type(x) is ad.UndefinedPrimal for x in ops]
-  in_avals = map(raise_to_shaped, branches[0].in_avals)
+  in_avals = branches[0].in_avals
   num_res = len(ops) - sum(linear)
   if any(isinstance(eff, RefEffect) for branch in branches for eff in
          branch.jaxpr.effects):
@@ -701,8 +700,7 @@ def _cond_transpose(cts, *args, branches):
 
   branches_trans = tuple(
       _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
-  lin_in_avals = [raise_to_shaped(a, weak_type=False)
-                  for a, l in zip(in_avals, linear) if l]
+  lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l]
   assert all(core.typematch(out_aval, lin_in_aval)
              for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))
@@ -35,7 +35,7 @@ from jax._src import source_info_util
 from jax._src import state
 from jax._src import util
 from jax._src.api_util import shaped_abstractify
-from jax._src.core import ShapedArray, raise_to_shaped
+from jax._src.core import ShapedArray
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -262,7 +262,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
     stacked_y = tree_map(stack, *maybe_reversed(ys))
     return carry, stacked_y
 
-  xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
+  xs_avals = [core.get_aval(x) for x in xs_flat]
   x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
 
   def _create_jaxpr(init):
@@ -1370,7 +1370,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts,
   if disallowed_effects:
     raise NotImplementedError(
         f'Effects not supported in `while`: {disallowed_effects}')
-  return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
+  return body_jaxpr.out_avals, joined_effects
 
 
 def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr,
@@ -23,7 +23,6 @@ from jax._src import api
 from jax._src import core
 from jax._src import custom_derivatives
 from jax._src import linear_util as lu
-from jax._src.core import raise_to_shaped
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -300,7 +299,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
   num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
   if num_aux > 0:
     args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
-  return _map(raise_to_shaped, args_to_raise)
+  return args_to_raise
 
 
 def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
@@ -48,7 +48,7 @@ from jax._src import state
 from jax._src import util
 from jax._src.abstract_arrays import array_types
 from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
-                           raise_to_shaped, abstract_token, canonicalize_shape)
+                           abstract_token, canonicalize_shape)
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -3044,7 +3044,7 @@ def _to_edtype_abstract_eval(x, *, edtype):
         f" has a representation shape {rep_aval.shape} while the given "
         f"representation array has shape {x.shape}, so the shape suffix "
         f"does not match: given {shape_suffix} but required {rep_aval.shape}.")
-  return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype)
+  return x.update(shape=shape_prefix, dtype=edtype)
 
 to_edtype_p = Primitive('to_edtype')
 to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p))
@@ -5246,7 +5246,7 @@ _INT_DTYPES = {
 
 
 def _sort_abstract_eval(*args, **kwargs):
-  args = tuple(raise_to_shaped(arg) for arg in args)
+  args = tuple(args)
   if any(arg.shape != args[0].shape for arg in args[1:]):
     shapes = " ".join(str(a.shape) for a in args)
     raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
@@ -6196,7 +6196,7 @@ def _eq_meet(a, b):
 
 
 def _abstractify(x):
-  return raise_to_shaped(core.get_aval(x))
+  return core.get_aval(x)
 
 
 def empty(dtype):
@@ -33,7 +33,7 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import util
 from jax._src.core import (
-    Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape)
+    Primitive, ShapedArray, is_constant_dim, is_constant_shape)
 from jax._src.extend import ffi
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -1289,7 +1289,6 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size):
 
 
 def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
-  pivots = raise_to_shaped(pivots)
   if isinstance(pivots, ShapedArray):
     if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32):
       raise ValueError(
@@ -1421,7 +1420,6 @@ def _lu_impl(operand):
   return lu, pivot, perm
 
 def _lu_abstract_eval(operand):
-  operand = raise_to_shaped(operand)
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2:
       raise ValueError("Argument to LU decomposition must have ndims >= 2")
@@ -27,7 +27,7 @@ from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import sharding_impls
-from jax._src.core import AxisName, ShapedArray, raise_to_shaped
+from jax._src.core import AxisName, ShapedArray
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -636,7 +636,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
     raise ValueError(f"axis_index_groups can only be used with reductions over "
                      f"named axes, but got: {axes}")
   out_avals = [
-      ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
+      ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes),
                   arg.dtype) for arg in args]
   return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
 
@@ -817,7 +817,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
 
 def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
   _check_axis_names(axis_name)
-  return raise_to_shaped(x)
+  return x
 
 ppermute_p = core.Primitive('ppermute')
 ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval)
@@ -1019,13 +1019,12 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in,
 
 
 def _all_to_all_effectful_abstract_eval(
-    x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
+    input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled
 ):
   del tiled  # expand_dims and squeeze is done in `all_to_all` if `True`
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   _check_axis_names(axis_name)
-  input_aval = raise_to_shaped(x)
   shape = list(input_aval.shape)
   axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
   assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
@@ -1169,12 +1168,11 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
 
 
 def _all_gather_effectful_abstract_eval(
-    x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
+    x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
 ):
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   _check_axis_names(axis_name)
-  x_aval = raise_to_shaped(x)
   new_shape = list(x_aval.shape)
   if tiled:
     new_shape[all_gather_dimension] *= axis_size
@@ -1298,12 +1296,11 @@ def _reduce_scatter_lowering(
 
 
 def _reduce_scatter_effectful_abstract_eval(
-    x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
+    x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
 ):
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   _check_axis_names(axis_name)
-  x_aval = core.raise_to_shaped(x)
   new_shape = list(x_aval.shape)
   scatter_dim_input_size = x_aval.shape[scatter_dimension]
   if tiled:
@@ -140,13 +140,6 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
         self.memory_space,
     ))
 
-  def at_least_vspace(self):
-    """Vector space method needed for AD."""
-    raise NotImplementedError
-
-  def join(self, other):
-    raise NotImplementedError
-
   def str_short(self, short_dtypes=False):
     dt_str = \
         dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
@@ -226,11 +219,6 @@ class AbstractMemoryRef(state.AbstractRef):
   def __repr__(self) -> str:
     return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
 
-  def join(self, other):
-    assert isinstance(other, AbstractMemoryRef)
-    return AbstractMemoryRef(self.inner_aval.join(other.inner_aval),
-                             self.memory_space)
-
   def update(self, inner_aval=None, memory_space=None):
     inner_aval = self.inner_aval if inner_aval is None else inner_aval
     memory_space = self.memory_space if memory_space is None else memory_space
@@ -262,13 +250,6 @@ class MemorySpace(enum.Enum):
     return self.value
 
 
-def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
-  return AbstractMemoryRef(
-      jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
-      ref_aval.memory_space)
-jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped
-
-
 @dataclasses.dataclass(frozen=True)
 class PallasGridContext:
   grid: GridMappingGrid
@@ -174,15 +174,6 @@ class SemaphoreType(enum.Enum):
 class AbstractSemaphore(jax_core.AbstractValue):
   sem_type: SemaphoreType
 
-  def join(self, other):
-    if not isinstance(other, AbstractSemaphore):
-      raise ValueError
-    if other.sem_type != self.sem_type:
-      raise ValueError
-    return self
-
-jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval
-
 
 @dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
@@ -529,7 +529,8 @@ assume_p.def_impl(lambda x, y: x)
 
 @assume_p.def_abstract_eval
 def _assume_abstract_eval(x, y):
-  return x.join(y)
+  assert jax_core.typematch(x, y)
+  return x
 
 def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y):
   return y if ctx.lowering_context.for_verification else x
@@ -458,15 +458,9 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef):
   def __repr__(self) -> str:
     return f'Accumulator{{{self.inner_aval.str_short()}}}'
 
-  def join(self, other):
-    return _as_accum(super().join(other))
-
   def update(self, inner_aval=None, memory_space=None):
     return _as_accum(super().update(inner_aval=None, memory_space=None))
 
-  def at_least_vspace(self):
-    return _as_accum(super().at_least_vspace())
-
   def _getitem(self, tracer, idx):
     from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref  # pytype: disable=import-error
     arr = wgmma_accumulator_deref(tracer)
@@ -483,10 +477,6 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
       memory_space=ref.memory_space,  # pytype: disable=attribute-error
   )
 
-def _ref_raise_to_shaped(ref_aval, weak_type):
-  return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type))
-jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped
-
 
 _WARPGROUP_AXIS_NAME = object()
 
@@ -567,7 +567,7 @@ def wgmma_accumulator_deref(acc):
 @wgmma_accumulator_deref_p.def_effectful_abstract_eval
 def _wgmma_accumulator_deref_abstract_eval(acc):
   # Dereferencing implies flushing so we have a wgmma pipeline effect.
-  ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc
+  ret = acc.inner_aval if isinstance(acc, state.AbstractRef) else acc
   assert isinstance(ret, jax_core.ShapedArray), acc
   return ret, {gpu_core._wgmma_pipeline_effect}
 
@@ -230,7 +230,6 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
   if not isinstance(ref_aval, AbstractRef):
     raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
   if isinstance(ref_aval.inner_aval, core.ShapedArray):
-    val_aval = core.raise_to_shaped(val_aval)
     assert isinstance(val_aval, core.ShapedArray)
     expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms)
     expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
@@ -262,7 +261,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef,
   if not isinstance(ref_aval, AbstractRef):
     raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
   if isinstance(ref_aval.inner_aval, core.ShapedArray):
-    val_aval = core.raise_to_shaped(val_aval)
     out_shape = _shape_after_transforming(ref_aval.shape, transforms)
     out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
     assert isinstance(val_aval, core.ShapedArray)
@@ -291,15 +291,14 @@ class AbstractRef(core.AbstractValue):
       raise AttributeError
     return self.inner_aval.weak_type
 
+  def update_weak_type(self, weak_type):
+    return AbstractRef(self.inner_aval.update_weak_type(weak_type))
+
   def update(self, inner_aval=None):
     if inner_aval is None:
       return AbstractRef(self.inner_aval)
     return AbstractRef(inner_aval)
 
-  def join(self, other):
-    assert isinstance(other, AbstractRef)
-    return AbstractRef(self.inner_aval.join(other.inner_aval))
-
   ndim = property(lambda self: len(self.shape))
   size = property(lambda self: math.prod(self.shape))
 
@@ -365,10 +364,6 @@ class AbstractRef(core.AbstractValue):
   def __hash__(self):
     return hash((self.__class__, self.inner_aval))
 
-def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type):
-  return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type))
-core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped
-
 def _map_ref(size, axis, ref_aval):
   return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))
 
@@ -105,7 +105,6 @@ from jax._src.core import (
   primitive_uses_outfeed as primitive_uses_outfeed,
   pytype_aval_mappings as pytype_aval_mappings,
   raise_to_shaped as raise_to_shaped,
-  raise_to_shaped_mappings as raise_to_shaped_mappings,
   reset_trace_state as reset_trace_state,
   set_current_trace as set_current_trace,
   str_eqn_compact as str_eqn_compact,
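The re-export hunk above keeps raise_to_shaped importable (as a no-op) while dropping raise_to_shaped_mappings entirely, so projects that registered custom aval handlers in that dict simply delete the registration. A hedged sketch of the resulting backwards-compat behavior (MyAval is a hypothetical custom AbstractValue, not something from this diff):

import jax.numpy as jnp
from jax._src import core

# Still works after this commit, but is now just an identity function:
aval = core.get_aval(jnp.float32(1.0))
assert core.raise_to_shaped(aval) is aval

# No longer possible -- the registry is gone, and no replacement is needed
# because raising to shaped is a no-op:
#   core.raise_to_shaped_mappings[MyAval] = lambda aval, _: aval  # AttributeError now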
@@ -533,15 +533,6 @@ class JaxprTypeChecks(jtu.JaxTestCase):
         r"Variable '.+_test' not defined\n\nin equation:",
         lambda: core.check_jaxpr(jaxpr))
 
-  @parameterized.parameters(
-    {'value': 0, 'weak_type': True},
-    {'value': np.int32(0), 'weak_type': False},
-    {'value': np.array([0]), 'weak_type': False}
-  )
-  def test_raise_to_shaped_weak_type(self, value, weak_type):
-    aval = core.raise_to_shaped(core.get_aval(value))
-    self.assertEqual(aval.weak_type, weak_type)
-
 
 @jtu.with_config(jax_dynamic_shapes=True)
 class DynamicShapesTest(jtu.JaxTestCase):
@@ -3821,7 +3821,7 @@ def shard_foo_array_handler(xs, shardings, layouts):
   results = []
   for x, sharding in safe_zip(xs, shardings):
     device, = sharding._addressable_device_assignment
-    aval = core.raise_to_shaped(core.get_aval(x.data))
+    aval = core.get_aval(x.data)
     results.append(pxla.batched_device_put(
         aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]))
   return results
@@ -870,8 +870,9 @@ class PallasCallDMATest(PallasBaseTest):
       pl.run_scoped(scope)
       return []
 
-    aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
-    in_avals = [aref, aref]
+    aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
+    aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
+    in_avals = [aref1, aref2]
     stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     discharged_jaxpr, _ = state_discharge.discharge_state(
@@ -746,9 +746,10 @@ class StateDischargeTest(jtu.JaxTestCase):
       b_ref[...] = jnp.array(1., dtype=jnp.float32)
      return a_ref[...], b_ref[...]
 
-    scalar_ref = shaped_array_ref((), jnp.float32)
+    scalar_ref_1 = shaped_array_ref((), jnp.float32)
+    scalar_ref_2 = shaped_array_ref((), jnp.float32)
     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(f), [scalar_ref, scalar_ref])
+        lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
 
     discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
     prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)