mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5
PiperOrigin-RevId: 729532213
This commit is contained in:
parent
c664a0cd44
commit
66037d10e7
@ -28,6 +28,7 @@ from jax._src import source_info_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.partition_spec import PartitionSpec as P
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
|
||||
replace_rule_output_symbolic_zeros,
|
||||
add_jaxvals, add_jaxvals_p)
|
||||
@ -1103,7 +1104,11 @@ def broadcast(x, sz, axis, mesh_axis=None):
|
||||
x_aval = core.get_aval(x)
|
||||
new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
|
||||
sharding = x_aval.sharding.with_spec(new_spec)
|
||||
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
|
||||
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
|
||||
# out how to ensure jaxpr arguments always have the context mesh.
|
||||
with mesh_lib.set_abstract_mesh(sharding.mesh):
|
||||
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims,
|
||||
out_sharding=sharding)
|
||||
|
||||
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
|
||||
if dst == jumble_axis:
|
||||
|
@ -2074,6 +2074,24 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
params = jnp.copy(arr_sharded)
|
||||
update_fn(params, arr_sharded) # doesn't crash
|
||||
|
||||
def test_shmap_close_over_unused_params_vmap(self):
|
||||
mesh = jtu.create_mesh((2,), ("data",))
|
||||
|
||||
def loss_fn(params, batch):
|
||||
return jnp.sum(params) + jnp.sum(batch)
|
||||
|
||||
@jax.jit
|
||||
def update_fn(params, batch):
|
||||
def grad_fn(batch):
|
||||
return jax.value_and_grad(loss_fn)(params, batch)
|
||||
return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"),
|
||||
out_specs=P("data"), check_rep=False)(batch)
|
||||
|
||||
arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8),
|
||||
NamedSharding(mesh, P()))
|
||||
params = jnp.copy(arr_sharded)
|
||||
update_fn(params, arr_sharded) # doesn't crash
|
||||
|
||||
def test_sharded_prng_with_abstract_mesh(self):
|
||||
shape = (8, 2, 2)
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user