Improve some xmap error messages

We used to raise uninformative errors (often bare assertion failures) when:
- we failed to infer an axis size
- axis sizes were not divisible by the number of resources assigned to them
- axis sizes were inconsistent between arguments
- in/out_axes referred to out-of-bounds positional axes
- in/out_axes used negative indices for positional axes (this is not implemented yet)
This commit is contained in:
Adam Paszke 2021-02-02 17:36:46 +00:00
parent 5616916cb2
commit be9c58ae21
2 changed files with 115 additions and 18 deletions

View File

@ -166,20 +166,33 @@ def fresh_resource_name(tag=None):
finally:
_next_resource_id += 1
# This is really a Dict[AxisName, int], but we don't define a
# pytree instance for it, so that it is treated as a leaf.
# FrozenDict subclass that additionally carries `user_repr`, a value supplied
# by the caller at construction time. The call sites visible in this diff pass
# the string form of the user's original in_axes/out_axes entry, and the error
# messages added in this commit quote it back to the user.
# NOTE(review): this diff view shows the removed `pass` body followed by the
# added lines; only the added lines exist after the commit.
class AxisNamePos(FrozenDict):
pass
# Set by __init__ below.
user_repr: Any
# `user_repr` is a required keyword-only argument; every other argument is
# forwarded unchanged to FrozenDict.__init__.
def __init__(self, *args, user_repr, **kwargs):
super().__init__(*args, **kwargs)
self.user_repr = user_repr
# str(...) == 'Ellipsis' which is really annoying
class DotDotDotRepr:
  """Stand-in for ``Ellipsis`` whose repr is the literal text ``...``.

  Useful inside a list that is about to be stringified, so that the result
  reads ``['i', ...]`` rather than ``['i', Ellipsis]``.
  """

  def __repr__(self):
    return '...'
