mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19441 from jakevdp:shard-alike-fix
PiperOrigin-RevId: 599929883
This commit is contained in:
commit
f0329bf033
@ -62,6 +62,7 @@ from jax._src.lax.utils import (
|
||||
standard_primitive)
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -1375,7 +1376,14 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
||||
# TODO(yashkatariya): Use shard_like in tracing mode too i.e. remove the
|
||||
# ArrayImpl check.
|
||||
if shape is None and isinstance(x, array.ArrayImpl):
|
||||
return shard_alike.shard_alike(x, val)[1]
|
||||
if xla_extension_version < 227:
|
||||
sharding = x.sharding # type: ignore[union-attr]
|
||||
if (not dispatch.is_single_device_sharding(sharding) and
|
||||
not isinstance(sharding, PmapSharding)):
|
||||
return array.make_array_from_callback(
|
||||
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
|
||||
else:
|
||||
return shard_alike.shard_alike(x, val)[1]
|
||||
return val
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ from jax._src.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.lib.mlir import ir
|
||||
|
||||
@ -30,6 +31,8 @@ _next_shard_group_id = itertools.count()
|
||||
|
||||
def shard_alike(x, y):
|
||||
"""Shards x and y alike."""
|
||||
if xla_extension_version < 227:
|
||||
raise ValueError("shard_alike requires jaxlib v0.4.24 or newer.")
|
||||
x_flat, x_tree = tree_flatten(x)
|
||||
y_flat, y_tree = tree_flatten(y)
|
||||
|
||||
|
@ -50,6 +50,16 @@ def tearDownModule():
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
|
||||
class ShardAlikeDownstreamTest(jtu.JaxTestCase):
|
||||
|
||||
def test_full_like(self):
|
||||
x = jnp.arange(16, dtype='float32').reshape(8, 2)
|
||||
mesh = jtu.create_global_mesh((8,), ("i",))
|
||||
x = jax.device_put(x, NamedSharding(mesh, P('i', None)))
|
||||
y = jnp.full_like(x, 1)
|
||||
self.assertEqual(x.sharding, y.sharding)
|
||||
|
||||
|
||||
class ShardAlikeTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user