mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #18752 from mattjj:shmap-remat-rule
PiperOrigin-RevId: 586729063
This commit is contained in:
commit
53e66c1214
@ -27,6 +27,7 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding, PartitionSpec, Mesh
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import array
|
||||
from jax._src import callback
|
||||
@ -983,6 +984,19 @@ def _pjit_check(mesh, *in_rep, jaxpr, **kwargs):
|
||||
return _check_rep(mesh, jaxpr.jaxpr, in_rep)
|
||||
|
||||
|
||||
@register_rewrite(ad_checkpoint.remat_p)
def _remat_rewrite(mesh, in_rep, *args, jaxpr, **kwargs):
  """Replication-rewrite rule registered for ``ad_checkpoint.remat_p``.

  The ``jaxpr`` parameter is passed through ``pe.close_jaxpr`` and then
  through ``_replication_rewrite_nomatch`` together with ``mesh`` and
  ``in_rep``. ``remat_p`` is then re-bound on the rewritten body with the
  original operands and remaining parameters.

  Returns:
    A pair ``(out_vals, out_rep)``: the results of the re-bound primitive and
    the second value returned by ``_replication_rewrite_nomatch``.
  """
  closed = pe.close_jaxpr(jaxpr)
  rewritten, out_rep = _replication_rewrite_nomatch(mesh, closed, in_rep)
  body = rewritten.jaxpr
  # Unpacking into an empty target raises if the rewritten jaxpr has consts.
  () = rewritten.consts
  results = ad_checkpoint.remat_p.bind(*args, jaxpr=body, **kwargs)
  return results, out_rep
|
||||
|
||||
@register_check(ad_checkpoint.remat_p)
def _remat_check(mesh, *in_rep, jaxpr, **kwargs):
  """Replication-check rule registered for ``ad_checkpoint.remat_p``.

  Delegates to ``_check_rep`` on the primitive's ``jaxpr`` parameter, passing
  the operands' replication info as a tuple. Other primitive parameters are
  accepted and ignored.
  """
  del kwargs  # Only the body jaxpr takes part in the check.
  out_rep = _check_rep(mesh, jaxpr, in_rep)
  return out_rep
|
||||
|
||||
|
||||
@register_check(core.call_p)
def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
  """Replication-check rule registered for ``core.call_p``.

  Delegates to ``_check_rep`` on the primitive's ``call_jaxpr`` parameter,
  passing the operands' replication info as a tuple. Other primitive
  parameters are accepted and ignored.
  """
  del kwargs  # Only the called jaxpr takes part in the check.
  out_rep = _check_rep(mesh, call_jaxpr, in_rep)
  return out_rep
|
||||
|
@ -430,6 +430,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(x, jnp.arange(4), check_dtypes=False)
|
||||
|
||||
def test_remat_basic(self):
|
||||
# this tests remat-of-shmap
|
||||
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
|
||||
|
||||
# check param updating is handled
|
||||
@ -451,6 +452,19 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
g2 = jax.grad(lambda x: f2(x).sum())(x) # doesn't crash
|
||||
self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)
|
||||
|
||||
def test_shmap_of_remat_basic(self):
  # Exercises shmap-of-remat: jax.remat is applied first (inner decorator),
  # shard_map wraps the result, and the gradient is compared against cos.
  devices = np.array(jax.devices()[:4])
  mesh = Mesh(devices, ('x',))
  inputs = jnp.arange(4.)

  # Built in the same order the stacked decorators were evaluated originally.
  shard = partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
  remat = partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)

  @shard
  @remat
  def sharded_sin(v):
    return jnp.sin(v)

  def summed(v):
    return sharded_sin(v).sum()

  grads = jax.grad(summed)(inputs)  # doesn't crash
  self.assertAllClose(grads, jnp.cos(inputs), check_dtypes=False)
|
||||
|
||||
def test_remat_scalar_residuals(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user