Fix float0 behavior inside shard_map transpose under scan.

PiperOrigin-RevId: 689512880
This commit is contained in:
Parker Schuh 2024-10-24 14:15:04 -07:00 committed by jax authors
parent 0d68a2bf3b
commit 9500bd451a
2 changed files with 50 additions and 1 deletions

View File

@ -1652,7 +1652,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite
else x if rewrite or dtypes.dtype(x) == dtypes.float0
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else

View File

@ -1495,6 +1495,55 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(str(e1.primitive), 'psum2')
self.assertEqual(str(e2.primitive), 'pbroadcast')
def test_transpose_float0(self):
# Regression test: differentiating through shard_map calls nested in
# jax.lax.map, with an integer argument whose gradient is float0, must not
# crash. The accompanying change makes _shard_map_transpose pass float0
# cotangents through unchanged instead of dividing them.
# Mesh built with shape (4,) and a single axis named 'x'.
mesh = jtu.create_mesh((4,), ('x',))
# Sharding applied to the jitted function's inputs below (second array
# dimension partitioned over 'x').
s = jax.sharding.NamedSharding(mesh, P(None, 'x'))
# vjp that triggers float0
@jax.custom_vjp
def f(x, _):
# Primal: returns x and ignores the second argument.
return x
def f_fwd(x, y):
# Residual is an int32 zeros array shaped like y.
return x, jnp.zeros(shape=y.shape, dtype=np.int32)
def f_rev(tmp, g):
# Cotangents: g for x, and the int32 residual for the second argument.
return (g, tmp)
f.defvjp(f_fwd, f_rev)
# trivial vjp that consumes float0
@jax.custom_vjp
def g(x, y):
# Primal: identity on the pair (x, y).
return x, y
def g_fwd(x, y):
# jax.vjp returns (outputs, vjp_fn); vjp_fn becomes the residual.
return jax.vjp(lambda x, y: (x, y), x, y)
def g_bwd(vjp_fn, result):
# Backward pass just applies the saved vjp_fn to the output cotangents.
return vjp_fn(result)
g.defvjp(g_fwd, g_bwd)
# shard_map'd reduction: x is split over 'x', y uses P(); the per-shard sum
# of f(x, y) is psum'd over 'x' and returned with out_specs=P().
@partial(shard_map, mesh=mesh, in_specs=(P('x'), P()), out_specs=P())
def f_shmapped(x, y):
return jax.lax.psum(f(x, y).sum(), axis_name=('x'))
# shard_map'd pass-through of g with check_rep=False; a single in_specs
# P('x') is given for both inputs, and the outputs use (P('x'), P()).
@partial(shard_map, mesh=mesh, check_rep=False,
in_specs=P('x'), out_specs=(P('x'), P()))
def f_shmapped2(x, y):
return g(x, y)
def f_wrapper(x, y):
# Both shard_map'd functions run inside jax.lax.map over the (x, y) pair.
# NOTE(review): presumably this is the "under scan" case named in the
# commit title -- confirm that lax.map lowers to scan.
x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y))
return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum()
# jit the gradient computation with sharded inputs and P() outputs.
@partial(jax.jit, in_shardings=s,
out_shardings=jax.sharding.NamedSharding(mesh, P()))
def example(x, y):
# allow_int=True permits differentiating with respect to the int32 y.
return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y)
# float32 x and int32 y, both zeros of shape (8, 16).
x = np.zeros(shape=(8,16), dtype=np.float32)
y = np.zeros(shape=(8,16), dtype=np.int32)
# Doesn't crash.
dx, dy = example(x, y)
# The gradient with respect to the integer input must have dtype float0.
self.assertEqual(dy.dtype, jax.dtypes.float0)
def test_rewrite_binops(self):
mesh = jtu.create_mesh((4,), ('x',))