Merge pull request #5835 from google:awn-abstract-eval

PiperOrigin-RevId: 361923732
This commit is contained in:
jax authors 2021-03-09 16:27:37 -08:00
commit 23099f6007
7 changed files with 173 additions and 52 deletions

View File

@ -271,7 +271,9 @@ def while_loop(cond_fun: Callable[[T], bool],
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
pred_aval = cond_jaxpr.out_avals[0]
if (not isinstance(pred_aval, ShapedArray)
or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

View File

@ -1973,15 +1973,19 @@ def _argnum_weak_type(*argnums):
return lambda *args, **_: all(args[i].weak_type for i in argnums)
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
weak_type_rule=None):
weak_type_rule=None, named_shape_rule=None):
weak_type_rule = weak_type_rule or _standard_weak_type_rule
named_shape_rule = named_shape_rule or standard_named_shape_rule
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule, weak_type_rule))
prim.def_abstract_eval(
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
weak_type_rule, named_shape_rule))
xla.translations[prim] = translation_rule or partial(standard_translate, name)
return prim
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
assert all(isinstance(aval, UnshapedArray) for aval in avals), avals
assert not prim.multiple_results
weak_type = weak_type_rule(*avals, **kwargs)
@ -1992,14 +1996,16 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, *avals,
weak_type=weak_type)
elif least_specialized is ShapedArray:
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type)
weak_type=weak_type,
named_shape=named_shape_rule(*avals, **kwargs))
elif least_specialized is UnshapedArray:
return UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)
else:
raise TypeError(avals, least_specialized)
def standard_multi_result_abstract_eval(
prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
assert prim.multiple_results
assert all(isinstance(aval, UnshapedArray) for aval in avals), avals
least_specialized = _max(map(type, avals),
@ -2012,8 +2018,10 @@ def standard_multi_result_abstract_eval(
elif least_specialized is ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
return [ShapedArray(s, d, weak_type=weak_type)
for s, d, weak_type in safe_zip(out_shapes, out_dtypes, weak_types)]
out_named_shapes = named_shape_rule(*avals, **kwargs)
return [ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
for s, d, weak_type, named_shape
in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
elif least_specialized is UnshapedArray:
out_dtypes = dtype_rule(*avals, **kwargs)
return [UnshapedArray(dtype, weak_type=weak_type)
@ -2021,11 +2029,13 @@ def standard_multi_result_abstract_eval(
else:
raise TypeError(avals, least_specialized)
def standard_translate(name, c, *args, **kwargs):
xla_opname = ''.join(term.capitalize() for term in name.split('_'))
return getattr(xops, xla_opname)(*args, **kwargs)
def standard_named_shape_rule(*avals, **kwargs):
  """Default named-shape rule for standard primitives.

  The output named shape is the join of all input avals' named shapes
  (see core.join_named_shapes); extra primitive params are ignored.
  """
  return core.join_named_shapes(*(a.named_shape for a in avals))
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
if not any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
@ -3150,7 +3160,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
input_bitwidth = np.dtype(input_dtype).itemsize
preferred_bitwidth = np.dtype(preferred_element_type).itemsize
if preferred_bitwidth < input_bitwidth:
raise TypeError("`preferred_element_type` must not be narrower than the original type.")
raise TypeError("`preferred_element_type` must not be narrower than the original type.")
return preferred_element_type
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
@ -4958,15 +4968,19 @@ def _reduce_translation_rule(c, *values, computation, jaxpr,
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
consts, dimensions):
# TODO(mattjj,frostig): use batch_jaxpr, delete computation (assumes poly??)
num_operands = len(batched_args) // 2
operands, init_values = split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])
if all(init_value_bdim is None for init_value_bdim in init_value_bdims):
if all(init_value_bdim is batching.not_mapped
for init_value_bdim in init_value_bdims):
# Assume all batch dims are the same for each of the operands
assert all(operand_bdim is not None for operand_bdim in operand_bdims)
assert all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims)
# TODO(sharadmv): handle the case when batch dims are different across
# operands or when some are unbatched
if not all(operand_bdim is not batching.not_mapped for operand_bdim in operand_bdims):
raise NotImplementedError
if not all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims):
raise NotImplementedError
operand_bdim = operand_bdims[0]
new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim)))
@ -5007,12 +5021,26 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
bind = prim_bind if input_shape is None else partial(prim_bind, input_shape=padded_shape)
return bind(masked_val, axes=axes)
def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
  """Named-shape rule for reduce_p: each output keeps its operand's named shape.

  Returns a list of named shapes (reduce_p has multiple results, one per
  operand).

  Raises:
    NotImplementedError: if any init value has a named shape, or if the
      operands' named shapes disagree.
  """
  # TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in
  # _reduce_batching_rule. We're making the same assumptions here for now.
  # avals is [*operands, *init_values] — split it down the middle.
  num_operands = len(avals) // 2
  operand_avals, init_avals = split_list(avals, [num_operands])
  # Init values carrying named axes are not supported yet.
  if any(a.named_shape for a in init_avals):
    raise NotImplementedError
  named_shapes = [a.named_shape for a in operand_avals]
  # For now, all operands must agree on a single named shape.
  if not all(named_shapes[0] == named_shape for named_shape in named_shapes):
    raise NotImplementedError
  return named_shapes
reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
_reduce_dtype_rule, _reduce_weak_type_rule))
_reduce_dtype_rule, _reduce_weak_type_rule,
_reduce_named_shape_rule))
xla.translations[reduce_p] = _reduce_translation_rule
batching.primitive_batchers[reduce_p] = _reduce_batch_rule

View File

@ -625,10 +625,18 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
return [pos_reducer(arg, axes) for arg in args]
def _allreduce_abstract_eval(*args, axes, axis_index_groups):
# TODO(frostig,mattjj,jekbradbury): maybe check aval names here
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
named_shapes = [arg.named_shape for arg in args]
if axis_index_groups is None:
named_axes = set(axis for axis in axes if not isinstance(axis, int))
named_shapes = [{name: size for name, size in arg.named_shape.items()
if name not in named_axes} for arg in args]
else:
assert len(pos_axes) == 0
return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
arg.dtype)
for arg in args]
arg.dtype, named_shape=named_shape)
for arg, named_shape in zip(args, named_shapes)]
def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_groups,
axis_env, platform):
@ -1082,18 +1090,16 @@ def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_
x_aval = raise_to_shaped(x)
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, axis_size)
return x_aval.update(shape=new_shape)
new_named_shape = {name: size for name, size in x_aval.named_shape.items()
if name != axis_name}
return x_aval.update(shape=new_shape, named_shape=new_named_shape)
def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
# TODO(cjfj): Add reduce-scatter op to XLA?
concat_axis = 0
return (lax_numpy.sum(
all_to_all(
cts,
axis_name=axis_name,
split_axis=all_gather_dimension,
concat_axis=concat_axis,
axis_index_groups=axis_index_groups),
return (lax_numpy.sum(all_to_all(
cts, axis_name=axis_name, split_axis=all_gather_dimension,
concat_axis=concat_axis, axis_index_groups=axis_index_groups),
axis=concat_axis),)
def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
@ -1137,10 +1143,13 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
def _axis_index_abstract_eval(*, axis_name):
  """Abstract eval for axis_index_p.

  The result is a scalar int32, but its value varies along the named mapped
  axis, so the aval records {axis_name: size} in its named shape. The axis
  size is looked up from the current axis environment frame.
  """
  frame = core.axis_frame(axis_name)
  return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), np.int32))
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
pxla.multi_host_supported_collectives.add(axis_index_p)
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
@ -1149,26 +1158,30 @@ core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param,
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
inner_size = 1
index = 0
for name in reversed(axis_name):
def name_idx(name):
frame = core.axis_frame(name)
if frame.main_trace is not None:
trace = frame.main_trace.with_cur_sublevel()
name_idx = trace.process_axis_index(frame)
dynamic = core.thread_local_state.trace_state.trace_stack.dynamic
if (frame.main_trace is None or dynamic.level > frame.main_trace.level):
return core.Primitive.bind(axis_index_p, axis_name=name)
else:
name_idx = core.Primitive.bind(axis_index_p, axis_name=name)
index += name_idx * inner_size
inner_size *= psum(1, name)
return index
trace = frame.main_trace.with_cur_sublevel()
return trace.process_axis_index(frame)
if not isinstance(axis_name, (tuple, list)):
return name_idx(axis_name)
else:
inner_size = 1
index = 0
for name in reversed(axis_name):
index += name_idx(name) * inner_size
inner_size *= psum(1, name)
return index
axis_index_p.def_custom_bind(_axis_index_bind)
def _process_axis_index(self, frame):
def _vmap_process_axis_index(self, frame):
assert frame.size is not None
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore
return batching.BatchTracer(self, lax.iota(np.int32, frame.size), 0)
batching.BatchTrace.process_axis_index = _vmap_process_axis_index # type: ignore
pdot_p = core.Primitive('pdot')
@ -1181,11 +1194,15 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
@pdot_p.def_abstract_eval
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
# TODO: avals with names, check inputs are mapped along axis_name, eliminate
# TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
if not len(set(axis_name)) == len(axis_name): raise ValueError
return lax.dot_general_p.abstract_eval(
pos_aval = lax.dot_general_p.abstract_eval(
x, y, dimension_numbers=[pos_contract, pos_batch],
precision=None, preferred_element_type=None)
named_shape = {name: size
for aval in (x, y) for name, size in aval.named_shape.items()
if name not in axis_name}
return pos_aval.update(named_shape=named_shape)
def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
pos_contract, pos_batch):
@ -1337,7 +1354,8 @@ def omnistaging_disabler() -> None:
nreps = dynamic_axis_env.nreps
trace = frame.pmap_trace
out_aval = ShapedArray((), np.int32)
out_aval = _axis_index_abstract_eval(
nreps=nreps, sizes=sizes, axis_name=axis_name)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
@ -1352,7 +1370,9 @@ def omnistaging_disabler() -> None:
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
def _axis_index_abstract_eval(*, nreps, sizes, axis_name):
  # Omnistaging-disabled variant: the innermost mapped axis size comes from
  # the dynamic axis env's `sizes` (last entry) rather than an axis frame.
  return ShapedArray((), np.int32, named_shape={axis_name: sizes[-1]})
