Merge pull request #18752 from mattjj:shmap-remat-rule

PiperOrigin-RevId: 586729063
This commit is contained in:
jax authors 2023-11-30 11:02:20 -08:00
commit 53e66c1214
2 changed files with 28 additions and 0 deletions

View File

@ -27,6 +27,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec, Mesh
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import array
from jax._src import callback
@ -983,6 +984,19 @@ def _pjit_check(mesh, *in_rep, jaxpr, **kwargs):
return _check_rep(mesh, jaxpr.jaxpr, in_rep)
@register_rewrite(ad_checkpoint.remat_p)
def _remat_rewrite(mesh, in_rep, *args, jaxpr, **kwargs):
jaxpr_ = pe.close_jaxpr(jaxpr)
jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr_, in_rep)
jaxpr, () = jaxpr_.jaxpr, jaxpr_.consts
out_vals = ad_checkpoint.remat_p.bind(*args, jaxpr=jaxpr, **kwargs)
return out_vals, out_rep
@register_check(ad_checkpoint.remat_p)
def _remat_check(mesh, *in_rep, jaxpr, **kwargs):
return _check_rep(mesh, jaxpr, in_rep)
@register_check(core.call_p)
def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
return _check_rep(mesh, call_jaxpr, in_rep)

View File

@ -430,6 +430,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(x, jnp.arange(4), check_dtypes=False)
def test_remat_basic(self):
# this tests remat-of-shmap
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
# check param updating is handled
@ -451,6 +452,19 @@ class ShardMapTest(jtu.JaxTestCase):
g2 = jax.grad(lambda x: f2(x).sum())(x) # doesn't crash
self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)
def test_shmap_of_remat_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
x = jnp.arange(4.)
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
@partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
def f2(x):
return jnp.sin(x)
g2 = jax.grad(lambda x: f2(x).sum())(x) # doesn't crash
self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)
def test_remat_scalar_residuals(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))