1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

[mutable-arrays] read values should have the same explicit sharding as ref

fixes 
This commit is contained in:
Matthew Johnson 2025-03-07 01:03:22 +00:00
parent 5d64b3d2dd
commit 0e30a3ace9
4 changed files with 66 additions and 4 deletions

@ -259,3 +259,26 @@ class NDIndexer:
def transform_dtype(self, dtype):
return dtype
def transform_sharding(self, sharding):
# If there are no explicit axes, do nothing.
if all(p is None for p in sharding.spec):
return sharding
# If there are explicit axes, we don't support changing the shape, so we
# don't support int indexers and instead require all slices.
if (self.int_indexer_shape or
not all(isinstance(idx, Slice) for idx in self.indices)):
raise TypeError("sharded ref (array reference) can only be indexed by "
"slices, not integers")
# Moreover, only allow trivial slice(None) slices on explicitly sharded
# axes. Then the sharding stays the same.
_, slice_indexers, _ = unpack_ndindexer(self)
for i, (d, sl, s) in enumerate(zip(self.shape, slice_indexers, sharding.spec)):
if s is None: continue
if not (type(sl.start) is int and sl.start == 0 and
type(sl.size) is int and sl.size == d and
type(sl.stride) is int and sl.stride == 1):
raise ValueError("sharded ref (array reference) can only be sliced "
f"along unsharded axes, but ref of shape {self.shape} "
f"was sliced on axis {i}, which is sharded like {s}")
return sharding

@ -206,6 +206,13 @@ def _dtype_after_transforming(
return dtype
def _sharding_after_transforming(sharding, transforms):
for transform in transforms:
sharding = transform.transform_sharding(sharding)
assert sharding is not None
return sharding
def _get_abstract_eval(ref_aval: AbstractRef, *args,
tree):
transforms = tree_util.tree_unflatten(tree, args)
@ -214,10 +221,9 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
if isinstance(ref_aval.inner_aval, core.ShapedArray):
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
# TODO(yashkatariya): Transform the sharding too instead of setting it to
# None.
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
sharding=core.get_cur_mesh_sharding())
out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms)
out_aval = ref_aval.inner_aval.update(
shape=out_shape, dtype=out_dtype, sharding=out_sharding)
else:
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")

@ -119,6 +119,12 @@ class RefBitcaster:
del dtype # Unused
return self.dtype
def transform_sharding(self, sharding):
# If there are no explicit axes, do nothing.
if all(p is None for p in sharding.spec):
return sharding
raise NotImplementedError
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
@ -166,6 +172,12 @@ class RefReshaper:
del dtype # Unused
return self.dtype
def transform_sharding(self, sharding):
# If there are no explicit axes, do nothing.
if all(p is None for p in sharding.spec):
return sharding
raise NotImplementedError
class Transform(Protocol):
@ -189,6 +201,10 @@ class Transform(Protocol):
"""
return dtype
def transform_sharding(self, sharding):
if all(p is None for p in sharding.spec): return sharding # no explicit axes
raise NotImplementedError
@dataclasses.dataclass
class RefIndexer:

@ -254,6 +254,23 @@ class MutableArrayTest(jtu.JaxTestCase):
self.assertEqual(s, a.sharding)
self.assertEqual(s, y.sharding)
def test_explicit_sharding_after_indexing(self):
# https://github.com/jax-ml/jax/issues/26936
mesh = jax.make_mesh((1, 1), ('x', 'y'), explicit_axes=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
@jax.jit
def f(x_ref):
self.assertEqual(core.get_ty(x_ref).sharding.spec,
core.get_ty(x_ref[...]).sharding.spec)
y = x_ref[...] + 1
return y
with jax.sharding.use_mesh(mesh):
x = jnp.zeros((4, 4), jnp.int32, device=sharding)
x_ref = core.mutable_array(x)
y = f(x_ref)
@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):