fix bug from #15335 by checking main_trace tag

This commit is contained in:
Matthew Johnson 2023-03-30 22:11:21 -07:00
parent 211bc29842
commit 6a2b081506
4 changed files with 39 additions and 13 deletions

View File

@ -2465,10 +2465,12 @@ class _TempAxisName:
return type(other) is _TempAxisName and self.id < other.id
def axis_frame(axis_name):
def axis_frame(axis_name: AxisName, main_trace: Optional[MainTrace] = None
) -> AxisEnvFrame:
frames = thread_local_state.trace_state.axis_env
for frame in reversed(frames):
if frame.name == axis_name:
if (frame.name == axis_name and
(main_trace is None or frame.main_trace is main_trace)):
return frame
named_axes = [frame.name for frame in reversed(frames)
if not isinstance(frame.name, _TempAxisName)]

View File

@ -341,7 +341,7 @@ class BatchTrace(Trace):
if self.axis_name is core.no_axis_name:
assert axis_size is not None # must be inferrable from data
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
frame = core.axis_frame(self.axis_name)
frame = core.axis_frame(self.axis_name, self.main)
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
assert frame.main_trace is self.main
return frame

View File

@ -2139,16 +2139,21 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
def test_axis_name_shadowing_with_vmap(self):
# TODO(mattjj): we don't want assertion errors here, but it's a start! the
# main point of including this test for now is to document the bug
# vmap-of-pmap with mismatched axis sizes
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
axis_name='i')(jax.numpy.ones((2, 1))) # don't crash
with self.assertRaises(AssertionError):
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
axis_name='i')(jax.numpy.ones((2, 4)))
# vmap-of-pmap with matched axis sizes
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
axis_name='i')(jax.numpy.ones((1, 1))) # don't crash
with self.assertRaises(AssertionError):
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),
axis_name='i')(jax.numpy.ones((4, 4)))
# vmap-of-vmap with mismatched axis sizes
jax.vmap(jax.vmap(lambda x: 2 * x, axis_name='i'),
axis_name='i')(jax.numpy.ones((2, 1))) # don't crash
# vmap-of-vmap with matched axis sizes
jax.vmap(jax.vmap(lambda x: 2 * x, axis_name='i'),
axis_name='i')(jax.numpy.ones((1, 1))) # don't crash
@jtu.pytest_mark_if_available('multiaccelerator')

View File

@ -319,7 +319,6 @@ class ShardMapTest(jtu.JaxTestCase):
f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
_ = jax.jit(f)(jnp.array(2.0)) # doesnt crash
@unittest.skip('Does not work with pjit with pjit batcher error')
def test_vmap_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
@ -327,7 +326,27 @@ class ShardMapTest(jtu.JaxTestCase):
def g(x):
return shard_map(lambda x: 2. * x, mesh,
in_specs=P('y'), out_specs=P('y'))(x)
y = jax.vmap(g, axis_name='x')(x)
y = jax.vmap(g)(x)
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_vmap_basic_axis_name(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
def g(x):
return shard_map(lambda x: 2. * x, mesh,
in_specs=P('y'), out_specs=P('y'))(x)
y = jax.vmap(g, axis_name='i')(x)
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_vmap_basic_axis_name_reuse_mesh_name(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
def g(x):
return shard_map(lambda x: 2. * x, mesh,
in_specs=P('y'), out_specs=P('y'))(x)
y = jax.vmap(g, axis_name='x')(x) # NOTE reuse same 'x' as on mesh
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_tree_prefix_error(self):