axis_index_p.def_custom_bind(_axis_index_bind)
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), np.int32))
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
xla.translations[axis_index_p] = _axis_index_translation_rule

View File

@ -2244,10 +2244,11 @@ def _valid_jaxtype(arg):
class ShapeDtypeStruct:
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):
__slots__ = ["shape", "dtype", "named_shape"]
def __init__(self, shape, dtype, named_shape={}):
self.shape = shape
self.dtype = np.dtype(dtype)
self.named_shape = named_shape
size = property(lambda self: prod(self.shape))
ndim = property(lambda self: len(self.shape))
@ -2259,7 +2260,8 @@ class ShapeDtypeStruct:
raise TypeError("len() of unsized object") from e # same as numpy error
def __repr__(self):
return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name})"
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name}{ns})"
__str__ = __repr__
@ -2267,10 +2269,11 @@ class ShapeDtypeStruct:
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return (other.shape, other.dtype) == (self.shape, self.dtype)
return (other.shape, other.dtype, other.named_shape) == (
self.shape, self.dtype, self.named_shape)
def __hash__(self):
return hash((self.shape, self.dtype))
return hash((self.shape, self.dtype, self.named_shape))
def eval_shape(fun: Callable, *args, **kwargs):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
@ -2336,7 +2339,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
*map(shaped_abstractify, args_flat))
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
out = [ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
return tree_unflatten(out_tree(), out)

View File

@ -1259,6 +1259,24 @@ class APITest(jtu.JaxTestCase):
out_shape = api.eval_shape(lambda x: x, x) # doesn't crash
self.assertEqual(out_shape.shape, (3,))
def test_eval_shape_names(self):
  # eval_shape should propagate named shapes: psum over 'i' removes 'i'
  # from x's named shape, and adding y contributes y's 'j' axis, so the
  # output named shape is expected to be {'j': 5}.
  def fun(x, y):
    return lax.psum(x, 'i') + y

  class MyArgArray(object):
    # Minimal duck-typed argument spec carrying shape/dtype/named_shape,
    # as accepted by eval_shape in place of a real array.
    def __init__(self, shape, dtype, named_shape):
      self.shape = shape
      self.dtype = dtype
      self.named_shape = named_shape

  x = MyArgArray((3, 2), jnp.float32, {'i': 10})
  y = MyArgArray((3, 2), jnp.float32, {'j': 5})
  # Both named axes must be bound in the axis env for tracing to succeed.
  with core.extend_axis_env('i', 10, None):
    with core.extend_axis_env('j', 5, None):
      out_shape = api.eval_shape(fun, x, y)

  self.assertEqual(out_shape.named_shape, {'j': 5})
def test_issue_871(self):
T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
x = jnp.array([1, 2, 3])
@ -2851,6 +2869,19 @@ class JaxprTest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
self.assertIn('psum', str(jaxpr))
def test_make_jaxpr_named(self):
  if not config.omnistaging_enabled:
    raise unittest.SkipTest("test only works with omnistaging")

  def f(x):
    return x - lax.psum(x, 'i')

  # A bare namespace with shape/dtype/named_shape is enough for make_jaxpr.
  x = types.SimpleNamespace(
      shape=(2, 3), dtype=jnp.float32, named_shape={'i': 10})
  jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
  # eqns[1] is the subtraction: its inputs are x (named shape {'i': 10})
  # and the psum result ({} — the 'i' axis was reduced away).
  named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
  self.assertEqual(named_shapes, [{'i': 10}, {}])
class LazyTest(jtu.JaxTestCase):

View File

@ -335,6 +335,20 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = np.array([4, 3, 4, 3])
self.assertAllClose(ans, expected, check_dtypes=False)
@skipIf(not config.omnistaging_enabled, "test only works with omnistaging")
def testWhileLoopAxisIndexBatched(self):
  # A while_loop whose condition depends on axis_index under vmap: each
  # batch element (axis position 0..3) iterates a different number of
  # times, stepping by 2 until it reaches its own axis index.
  def fun(x):
    return lax.while_loop(lambda x: x < lax.axis_index('i'), lambda x: x + 2, x)

  ans = api.vmap(fun, axis_name='i')(np.array([0, 0, 0, 0]))
  expected = np.array([0, 2, 2, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Same result when the function is jitted before vmapping.
  fun = api.jit(fun)
  ans = api.vmap(fun, axis_name='i')(np.array([0, 0, 0, 0]))
  expected = np.array([0, 2, 2, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)
def testWhileLoopCondConstsBatched(self):
def fun(x, y):
return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)

View File

@ -2453,5 +2453,28 @@ class LazyConstantTest(jtu.JaxTestCase):
out = lax.cumsum(x)
self.assertArraysEqual(out, x)
class LaxNamedShapeTest(jtu.JaxTestCase):
  """Tests that lax primitives' abstract eval propagates named shapes."""

  def test_abstract_eval(self):
    # Unary op: the named shape passes through unchanged.
    aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
    out = lax.sin_p.abstract_eval(aval1)
    self.assertEqual(out, aval1)

    # Binary op: the output named shape is the join of the inputs'.
    aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
    aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
    expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
    out = lax.add_p.abstract_eval(aval1, aval2)
    self.assertEqual(out, expected)

  def test_abstract_eval_collective(self):
    if not config.omnistaging_enabled:
      raise SkipTest("test requires omnistaging")
    # psum over 'i' eliminates 'i' from the named shape, leaving only 'j'.
    with core.extend_axis_env('i', 10, None):
      aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
      expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
      out, = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
      self.assertEqual(out, expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())