mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[sharding_in_types] If an indexing operation hits into gather_p
, error out saying to use .at[...].get(out_spec=...)
instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 716295953
This commit is contained in:
parent
994c3f59e2
commit
b23c42372b
@ -362,6 +362,7 @@ pytype_strict_library(
|
||||
srcs = ["_src/basearray.py"],
|
||||
pytype_srcs = ["_src/basearray.pyi"],
|
||||
deps = [
|
||||
":partition_spec",
|
||||
":sharding",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy"),
|
||||
|
@ -18,6 +18,7 @@ from typing import Any, Protocol, Union, runtime_checkable
|
||||
import numpy as np
|
||||
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
|
||||
# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py.
|
||||
# We redefine these here to prevent circular imports.
|
||||
@ -278,7 +279,8 @@ class _IndexUpdateHelper:
|
||||
|
||||
class _IndexUpdateRef:
|
||||
def get(self, indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
|
||||
mode: str | None = None, fill_value: StaticScalar | None = None,
|
||||
out_spec: PartitionSpec | None = None) -> Array: ...
|
||||
def set(self, values: Any,
|
||||
indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
|
||||
|
@ -6424,7 +6424,6 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
|
||||
out = mlir.iota(ctx, aval_out, dimension=dimension)
|
||||
if config.sharding_in_types.value:
|
||||
assert aval_out.sharding == sharding
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
mlir.register_lowering(iota_p, _iota_lower)
|
||||
|
@ -32,6 +32,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -1875,6 +1876,18 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes):
|
||||
else next(indices_shape_gen) for i in range(output_shape_rank))
|
||||
return ans
|
||||
|
||||
class GatherShardingError(Exception):
|
||||
pass
|
||||
|
||||
def _gather_sharding_rule(operand, indices, *, dimension_numbers,
|
||||
slice_sizes, unique_indices, indices_are_sorted,
|
||||
mode, fill_value):
|
||||
# TODO(yashkatariya): Write a proper gather sharding rule.
|
||||
if mesh_lib.get_abstract_mesh()._are_all_axes_auto: # type: ignore
|
||||
return None
|
||||
raise GatherShardingError(
|
||||
"Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the"
|
||||
" gather indexing.")
|
||||
|
||||
def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
|
||||
unique_indices, indices_are_sorted, fill_value,
|
||||
@ -2056,7 +2069,7 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *,
|
||||
|
||||
gather_p = standard_primitive(
|
||||
_gather_shape_rule, _gather_dtype_rule, 'gather',
|
||||
weak_type_rule=_argnum_weak_type(0))
|
||||
weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule)
|
||||
ad.defjvp(gather_p, _gather_jvp_rule, None)
|
||||
ad.primitive_transposes[gather_p] = _gather_transpose_rule
|
||||
batching.primitive_batchers[gather_p] = _gather_batching_rule
|
||||
|
@ -41,6 +41,8 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy import array_api_metadata
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.pjit import auto_mode, PartitionSpec
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.ops import scatter
|
||||
@ -763,7 +765,7 @@ class _IndexUpdateRef:
|
||||
return f"_IndexUpdateRef({self.array!r}, {self.index!r})"
|
||||
|
||||
def get(self, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
mode=None, fill_value=None, out_spec=None):
|
||||
"""Equivalent to ``x[idx]``.
|
||||
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
@ -773,10 +775,15 @@ class _IndexUpdateRef:
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return lax_numpy._rewriting_take(self.array, self.index,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode,
|
||||
fill_value=fill_value)
|
||||
take = partial(lax_numpy._rewriting_take,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode,
|
||||
fill_value=fill_value)
|
||||
if out_spec is not None:
|
||||
assert isinstance(out_spec, PartitionSpec)
|
||||
take = auto_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
|
||||
out_specs=out_spec)
|
||||
return take(self.array, self.index)
|
||||
|
||||
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
|
@ -6092,6 +6092,33 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = f(arr)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_auto_gather_out_spec(self, mesh):
|
||||
embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16),
|
||||
jax.NamedSharding(mesh, P(None, 'x')))
|
||||
tok = jax.device_put(jnp.arange(8 * 4).reshape(8, 4),
|
||||
jax.NamedSharding(mesh, P('x', None)))
|
||||
|
||||
@jax.jit
|
||||
def f(embed_vd, token_bt):
|
||||
out = embed_vd.at[token_bt].get(out_spec=P('x', None, None))
|
||||
self.assertEqual(out.shape, (8, 4, 16))
|
||||
self.assertEqual(out.sharding.spec, P('x', None, None))
|
||||
return out
|
||||
|
||||
out = f(embed, tok)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, None)))
|
||||
|
||||
lowered_text = f.lower(embed, tok).as_text()
|
||||
self.check_wsc_in_lowered(lowered_text)
|
||||
|
||||
def g(x, y):
|
||||
out = f(x, y)
|
||||
return jnp.sum(out)
|
||||
|
||||
out = jax.jit(jax.grad(g))(embed, tok)
|
||||
self.assertEqual(out.sharding, embed.sharding)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user