mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve undefined axis checks
Previously we checked for out axes being a superset of the defined axes, but that's just not the right relation. In particular, out_axes of {'a'} are not a superset of defined axes {'b'}, but axis 'a' is undefined. The correct check is to verify emptiness of their difference.
This commit is contained in:
parent
40c85bdab8
commit
0fce5be556
@ -490,6 +490,16 @@ def xmap(fun: Callable,
|
||||
return name
|
||||
return r
|
||||
|
||||
axes_with_resources = set(axis_resources.keys())
|
||||
if axes_with_resources - defined_names:
|
||||
raise ValueError(f"All axes that were assigned resources have to appear in "
|
||||
f"in_axes or axis_sizes, but the following are missing: "
|
||||
f"{axes_with_resources - defined_names}")
|
||||
if out_axes_names - defined_names:
|
||||
raise ValueError(f"All axis names appearing in out_axes must also appear in "
|
||||
f"in_axes or axis_sizes, but the following are missing: "
|
||||
f"{out_axes_names - defined_names}")
|
||||
|
||||
normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = {}
|
||||
for axis in defined_names:
|
||||
resources = axis_resources.get(axis, ())
|
||||
@ -499,16 +509,6 @@ def xmap(fun: Callable,
|
||||
frozen_axis_resources = FrozenDict(normalized_axis_resources)
|
||||
necessary_resources = set(it.chain(*frozen_axis_resources.values()))
|
||||
|
||||
axes_with_resources = set(frozen_axis_resources.keys())
|
||||
if axes_with_resources > defined_names:
|
||||
raise ValueError(f"All axes that were assigned resources have to appear in "
|
||||
f"in_axes or axis_sizes, but the following are missing: "
|
||||
f"{axes_with_resources - defined_names}")
|
||||
if out_axes_names > defined_names:
|
||||
raise ValueError(f"All axis names appearing in out_axes must also appear in "
|
||||
f"in_axes or axis_sizes, but the following are missing: "
|
||||
f"{out_axes_names - defined_names}")
|
||||
|
||||
for axis, resources in frozen_axis_resources.items():
|
||||
if len(set(resources)) != len(resources): # type: ignore
|
||||
raise ValueError(f"Resource assignment of a single axis must be a tuple of "
|
||||
|
@ -1779,6 +1779,22 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, error):
|
||||
fm(x, y)
|
||||
|
||||
def testUndefinedOutAxis(self):
|
||||
error = (r"All axis names appearing in out_axes must also appear in "
|
||||
r"in_axes or axis_sizes, but the following are missing: {'c'}")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x, y: x + y,
|
||||
in_axes=(['a', ...], ['b', ...]), out_axes=['c', ...])
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testUndefinedAxisInAxisResources(self):
|
||||
error = (r"All axes that were assigned resources have to appear in in_axes "
|
||||
r"or axis_sizes, but the following are missing: {'b'}")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x, y: x + y,
|
||||
in_axes=(['a', ...], ['a', ...]), out_axes=['a', ...],
|
||||
axis_resources={'b': 'x'})
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testResourceConflictArgs(self):
|
||||
fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
|
||||
|
Loading…
x
Reference in New Issue
Block a user