[Pallas] Refactor indexing primitives to use NDIndexer abstraction

Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
  This allows doing things like `ref.at[0].at[0]` to recursively create a view
  of a `Ref`. `RefView`s therefore encapsluate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
  but does *not* generalize their rules. Most of the rules will raise a
  NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
  done in a future CL.
* With the above in mind, this change only preserves existing public facing APIs
  and adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
This commit is contained in:
Sharad Vikram 2024-01-02 15:52:57 -08:00 committed by jax authors
parent 8c5e7b26ba
commit 836563fadf
15 changed files with 819 additions and 608 deletions

View File

@ -738,15 +738,17 @@ pytype_strict_library(
name = "state_types",
srcs = [
"_src/state/__init__.py",
"_src/state/indexing.py",
"_src/state/types.py",
],
deps = [
":core",
":effects",
":pretty_printer",
":tree_util",
":typing",
":util",
],
] + py_deps("numpy"),
)
pytype_strict_library(

View File

@ -1,155 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains shared logic and abstractions for Pallas indexing ops."""
from __future__ import annotations
import dataclasses
from typing import Any
import jax
from jax import core as jax_core
from jax import tree_util
from jax._src.interpreters import mlir
from jax._src.util import merge_lists
from jax._src.util import partition_list
import jax.numpy as jnp
import numpy as np
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')
def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array:
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
return jnp.broadcast_to(a, shape)
@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
return jax_core.ShapedArray(shape, aval.dtype)
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
start: Any
size: int
def __post_init__(self):
if self.size < 0:
raise ValueError("`size` must not be negative.")
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
if isinstance(self.start, int):
return (), (self.start, self.size)
return (self.start,), (self.size,)
@classmethod
def tree_unflatten(cls, aux_data, children) -> Slice:
return cls(*children, *aux_data)
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop, step = slc.indices(size)
if step != 1:
raise ValueError(f"slice must have a step of 1 (found: {step})")
return cls(start, stop - start)
def dslice(start: int | jax.Array | None, size: int | None = None
) -> slice | Slice:
"""Constructs a `Slice` from a start and a size."""
if start is None:
return slice(None)
if size is None:
if not isinstance(start, int):
raise ValueError("Non-static `dslice`")
return Slice(0, start)
return Slice(start, size)
ds = dslice # Handy alias
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: tuple[int | Slice | jax.Array, ...]
shape: tuple[int, ...]
int_indexer_shape: tuple[int, ...]
def __post_init__(self):
if len(self.indices) != len(self.shape):
raise ValueError("`indices` must be the same length as `Ref` shape.")
def tree_flatten(self):
indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
self.int_indexer_shape)
@classmethod
def tree_unflatten(cls, data, flat_idx):
slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)
@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`.")
# Pad out indices with slice(None)
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
# Convert all `slice`s to `Slice`s
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
else i for i, s in zip(indices, shape))
is_int_indexing = [not isinstance(i, Slice) for i in indices]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
int_indexers]
indexer_shapes = [i.shape for i in int_indexers]
if indexer_shapes:
try:
bcast_shape = np.broadcast_shapes(*indexer_shapes)
except ValueError as e:
# Raise a nicer error than the NumPy one.
raise ValueError("Cannot broadcast shapes for indexing: "
f"{tuple(a for a in indexer_shapes)}") from e
else:
bcast_shape = ()
int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape)
def get_indexer_shape(self) -> tuple[int, ...]:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers] # type: ignore
return (*self.int_indexer_shape, *other_shape)

View File

@ -42,12 +42,12 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core
from jax._src.pallas import indexing
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.util import safe_map
from jax._src.util import safe_zip
@ -610,12 +610,16 @@ def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
def _get_lowering_rule(
ctx: LoweringRuleContext, ref, *non_slice_idx, indexed_dims: Sequence[bool]
ctx: LoweringRuleContext, ref, *idx, tree,
):
indexers = tree_util.tree_unflatten(tree, idx)
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
if len(indexers) > 1:
raise NotImplementedError("Only one indexer currently supported.")
# Call _load_lowering_rule (since it's more general)
ref_aval, *non_slice_idx_avals = ctx.avals_in
nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer(
ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims)
nd_indexer = indexers[0]
nd_indexer_avals = indexers_avals[0]
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None))
avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None))
ctx = ctx.replace(avals_in=avals_flat)
@ -630,13 +634,17 @@ def _swap_lowering_rule(
ctx: LoweringRuleContext,
ref,
val,
*non_slice_idx,
indexed_dims: Sequence[bool],
*idx,
tree
):
indexers = tree_util.tree_unflatten(tree, idx)
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
if len(indexers) > 1:
raise NotImplementedError("Only one indexer currently supported.")
# Call _masked_swap_lowering_rule (since it's more general)
ref_aval, val_aval, *non_slice_idx_avals = ctx.avals_in
nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer(
ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims)
ref_aval, val_aval, *_ = ctx.avals_in
nd_indexer = indexers[0]
nd_indexer_avals = indexers_avals[0]
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None))
avals_flat = tree_util.tree_leaves(
(ref_aval, nd_indexer_avals, val_aval, None)

View File

@ -29,10 +29,10 @@ from jax._src import pretty_printer as pp
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.state import primitives as state_primitives
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import indexing
from jax._src.pallas.mosaic import core as tpu_core
import jax.numpy as jnp
@ -264,38 +264,6 @@ def _dma_start_abstract_eval(*args, tree, device_id_type):
del args, tree, device_id_type
return []
def _pp_slice(slc: indexing.Slice, dim: int, context: jax_core.JaxprPpContext
) -> str:
start, size = slc.start, slc.size
if isinstance(start, jax_core.Var):
start_str = jax_core.pp_var(start, context)
end_str = f'{start_str}+{size}'
else:
start_str = '' if start == 0 else str(start)
end = start + size
end_str = '' if end == dim else str(end)
return f'{start_str}:{end_str}'
def _pp_indexer(indexer: indexing.NDIndexer,
context: jax_core.JaxprPpContext) -> pp.Doc:
indices = []
for idx, dim in zip(indexer.indices, indexer.shape):
if isinstance(idx, indexing.Slice):
indices.append(_pp_slice(idx, dim, context))
else:
indices.append(jax_core.pp_var(idx, context)) # type: ignore
return pp.text(','.join(indices))
def _pp_ref(ref, indexer, context):
return state_primitives.pp_ref(
pp.concat([
pp.text(jax_core.pp_var(ref, context)),
pp.text("["),
_pp_indexer(indexer, context),
pp.text("]"),
])
)
def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
context: jax_core.JaxprPpContext,
settings: jax_core.JaxprPpSettings):
@ -309,9 +277,9 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
return pp.concat([
pp.text('dma_start'),
pp.text(' '),
_pp_ref(src_ref, src_indexer, context),
sp.pp_ref_indexers(context, src_ref, (src_indexer,)),
pp.text(' -> '),
_pp_ref(dst_ref, dst_indexer, context),
sp.pp_ref_indexers(context, dst_ref, (dst_indexer,)),
pp.text(' '),
pp.text(jax_core.pp_var(dst_sem, context)),
])
@ -358,13 +326,14 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type):
def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
context: jax_core.JaxprPpContext,
settings: jax_core.JaxprPpSettings):
del settings
invars = eqn.invars
tree = eqn.params["tree"]
sem, ref, indexer = tree_util.tree_unflatten(tree, invars)
return pp.concat([
pp.text('dma_wait'),
pp.text(' '),
_pp_ref(ref, indexer, context),
sp.pp_ref_indexers(context, ref, (indexer,)),
pp.text(' '),
pp.text(jax_core.pp_var(sem, context)),
])
@ -373,7 +342,10 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
def _get_ref_and_indexer(ref):
if isinstance(ref, state.RefView):
return ref.ref, ref.indexer
indexers = ref.indexers
if len(indexers) > 1:
raise NotImplementedError("Only one indexer supported.")
return ref.ref, ref.indexers[0].indices
return ref, (slice(None),) * len(ref.shape)
def make_async_copy(src_ref, dst_ref, sem):

View File

@ -27,15 +27,15 @@ from jax._src import core as jax_core
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src.util import (safe_map, safe_zip)
from jax._src.state import primitives as state_primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
@ -224,44 +224,13 @@ def _load_abstract_eval(*avals_flat, args_tree, **_):
load_p.def_effectful_abstract_eval(_load_abstract_eval)
def _pp_dslice(dim: int, slice: Slice, context):
size = pp.text(str(slice.size))
if isinstance(slice.start, int):
if slice.start == 0:
start = pp.text("")
else:
start = pp.text(str(slice.start))
if slice.size == dim:
end = pp.text("")
else:
end = pp.text(str(slice.start + slice.size))
else:
start = pp.text(jax_core.pp_var(slice.start, context))
end = pp.concat([start, pp.text("+"), size])
return pp.concat([start, pp.text(":"), end])
def _pp_idx(ref_aval, idx: NDIndexer, context):
docs = [
_pp_dslice(d, s, context) if isinstance(s, Slice)
else pp.text(jax_core.pp_var(s, context))
for s, d in zip(idx.indices, ref_aval.shape)]
if not docs:
return pp.text("")
doc = [docs[0]]
for d in docs[1:]:
doc.append(pp.text(","))
doc.append(d)
return pp.concat(doc)
def _load_pp_rule(eqn, context, settings):
# Pretty prints `a = load x i` as `x[i] <- a`
y, = eqn.outvars
x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars)
idx = _pp_idx(eqn.invars[0].aval, idx, context)
lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
]))])
return pp.concat([
lhs, pp.text(' <- '), sp.pp_ref_indexers(context, x, (idx,))])
jax_core.pp_eqn_rules[load_p] = _load_pp_rule
@ -339,15 +308,14 @@ def _swap_pp_rule(eqn, context, settings):
# Pretty prints `_ = swap x v i` as `x[i] <- v`
y, = eqn.outvars
x, idx, val, _ = eqn.params["args_tree"].unflatten(eqn.invars)
idx = _pp_idx(eqn.invars[0].aval, idx, context)
x_i = pp.concat([pp.text(jax_core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')])
x_i = sp.pp_ref_indexers(context, x, (idx,))
if isinstance(y, jax_core.DropVar):
return pp.concat([state_primitives.pp_ref(
x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
return pp.concat([
x_i,
pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i),
pp.text(' <- '), state_primitives.pp_ref(x_i),
return pp.concat([y, pp.text(', '), x_i,
pp.text(' <- '), x_i,
pp.text(', '), pp.text(jax_core.pp_var(val, context))])
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule

View File

@ -39,12 +39,12 @@ from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import hlo_helpers
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.state import AbstractRef
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.util import merge_lists
from jax._src.util import partition_list
@ -438,7 +438,7 @@ _TRITON_FN_MAPPING = {
lax.nextafter_p: tl.math.nextafter,
ad_util.add_any_p: tl.semantic.add,
# Other ops.
indexing.broadcast_to_p: tl.broadcast_to,
sp.broadcast_to_p: tl.broadcast_to,
primitives.atomic_cas_p: tl.atomic_cas,
primitives.max_contiguous_p: tl.max_contiguous,
primitives.multiple_of_p: tl.multiple_of,
@ -727,28 +727,16 @@ def _pack_indices(non_slice_idx, indexed_dims):
def _get_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
ctx: TritonLoweringRuleContext, ptr, *idx, tree
):
indexers = tree_util.tree_unflatten(tree, idx)
if not isinstance(ptr.type, tl.pointer_type):
assert not non_slice_idx
assert len(indexers) == 0
return ptr
ref_aval, *idx_avals = ctx.avals_in
idx_avals = _pack_indices(idx_avals, indexed_dims)
if non_slice_idx:
(int_indexer_shape,) = {
i.shape for i in idx_avals if not isinstance(i, slice)
}
else:
int_indexer_shape = ()
idx = _pack_indices(non_slice_idx, indexed_dims)
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(ref_aval.shape, idx)
)
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None))
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, None, None))
return _masked_load_lowering_rule(
ctx,
*args_flat,
@ -794,24 +782,16 @@ triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
def _swap_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
):
ref_aval, _, *idx_avals = ctx.avals_in
idx_avals = _pack_indices(idx_avals, indexed_dims)
if non_slice_idx:
(int_indexer_shape,) = {
i.shape for i in idx_avals if not isinstance(i, slice)
}
else:
int_indexer_shape = ()
idx = _pack_indices(non_slice_idx, indexed_dims)
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(ref_aval.shape, idx)
)
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None))
indexers = tree_util.tree_unflatten(tree, idx)
if not isinstance(ptr.type, tl.pointer_type):
assert len(indexers) == 0
return ptr
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, value, None))
return _masked_swap_lowering_rule(
ctx, *args_flat, args_tree=args_tree, eviction_policy=None
)
@ -850,24 +830,17 @@ triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
def _addupdate_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
):
ref_block_info, *_ = ctx.block_infos
avals_in = ctx.avals_in
idx = _pack_indices(non_slice_idx, indexed_dims)
if non_slice_idx:
(int_indexer_shape,) = {
tuple(map(lambda x: x.value, i.shape)) for i in non_slice_idx
}
else:
int_indexer_shape = ()
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(avals_in[0].shape, idx)
)
idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
indexers = tree_util.tree_unflatten(tree, idx)
if not isinstance(ptr.type, tl.pointer_type):
assert len(indexers) == 0
return ptr
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
ptr = _compute_pointers_from_indices(
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
ptr, ctx.block_infos[0], indexer, ctx.avals_in[0].shape, ctx.builder
)
tl.atomic_add(ptr, value, _builder=ctx.builder)
return []

View File

@ -34,9 +34,11 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src.state import indexing
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.state.utils import hoist_consts_to_refs
from jax._src.typing import Array
from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
partition_list, merge_lists, split_dict)
@ -144,34 +146,112 @@ def _eval_jaxpr_discharge_state(
env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge])
return out_vals + ref_vals
def _is_trivial_indexer(indexer: indexing.NDIndexer):
for s, idx in zip(indexer.shape, indexer.indices):
if not isinstance(idx, indexing.Slice):
return False
if not isinstance(idx.start, int):
return False
if idx.start:
return False
if idx.size != s:
return False
return True
def _convert_to_array_indexer(indexer: indexing.NDIndexer
) -> tuple[int | Array, ...]:
# This is the general gather case. We need to create the gather arrays.
is_integer_indexer, _, integer_indexer = (
indexing.unpack_ndindexer(indexer)
)
total_shape = indexer.get_indexer_shape()
int_indexer_shape = indexer.int_indexer_shape
slice_shape = total_shape[len(int_indexer_shape):]
slice_dims = tuple(
i + len(int_indexer_shape) for i in range(len(slice_shape))
)
slice_dim_iter = iter(slice_dims)
slice_indexer: list[Array] = []
for idx, is_int_index in zip(indexer.indices, is_integer_indexer):
if not is_int_index:
assert isinstance(idx, indexing.Slice)
slice_indices = lax.broadcasted_iota(
np.dtype("int32"), total_shape, next(slice_dim_iter)
) + idx.start
slice_indexer.append(slice_indices)
integer_indexer = tuple(
lax.expand_dims(idx, (-1,)) for idx in integer_indexer
)
continue
assert next(slice_dim_iter, None) is None
return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer))
def _maybe_convert_to_dynamic_slice(
indexer: indexing.NDIndexer,
) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
# An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
# if each of the indexers is a `Slice` or a ()-shaped value.
if not all(isinstance(i, indexing.Slice) or not np.shape(i)
for i in indexer.indices):
return None
_convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
starts = tuple(
_convert_i32(i.start) if isinstance(i, indexing.Slice)
else _convert_i32(i) for i in indexer.indices
)
sizes = tuple(
i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices
)
squeeze_dims = tuple(
i
for i, idx in enumerate(indexer.indices)
if not isinstance(idx, indexing.Slice)
)
return starts, sizes, squeeze_dims
@register_discharge_rule(get_p)
def _get_discharge_rule(
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, *non_slice_idx,
indexed_dims: Sequence[bool]):
out_avals: Sequence[core.AbstractValue], x, *idx,
tree):
del in_avals, out_avals
y = _get_discharge(x, non_slice_idx, indexed_dims)
return (None,) * (len(non_slice_idx) + 1), y
y = _get_discharge(x, idx, tree)
return (None,) * (len(idx) + 1), y
def _get_discharge(x, idx, indexed_dims):
if not any(indexed_dims):
return x
if all(not i.shape for i in idx):
return _dynamic_index(x, idx, indexed_dims)
else:
return _prepend_gather(x, idx, indexed_dims)
def _prepend_gather(x, idx, indexed_dims):
indexer = _indexer(idx, indexed_dims)
def _prepend_gather(x, indexer):
# NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
return x[None][(np.array(0, 'int32'), *indexer)]
def _prepend_scatter(x, idx, indexed_dims, val, *, add=False):
indexer = _indexer(idx, indexed_dims)
def _prepend_scatter(x, indexer, val, *, add=False):
# NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
# However, since this is scatter, we need to remove the 1-sized dimension
# we added at the front.
if add:
return x[None].at[(0, *indexer)].add(val)[0]
return x[None].at[(0, *indexer)].set(val)[0]
def _get_discharge(x, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(y, squeeze_dims)
indexer = _convert_to_array_indexer(indexer)
if indexer is None:
return x
return x[None][(np.array(0, 'int32'), *indexer)]
def _indexer(idx, indexed_dims):
idx_ = iter(idx)
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
@ -181,59 +261,59 @@ def _indexer(idx, indexed_dims):
@register_discharge_rule(swap_p)
def _swap_discharge_rule(
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
indexed_dims: Sequence[bool]):
out_avals: Sequence[core.AbstractValue], x, val, *idx,
tree):
del in_avals, out_avals
if not any(indexed_dims):
z, x_new = x, val
z, x_new = _swap_discharge(x, val, non_slice_idx, indexed_dims)
return (x_new, None) + (None,) * len(non_slice_idx), z
z, x_new = _swap_discharge(x, val, idx, tree)
return (x_new, None) + (None,) * len(idx), z
def _swap_discharge(x, val, idx, indexed_dims):
if not any(indexed_dims):
z, x_new = x, val
elif all(not i.shape for i in idx):
z = _dynamic_index(x, idx, indexed_dims)
x_new = _dynamic_update_index(x, idx, val, indexed_dims)
else:
z = _prepend_gather(x, idx, indexed_dims)
x_new = _prepend_scatter(x, idx, indexed_dims, val)
return z, x_new
def _swap_discharge(x, val, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x, val
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
val = lax.expand_dims(val, squeeze_dims)
y = lax_slicing.dynamic_update_slice(x, val, starts)
return lax.squeeze(x_old, squeeze_dims), y
indexer = _convert_to_array_indexer(indexer)
x_old = _prepend_gather(x, indexer)
return x_old, _prepend_scatter(x, indexer, val)
@register_discharge_rule(addupdate_p)
def _addupdate_discharge_rule(
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
indexed_dims: Sequence[bool]):
out_avals: Sequence[core.AbstractValue], x, val, *idx,
tree):
del in_avals, out_avals
ans = _addupdate_discharge(x, val, non_slice_idx, indexed_dims)
return (ans, None) + (None,) * len(non_slice_idx), []
ans = _addupdate_discharge(x, val, idx, tree)
return (ans, None) + (None,) * len(idx), []
def _addupdate_discharge(x, val, idx, indexed_dims):
if not any(indexed_dims):
def _addupdate_discharge(x, val, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x + val
if all(not i.shape for i in idx):
y = val + _dynamic_index(x, idx, indexed_dims)
return _dynamic_update_index(x, idx, y, indexed_dims)
else:
return _prepend_scatter(x, idx, indexed_dims, val, add=True)
def _dynamic_index(x, idx, indexed_dims):
assert isinstance(idx, (list, tuple)) and idx
idx_ = iter(idx)
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
out = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])
def _dynamic_update_index(x, idx, val, indexed_dims):
assert isinstance(idx, (list, tuple)) and idx
idx_ = iter(idx)
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
val = lax.expand_dims(val, squeeze_dims)
y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
return y
indexer = _convert_to_array_indexer(indexer)
return _prepend_scatter(x, indexer, val, add=True)
@weakref_lru_cache
def _cached_closed_jaxpr_discharge(closed_jaxpr):

192
jax/_src/state/indexing.py Normal file
View File

@ -0,0 +1,192 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains shared logic and abstractions for Pallas indexing ops."""
from __future__ import annotations
import dataclasses
from typing import Any, Union
from jax._src import core
from jax._src import tree_util
from jax._src.typing import Array
from jax._src.util import merge_lists
from jax._src.util import partition_list
import numpy as np
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
start: Any
size: int
def __post_init__(self):
if self.size < 0:
raise ValueError("`size` must not be negative.")
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
if isinstance(self.start, int):
return (), (self.start, self.size)
return (self.start,), (self.size,)
@classmethod
def tree_unflatten(cls, aux_data, children) -> Slice:
return cls(*children, *aux_data)
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop, step = slc.indices(size)
if step != 1:
raise ValueError(f"slice must have a step of 1 (found: {step})")
return cls(start, stop - start)
def dslice(start: int | Array | None, size: int | None = None
) -> slice | Slice:
"""Constructs a `Slice` from a start and a size."""
if start is None:
return slice(None)
if size is None:
if not isinstance(start, int):
raise ValueError("Non-static `dslice`")
return Slice(0, start)
return Slice(start, size)
ds = dslice # Handy alias
IntIndexer = Union[int, Array]
DimIndexer = Union[IntIndexer, Slice]
def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
tuple[Slice, ...],
tuple[IntIndexer, ...]]:
is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
slice_indexers, int_indexers = partition_list(
is_int_indexing, indexer.indices)
return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # type: ignore
def _maybe_concretize(x: Any):
try:
return core.concrete_or_error(None, x)
except core.ConcretizationTypeError:
return None
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: tuple[DimIndexer, ...]
shape: tuple[int, ...]
int_indexer_shape: tuple[int, ...]
validate: bool = False
def __post_init__(self):
if not self.validate:
return
if len(self.indices) != len(self.shape):
raise ValueError(
f"`indices` must be the same length as `Ref` shape.: {self}."
)
# We validate integer indexing shapes here
for idx, s in zip(self.indices, self.shape):
if isinstance(idx, Slice):
start = idx.start
if value := _maybe_concretize(start):
if value >= s:
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
if value + idx.size > s:
raise ValueError(
f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
)
continue
# The shape of indexer integers should be broadcastable up to the
# int_indexer_shape of the whole NDIndexer
if not np.shape(idx):
if (value := _maybe_concretize(idx)) and value >= s:
raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
# For ()-shaped indexers, we can broadcast no problm.
continue
# If we don't have a ()-shaped indexer, the rank must match
# int_indexer_shape
if np.ndim(idx) != len(self.int_indexer_shape):
raise ValueError(
f"Indexer must have rank {np.ndim(idx)}: {idx=} vs."
f" {self.int_indexer_shape=}"
)
# Here we check that the shapes broadcast.
try:
np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
except ValueError as e:
raise ValueError(
f"Could not broadcast integer indexer: {idx=} vs."
f" {self.int_indexer_shape=}"
) from e
def tree_flatten(self):
flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
@classmethod
def tree_unflatten(cls, data, flat_idx):
idx_tree, shape, int_indexer_shape = data
indices = tree_util.tree_unflatten(idx_tree, flat_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)
@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
if indices == ...:
indices = (slice(None),) * len(shape)
if not isinstance(indices, tuple):
indices = (indices,)
if any(idx is ... for idx in indices):
# TODO(sharadmv,mattjj): support patterns that include ellipsis in them
# e.g. x[0, ..., 1].
raise NotImplementedError("Ellipsis in indexer not supported yet.")
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`: "
f"{indices=}, {shape=}")
# Pad out indices with slice(None)
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
# Convert all `slice`s to `Slice`s
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
else i for i, s in zip(indices, shape))
is_int_indexing = [not isinstance(i, Slice) for i in indices]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
if indexer_shapes:
try:
bcast_shape = np.broadcast_shapes(*indexer_shapes)
except ValueError as e:
# Raise a nicer error than the NumPy one.
raise ValueError("Cannot broadcast shapes for indexing: "
f"{tuple(a for a in indexer_shapes)}") from e
else:
bcast_shape = ()
# Here we use the `broadcast_to` primitive instead of composing lax
# primitives together because it is easier to lower in targets like
# Triton/Mosaic.
from jax._src.state import primitives as sp # pytype: disable=import-error
int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
def get_indexer_shape(self) -> tuple[int, ...]:
_, slice_indexers, _ = unpack_ndindexer(self)
slice_shape = [s.size for s in slice_indexers]
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)

View File

@ -23,14 +23,17 @@ import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
from jax._src.state import indexing
from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect,
AccumEffect)
from jax._src.util import safe_map, safe_zip, tuple_insert
from jax._src.util import safe_map, safe_zip
## General utilities
@ -51,41 +54,14 @@ zip, unsafe_zip = safe_zip, zip
# a:f32[3] <- x[]
get_p = core.Primitive("get")
def _get_impl(ref: AbstractRef, *idx: int, **_):
del ref, idx
def _get_impl(ref: AbstractRef, *args: Any, tree):
del ref, args, tree
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
Indexer = tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
def _is_trivial_indexer(idx: Indexer) -> bool:
if idx is ...:
return True
if type(idx) is tuple:
if len(idx) == 0:
return True
return len(idx) == 1 and idx[0] is ...
return False
def _unpack_idx(idx: Indexer, ndim: int
) -> tuple[tuple[Array, ...], tuple[bool, ...]]:
if _is_trivial_indexer(idx):
idx = tuple(slice(None) for _ in range(ndim))
indexed_dims_ = []
non_slice_idx = []
for i in idx:
if isinstance(i, slice):
if i.start is not None or i.stop is not None or i.step is not None:
raise NotImplementedError("Reference indexing only supports trivial slices")
indexed_dims_.append(False)
else:
non_slice_idx.append(i)
indexed_dims_.append(True)
indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
import jax.numpy as jnp
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
def _get_slice_output_shape(in_shape: tuple[int, ...],
idx_shapes: tuple[tuple[int, ...], ...],
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
@ -95,24 +71,30 @@ def _get_slice_output_shape(in_shape: tuple[int, ...],
shape = (*shape_prefix, *shape_suffix)
return shape
def _get_indexer(ref: AbstractRef, idx: Indexer
) -> tuple[Indexer, tuple[bool, ...]]:
if isinstance(ref.inner_aval, core.ShapedArray):
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
else:
if not _is_trivial_indexer(idx):
raise ValueError(
f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
non_slice_idx, indexed_dims = (), ()
return non_slice_idx, indexed_dims
def ref_get(ref: Any, idx: Indexer) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
def _get_indexer(
ref_or_view: Any, idx: Indexer | None, function_name: str
) -> tuple[Any, tuple[indexing.NDIndexer, ...]]:
if isinstance(ref_or_view, RefView):
ref, indexers = ref_or_view.ref, ref_or_view.indexers
else:
ref, indexers = ref_or_view, ()
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
return ref, ()
if idx is None:
return ref, indexers
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
return ref, (*indexers, nd_indexer)
def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
ref, indexers = _get_indexer(ref_or_view, idx, "ref_get")
flat_indexers, tree = tree_util.tree_flatten(indexers)
return get_p.bind(ref, *flat_indexers, tree=tree)
# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
@ -132,22 +114,21 @@ def ref_get(ref: Any, idx: Indexer) -> Array:
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
del ref, value, idx
def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
del ref, value, idx, tree
raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl)
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
def ref_swap(ref_or_view: AbstractRef, idx: Indexer | None, value: Array,
*, _function_name: str = "ref_swap") -> Array:
"""Sets a `Ref`'s value and returns the original value."""
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
ref, indexers = _get_indexer(ref_or_view, idx, _function_name)
flat_indexers, tree = tree_util.tree_flatten(indexers)
return swap_p.bind(ref, value, *flat_indexers, tree=tree)
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
def ref_set(ref: AbstractRef, idx: Indexer | None, value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref, idx, value)
ref_swap(ref, idx, value, _function_name="ref_set")
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically,
@ -163,38 +144,40 @@ def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
del ref, idx, value
def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
del ref, value, args, tree
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
ref, indexers = _get_indexer(ref_or_view, idx, "ref_addupdate")
flat_indexers, tree = tree_util.tree_flatten(indexers)
return addupdate_p.bind(ref, x, *flat_indexers, tree=tree)
## get/set/addupdate abstract evaluation rules
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
indexed_dims):
def _shape_after_indexing(
shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...]
) -> tuple[int, ...]:
for indexer in indexers:
# Run some simple checks that all the indexers have consistent shapes
assert indexer.shape == shape, (indexer.shape, shape)
shape = indexer.get_indexer_shape()
return shape
def _get_abstract_eval(ref_aval: AbstractRef, *args,
tree):
indexers = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
raise ValueError("`get` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
idx_shapes = tuple(i.shape for i in idx)
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
out_aval = ref_aval.inner_aval.update(shape=shape)
out_shape = _shape_after_indexing(ref_aval.shape, indexers)
out_aval = ref_aval.inner_aval.update(shape=out_shape)
else:
if idx:
if indexers:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)})
@ -202,34 +185,29 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
*args: Any, tree):
indexers = tree_util.tree_unflatten(tree, args)
out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
expected_output_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if expected_output_shape != val_aval.shape:
expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers)
if expected_out_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Expected shape: {expected_out_shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
f"Indices: {indexers}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype)
else:
if idx:
raise ValueError("`swap` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
if indexers:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
@ -237,93 +215,118 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
*args: Any, tree):
indexers = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
raise ValueError("`addupdate` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
slice_shape = _shape_after_indexing(ref_aval.shape, indexers)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
slice_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if slice_shape != val_aval.shape:
raise ValueError("Invalid shape for `addupdate`. "
f"Ref shape: {ref_aval.shape}. "
f"Slice shape: {slice_shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
f"Indices: {indexers}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
elif idx:
raise ValueError("`addupdate` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
else:
# Check that the indexers are valid
if indexers:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL,
foreground=pp.Color.GREEN)
def _pp_idx(context, non_slice_idx, indexed_dims):
idx_iter = iter(non_slice_idx)
idx = ','.join(core.pp_var(next(idx_iter), context) if indexed else ':'
for indexed in indexed_dims)
assert next(idx_iter, None) is None
return pp.text(idx)
def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice
) -> str:
start, size = slc.start, slc.size
if isinstance(start, core.Var):
start_str = core.pp_var(start, context)
end_str = f'{start_str}+{size}'
else:
start_str = '' if start == 0 else str(start)
end = start + size
end_str = '' if end == dim else str(end)
return f'{start_str}:{end_str}'
def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer
) -> pp.Doc:
indices = []
for idx, dim in zip(indexer.indices, indexer.shape):
if isinstance(idx, indexing.Slice):
indices.append(_pp_slice(context, dim, idx))
else:
indices.append(core.pp_var(idx, context)) # type: ignore
return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")])
def _pp_indexers(
context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...],
):
return pp.concat(
[pp_indexer(context, indexer) for indexer in indexers]
)
def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers):
return pp_ref_var(
pp.concat([
pp.text(core.pp_var(ref, context)),
_pp_indexers(context, indexers),
])
)
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
# Pretty prints `a = get x i` as `x[i] <- a`
y, = eqn.outvars
x, *idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
x, *flat_idx = eqn.invars
indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
# TODO more general get
return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([
pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))])
return pp.concat([
lhs,
pp.text(' <- '),
pp_ref_indexers(context, x, indexers)
])
core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
y, = eqn.outvars
x, v, *idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
x, v, *flat_idx = eqn.invars
indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
if type(y) is core.DropVar:
# In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v`
del y
return pp.concat([
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')
])), pp.text(' <- '), pp.text(core.pp_var(v, context))])
pp_ref_indexers(context, x, indexers),
pp.text(' <- '),
pp.text(core.pp_var(v, context))
])
else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
x_i = pp.concat([pp.text(core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')])
x_i = pp_ref_indexers(context, x, indexers)
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), pp_ref(x_i), pp.text(' <- '),
pp_ref(x_i), pp.text(', '),
return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
x_i, pp.text(', '),
pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
del settings
# pretty-print ` = addupdate x i v` as `x[i] += v`
() = eqn.outvars
x, v, *idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
x, v, *flat_idx = eqn.invars
indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
return pp.concat([
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')
])), pp.text(' += '), pp.text(core.pp_var(v, context))])
pp_ref_indexers(context, x, indexers),
pp.text(' += '),
pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
@ -366,6 +369,7 @@ def _get_transpose(g, ref, *idx, **params):
ad.primitive_transposes[get_p] = _get_transpose
def _swap_transpose(g, ref, x, *idx, **params):
del x # old value doesn't matter anymore
# swap transpose is swap
x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
return [None, x_bar] + [None] * len(idx)
@ -403,101 +407,162 @@ def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
num_idxs_to_left = sum(indexed_dims[:ref_dim])
return ref_dim - num_idxs_to_left + len(idxs_shape)
def _get_vmap(batched_args, batched_dims, *, indexed_dims):
def _batch_indexer(indexer: indexing.NDIndexer, dims,
axis_size: int,
ref_shape: tuple[int, ...],
ref_dim: int | batching.NotMapped,
idx_is_batched: bool) -> indexing.NDIndexer:
indices = indexer.indices
indices_dims = dims.indices
new_indices: list[Array | indexing.Slice | int] = []
new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape)
for idx, dim in zip(indices, indices_dims):
if idx_is_batched:
# If at least one of the idx is batched, we broadcast them all and move the
# batch dim to the front.
if isinstance(idx, indexing.Slice):
# size is static, but start can be dynamic
# Check if start is static (which it can be)
is_static_slice = len(tree_util.tree_leaves(idx)) == 0
if is_static_slice:
new_indices.append(idx)
continue
dim = dim.start
if dim is batching.not_mapped:
# Broadcasting the slice is free (the start index stays the same)
new_indices.append(idx)
else:
raise NotImplementedError(
f"No support for vmapping over nontrivial slices just yet: {idx}")
else:
# Check if we are indexing with a scalar or not. If we are indexing
# with a scalar and we are not batched, we can avoid broadcasting it.
assert hasattr(idx, "shape")
if not idx.shape:
if dim is not batching.not_mapped:
assert idx.shape == (axis_size,)
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,))
new_indices.append(idx)
else:
if dim is batching.not_mapped:
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
bcast_dims)
else:
idx = batching.moveaxis(idx, dim, 0)
new_indices.append(idx)
else:
if ref_dim is not batching.not_mapped:
if not isinstance(idx, indexing.Slice):
assert hasattr(idx, "shape")
if idx.shape:
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
bcast_dims)
new_indices.append(idx)
if ref_dim is not batching.not_mapped:
iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0)
new_indices.insert(ref_dim, iota)
return indexing.NDIndexer(tuple(new_indices), ref_shape,
new_integer_indexer_shape,
validate=True)
def _get_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, *idxs = batched_args
ref_dim, *idx_dims = batched_dims
ref, *flat_idxs = batched_args
ref_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
bdim_out = 0
if idx_is_batched:
# If at least one of the idx is batched, we broadcast them all and move the
# batch dim to the front.
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
in zip(idxs, idx_dims))
idxs_shape, = {i.shape for i in idxs} or [()]
if ref_is_batched:
# If ref is batched, we are doing a `get` with an additional axis. If `idxs`
# are also batched, then we are indexing into the batch axis with an `iota`.
indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched)
if idx_is_batched:
# If we have batched idx, we need to insert the new iota index. The place
# where we add in the new `iota` index is `ref_dim` so we need to compute
# what `ref_dim` *would be* if we inserted it into `idxs` instead, because
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out
idx_is_batched = any(i_dim is not batching.not_mapped
for i_dim in flat_idx_dims)
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
# TODO(sharadmv): handle vmap of multiple indexers
indexers = tuple(_batch_indexer(indexer, dims, axis_size,
ref.shape, ref_dim, idx_is_batched)
for indexer, dims in zip(indexers, indexers_dims))
flat_indexers, tree = tree_util.tree_flatten(indexers)
return get_p.bind(ref, *flat_indexers, tree=tree), 0
batching.primitive_batchers[get_p] = _get_vmap
def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
def _swap_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, val, *idxs = batched_args
ref_dim, val_dim, *idx_dims = batched_dims
ref, val, *flat_idxs = batched_args
ref_dim, val_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
if idx_is_batched:
# If at least one of the idx is batched, we broadcast them all and move the
# batch dim to the front.
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
in zip(idxs, idx_dims))
idxs_shape, = {i.shape for i in idxs} or [()]
if ref_is_batched and not idx_is_batched:
indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
if not val_is_batched:
val = batching.broadcast(val, axis_size, 0)
val_dim = 0
val = batching.moveaxis(val, val_dim, bdim_out)
elif idx_is_batched:
assert ref_is_batched and val_is_batched
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
idx_is_batched = any(i_dim is not batching.not_mapped
for i_dim in flat_idx_dims)
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
# TODO(sharadmv): handle vmap of multiple indexers
indexers = tuple(_batch_indexer(indexer, dims, axis_size,
ref.shape, ref_dim, idx_is_batched)
for indexer, dims in zip(indexers, indexers_dims))
flat_indexers, tree = tree_util.tree_flatten(indexers)
if (ref_is_batched or idx_is_batched) and not val_is_batched:
val = batching.broadcast(val, axis_size, 0)
if val_is_batched:
val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0
return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out
return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0
batching.primitive_batchers[swap_p] = _swap_vmap
def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
def _addupdate_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, val, *idxs = batched_args
ref_dim, val_dim, *idx_dims = batched_dims
ref, val, *flat_idxs = batched_args
ref_dim, val_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
if idx_is_batched:
# If at least one of the idx is batched, we ensure all have bdims at front.
idxs = tuple(batching.bdim_at_front(i, d, axis_size)
for i, d in zip(idxs, idx_dims))
idxs_shape, = {i.shape for i in idxs} or [()]
if ref_is_batched and not idx_is_batched:
indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
if not val_is_batched:
val = batching.broadcast(val, axis_size, 0)
val_dim = 0
val = batching.moveaxis(val, val_dim, bdim_out)
elif idx_is_batched:
assert ref_is_batched and val_is_batched
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
idx_is_batched = any(i_dim is not batching.not_mapped
for i_dim in flat_idx_dims)
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
# TODO(sharadmv): handle vmap of multiple indexers
indexers = tuple(_batch_indexer(indexer, dims, axis_size,
ref.shape, ref_dim, idx_is_batched)
for indexer, dims in zip(indexers, indexers_dims))
flat_indexers, tree = tree_util.tree_flatten(indexers)
if (ref_is_batched or idx_is_batched) and not val_is_batched:
val = batching.broadcast(val, axis_size, 0)
if val_is_batched:
val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), []
batching.primitive_batchers[addupdate_p] = _addupdate_vmap
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = core.Primitive('broadcast_to')
def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array:
import jax.numpy as jnp
a = jnp.asarray(a)
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
import jax.numpy as jnp
return jnp.broadcast_to(a, shape)
@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
return core.ShapedArray(shape, aval.dtype)
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)

