[sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.

This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
This commit is contained in:
Yash Katariya 2025-01-16 10:50:30 -08:00 committed by jax authors
parent 994c3f59e2
commit b23c42372b
6 changed files with 57 additions and 8 deletions

View File

@ -362,6 +362,7 @@ pytype_strict_library(
srcs = ["_src/basearray.py"],
pytype_srcs = ["_src/basearray.pyi"],
deps = [
":partition_spec",
":sharding",
"//jax/_src/lib",
] + py_deps("numpy"),

View File

@ -18,6 +18,7 @@ from typing import Any, Protocol, Union, runtime_checkable
import numpy as np
from jax._src.sharding import Sharding
from jax._src.partition_spec import PartitionSpec
# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py.
# We redefine these here to prevent circular imports.
@ -278,7 +279,8 @@ class _IndexUpdateHelper:
class _IndexUpdateRef:
def get(self, indices_are_sorted: bool = False, unique_indices: bool = False,
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
mode: str | None = None, fill_value: StaticScalar | None = None,
out_spec: PartitionSpec | None = None) -> Array: ...
def set(self, values: Any,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...

View File

@ -6424,7 +6424,6 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
out = mlir.iota(ctx, aval_out, dimension=dimension)
if config.sharding_in_types.value:
assert aval_out.sharding == sharding
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(iota_p, _iota_lower)

View File

@ -32,6 +32,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src import mesh as mesh_lib
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -1875,6 +1876,18 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes):
else next(indices_shape_gen) for i in range(output_shape_rank))
return ans
class GatherShardingError(Exception):
pass
def _gather_sharding_rule(operand, indices, *, dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted,
mode, fill_value):
# TODO(yashkatariya): Write a proper gather sharding rule.
if mesh_lib.get_abstract_mesh()._are_all_axes_auto: # type: ignore
return None
raise GatherShardingError(
"Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the"
" gather indexing.")
def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
unique_indices, indices_are_sorted, fill_value,
@ -2056,7 +2069,7 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *,
gather_p = standard_primitive(
_gather_shape_rule, _gather_dtype_rule, 'gather',
weak_type_rule=_argnum_weak_type(0))
weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule)
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule

View File

@ -41,6 +41,8 @@ from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.numpy import array_api_metadata
from jax._src.numpy import lax_numpy
from jax._src import mesh as mesh_lib
from jax._src.pjit import auto_mode, PartitionSpec
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.ops import scatter
@ -763,7 +765,7 @@ class _IndexUpdateRef:
return f"_IndexUpdateRef({self.array!r}, {self.index!r})"
def get(self, *, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
mode=None, fill_value=None, out_spec=None):
"""Equivalent to ``x[idx]``.
Returns the value of ``x`` that would result from the NumPy-style
@ -773,10 +775,15 @@ class _IndexUpdateRef:
See :mod:`jax.ops` for details.
"""
return lax_numpy._rewriting_take(self.array, self.index,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
take = partial(lax_numpy._rewriting_take,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
if out_spec is not None:
assert isinstance(out_spec, PartitionSpec)
take = auto_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
out_specs=out_spec)
return take(self.array, self.index)
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):

View File

@ -6092,6 +6092,33 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = f(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_auto_gather_out_spec(self, mesh):
embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16),
jax.NamedSharding(mesh, P(None, 'x')))
tok = jax.device_put(jnp.arange(8 * 4).reshape(8, 4),
jax.NamedSharding(mesh, P('x', None)))
@jax.jit
def f(embed_vd, token_bt):
out = embed_vd.at[token_bt].get(out_spec=P('x', None, None))
self.assertEqual(out.shape, (8, 4, 16))
self.assertEqual(out.sharding.spec, P('x', None, None))
return out
out = f(embed, tok)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, None)))
lowered_text = f.lower(embed, tok).as_text()
self.check_wsc_in_lowered(lowered_text)
def g(x, y):
out = f(x, y)
return jnp.sum(out)
out = jax.jit(jax.grad(g))(embed, tok)
self.assertEqual(out.sharding, embed.sharding)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):