mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19566 from mattjj:attrs-aqt
PiperOrigin-RevId: 602864008
This commit is contained in:
commit
80d23d64cd
@ -1878,9 +1878,10 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
|
||||
|
||||
@lu.cache
|
||||
def _pjit_transpose_trace(fun, in_avals):
|
||||
transpose_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
|
||||
fun, in_avals)
|
||||
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
|
||||
return transpose_jaxpr
|
||||
return transpose_jaxpr, attrs_tracked
|
||||
|
||||
|
||||
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
@ -1901,13 +1902,20 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
|
||||
for ct in primals_and_nz_cts_in)
|
||||
|
||||
transpose_jaxpr = _pjit_transpose_trace(body, global_cts_in_avals)
|
||||
transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
|
||||
body, global_cts_in_avals)
|
||||
cts_out_treedef = cts_out_treedef_thunk()
|
||||
transpose_out_shardings = prune_type(
|
||||
ad.Zero,
|
||||
in_shardings,
|
||||
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
|
||||
|
||||
if attrs_tracked:
|
||||
init_states = _get_states(attrs_tracked)
|
||||
primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in]
|
||||
transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings
|
||||
transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings
|
||||
|
||||
nz_cts_out = pjit_p.bind(
|
||||
*primals_and_nz_cts_in,
|
||||
jaxpr=transpose_jaxpr,
|
||||
@ -1918,6 +1926,9 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
if attrs_tracked:
|
||||
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
|
||||
_set_states(attrs_tracked, final_states)
|
||||
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
|
||||
|
@ -20,6 +20,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
@ -78,6 +79,51 @@ class AttrsTest(jtu.JaxTestCase):
|
||||
double_it()
|
||||
self.assertEqual(thing.x, 16.0)
|
||||
|
||||
def test_jit_transpose_basic(self):
|
||||
thing = Thing(jnp.array(2.0))
|
||||
|
||||
@jax.custom_vjp
|
||||
def foo(x):
|
||||
return x
|
||||
|
||||
def foo_fwd(x):
|
||||
return x, None
|
||||
|
||||
def foo_bwd(x, g):
|
||||
jax_setattr(thing, 'x', g)
|
||||
return g,
|
||||
|
||||
foo.defvjp(foo_fwd, foo_bwd)
|
||||
|
||||
foo(3.14)
|
||||
self.assertEqual(thing.x, 2.0)
|
||||
|
||||
jax.grad(foo)(3.14)
|
||||
self.assertEqual(thing.x, 1.0)
|
||||
|
||||
thing.x = jnp.array(3.14)
|
||||
self.assertEqual(thing.x, 3.14)
|
||||
|
||||
jax.jit(jax.grad(foo))(3.14)
|
||||
self.assertEqual(thing.x, 1.0)
|
||||
|
||||
thing.x = jnp.array(2.718)
|
||||
self.assertEqual(thing.x, 2.718)
|
||||
|
||||
jax.grad(jax.jit(lambda x: jnp.sin(foo(x))))(3.0)
|
||||
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
|
||||
|
||||
thing.x = jnp.array(3.14)
|
||||
self.assertEqual(thing.x, 3.14)
|
||||
|
||||
def bar(x):
|
||||
out = jnp.sin(foo(x))
|
||||
jax_setattr(thing, 'x', 5.0)
|
||||
return out
|
||||
|
||||
jax.grad(jax.jit(bar))(3.0)
|
||||
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user