Merge pull request #19441 from jakevdp:shard-alike-fix

PiperOrigin-RevId: 599929883
This commit is contained in:
jax authors 2024-01-19 13:58:12 -08:00
commit f0329bf033
3 changed files with 22 additions and 1 deletions

View File

@ -62,6 +62,7 @@ from jax._src.lax.utils import (
standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@ -1375,7 +1376,14 @@ def full_like(x: ArrayLike | DuckTypedArray,
# TODO(yashkatariya): Use shard_like in tracing mode too i.e. remove the
# ArrayImpl check.
if shape is None and isinstance(x, array.ArrayImpl):
return shard_alike.shard_alike(x, val)[1]
if xla_extension_version < 227:
sharding = x.sharding # type: ignore[union-attr]
if (not dispatch.is_single_device_sharding(sharding) and
not isinstance(sharding, PmapSharding)):
return array.make_array_from_callback(
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
else:
return shard_alike.shard_alike(x, val)[1]
return val

View File

@ -23,6 +23,7 @@ from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.interpreters import batching
from jax._src.util import safe_zip
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.api_util import shaped_abstractify
from jax._src.lib.mlir import ir
@ -30,6 +31,8 @@ _next_shard_group_id = itertools.count()
def shard_alike(x, y):
"""Shards x and y alike."""
if xla_extension_version < 227:
raise ValueError("shard_alike requires jaxlib v0.4.24 or newer.")
x_flat, x_tree = tree_flatten(x)
y_flat, y_tree = tree_flatten(y)

View File

@ -50,6 +50,16 @@ def tearDownModule():
xla_bridge.get_backend.cache_clear()
class ShardAlikeDownstreamTest(jtu.JaxTestCase):
def test_full_like(self):
x = jnp.arange(16, dtype='float32').reshape(8, 2)
mesh = jtu.create_global_mesh((8,), ("i",))
x = jax.device_put(x, NamedSharding(mesh, P('i', None)))
y = jnp.full_like(x, 1)
self.assertEqual(x.sharding, y.sharding)
class ShardAlikeTest(jtu.JaxTestCase):
def setUp(self):