mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix float0 behavior inside shard_map transpose under scan.
PiperOrigin-RevId: 689512880
This commit is contained in:
parent
0d68a2bf3b
commit
9500bd451a
@ -1652,7 +1652,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
|
||||
check_rep, rewrite, auto):
|
||||
mb_div = lambda x, y: x / y if y != 1 else x
|
||||
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
|
||||
else x if rewrite
|
||||
else x if rewrite or dtypes.dtype(x) == dtypes.float0
|
||||
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
|
||||
for ns, x in zip(out_names, out_cts)]
|
||||
args = [x if type(x) is not ad.UndefinedPrimal else
|
||||
|
@ -1495,6 +1495,55 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(str(e1.primitive), 'psum2')
|
||||
self.assertEqual(str(e2.primitive), 'pbroadcast')
|
||||
|
||||
def test_transpose_float0(self):
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
s = jax.sharding.NamedSharding(mesh, P(None, 'x'))
|
||||
|
||||
# vjp that triggers float0
|
||||
@jax.custom_vjp
|
||||
def f(x, _):
|
||||
return x
|
||||
def f_fwd(x, y):
|
||||
return x, jnp.zeros(shape=y.shape, dtype=np.int32)
|
||||
def f_rev(tmp, g):
|
||||
return (g, tmp)
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
|
||||
# trivial vjp that consumes float0
|
||||
@jax.custom_vjp
|
||||
def g(x, y):
|
||||
return x, y
|
||||
def g_fwd(x, y):
|
||||
return jax.vjp(lambda x, y: (x, y), x, y)
|
||||
def g_bwd(vjp_fn, result):
|
||||
return vjp_fn(result)
|
||||
g.defvjp(g_fwd, g_bwd)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('x'), P()), out_specs=P())
|
||||
def f_shmapped(x, y):
|
||||
return jax.lax.psum(f(x, y).sum(), axis_name=('x'))
|
||||
|
||||
@partial(shard_map, mesh=mesh, check_rep=False,
|
||||
in_specs=P('x'), out_specs=(P('x'), P()))
|
||||
def f_shmapped2(x, y):
|
||||
return g(x, y)
|
||||
|
||||
def f_wrapper(x, y):
|
||||
x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y))
|
||||
return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum()
|
||||
|
||||
@partial(jax.jit, in_shardings=s,
|
||||
out_shardings=jax.sharding.NamedSharding(mesh, P()))
|
||||
def example(x, y):
|
||||
return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y)
|
||||
|
||||
x = np.zeros(shape=(8,16), dtype=np.float32)
|
||||
y = np.zeros(shape=(8,16), dtype=np.int32)
|
||||
# Doesn't crash.
|
||||
dx, dy = example(x, y)
|
||||
self.assertEqual(dy.dtype, jax.dtypes.float0)
|
||||
|
||||
def test_rewrite_binops(self):
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user