Make sure that xmap raises a clear error when using an undefined resource name

PiperOrigin-RevId: 397762568
This commit is contained in:
Adam Paszke 2021-09-20 08:58:09 -07:00 committed by jax authors
parent 5caaed9860
commit 45c1a1b060
2 changed files with 12 additions and 3 deletions

View File

@ -565,7 +565,7 @@ def xmap(fun: Callable,
resource_env = thread_resources.env
available_resources = set(resource_env.shape.keys())
if necessary_resources > available_resources:
if necessary_resources - available_resources:
raise ValueError(f"In-scope resources are insufficient to execute the "
f"xmapped function. The missing resources are: "
f"{necessary_resources - available_resources}")
@ -585,14 +585,14 @@ def xmap(fun: Callable,
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
closure=(out_axes_entries, out_axes_treedef))
axis_resource_count = _get_axis_resource_count(normalized_axis_resources, resource_env)
axis_resource_count = _get_axis_resource_count(frozen_axis_resources, resource_env)
for axis, size in axis_sizes.items():
resources = axis_resource_count[axis]
if size % resources.nglobal != 0:
global_size = "Global size" if resources.distributed else "Size"
raise ValueError(f"{global_size} of axis {axis} ({size}) is not divisible "
f"by the total number of resources assigned to this axis "
f"({normalized_axis_resources[axis]}, {resources.nglobal} in total)")
f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)")
frozen_global_axis_sizes = _get_axis_sizes(args_flat, in_axes_flat,
axis_sizes, axis_resource_count)

View File

@ -1093,6 +1093,15 @@ class XMapErrorTest(jtu.JaxTestCase):
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': ('x', 'x')})
@jtu.with_mesh([('y', 2)])
def testUndefinedAxisResource(self):
error = re.escape(
r"In-scope resources are insufficient to execute the xmapped function. "
r"The missing resources are: {'x'}")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x: x, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(jnp.zeros((4,)))
@jtu.with_mesh([('x', 2)])
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})