View File

@ -23,6 +23,7 @@ from typing import Any, Generic, TypeVar, Union
from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src.state import indexing
from jax._src.util import safe_map, safe_zip
from jax._src.typing import Array
@ -77,21 +78,34 @@ Aval = TypeVar("Aval", bound=core.AbstractValue)
@dataclasses.dataclass
class RefIndexer:
ref: Any
ref_or_view: Any
def __getitem__(self, slc):
if not isinstance(slc, tuple):
slc = (slc,)
return RefView(self.ref, slc)
indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape)
if isinstance(self.ref_or_view, RefView):
view = self.ref_or_view
return RefView(view.ref, (*view.indexers, indexer))
return RefView(self.ref_or_view, (indexer,))
Indexer = Any
@dataclasses.dataclass
class RefView:
ref: Any
indexer: Any
indexers: tuple[indexing.NDIndexer, ...]
@property
def at(self):
raise NotImplementedError("Can't call `.at` multiple times.")
def shape(self) -> tuple[int, ...]:
assert (
len(self.indexers) > 0
), "Should not be able to create a trivial RefView"
return self.indexers[-1].get_indexer_shape()
@property
def at(self) -> RefIndexer:
return RefIndexer(self)
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
@ -137,14 +151,10 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
return ref_set(tracer, idx, value)
def _getitem(self, tracer, idx) -> Array:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(tracer, idx)
def _setitem(self, tracer, idx, value) -> None:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_set # pytype: disable=import-error
return ref_set(tracer, idx, value)

