[vmappable] fix trace context bugs

to_elt must run in the parent trace context, while from_elt must run in the
batching trace context. We previously had these precisely backward!

The tests didn't catch this because they are extremely minimal; in
particular, none of them exercised a to_elt that binds primitives.
This commit is contained in:
Matthew Johnson 2025-04-06 00:29:23 +00:00
parent 2e62693f72
commit 6bae8c75c8
2 changed files with 55 additions and 20 deletions

View File

@ -616,17 +616,15 @@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals):
trace = BatchTrace(parent_trace, tag, axis_data)
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
source_info_util.current()))
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
with core.set_current_trace(parent_trace):
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = f(*in_tracers)
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_data.size,
axis_data.explicit_mesh_axis),
range(len(outs)), outs, out_dim_dests)
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis),
range(len(outs)), outs, out_dim_dests)
return out_vals, trace
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.

View File

@ -1328,33 +1328,70 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
@jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe
class VmappableTest(jtu.JaxTestCase):
def test_basic(self):
@parameterized.parameters([False, True])
def test_basic(self, jit):
with temporarily_register_named_array_vmappable():
def f(x):
return named_mul(x, x)
if jit:
f = jax.jit(f)
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
g = jax.vmap(f,
in_axes=NamedMapSpec('i', 0),
out_axes=NamedMapSpec('i', 1),
axis_size=3)
in_axes=NamedMapSpec('i', 0),
out_axes=NamedMapSpec('i', 1),
axis_size=3)
ans = g(x)
expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2)
self.assertEqual(ans.names, expected.names)
self.assertAllClose(ans.data, expected.data)
def test_basic_jit(self):
with temporarily_register_named_array_vmappable():
def f(x):
return named_mul(x, x)
def test_to_elt_that_binds_primitives(self):
class A:
data: Array
def __init__(self, data):
self.data = data
def to_elt(cont, _, val, spec):
return cont(val.data + 1, spec)
def from_elt(cont, size, elt, spec):
assert False
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
ans = jax.jit(f)(x)
expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
@jax.jit
def f():
a = A(jnp.arange(3.))
return jax.vmap(lambda x: x - 1, axis_size=3)(a)
self.assertEqual(ans.names, expected.names)
self.assertAllClose(ans.data, expected.data)
try:
batching.register_vmappable(A, int, int, to_elt, from_elt, None)
ans = f()
finally:
batching.unregister_vmappable(A)
self.assertAllClose(ans, jnp.arange(3.))
def test_from_elt_that_binds_primitives(self):
  """Check a vmappable registration whose ``from_elt`` binds a primitive.

  ``to_elt`` forwards the wrapped data unchanged, while ``from_elt`` adds one
  to the element's data before handing it to its continuation. Mapping the
  identity function over the wrapper inside ``jax.jit`` is therefore expected
  to produce ``arange(3.) + 1``.
  """
  class Wrapped:
    data: Array

    def __init__(self, data):
      self.data = data

  def to_elt(cont, _, val, spec):
    # Input side: pass the wrapped data straight to the continuation.
    elt_data = cont(val.data, spec)
    return Wrapped(elt_data)

  def from_elt(cont, size, elt, spec):
    # Output side: the addition here is the primitive bind under test.
    bumped = elt.data + 1
    return Wrapped(cont(size, bumped, spec))

  def identity(x):
    return x

  @jax.jit
  def run():
    boxed = Wrapped(jnp.arange(3.))
    mapped = jax.vmap(identity, axis_size=3)(boxed)
    return mapped.data

  try:
    batching.register_vmappable(Wrapped, int, int, to_elt, from_elt, None)
    ans = run()
  finally:
    # Always drop the temporary registration, even if the call above raised.
    batching.unregister_vmappable(Wrapped)

  self.assertAllClose(ans, jnp.arange(3.) + 1)
def test_types_with_same_spec(self):
# We register NamedArray.