def _parse_entry(arg_name, entry):
# Dictionaries mapping axis names to positional axes
if isinstance(entry, dict) and all(isinstance(v, int) for v in entry.keys()):
result = AxisNamePos((name, axis) for axis, name in entry.items())
result = AxisNamePos(((name, axis) for axis, name in entry.items()),
user_repr=str(entry))
num_mapped_dims = len(entry)
# Non-empty lists or tuples that terminate with an ellipsis
elif isinstance(entry, (tuple, list)) and entry and entry[-1] == ...:
result = AxisNamePos((name, axis) for axis, name in enumerate(entry[:-1])
if name is not None)
result = AxisNamePos(((name, axis) for axis, name in enumerate(entry[:-1])
if name is not None),
user_repr=str(entry[:-1] + [DotDotDotRepr()]))
num_mapped_dims = sum(name is not None for name in entry[:-1])
else:
raise TypeError(f"""\
@ -190,6 +203,9 @@ but got: {entry}""")
if len(result) != num_mapped_dims:
raise ValueError(f"Named axes should be unique within each {arg_name} argument "
f"specification, but one them is: {entry}")
for axis in result.values():
if axis < 0:
raise ValueError(f"xmap doesn't support negative axes in {arg_name}")
return result
def _is_axes_leaf(entry):
@ -205,6 +221,7 @@ def _prepare_axes(axes, arg_name):
entries = map(partial(_parse_entry, arg_name), entries)
return tree_unflatten(treedef, entries), entries
# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
def xmap(fun: Callable,
@ -374,16 +391,19 @@ def xmap(fun: Callable,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
if isinstance(in_axes, list):
if isinstance(in_axes, list) and not _is_axes_leaf(in_axes):
in_axes = tuple(in_axes)
if isinstance(out_axes, list):
if isinstance(out_axes, list) and not _is_axes_leaf(out_axes):
out_axes = tuple(out_axes)
if in_axes == (): # Allow empty argument lists
in_axes, in_axes_entries = (), []
else:
in_axes, in_axes_entries = _prepare_axes(in_axes, "in_axes")
out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
if out_axes == ():
raise ValueError("xmapped functions cannot have no return values")
else:
out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
axis_sizes_names = set(axis_sizes.keys())
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
@ -433,7 +453,12 @@ def xmap(fun: Callable,
lambda: tuple(flatten_axes("xmap out_axes", out_tree(), out_axes)),
closure=out_axes)
frozen_axis_sizes = FrozenDict(_get_axis_sizes(args_flat, in_axes_flat, axis_sizes))
assert set(frozen_axis_sizes.keys()) == set(frozen_axis_resources.keys())
missing_sizes = defined_names - set(frozen_axis_sizes.keys())
if missing_sizes:
raise ValueError(f"Failed to infer size of axes: {', '.join(unsafe_map(str, missing_sizes))}. "
f"You've probably passed in empty containers in place of arguments that had "
f"those axes in their in_axes. Provide the sizes of missing axes explicitly "
f"via axis_sizes to fix this error.")
out_flat = xmap_p.bind(
fun_flat, *args_flat,
name=getattr(fun, '__name__', '<unnamed function>'),
@ -516,7 +541,10 @@ class EvaluationPlan(NamedTuple):
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
paxes_size = int(np.prod([resource_shape[paxis] for paxis in paxes], dtype=np.int64))
assert self.axis_sizes[naxis] % paxes_size == 0
if self.axis_sizes[naxis] % paxes_size != 0:
raise ValueError(f"Size of axis {naxis} ({self.axis_sizes[naxis]}) is not divisible "
f"by the total number of resources assigned to this axis ({paxes}, "
f"{paxes_size} in total)")
tile_size = self.axis_sizes[naxis] // paxes_size
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=tile_size, axis_name=vaxis)
return f
@ -605,7 +633,8 @@ def _batch_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
else:
assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
# Builds a new AxisNamePos with the same axis names as `axes`, where every
# positional index has been passed through `f`.
# NOTE(review): this diff view lists the removed single-line return first and
# the added return second; the added one also forwards `axes.user_repr` so the
# user-facing representation survives the remapping.
def fmap_dims(axes, f):
return AxisNamePos((name, f(axis)) for name, axis in axes.items())
return AxisNamePos(((name, f(axis)) for name, axis in axes.items()),
user_repr=axes.user_repr)
new_in_axes = tuple(
fmap_dims(in_axes, lambda a: a + (d is not not_mapped and d <= a))
for d, in_axes in zip(dims, params['in_axes']))
@ -650,7 +679,6 @@ def _xmap_translation_rule_replica(c, axis_env,
local_mesh_shape = local_mesh.shape
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
assert type(call_jaxpr) is core.Jaxpr
local_avals = [pxla.tile_aval_nd(
local_mesh_shape, aval_mesh_in_axes,
_insert_aval_axes(v.aval, aval_in_axes, axis_sizes))
@ -775,17 +803,26 @@ def _insert_aval_axes(aval, axes: AxisNamePos, axis_sizes):
return aval.update(shape=tuple(shape))
# TODO: pmap has some very fancy error messages for this function!
def _get_axis_sizes(args_flat: Iterable[Any],
in_axes_flat: Iterable[AxisNamePos],
axis_sizes: Dict[AxisName, int]):
axis_sizes = dict(axis_sizes)
for arg, in_axes in zip(args_flat, in_axes_flat):
for name, dim in in_axes.items():
if name in axis_sizes:
assert axis_sizes[name] == arg.shape[dim]
if name in axis_sizes and axis_sizes[name] != arg.shape[dim]:
raise ValueError(f"The size of axis {name} was previously inferred to be "
f"{axis_sizes[name]}, but found an argument of shape {arg.shape} "
f"with in_axes specification {in_axes.user_repr}. Shape mismatch "
f"occurs in dimension {dim}: {arg.shape[dim]} != {axis_sizes[name]}")
else:
axis_sizes[name] = arg.shape[dim]
try:
axis_sizes[name] = arg.shape[dim]
except IndexError:
# TODO(apaszke): Handle negative indices. Check for overlap too!
raise ValueError(f"One of xmap arguments has an in_axes specification of "
f"{in_axes.user_repr}, which implies that it has at least "
f"{max(in_axes.values()) + 1} dimensions, but the argument "
f"has rank {arg.ndim}")
return axis_sizes
@ -807,9 +844,17 @@ def hide_mapped_axes(flat_in_axes, flat_out_axes, *flat_args):
return arg
def _unsqueeze_mapped_axes(out, axes: AxisNamePos):
for dim in sorted(axes.values()):
out = jnp.expand_dims(out, dim)
return out
try:
return jnp.expand_dims(out, tuple(axes.values()))
except ValueError as e:
# Improve the axis out of bounds errors
# TODO(apaszke): Handle negative indices. Check for overlap too!
if e.args[0].startswith('axis') and 'out of bounds' in e.args[0]:
raise ValueError(f"One of xmap outputs has an out_axes specification of "
f"{axes.user_repr}, which requires the result of the xmapped "
f"function to have at least {max(axes.values()) - len(axes) + 1} "
f"positional dimensions, but it only has {out.ndim}")
raise
squeezed_args = map(_squeeze_mapped_axes, flat_args, flat_in_axes)
flat_outputs = yield squeezed_args, {}

View File

@ -855,6 +855,11 @@ class PDotTests(jtu.JaxTestCase):
class XMapErrorTest(jtu.JaxTestCase):
def setUp(self):
  """Run the base-class setup, or skip the test when omnistaging is off."""
  if config.omnistaging_enabled:
    super().setUp()
    return
  raise SkipTest("xmap requires omnistaging")
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testRepeatedAxisResource(self):
@ -880,6 +885,53 @@ class XMapErrorTest(jtu.JaxTestCase):
"Changing the resource environment.*"):
f(x)
@ignore_xmap_warning()
def testEmptyArgumentTrees(self):
  """An empty argument container must yield the 'Failed to infer' error for axis i."""
  expected = "Failed to infer size of axes: i."
  with self.assertRaisesRegex(ValueError, expected):
    # Built inside the context so an error at construction time also counts.
    mapped = xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])
    mapped({})
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testAxesNotDivisibleByResources(self):
  """An axis of size 5 mapped onto 2 x 2 resources must be rejected."""
  expected = (r"Size of axis i \(5\) is not divisible.*"
              r"\(\('x', 'y'\), 4 in total\)")
  with self.assertRaisesRegex(ValueError, expected):
    mapped = xmap(lambda x: x,
                  in_axes=['i', ...],
                  out_axes=['i', ...],
                  axis_sizes={'i': 5},
                  axis_resources={'i': ('x', 'y')})
    mapped({})
@ignore_xmap_warning()
def testInconsistentAxisSizes(self):
  """A size-6 argument must clash with a size of 5 for axis i.

  The 5 comes first from another argument and then from `axis_sizes`.
  """
  five = jnp.arange(5)
  six = jnp.arange(6)
  expected = (r"The size of axis i was previously inferred to be 5, but found an "
              r"argument of shape \(6,\) with in_axes specification \['i', ...\]. "
              r"Shape mismatch occurs in dimension 0: 6 != 5")
  # Size 5 comes from the first argument.
  with self.assertRaisesRegex(ValueError, expected):
    two_args = xmap(lambda x, y: x,
                    in_axes=(['i', ...], ['i', ...]),
                    out_axes=['i', ...])
    two_args(five, six)
  # Size 5 comes from the explicit axis_sizes.
  with self.assertRaisesRegex(ValueError, expected):
    one_arg = xmap(lambda x: x,
                   in_axes=['i', ...],
                   out_axes=['i', ...],
                   axis_sizes={'i': 5})
    one_arg(six)
@ignore_xmap_warning()
def testInAxesRankError(self):
  """in_axes naming two positional axes must be rejected for a rank-1 argument."""
  expected = (r"One of xmap arguments has an in_axes specification of \['i', 'j', ...\], "
              r"which implies that it has at least 2 dimensions, but the argument has rank 1")
  with self.assertRaisesRegex(ValueError, expected):
    mapped = xmap(lambda x: x,
                  in_axes=['i', 'j', ...],
                  out_axes=['j', 'i', ...])
    mapped(jnp.ones((5,)))
@ignore_xmap_warning()
def testOutAxesRankError(self):
  """out_axes placing `i` at position 1 must be rejected for a rank-0 result."""
  expected = (r"One of xmap outputs has an out_axes specification of {1: 'i'}, "
              r"which requires the result of the xmapped function to have at least "
              r"1 positional dimensions, but it only has 0")
  with self.assertRaisesRegex(ValueError, expected):
    mapped = xmap(lambda x: x, in_axes=['i', ...], out_axes={1: 'i'})
    mapped(jnp.ones((5,)))
@ignore_xmap_warning()
def testNegativeAxes(self):
  """Negative positional axes must be rejected in both in_axes and out_axes."""
  cases = (
      ("xmap doesn't support negative axes in in_axes", {-1: 'i'}, {0: 'i'}),
      ("xmap doesn't support negative axes in out_axes", {0: 'i'}, {-1: 'i'}),
  )
  for expected, in_spec, out_spec in cases:
    with self.assertRaisesRegex(ValueError, expected):
      mapped = xmap(lambda x: x, in_axes=in_spec, out_axes=out_spec)
      mapped(jnp.ones((5,)))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())