Improve undefined axis checks

Previously we checked for out axes being a superset of the defined axes,
but that's just not the right relation. In particular, out_axes of {'a'}
are not a superset of defined axes {'b'}, but axis 'a' is undefined. The
correct check is to verify emptiness of their difference.
This commit is contained in:
Adam Paszke 2022-10-26 15:30:11 +00:00
parent 40c85bdab8
commit 0fce5be556
2 changed files with 26 additions and 10 deletions

View File

@ -490,6 +490,16 @@ def xmap(fun: Callable,
return name
return r
axes_with_resources = set(axis_resources.keys())
if axes_with_resources - defined_names:
raise ValueError(f"All axes that were assigned resources have to appear in "
f"in_axes or axis_sizes, but the following are missing: "
f"{axes_with_resources - defined_names}")
if out_axes_names - defined_names:
raise ValueError(f"All axis names appearing in out_axes must also appear in "
f"in_axes or axis_sizes, but the following are missing: "
f"{out_axes_names - defined_names}")
normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = {}
for axis in defined_names:
resources = axis_resources.get(axis, ())
@ -499,16 +509,6 @@ def xmap(fun: Callable,
frozen_axis_resources = FrozenDict(normalized_axis_resources)
necessary_resources = set(it.chain(*frozen_axis_resources.values()))
axes_with_resources = set(frozen_axis_resources.keys())
if axes_with_resources > defined_names:
raise ValueError(f"All axes that were assigned resources have to appear in "
f"in_axes or axis_sizes, but the following are missing: "
f"{axes_with_resources - defined_names}")
if out_axes_names > defined_names:
raise ValueError(f"All axis names appearing in out_axes must also appear in "
f"in_axes or axis_sizes, but the following are missing: "
f"{out_axes_names - defined_names}")
for axis, resources in frozen_axis_resources.items():
if len(set(resources)) != len(resources): # type: ignore
raise ValueError(f"Resource assignment of a single axis must be a tuple of "

View File

@ -1779,6 +1779,22 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, error):
fm(x, y)
def testUndefinedOutAxis(self):
error = (r"All axis names appearing in out_axes must also appear in "
r"in_axes or axis_sizes, but the following are missing: {'c'}")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x, y: x + y,
in_axes=(['a', ...], ['b', ...]), out_axes=['c', ...])
@jtu.with_mesh([('x', 2)])
def testUndefinedAxisInAxisResources(self):
error = (r"All axes that were assigned resources have to appear in in_axes "
r"or axis_sizes, but the following are missing: {'b'}")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x, y: x + y,
in_axes=(['a', ...], ['a', ...]), out_axes=['a', ...],
axis_resources={'b': 'x'})
@jtu.with_mesh([('x', 2)])
def testResourceConflictArgs(self):
fm = xmap(lambda x: lax.psum(x, ('a', 'b')),