View File

@ -1472,6 +1472,7 @@ tf_not_yet_impl = [
"for",
"inspect_sharding",
"io_callback",
"broadcast_to",
"shard_map",
"global_array_to_host_local_array",
"host_local_array_to_global_array",

View File

@ -17,10 +17,6 @@
from jax._src import pallas
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.indexing import broadcast_to
from jax._src.pallas.indexing import ds
from jax._src.pallas.indexing import dslice
from jax._src.pallas.indexing import Slice
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
@ -42,6 +38,10 @@ from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when
from jax._src.state.primitives import broadcast_to
from jax._src.state.indexing import ds
from jax._src.state.indexing import dslice
from jax._src.state.indexing import Slice
try:
from jax.experimental.pallas import gpu # pytype: disable=import-error

View File

@ -22,7 +22,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import util
from jax._src.pallas import indexing
from jax._src.state import indexing
import numpy as np
try:
@ -49,7 +49,8 @@ def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
@hps.composite
def slice_indexer_strategy(draw, dim) -> Slice | slice:
start = draw(int_indexer_strategy(dim))
size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max))
max_size = dim - start
size = draw(hps.integers(min_value=0, max_value=max_size))
return draw(
hps.one_of(
hps.just(Slice(start, size)), hps.just(slice(start, start + size))
@ -78,8 +79,8 @@ def indexer_strategy(draw, dim, int_indexer_shape
def nd_indexer_strategy(draw, shape) -> NDIndexer:
num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
int_indexer_shape = draw(hnp.array_shapes())
indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim
in shape[:num_indices]]
indices = tuple(draw(indexer_strategy(dim, int_indexer_shape))
for dim in shape[:num_indices])
return NDIndexer.from_indices_shape(indices, shape)
@ -97,6 +98,24 @@ class IndexerTest(parameterized.TestCase):
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_invalid_ndindexer_oob_int(self):
indices = (4, 0)
shape = (3, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_invalid_ndindexer_oob_slice_start(self):
indices = (slice(3, 2), 0)
shape = (3, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_invalid_ndindexer_oob_slice_end(self):
indices = (Slice(2, 2), 0)
shape = (3, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_ndindexer_with_padding(self):
indices = ()
shape = (5, 5)
@ -137,17 +156,17 @@ class IndexerTest(parameterized.TestCase):
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3))
indices = (0, slice(4, 10), np.arange(5))
indices = (0, slice(2, 10), np.arange(5))
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 0))
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 1))
indices = (0, 5, np.arange(5))
indices = (0, 1, np.arange(5))
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5,))
indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None])
indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None])
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3))
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2))
@hp.given(hps.data())
def test_ndindexer(self, data):

