mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve some xmap error messages
We used to raise really bad errors when: - we failed to infer an axis size - axis sizes were not divisible by the number of resources assigned to them - axis sizes were inconsistent between arguments - in/out_axes referring to out-of-bounds positional axes - in/out_axes using negative indices for positional axes (this is not implemented yet)
This commit is contained in:
parent
5616916cb2
commit
be9c58ae21
@ -166,20 +166,33 @@ def fresh_resource_name(tag=None):
|
||||
finally:
|
||||
_next_resource_id += 1
|
||||
|
||||
|
||||
# This is really a Dict[AxisName, int], but we don't define a
|
||||
# pytree instance for it, so that it is treated as a leaf.
|
||||
class AxisNamePos(FrozenDict):
|
||||
pass
|
||||
user_repr: Any
|
||||
|
||||
def __init__(self, *args, user_repr, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.user_repr = user_repr
|
||||
|
||||
|
||||
# str(...) == 'Ellipsis' which is really annoying
|
||||
class DotDotDotRepr:
|
||||
def __repr__(self): return '...'
|
||||
|
||||
|
||||
def _parse_entry(arg_name, entry):
|
||||
# Dictionaries mapping axis names to positional axes
|
||||
if isinstance(entry, dict) and all(isinstance(v, int) for v in entry.keys()):
|
||||
result = AxisNamePos((name, axis) for axis, name in entry.items())
|
||||
result = AxisNamePos(((name, axis) for axis, name in entry.items()),
|
||||
user_repr=str(entry))
|
||||
num_mapped_dims = len(entry)
|
||||
# Non-empty lists or tuples that terminate with an ellipsis
|
||||
elif isinstance(entry, (tuple, list)) and entry and entry[-1] == ...:
|
||||
result = AxisNamePos((name, axis) for axis, name in enumerate(entry[:-1])
|
||||
if name is not None)
|
||||
result = AxisNamePos(((name, axis) for axis, name in enumerate(entry[:-1])
|
||||
if name is not None),
|
||||
user_repr=str(entry[:-1] + [DotDotDotRepr()]))
|
||||
num_mapped_dims = sum(name is not None for name in entry[:-1])
|
||||
else:
|
||||
raise TypeError(f"""\
|
||||
@ -190,6 +203,9 @@ but got: {entry}""")
|
||||
if len(result) != num_mapped_dims:
|
||||
raise ValueError(f"Named axes should be unique within each {arg_name} argument "
|
||||
f"specification, but one them is: {entry}")
|
||||
for axis in result.values():
|
||||
if axis < 0:
|
||||
raise ValueError(f"xmap doesn't support negative axes in {arg_name}")
|
||||
return result
|
||||
|
||||
def _is_axes_leaf(entry):
|
||||
@ -205,6 +221,7 @@ def _prepare_axes(axes, arg_name):
|
||||
entries = map(partial(_parse_entry, arg_name), entries)
|
||||
return tree_unflatten(treedef, entries), entries
|
||||
|
||||
|
||||
# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
|
||||
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
|
||||
def xmap(fun: Callable,
|
||||
@ -374,16 +391,19 @@ def xmap(fun: Callable,
|
||||
# in cases like these users expect tuples and lists to be treated
|
||||
# essentially interchangeably, so we canonicalize lists to tuples here
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
if isinstance(in_axes, list):
|
||||
if isinstance(in_axes, list) and not _is_axes_leaf(in_axes):
|
||||
in_axes = tuple(in_axes)
|
||||
if isinstance(out_axes, list):
|
||||
if isinstance(out_axes, list) and not _is_axes_leaf(out_axes):
|
||||
out_axes = tuple(out_axes)
|
||||
|
||||
if in_axes == (): # Allow empty argument lists
|
||||
in_axes, in_axes_entries = (), []
|
||||
else:
|
||||
in_axes, in_axes_entries = _prepare_axes(in_axes, "in_axes")
|
||||
out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
|
||||
if out_axes == ():
|
||||
raise ValueError("xmapped functions cannot have no return values")
|
||||
else:
|
||||
out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
|
||||
|
||||
axis_sizes_names = set(axis_sizes.keys())
|
||||
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
|
||||
@ -433,7 +453,12 @@ def xmap(fun: Callable,
|
||||
lambda: tuple(flatten_axes("xmap out_axes", out_tree(), out_axes)),
|
||||
closure=out_axes)
|
||||
frozen_axis_sizes = FrozenDict(_get_axis_sizes(args_flat, in_axes_flat, axis_sizes))
|
||||
assert set(frozen_axis_sizes.keys()) == set(frozen_axis_resources.keys())
|
||||
missing_sizes = defined_names - set(frozen_axis_sizes.keys())
|
||||
if missing_sizes:
|
||||
raise ValueError(f"Failed to infer size of axes: {', '.join(unsafe_map(str, missing_sizes))}. "
|
||||
f"You've probably passed in empty containers in place of arguments that had "
|
||||
f"those axes in their in_axes. Provide the sizes of missing axes explicitly "
|
||||
f"via axis_sizes to fix this error.")
|
||||
out_flat = xmap_p.bind(
|
||||
fun_flat, *args_flat,
|
||||
name=getattr(fun, '__name__', '<unnamed function>'),
|
||||
@ -516,7 +541,10 @@ class EvaluationPlan(NamedTuple):
|
||||
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
|
||||
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
|
||||
paxes_size = int(np.prod([resource_shape[paxis] for paxis in paxes], dtype=np.int64))
|
||||
assert self.axis_sizes[naxis] % paxes_size == 0
|
||||
if self.axis_sizes[naxis] % paxes_size != 0:
|
||||
raise ValueError(f"Size of axis {naxis} ({self.axis_sizes[naxis]}) is not divisible "
|
||||
f"by the total number of resources assigned to this axis ({paxes}, "
|
||||
f"{paxes_size} in total)")
|
||||
tile_size = self.axis_sizes[naxis] // paxes_size
|
||||
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=tile_size, axis_name=vaxis)
|
||||
return f
|
||||
@ -605,7 +633,8 @@ def _batch_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
|
||||
else:
|
||||
assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
|
||||
def fmap_dims(axes, f):
|
||||
return AxisNamePos((name, f(axis)) for name, axis in axes.items())
|
||||
return AxisNamePos(((name, f(axis)) for name, axis in axes.items()),
|
||||
user_repr=axes.user_repr)
|
||||
new_in_axes = tuple(
|
||||
fmap_dims(in_axes, lambda a: a + (d is not not_mapped and d <= a))
|
||||
for d, in_axes in zip(dims, params['in_axes']))
|
||||
@ -650,7 +679,6 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
local_mesh_shape = local_mesh.shape
|
||||
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
|
||||
|
||||
assert type(call_jaxpr) is core.Jaxpr
|
||||
local_avals = [pxla.tile_aval_nd(
|
||||
local_mesh_shape, aval_mesh_in_axes,
|
||||
_insert_aval_axes(v.aval, aval_in_axes, axis_sizes))
|
||||
@ -775,17 +803,26 @@ def _insert_aval_axes(aval, axes: AxisNamePos, axis_sizes):
|
||||
return aval.update(shape=tuple(shape))
|
||||
|
||||
|
||||
# TODO: pmap has some very fancy error messages for this function!
|
||||
def _get_axis_sizes(args_flat: Iterable[Any],
|
||||
in_axes_flat: Iterable[AxisNamePos],
|
||||
axis_sizes: Dict[AxisName, int]):
|
||||
axis_sizes = dict(axis_sizes)
|
||||
for arg, in_axes in zip(args_flat, in_axes_flat):
|
||||
for name, dim in in_axes.items():
|
||||
if name in axis_sizes:
|
||||
assert axis_sizes[name] == arg.shape[dim]
|
||||
if name in axis_sizes and axis_sizes[name] != arg.shape[dim]:
|
||||
raise ValueError(f"The size of axis {name} was previously inferred to be "
|
||||
f"{axis_sizes[name]}, but found an argument of shape {arg.shape} "
|
||||
f"with in_axes specification {in_axes.user_repr}. Shape mismatch "
|
||||
f"occurs in dimension {dim}: {arg.shape[dim]} != {axis_sizes[name]}")
|
||||
else:
|
||||
axis_sizes[name] = arg.shape[dim]
|
||||
try:
|
||||
axis_sizes[name] = arg.shape[dim]
|
||||
except IndexError:
|
||||
# TODO(apaszke): Handle negative indices. Check for overlap too!
|
||||
raise ValueError(f"One of xmap arguments has an in_axes specification of "
|
||||
f"{in_axes.user_repr}, which implies that it has at least "
|
||||
f"{max(in_axes.values()) + 1} dimensions, but the argument "
|
||||
f"has rank {arg.ndim}")
|
||||
return axis_sizes
|
||||
|
||||
|
||||
@ -807,9 +844,17 @@ def hide_mapped_axes(flat_in_axes, flat_out_axes, *flat_args):
|
||||
return arg
|
||||
|
||||
def _unsqueeze_mapped_axes(out, axes: AxisNamePos):
|
||||
for dim in sorted(axes.values()):
|
||||
out = jnp.expand_dims(out, dim)
|
||||
return out
|
||||
try:
|
||||
return jnp.expand_dims(out, tuple(axes.values()))
|
||||
except ValueError as e:
|
||||
# Improve the axis out of bounds errors
|
||||
# TODO(apaszke): Handle negative indices. Check for overlap too!
|
||||
if e.args[0].startswith('axis') and 'out of bounds' in e.args[0]:
|
||||
raise ValueError(f"One of xmap outputs has an out_axes specification of "
|
||||
f"{axes.user_repr}, which requires the result of the xmapped "
|
||||
f"function to have at least {max(axes.values()) - len(axes) + 1} "
|
||||
f"positional dimensions, but it only has {out.ndim}")
|
||||
raise
|
||||
|
||||
squeezed_args = map(_squeeze_mapped_axes, flat_args, flat_in_axes)
|
||||
flat_outputs = yield squeezed_args, {}
|
||||
|
@ -855,6 +855,11 @@ class PDotTests(jtu.JaxTestCase):
|
||||
|
||||
class XMapErrorTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2)])
|
||||
def testRepeatedAxisResource(self):
|
||||
@ -880,6 +885,53 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
"Changing the resource environment.*"):
|
||||
f(x)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testEmptyArgumentTrees(self):
|
||||
with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
|
||||
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2), ('y', 2)])
|
||||
def testAxesNotDivisibleByResources(self):
|
||||
with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
|
||||
r"\(\('x', 'y'\), 4 in total\)"):
|
||||
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
|
||||
axis_sizes={'i': 5}, axis_resources={'i': ('x', 'y')})({})
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testInconsistentAxisSizes(self):
|
||||
x5 = jnp.arange(5)
|
||||
x6 = jnp.arange(6)
|
||||
error = (r"The size of axis i was previously inferred to be 5, but found an "
|
||||
r"argument of shape \(6,\) with in_axes specification \['i', ...\]. "
|
||||
r"Shape mismatch occurs in dimension 0: 6 != 5")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x, y: x, in_axes=(['i', ...], ['i', ...]), out_axes=['i', ...])(x5, x6)
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...], axis_sizes={'i': 5})(x6)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testInAxesRankError(self):
|
||||
error = (r"One of xmap arguments has an in_axes specification of \['i', 'j', ...\], "
|
||||
r"which implies that it has at least 2 dimensions, but the argument has rank 1")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x: x, in_axes=['i', 'j', ...], out_axes=['j', 'i', ...])(jnp.ones((5,)))
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testOutAxesRankError(self):
|
||||
error = (r"One of xmap outputs has an out_axes specification of {1: 'i'}, "
|
||||
r"which requires the result of the xmapped function to have at least "
|
||||
r"1 positional dimensions, but it only has 0")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x: x, in_axes=['i', ...], out_axes={1: 'i'})(jnp.ones((5,)))
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testNegativeAxes(self):
|
||||
with self.assertRaisesRegex(ValueError, "xmap doesn't support negative axes in in_axes"):
|
||||
xmap(lambda x: x, in_axes={-1: 'i'}, out_axes={0: 'i'})(jnp.ones((5,)))
|
||||
with self.assertRaisesRegex(ValueError, "xmap doesn't support negative axes in out_axes"):
|
||||
xmap(lambda x: x, in_axes={0: 'i'}, out_axes={-1: 'i'})(jnp.ones((5,)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user