Merge pull request #19566 from mattjj:attrs-aqt

PiperOrigin-RevId: 602864008
This commit is contained in:
jax authors 2024-01-30 15:51:00 -08:00
commit 80d23d64cd
2 changed files with 60 additions and 3 deletions

View File

@ -1878,9 +1878,10 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
@lu.cache
def _pjit_transpose_trace(fun, in_avals):
transpose_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_avals)
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
return transpose_jaxpr
return transpose_jaxpr, attrs_tracked
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
@ -1901,13 +1902,20 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
for ct in primals_and_nz_cts_in)
transpose_jaxpr = _pjit_transpose_trace(body, global_cts_in_avals)
transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
body, global_cts_in_avals)
cts_out_treedef = cts_out_treedef_thunk()
transpose_out_shardings = prune_type(
ad.Zero,
in_shardings,
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
if attrs_tracked:
init_states = _get_states(attrs_tracked)
primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in]
transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings
transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings
nz_cts_out = pjit_p.bind(
*primals_and_nz_cts_in,
jaxpr=transpose_jaxpr,
@ -1918,6 +1926,9 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
name=name,
keep_unused=keep_unused,
inline=inline)
if attrs_tracked:
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
_set_states(attrs_tracked, final_states)
return tree_unflatten(cts_out_treedef, nz_cts_out)
ad.reducing_transposes[pjit_p] = _pjit_transpose

View File

@ -20,6 +20,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
@ -78,6 +79,51 @@ class AttrsTest(jtu.JaxTestCase):
double_it()
self.assertEqual(thing.x, 16.0)
def test_jit_transpose_basic(self):
thing = Thing(jnp.array(2.0))
@jax.custom_vjp
def foo(x):
return x
def foo_fwd(x):
return x, None
def foo_bwd(x, g):
jax_setattr(thing, 'x', g)
return g,
foo.defvjp(foo_fwd, foo_bwd)
foo(3.14)
self.assertEqual(thing.x, 2.0)
jax.grad(foo)(3.14)
self.assertEqual(thing.x, 1.0)
thing.x = jnp.array(3.14)
self.assertEqual(thing.x, 3.14)
jax.jit(jax.grad(foo))(3.14)
self.assertEqual(thing.x, 1.0)
thing.x = jnp.array(2.718)
self.assertEqual(thing.x, 2.718)
jax.grad(jax.jit(lambda x: jnp.sin(foo(x))))(3.0)
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
thing.x = jnp.array(3.14)
self.assertEqual(thing.x, 3.14)
def bar(x):
out = jnp.sin(foo(x))
jax_setattr(thing, 'x', 5.0)
return out
jax.grad(jax.jit(bar))(3.0)
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())