View File

@ -1471,7 +1471,7 @@ class PallasPrimitivesTest(PallasTest):
@parameterized.parameters(*[
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"),
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:4,:,:4]"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:,:,:4]"),
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"),
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"),
])
@ -1486,7 +1486,7 @@ class PallasPrimitivesTest(PallasTest):
@parameterized.parameters(*[
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"),
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:4,:,:4] <-"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:,:,:4] <-"),
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"),
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"),
])
@ -1504,7 +1504,7 @@ class PallasPrimitivesTest(PallasTest):
(lambda: (pl.dslice(0, 3), slice(None), slice(None)),
"c:i32[3,3,2], a[:3,:,:] <-"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)),
"c:i32[3,3,4], a[1:4,:,:4] <-"),
"c:i32[3,3,4], a[1:,:,:4] <-"),
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)),
"e:i32[5,3,4], a[b,:,:4] <-"),
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)),

View File

@ -56,15 +56,15 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def test_cant_eval_get_primitive(self):
with self.assertRaises(ValueError):
get_p.bind(jnp.ones(5))
get_p.bind(jnp.ones(5), tree=None)
def test_cant_eval_swap_primitive(self):
with self.assertRaises(ValueError):
swap_p.bind(jnp.ones(5), jnp.zeros(5))
swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
def test_cant_eval_addupdate_primitive(self):
with self.assertRaises(ValueError):
addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
def test_get_abstract_aval_must_take_in_refs(self):
ref_aval = core.ShapedArray((), jnp.float32)
@ -95,11 +95,37 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_nontrivial_slice",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_nontrivial_slice2",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
dict(testcase_name="get_with_ref_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(1, 3), slice(None), slice(None)),
out_shape=(2, 2, 4), out_dtype=jnp.float32,
at_indices=((0,),)),
dict(testcase_name="get_with_ref_simple_at2",
ref_shape=(6, 1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 2), slice(0, 1), slice(1, 3), slice(None), slice(None)),
out_shape=(2, 1, 2, 2, 4), out_dtype=jnp.float32,
at_indices=((slice(2, 6),),)),
dict(testcase_name="get_with_ref_multiple_at",
ref_shape=(1, 3, 5, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(None), slice(0, 2)),
out_shape=(3, 1, 2), out_dtype=jnp.float32,
at_indices=((0,), (slice(None), slice(0, 1)))),
)
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
out_dtype=None, should_error=False):
out_dtype=None, at_indices=(),
should_error=False):
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
def f(x_ref):
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
out = ref_get(x_ref, idx)
return [out]
if should_error:
@ -160,13 +186,43 @@ class StatePrimitivesTest(jtu.JaxTestCase):
val_shape=(2, 1, 2), val_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nontrivial_slice",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
val_shape=(2, 1, 2), val_dtype=jnp.float32,
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nontrivial_slice2",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
val_shape=(1, 2, 2, 4), val_dtype=jnp.float32,
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
dict(testcase_name="swap_with_ref_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None),),
val_shape=(1, 1, 4), val_dtype=jnp.float32,
out_shape=(1, 1, 4), out_dtype=jnp.float32,
at_indices=((0,),),),
dict(testcase_name="swap_with_ref_simple_at2",
ref_shape=(4, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
at_indices=((slice(0, 2),),),),
dict(testcase_name="swap_with_ref_multiple_at2",
ref_shape=(1, 4, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
at_indices=((slice(None), slice(0, 2),), (0,)),),
)
def test_swap_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
should_error=False):
at_indices=(), should_error=False):
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val):
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
out = ref_swap(x_ref, idx, val)
return [out]
if should_error:
@ -192,13 +248,13 @@ class StatePrimitivesTest(jtu.JaxTestCase):
idx=(slice(None),), should_error=True),
dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
idx=(),),
dict(testcase_name="bad_dtype", ref_shape=(1, 2),
ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), should_error=True),
dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
idx=(0,),),
dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
idx=(slice(None), 0)),
@ -216,13 +272,34 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 1, 2), val_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
dict(testcase_name="ref_with_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((0,),)),
dict(testcase_name="ref_with_simple_at2",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 3, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(0, 3),),)),
dict(testcase_name="ref_with_multiple_at",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(0, 3),), (0,))),
dict(testcase_name="ref_with_multiple_at2",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(None), slice(0, 3),), (0,))),
)
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
should_error=False):
val_shape, val_dtype, idx, at_indices=(), should_error=False):
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val):
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
ref_addupdate(x_ref, idx, val)
return []
if should_error:
@ -1595,7 +1672,6 @@ if CAN_USE_HYPOTHESIS:
def test_vjp(self, data):
spec = data.draw(func_spec())
print(spec)
def impl(x):
return spec.call((x, jnp.zeros_like(x)))[1]