mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15337 from mattjj:axis-name-shadowing-2
PiperOrigin-RevId: 520838748
This commit is contained in:
commit
76b922aade
@ -2465,10 +2465,12 @@ class _TempAxisName:
|
||||
return type(other) is _TempAxisName and self.id < other.id
|
||||
|
||||
|
||||
def axis_frame(axis_name):
|
||||
def axis_frame(axis_name: AxisName, main_trace: Optional[MainTrace] = None
|
||||
) -> AxisEnvFrame:
|
||||
frames = thread_local_state.trace_state.axis_env
|
||||
for frame in reversed(frames):
|
||||
if frame.name == axis_name:
|
||||
if (frame.name == axis_name and
|
||||
(main_trace is None or frame.main_trace is main_trace)):
|
||||
return frame
|
||||
named_axes = [frame.name for frame in reversed(frames)
|
||||
if not isinstance(frame.name, _TempAxisName)]
|
||||
|
@ -341,7 +341,7 @@ class BatchTrace(Trace):
|
||||
if self.axis_name is core.no_axis_name:
|
||||
assert axis_size is not None # must be inferrable from data
|
||||
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
|
||||
frame = core.axis_frame(self.axis_name)
|
||||
frame = core.axis_frame(self.axis_name, self.main)
|
||||
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
|
||||
assert frame.main_trace is self.main
|
||||
return frame
|
||||
|
@ -2139,16 +2139,21 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
|
||||
def test_axis_name_shadowing_with_vmap(self):
|
||||
# TODO(mattjj): we don't want assertion errors here, but it's a start! the
|
||||
# main point of including this test for now is to document the bug
|
||||
# vmap-of-pmap with mismatched axis sizes
|
||||
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
|
||||
axis_name='i')(jax.numpy.ones((2, 1))) # don't crash
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
|
||||
axis_name='i')(jax.numpy.ones((2, 4)))
|
||||
# vmap-of-pmap with matched axis sizes
|
||||
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
|
||||
axis_name='i')(jax.numpy.ones((1, 1))) # don't crash
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
|
||||
axis_name='i')(jax.numpy.ones((4, 4)))
|
||||
# vmap-of-vmap with mismatched axis sizes
|
||||
jax.vmap(jax.vmap(lambda x: 2 * x, axis_name='i'),
|
||||
axis_name='i')(jax.numpy.ones((2, 1))) # don't crash
|
||||
|
||||
# vmap-of-vmap with matched axis sizes
|
||||
jax.vmap(jax.vmap(lambda x: 2 * x, axis_name='i'),
|
||||
axis_name='i')(jax.numpy.ones((1, 1))) # don't crash
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
|
@ -319,7 +319,6 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
|
||||
_ = jax.jit(f)(jnp.array(2.0)) # doesnt crash
|
||||
|
||||
@unittest.skip('Does not work with pjit with pjit batcher error')
|
||||
def test_vmap_basic(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
x = jnp.arange(8 * 8.).reshape(8, 8)
|
||||
@ -327,7 +326,27 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def g(x):
|
||||
return shard_map(lambda x: 2. * x, mesh,
|
||||
in_specs=P('y'), out_specs=P('y'))(x)
|
||||
y = jax.vmap(g, axis_name='x')(x)
|
||||
y = jax.vmap(g)(x)
|
||||
self.assertAllClose(y, 2 * x, check_dtypes=False)
|
||||
|
||||
def test_vmap_basic_axis_name(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
x = jnp.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def g(x):
|
||||
return shard_map(lambda x: 2. * x, mesh,
|
||||
in_specs=P('y'), out_specs=P('y'))(x)
|
||||
y = jax.vmap(g, axis_name='i')(x)
|
||||
self.assertAllClose(y, 2 * x, check_dtypes=False)
|
||||
|
||||
def test_vmap_basic_axis_name_reuse_mesh_name(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
x = jnp.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def g(x):
|
||||
return shard_map(lambda x: 2. * x, mesh,
|
||||
in_specs=P('y'), out_specs=P('y'))(x)
|
||||
y = jax.vmap(g, axis_name='x')(x) # NOTE reuse same 'x' as on mesh
|
||||
self.assertAllClose(y, 2 * x, check_dtypes=False)
|
||||
|
||||
def test_tree_prefix_error(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user