Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5

PiperOrigin-RevId: 729532213
This commit is contained in:
Yash Katariya 2025-02-21 08:05:02 -08:00 committed by jax authors
parent c664a0cd44
commit 66037d10e7
2 changed files with 24 additions and 1 deletions

View File

@ -28,6 +28,7 @@ from jax._src import source_info_util
from jax._src import linear_util as lu
from jax._src.partition_spec import PartitionSpec as P
from jax._src.sharding_impls import NamedSharding
from jax._src import mesh as mesh_lib
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros,
add_jaxvals, add_jaxvals_p)
@ -1103,7 +1104,11 @@ def broadcast(x, sz, axis, mesh_axis=None):
x_aval = core.get_aval(x)
new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
sharding = x_aval.sharding.with_spec(new_spec)
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
# out how to ensure jaxpr arguments always have the context mesh.
with mesh_lib.set_abstract_mesh(sharding.mesh):
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims,
out_sharding=sharding)
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
if dst == jumble_axis:

View File

@ -2074,6 +2074,24 @@ class ShardMapTest(jtu.JaxTestCase):
params = jnp.copy(arr_sharded)
update_fn(params, arr_sharded) # doesn't crash
def test_shmap_close_over_unused_params_vmap(self):
mesh = jtu.create_mesh((2,), ("data",))
def loss_fn(params, batch):
return jnp.sum(params) + jnp.sum(batch)
@jax.jit
def update_fn(params, batch):
def grad_fn(batch):
return jax.value_and_grad(loss_fn)(params, batch)
return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"),
out_specs=P("data"), check_rep=False)(batch)
arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8),
NamedSharding(mesh, P()))
params = jnp.copy(arr_sharded)
update_fn(params, arr_sharded) # doesn't crash
def test_sharded_prng_with_abstract_mesh(self):
shape = (8, 2, 2)
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))