mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #5835 from google:awn-abstract-eval
PiperOrigin-RevId: 361923732
This commit is contained in:
commit
23099f6007
@ -271,7 +271,9 @@ def while_loop(cond_fun: Callable[[T], bool],
|
||||
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
|
||||
msg = "cond_fun must return a boolean scalar, but got pytree {}."
|
||||
raise TypeError(msg.format(cond_tree))
|
||||
if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
|
||||
pred_aval = cond_jaxpr.out_avals[0]
|
||||
if (not isinstance(pred_aval, ShapedArray)
|
||||
or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
|
||||
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
|
||||
raise TypeError(msg.format(cond_jaxpr.out_avals))
|
||||
return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree
|
||||
|
@ -1973,15 +1973,19 @@ def _argnum_weak_type(*argnums):
|
||||
return lambda *args, **_: all(args[i].weak_type for i in argnums)
|
||||
|
||||
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
|
||||
weak_type_rule=None):
|
||||
weak_type_rule=None, named_shape_rule=None):
|
||||
weak_type_rule = weak_type_rule or _standard_weak_type_rule
|
||||
named_shape_rule = named_shape_rule or standard_named_shape_rule
|
||||
prim = Primitive(name)
|
||||
prim.def_impl(partial(xla.apply_primitive, prim))
|
||||
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule, weak_type_rule))
|
||||
prim.def_abstract_eval(
|
||||
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
|
||||
weak_type_rule, named_shape_rule))
|
||||
xla.translations[prim] = translation_rule or partial(standard_translate, name)
|
||||
return prim
|
||||
|
||||
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
|
||||
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
|
||||
named_shape_rule, *avals, **kwargs):
|
||||
assert all(isinstance(aval, UnshapedArray) for aval in avals), avals
|
||||
assert not prim.multiple_results
|
||||
weak_type = weak_type_rule(*avals, **kwargs)
|
||||
@ -1992,14 +1996,16 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, *avals,
|
||||
weak_type=weak_type)
|
||||
elif least_specialized is ShapedArray:
|
||||
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
|
||||
weak_type=weak_type)
|
||||
weak_type=weak_type,
|
||||
named_shape=named_shape_rule(*avals, **kwargs))
|
||||
elif least_specialized is UnshapedArray:
|
||||
return UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)
|
||||
else:
|
||||
raise TypeError(avals, least_specialized)
|
||||
|
||||
def standard_multi_result_abstract_eval(
|
||||
prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
|
||||
prim, shape_rule, dtype_rule, weak_type_rule,
|
||||
named_shape_rule, *avals, **kwargs):
|
||||
assert prim.multiple_results
|
||||
assert all(isinstance(aval, UnshapedArray) for aval in avals), avals
|
||||
least_specialized = _max(map(type, avals),
|
||||
@ -2012,8 +2018,10 @@ def standard_multi_result_abstract_eval(
|
||||
elif least_specialized is ShapedArray:
|
||||
out_shapes = shape_rule(*avals, **kwargs)
|
||||
out_dtypes = dtype_rule(*avals, **kwargs)
|
||||
return [ShapedArray(s, d, weak_type=weak_type)
|
||||
for s, d, weak_type in safe_zip(out_shapes, out_dtypes, weak_types)]
|
||||
out_named_shapes = named_shape_rule(*avals, **kwargs)
|
||||
return [ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
|
||||
for s, d, weak_type, named_shape
|
||||
in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
|
||||
elif least_specialized is UnshapedArray:
|
||||
out_dtypes = dtype_rule(*avals, **kwargs)
|
||||
return [UnshapedArray(dtype, weak_type=weak_type)
|
||||
@ -2021,11 +2029,13 @@ def standard_multi_result_abstract_eval(
|
||||
else:
|
||||
raise TypeError(avals, least_specialized)
|
||||
|
||||
|
||||
def standard_translate(name, c, *args, **kwargs):
|
||||
xla_opname = ''.join(term.capitalize() for term in name.split('_'))
|
||||
return getattr(xops, xla_opname)(*args, **kwargs)
|
||||
|
||||
def standard_named_shape_rule(*avals, **kwargs):
|
||||
return core.join_named_shapes(*(a.named_shape for a in avals))
|
||||
|
||||
|
||||
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
|
||||
if not any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
|
||||
@ -3150,7 +3160,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
input_bitwidth = np.dtype(input_dtype).itemsize
|
||||
preferred_bitwidth = np.dtype(preferred_element_type).itemsize
|
||||
if preferred_bitwidth < input_bitwidth:
|
||||
raise TypeError("`preferred_element_type` must not be narrower than the original type.")
|
||||
raise TypeError("`preferred_element_type` must not be narrower than the original type.")
|
||||
return preferred_element_type
|
||||
|
||||
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
|
||||
@ -4958,15 +4968,19 @@ def _reduce_translation_rule(c, *values, computation, jaxpr,
|
||||
|
||||
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
|
||||
consts, dimensions):
|
||||
# TODO(mattjj,frostig): use batch_jaxpr, delete computation (assumes poly??)
|
||||
num_operands = len(batched_args) // 2
|
||||
operands, init_values = split_list(batched_args, [num_operands])
|
||||
operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])
|
||||
if all(init_value_bdim is None for init_value_bdim in init_value_bdims):
|
||||
if all(init_value_bdim is batching.not_mapped
|
||||
for init_value_bdim in init_value_bdims):
|
||||
# Assume all batch dims are the same for each of the operands
|
||||
assert all(operand_bdim is not None for operand_bdim in operand_bdims)
|
||||
assert all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims)
|
||||
# TODO(sharadmv): handle the case when batch dims are different across
|
||||
# operands or when some are unbatched
|
||||
if not all(operand_bdim is not batching.not_mapped for operand_bdim in operand_bdims):
|
||||
raise NotImplementedError
|
||||
if not all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims):
|
||||
raise NotImplementedError
|
||||
operand_bdim = operand_bdims[0]
|
||||
new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
|
||||
new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim)))
|
||||
@ -5007,12 +5021,26 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
|
||||
bind = prim_bind if input_shape is None else partial(prim_bind, input_shape=padded_shape)
|
||||
return bind(masked_val, axes=axes)
|
||||
|
||||
def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
|
||||
# TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in
|
||||
# _reduce_batching_rule. We're making the same assumptions here for now.
|
||||
num_operands = len(avals) // 2
|
||||
operand_avals, init_avals = split_list(avals, [num_operands])
|
||||
if any(a.named_shape for a in init_avals):
|
||||
raise NotImplementedError
|
||||
named_shapes = [a.named_shape for a in operand_avals]
|
||||
if not all(named_shapes[0] == named_shape for named_shape in named_shapes):
|
||||
raise NotImplementedError
|
||||
return named_shapes
|
||||
|
||||
|
||||
reduce_p = core.Primitive('reduce')
|
||||
reduce_p.multiple_results = True
|
||||
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
|
||||
reduce_p.def_abstract_eval(
|
||||
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
|
||||
_reduce_dtype_rule, _reduce_weak_type_rule))
|
||||
_reduce_dtype_rule, _reduce_weak_type_rule,
|
||||
_reduce_named_shape_rule))
|
||||
xla.translations[reduce_p] = _reduce_translation_rule
|
||||
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
|
||||
|
||||
|
@ -625,10 +625,18 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
|
||||
return [pos_reducer(arg, axes) for arg in args]
|
||||
|
||||
def _allreduce_abstract_eval(*args, axes, axis_index_groups):
|
||||
# TODO(frostig,mattjj,jekbradbury): maybe check aval names here
|
||||
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
|
||||
named_shapes = [arg.named_shape for arg in args]
|
||||
if axis_index_groups is None:
|
||||
named_axes = set(axis for axis in axes if not isinstance(axis, int))
|
||||
named_shapes = [{name: size for name, size in arg.named_shape.items()
|
||||
if name not in named_axes} for arg in args]
|
||||
else:
|
||||
assert len(pos_axes) == 0
|
||||
return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
|
||||
arg.dtype)
|
||||
for arg in args]
|
||||
arg.dtype, named_shape=named_shape)
|
||||
for arg, named_shape in zip(args, named_shapes)]
|
||||
|
||||
def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_groups,
|
||||
axis_env, platform):
|
||||
@ -1082,18 +1090,16 @@ def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_
|
||||
x_aval = raise_to_shaped(x)
|
||||
new_shape = list(x_aval.shape)
|
||||
new_shape.insert(all_gather_dimension, axis_size)
|
||||
return x_aval.update(shape=new_shape)
|
||||
new_named_shape = {name: size for name, size in x_aval.named_shape.items()
|
||||
if name != axis_name}
|
||||
return x_aval.update(shape=new_shape, named_shape=new_named_shape)
|
||||
|
||||
def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
|
||||
# TODO(cjfj): Add reduce-scatter op to XLA?
|
||||
concat_axis = 0
|
||||
return (lax_numpy.sum(
|
||||
all_to_all(
|
||||
cts,
|
||||
axis_name=axis_name,
|
||||
split_axis=all_gather_dimension,
|
||||
concat_axis=concat_axis,
|
||||
axis_index_groups=axis_index_groups),
|
||||
return (lax_numpy.sum(all_to_all(
|
||||
cts, axis_name=axis_name, split_axis=all_gather_dimension,
|
||||
concat_axis=concat_axis, axis_index_groups=axis_index_groups),
|
||||
axis=concat_axis),)
|
||||
|
||||
def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
|
||||
@ -1137,10 +1143,13 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
def _axis_index_abstract_eval(*, axis_name):
|
||||
frame = core.axis_frame(axis_name)
|
||||
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
|
||||
|
||||
axis_index_p = core.Primitive('axis_index')
|
||||
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
|
||||
axis_index_p.def_abstract_eval(
|
||||
lambda *args, **params: ShapedArray((), np.int32))
|
||||
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
|
||||
pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
@ -1149,26 +1158,30 @@ core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param,
|
||||
# wants to bind an axis name has to additionally implement `process_axis_index`
|
||||
# and put its main trace on the axis env stack.
|
||||
def _axis_index_bind(*, axis_name):
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
inner_size = 1
|
||||
index = 0
|
||||
for name in reversed(axis_name):
|
||||
def name_idx(name):
|
||||
frame = core.axis_frame(name)
|
||||
if frame.main_trace is not None:
|
||||
trace = frame.main_trace.with_cur_sublevel()
|
||||
name_idx = trace.process_axis_index(frame)
|
||||
dynamic = core.thread_local_state.trace_state.trace_stack.dynamic
|
||||
if (frame.main_trace is None or dynamic.level > frame.main_trace.level):
|
||||
return core.Primitive.bind(axis_index_p, axis_name=name)
|
||||
else:
|
||||
name_idx = core.Primitive.bind(axis_index_p, axis_name=name)
|
||||
index += name_idx * inner_size
|
||||
inner_size *= psum(1, name)
|
||||
return index
|
||||
trace = frame.main_trace.with_cur_sublevel()
|
||||
return trace.process_axis_index(frame)
|
||||
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
return name_idx(axis_name)
|
||||
else:
|
||||
inner_size = 1
|
||||
index = 0
|
||||
for name in reversed(axis_name):
|
||||
index += name_idx(name) * inner_size
|
||||
inner_size *= psum(1, name)
|
||||
return index
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
|
||||
def _process_axis_index(self, frame):
|
||||
def _vmap_process_axis_index(self, frame):
|
||||
assert frame.size is not None
|
||||
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
|
||||
batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore
|
||||
return batching.BatchTracer(self, lax.iota(np.int32, frame.size), 0)
|
||||
batching.BatchTrace.process_axis_index = _vmap_process_axis_index # type: ignore
|
||||
|
||||
|
||||
pdot_p = core.Primitive('pdot')
|
||||
@ -1181,11 +1194,15 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
|
||||
|
||||
@pdot_p.def_abstract_eval
|
||||
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
|
||||
# TODO: avals with names, check inputs are mapped along axis_name, eliminate
|
||||
# TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
|
||||
if not len(set(axis_name)) == len(axis_name): raise ValueError
|
||||
return lax.dot_general_p.abstract_eval(
|
||||
pos_aval = lax.dot_general_p.abstract_eval(
|
||||
x, y, dimension_numbers=[pos_contract, pos_batch],
|
||||
precision=None, preferred_element_type=None)
|
||||
named_shape = {name: size
|
||||
for aval in (x, y) for name, size in aval.named_shape.items()
|
||||
if name not in axis_name}
|
||||
return pos_aval.update(named_shape=named_shape)
|
||||
|
||||
def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
|
||||
pos_contract, pos_batch):
|
||||
@ -1337,7 +1354,8 @@ def omnistaging_disabler() -> None:
|
||||
nreps = dynamic_axis_env.nreps
|
||||
trace = frame.pmap_trace
|
||||
|
||||
out_aval = ShapedArray((), np.int32)
|
||||
out_aval = _axis_index_abstract_eval(
|
||||
nreps=nreps, sizes=sizes, axis_name=axis_name)
|
||||
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
|
||||
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
|
||||
dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
|
||||
@ -1352,7 +1370,9 @@ def omnistaging_disabler() -> None:
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
def _axis_index_abstract_eval(*, nreps, sizes, axis_name):
|
||||
return ShapedArray((), np.int32, named_shape={axis_name: sizes[-1]})
|
||||
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
axis_index_p.def_abstract_eval(
|
||||
lambda *args, **params: ShapedArray((), np.int32))
|
||||
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
|
||||
xla.translations[axis_index_p] = _axis_index_translation_rule
|
||||
|
15
jax/api.py
15
jax/api.py
@ -2244,10 +2244,11 @@ def _valid_jaxtype(arg):
|
||||
|
||||
|
||||
class ShapeDtypeStruct:
|
||||
__slots__ = ["shape", "dtype"]
|
||||
def __init__(self, shape, dtype):
|
||||
__slots__ = ["shape", "dtype", "named_shape"]
|
||||
def __init__(self, shape, dtype, named_shape={}):
|
||||
self.shape = shape
|
||||
self.dtype = np.dtype(dtype)
|
||||
self.named_shape = named_shape
|
||||
|
||||
size = property(lambda self: prod(self.shape))
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
@ -2259,7 +2260,8 @@ class ShapeDtypeStruct:
|
||||
raise TypeError("len() of unsized object") from e # same as numpy error
|
||||
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name})"
|
||||
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
|
||||
return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name}{ns})"
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@ -2267,10 +2269,11 @@ class ShapeDtypeStruct:
|
||||
if not isinstance(other, ShapeDtypeStruct):
|
||||
return False
|
||||
else:
|
||||
return (other.shape, other.dtype) == (self.shape, self.dtype)
|
||||
return (other.shape, other.dtype, other.named_shape) == (
|
||||
self.shape, self.dtype, self.named_shape)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.shape, self.dtype))
|
||||
return hash((self.shape, self.dtype, self.named_shape))
|
||||
|
||||
def eval_shape(fun: Callable, *args, **kwargs):
|
||||
"""Compute the shape/dtype of ``fun`` without any FLOPs.
|
||||
@ -2336,7 +2339,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
|
||||
*map(shaped_abstractify, args_flat))
|
||||
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
|
||||
out = [ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
|
||||
return tree_unflatten(out_tree(), out)
|
||||
|
||||
|
||||
|
@ -1259,6 +1259,24 @@ class APITest(jtu.JaxTestCase):
|
||||
out_shape = api.eval_shape(lambda x: x, x) # doesn't crash
|
||||
self.assertEqual(out_shape.shape, (3,))
|
||||
|
||||
def test_eval_shape_names(self):
|
||||
def fun(x, y):
|
||||
return lax.psum(x, 'i') + y
|
||||
|
||||
class MyArgArray(object):
|
||||
def __init__(self, shape, dtype, named_shape):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.named_shape = named_shape
|
||||
|
||||
x = MyArgArray((3, 2), jnp.float32, {'i': 10})
|
||||
y = MyArgArray((3, 2), jnp.float32, {'j': 5})
|
||||
with core.extend_axis_env('i', 10, None):
|
||||
with core.extend_axis_env('j', 5, None):
|
||||
out_shape = api.eval_shape(fun, x, y)
|
||||
|
||||
self.assertEqual(out_shape.named_shape, {'j': 5})
|
||||
|
||||
def test_issue_871(self):
|
||||
T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
|
||||
x = jnp.array([1, 2, 3])
|
||||
@ -2851,6 +2869,19 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
|
||||
self.assertIn('psum', str(jaxpr))
|
||||
|
||||
def test_make_jaxpr_named(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f(x):
|
||||
return x - lax.psum(x, 'i')
|
||||
|
||||
x = types.SimpleNamespace(
|
||||
shape=(2, 3), dtype=jnp.float32, named_shape={'i': 10})
|
||||
jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
|
||||
named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
|
||||
self.assertEqual(named_shapes, [{'i': 10}, {}])
|
||||
|
||||
|
||||
class LazyTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -335,6 +335,20 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = np.array([4, 3, 4, 3])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test only works with omnistaging")
|
||||
def testWhileLoopAxisIndexBatched(self):
|
||||
def fun(x):
|
||||
return lax.while_loop(lambda x: x < lax.axis_index('i'), lambda x: x + 2, x)
|
||||
|
||||
ans = api.vmap(fun, axis_name='i')(np.array([0, 0, 0, 0]))
|
||||
expected = np.array([0, 2, 2, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
fun = api.jit(fun)
|
||||
ans = api.vmap(fun, axis_name='i')(np.array([0, 0, 0, 0]))
|
||||
expected = np.array([0, 2, 2, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testWhileLoopCondConstsBatched(self):
|
||||
def fun(x, y):
|
||||
return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)
|
||||
|
@ -2453,5 +2453,28 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
out = lax.cumsum(x)
|
||||
self.assertArraysEqual(out, x)
|
||||
|
||||
|
||||
class LaxNamedShapeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_abstract_eval(self):
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
|
||||
out = lax.sin_p.abstract_eval(aval1)
|
||||
self.assertEqual(out, aval1)
|
||||
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
|
||||
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
|
||||
expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
|
||||
out = lax.add_p.abstract_eval(aval1, aval2)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
def test_abstract_eval_collective(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("test requires omnistaging")
|
||||
with core.extend_axis_env('i', 10, None):
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
|
||||
expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
|
||||
out, = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user