mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
[mutable-arrays] read values should have the same explicit sharding as ref
fixes #26936
This commit is contained in:
parent
5d64b3d2dd
commit
0e30a3ace9
@ -259,3 +259,26 @@ class NDIndexer:
|
||||
|
||||
def transform_dtype(self, dtype):
  """Return ``dtype`` unchanged.

  The incoming dtype is handed straight back without being inspected.
  """
  return dtype
|
||||
|
||||
def transform_sharding(self, sharding):
  """Return the sharding of the value obtained by applying this indexer.

  Args:
    sharding: object exposing a ``spec`` iterable with one entry per axis;
      a ``None`` entry marks an axis that is not explicitly sharded.

  Returns:
    ``sharding`` itself. It is never modified: either no axis is explicitly
    sharded, or every explicitly sharded axis is read in full.

  Raises:
    TypeError: if some axis is explicitly sharded and the indexer has a
      non-empty ``int_indexer_shape`` or any index that is not a ``Slice``.
    ValueError: if an explicitly sharded axis is indexed by anything other
      than the full-range slice (start 0, size equal to the axis length,
      stride 1, all plain Python ints).
  """
  # Nothing is explicitly sharded: the sharding passes through untouched.
  if not any(p is not None for p in sharding.spec):
    return sharding

  # With explicit axes the shape must not change, so integer indexers are
  # rejected and every index has to be a slice.
  has_non_slice = (self.int_indexer_shape or
                   any(not isinstance(idx, Slice) for idx in self.indices))
  if has_non_slice:
    raise TypeError("sharded ref (array reference) can only be indexed by "
                    "slices, not integers")

  def _is_plain_int_equal(value, expected):
    # Exact `int` only: anything else (e.g. a traced value) does not count.
    return type(value) is int and value == expected

  # Explicitly sharded axes may only be read with the trivial full slice, in
  # which case the sharding is left as it was.
  _, slice_indexers, _ = unpack_ndindexer(self)
  per_axis = zip(self.shape, slice_indexers, sharding.spec)
  for i, (dim_size, slc, s) in enumerate(per_axis):
    if s is None:
      continue
    is_full_slice = (_is_plain_int_equal(slc.start, 0) and
                     _is_plain_int_equal(slc.size, dim_size) and
                     _is_plain_int_equal(slc.stride, 1))
    if is_full_slice:
      continue
    raise ValueError("sharded ref (array reference) can only be sliced "
                     f"along unsharded axes, but ref of shape {self.shape} "
                     f"was sliced on axis {i}, which is sharded like {s}")
  return sharding
|
||||
|
@ -206,6 +206,13 @@ def _dtype_after_transforming(
|
||||
return dtype
|
||||
|
||||
|
||||
def _sharding_after_transforming(sharding, transforms):
|
||||
for transform in transforms:
|
||||
sharding = transform.transform_sharding(sharding)
|
||||
assert sharding is not None
|
||||
return sharding
|
||||
|
||||
|
||||
def _get_abstract_eval(ref_aval: AbstractRef, *args,
|
||||
tree):
|
||||
transforms = tree_util.tree_unflatten(tree, args)
|
||||
@ -214,10 +221,9 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
|
||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
|
||||
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
|
||||
# TODO(yashkatariya): Transform the sharding too instead of setting it to
|
||||
# None.
|
||||
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
|
||||
sharding=core.get_cur_mesh_sharding())
|
||||
out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms)
|
||||
out_aval = ref_aval.inner_aval.update(
|
||||
shape=out_shape, dtype=out_dtype, sharding=out_sharding)
|
||||
else:
|
||||
if transforms:
|
||||
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||
|
@ -119,6 +119,12 @@ class RefBitcaster:
|
||||
del dtype # Unused
|
||||
return self.dtype
|
||||
|
||||
def transform_sharding(self, sharding):
  """Return the sharding a ref has after this bitcast is applied.

  Args:
    sharding: object exposing a ``spec`` iterable with one entry per axis;
      a ``None`` entry marks an axis that is not explicitly sharded.

  Returns:
    ``sharding`` itself when every entry of ``sharding.spec`` is ``None``.

  Raises:
    NotImplementedError: if any entry of ``sharding.spec`` is not ``None``.
  """
  # If there are no explicit axes, do nothing.
  if all(p is None for p in sharding.spec):
    return sharding
  # Say what is unsupported instead of raising a bare NotImplementedError.
  raise NotImplementedError(
      "bitcasting a sharded ref (array reference) is not supported, but got "
      f"a ref sharded like {sharding.spec}")
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -166,6 +172,12 @@ class RefReshaper:
|
||||
del dtype # Unused
|
||||
return self.dtype
|
||||
|
||||
def transform_sharding(self, sharding):
  """Return the sharding a ref has after this reshape is applied.

  Args:
    sharding: object exposing a ``spec`` iterable with one entry per axis;
      a ``None`` entry marks an axis that is not explicitly sharded.

  Returns:
    ``sharding`` itself when every entry of ``sharding.spec`` is ``None``.

  Raises:
    NotImplementedError: if any entry of ``sharding.spec`` is not ``None``.
  """
  # If there are no explicit axes, do nothing.
  if all(p is None for p in sharding.spec):
    return sharding
  # Say what is unsupported instead of raising a bare NotImplementedError.
  raise NotImplementedError(
      "reshaping a sharded ref (array reference) is not supported, but got "
      f"a ref sharded like {sharding.spec}")
|
||||
|
||||
|
||||
class Transform(Protocol):
|
||||
|
||||
@ -189,6 +201,10 @@ class Transform(Protocol):
|
||||
"""
|
||||
return dtype
|
||||
|
||||
def transform_sharding(self, sharding):
  """Default sharding rule for a transform.

  Args:
    sharding: object exposing a ``spec`` iterable with one entry per axis;
      a ``None`` entry marks an axis that is not explicitly sharded.

  Returns:
    ``sharding`` itself when every entry of ``sharding.spec`` is ``None``.

  Raises:
    NotImplementedError: if any entry of ``sharding.spec`` is not ``None``.
  """
  # No explicit axes: the sharding passes through unchanged.
  if all(p is None for p in sharding.spec):
    return sharding
  # Name the concrete transform so the failure can be traced to its source.
  raise NotImplementedError(
      f"{type(self).__name__} does not support sharded refs (array "
      f"references), but got a ref sharded like {sharding.spec}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RefIndexer:
|
||||
|
@ -254,6 +254,23 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(s, a.sharding)
|
||||
self.assertEqual(s, y.sharding)
|
||||
|
||||
def test_explicit_sharding_after_indexing(self):
  # Regression test for https://github.com/jax-ml/jax/issues/26936:
  # a value read from a ref with `[...]` must report the same sharding spec
  # as the ref it was read from.
  mesh = jax.make_mesh((1, 1), ('x', 'y'), explicit_axes=('x', 'y'))
  sharding = NamedSharding(mesh, P('x', 'y'))

  @jax.jit
  def f(x_ref):
    ref_spec = core.get_ty(x_ref).sharding.spec
    read_spec = core.get_ty(x_ref[...]).sharding.spec
    self.assertEqual(ref_spec, read_spec)
    return x_ref[...] + 1

  with jax.sharding.use_mesh(mesh):
    x_ref = core.mutable_array(jnp.zeros((4, 4), jnp.int32, device=sharding))
    # The check runs while `f` is traced; the result itself is not inspected.
    f(x_ref)
|
||||
|
||||
|
||||
@jtu.with_config(jax_mutable_array_checks=True)
|
||||
class MutableArrayErrorsTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user