mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[vmappable] fix trace context bugs
to_elt must run in the parent context, while from_elt must run in the batching context. We previously had it precisely backward! Tests didn't catch it because our tests are extremely minimal, and in particular didn't check a to_elt that binds primitives.
This commit is contained in:
parent
2e62693f72
commit
6bae8c75c8
@ -616,17 +616,15 @@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals):
|
||||
trace = BatchTrace(parent_trace, tag, axis_data)
|
||||
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
|
||||
source_info_util.current()))
|
||||
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
|
||||
with core.set_current_trace(parent_trace):
|
||||
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
|
||||
with (core.set_current_trace(trace),
|
||||
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
|
||||
core.add_spmd_axis_names(axis_data.spmd_name)):
|
||||
outs = f(*in_tracers)
|
||||
|
||||
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
|
||||
out_vals = map(partial(from_elt, trace, axis_data.size,
|
||||
axis_data.explicit_mesh_axis),
|
||||
range(len(outs)), outs, out_dim_dests)
|
||||
|
||||
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
|
||||
out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis),
|
||||
range(len(outs)), outs, out_dim_dests)
|
||||
return out_vals, trace
|
||||
|
||||
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
|
||||
|
@ -1328,33 +1328,70 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
|
||||
|
||||
@jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe
|
||||
class VmappableTest(jtu.JaxTestCase):
|
||||
def test_basic(self):
|
||||
@parameterized.parameters([False, True])
|
||||
def test_basic(self, jit):
|
||||
with temporarily_register_named_array_vmappable():
|
||||
def f(x):
|
||||
return named_mul(x, x)
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
|
||||
g = jax.vmap(f,
|
||||
in_axes=NamedMapSpec('i', 0),
|
||||
out_axes=NamedMapSpec('i', 1),
|
||||
axis_size=3)
|
||||
in_axes=NamedMapSpec('i', 0),
|
||||
out_axes=NamedMapSpec('i', 1),
|
||||
axis_size=3)
|
||||
ans = g(x)
|
||||
expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2)
|
||||
|
||||
self.assertEqual(ans.names, expected.names)
|
||||
self.assertAllClose(ans.data, expected.data)
|
||||
|
||||
def test_basic_jit(self):
|
||||
with temporarily_register_named_array_vmappable():
|
||||
def f(x):
|
||||
return named_mul(x, x)
|
||||
def test_to_elt_that_binds_primitives(self):
|
||||
class A:
|
||||
data: Array
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
def to_elt(cont, _, val, spec):
|
||||
return cont(val.data + 1, spec)
|
||||
def from_elt(cont, size, elt, spec):
|
||||
assert False
|
||||
|
||||
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
|
||||
ans = jax.jit(f)(x)
|
||||
expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
|
||||
@jax.jit
|
||||
def f():
|
||||
a = A(jnp.arange(3.))
|
||||
return jax.vmap(lambda x: x - 1, axis_size=3)(a)
|
||||
|
||||
self.assertEqual(ans.names, expected.names)
|
||||
self.assertAllClose(ans.data, expected.data)
|
||||
try:
|
||||
batching.register_vmappable(A, int, int, to_elt, from_elt, None)
|
||||
ans = f()
|
||||
finally:
|
||||
batching.unregister_vmappable(A)
|
||||
|
||||
self.assertAllClose(ans, jnp.arange(3.))
|
||||
|
||||
def test_from_elt_that_binds_primitives(self):
|
||||
class A:
|
||||
data: Array
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
def to_elt(cont, _, val, spec):
|
||||
return A(cont(val.data, spec))
|
||||
def from_elt(cont, size, elt, spec):
|
||||
return A(cont(size, elt.data + 1, spec))
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
a = A(jnp.arange(3.))
|
||||
return jax.vmap(lambda x: x, axis_size=3)(a).data
|
||||
|
||||
try:
|
||||
batching.register_vmappable(A, int, int, to_elt, from_elt, None)
|
||||
ans = f()
|
||||
finally:
|
||||
batching.unregister_vmappable(A)
|
||||
|
||||
self.assertAllClose(ans, jnp.arange(3.) + 1)
|
||||
|
||||
def test_types_with_same_spec(self):
|
||||
# We register NamedArray.
|
||||
|
Loading…
x
Reference in New Issue
Block a user