diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index 0223efdb2..edeef78c4 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -28,6 +28,7 @@ from jax._src import source_info_util
 from jax._src import linear_util as lu
 from jax._src.partition_spec import PartitionSpec as P
 from jax._src.sharding_impls import NamedSharding
+from jax._src import mesh as mesh_lib
 from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
                               replace_rule_output_symbolic_zeros, add_jaxvals,
                               add_jaxvals_p)
@@ -1103,7 +1104,11 @@ def broadcast(x, sz, axis, mesh_axis=None):
   x_aval = core.get_aval(x)
   new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
   sharding = x_aval.sharding.with_spec(new_spec)
-  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
+  # TODO(dougalm, yashkatariya): Delete this context manager once we figure
+  # out how to ensure jaxpr arguments always have the context mesh.
+  with mesh_lib.set_abstract_mesh(sharding.mesh):
+    return jax.lax.broadcast_in_dim(x, shape, broadcast_dims,
+                                    out_sharding=sharding)
 
 def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
   if dst == jumble_axis:
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 4d7c5ac03..a28de971d 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -2074,6 +2074,24 @@ class ShardMapTest(jtu.JaxTestCase):
     params = jnp.copy(arr_sharded)
     update_fn(params, arr_sharded)  # doesn't crash
 
+  def test_shmap_close_over_unused_params_vmap(self):
+    mesh = jtu.create_mesh((2,), ("data",))
+
+    def loss_fn(params, batch):
+      return jnp.sum(params) + jnp.sum(batch)
+
+    @jax.jit
+    def update_fn(params, batch):
+      def grad_fn(batch):
+        return jax.value_and_grad(loss_fn)(params, batch)
+      return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"),
+                       out_specs=P("data"), check_rep=False)(batch)
+
+    arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8),
+                                 NamedSharding(mesh, P()))
+    params = jnp.copy(arr_sharded)
+    update_fn(params, arr_sharded)  # doesn't crash
+
   def test_sharded_prng_with_abstract_mesh(self):
     shape = (8, 2, 2)
     mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))