mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make sure that xmap raises a clear error when using an undefined resource name
PiperOrigin-RevId: 397762568
This commit is contained in:
parent
5caaed9860
commit
45c1a1b060
@ -565,7 +565,7 @@ def xmap(fun: Callable,
|
||||
resource_env = thread_resources.env
|
||||
available_resources = set(resource_env.shape.keys())
|
||||
|
||||
if necessary_resources > available_resources:
|
||||
if necessary_resources - available_resources:
|
||||
raise ValueError(f"In-scope resources are insufficient to execute the "
|
||||
f"xmapped function. The missing resources are: "
|
||||
f"{necessary_resources - available_resources}")
|
||||
@ -585,14 +585,14 @@ def xmap(fun: Callable,
|
||||
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
|
||||
closure=(out_axes_entries, out_axes_treedef))
|
||||
|
||||
axis_resource_count = _get_axis_resource_count(normalized_axis_resources, resource_env)
|
||||
axis_resource_count = _get_axis_resource_count(frozen_axis_resources, resource_env)
|
||||
for axis, size in axis_sizes.items():
|
||||
resources = axis_resource_count[axis]
|
||||
if size % resources.nglobal != 0:
|
||||
global_size = "Global size" if resources.distributed else "Size"
|
||||
raise ValueError(f"{global_size} of axis {axis} ({size}) is not divisible "
|
||||
f"by the total number of resources assigned to this axis "
|
||||
f"({normalized_axis_resources[axis]}, {resources.nglobal} in total)")
|
||||
f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)")
|
||||
frozen_global_axis_sizes = _get_axis_sizes(args_flat, in_axes_flat,
|
||||
axis_sizes, axis_resource_count)
|
||||
|
||||
|
@ -1093,6 +1093,15 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': ('x', 'x')})
|
||||
|
||||
@jtu.with_mesh([('y', 2)])
|
||||
def testUndefinedAxisResource(self):
|
||||
error = re.escape(
|
||||
r"In-scope resources are insufficient to execute the xmapped function. "
|
||||
r"The missing resources are: {'x'}")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
xmap(lambda x: x, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': 'x'})(jnp.zeros((4,)))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testNestedDifferentResources(self):
|
||||
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
|
||||
|
Loading…
x
Reference in New Issue
Block a user