[Pallas] Refactor indexing primitives to use NDIndexer abstraction

Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
  This allows doing things like `ref.at[0].at[0]` to recursively create a view
  of a `Ref`. `RefView`s therefore encapsluate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
  but does *not* generalize their rules. Most of the rules will raise a
  NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
  done in a future CL.
* With the above in mind, this change only preserves existing public facing APIs
  and adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
This commit is contained in:
Sharad Vikram 2024-01-02 15:52:57 -08:00 committed by jax authors
parent 8c5e7b26ba
commit 836563fadf
15 changed files with 819 additions and 608 deletions

View File

@ -738,15 +738,17 @@ pytype_strict_library(
name = "state_types", name = "state_types",
srcs = [ srcs = [
"_src/state/__init__.py", "_src/state/__init__.py",
"_src/state/indexing.py",
"_src/state/types.py", "_src/state/types.py",
], ],
deps = [ deps = [
":core", ":core",
":effects", ":effects",
":pretty_printer", ":pretty_printer",
":tree_util",
":typing", ":typing",
":util", ":util",
], ] + py_deps("numpy"),
) )
pytype_strict_library( pytype_strict_library(

View File

@ -1,155 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains shared logic and abstractions for Pallas indexing ops."""
from __future__ import annotations
import dataclasses
from typing import Any
import jax
from jax import core as jax_core
from jax import tree_util
from jax._src.interpreters import mlir
from jax._src.util import merge_lists
from jax._src.util import partition_list
import jax.numpy as jnp
import numpy as np
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')
def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array:
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
return jnp.broadcast_to(a, shape)
@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
return jax_core.ShapedArray(shape, aval.dtype)
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
start: Any
size: int
def __post_init__(self):
if self.size < 0:
raise ValueError("`size` must not be negative.")
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
if isinstance(self.start, int):
return (), (self.start, self.size)
return (self.start,), (self.size,)
@classmethod
def tree_unflatten(cls, aux_data, children) -> Slice:
return cls(*children, *aux_data)
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop, step = slc.indices(size)
if step != 1:
raise ValueError(f"slice must have a step of 1 (found: {step})")
return cls(start, stop - start)
def dslice(start: int | jax.Array | None, size: int | None = None
) -> slice | Slice:
"""Constructs a `Slice` from a start and a size."""
if start is None:
return slice(None)
if size is None:
if not isinstance(start, int):
raise ValueError("Non-static `dslice`")
return Slice(0, start)
return Slice(start, size)
ds = dslice # Handy alias
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: tuple[int | Slice | jax.Array, ...]
shape: tuple[int, ...]
int_indexer_shape: tuple[int, ...]
def __post_init__(self):
if len(self.indices) != len(self.shape):
raise ValueError("`indices` must be the same length as `Ref` shape.")
def tree_flatten(self):
indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
self.int_indexer_shape)
@classmethod
def tree_unflatten(cls, data, flat_idx):
slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)
@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`.")
# Pad out indices with slice(None)
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
# Convert all `slice`s to `Slice`s
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
else i for i, s in zip(indices, shape))
is_int_indexing = [not isinstance(i, Slice) for i in indices]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
int_indexers]
indexer_shapes = [i.shape for i in int_indexers]
if indexer_shapes:
try:
bcast_shape = np.broadcast_shapes(*indexer_shapes)
except ValueError as e:
# Raise a nicer error than the NumPy one.
raise ValueError("Cannot broadcast shapes for indexing: "
f"{tuple(a for a in indexer_shapes)}") from e
else:
bcast_shape = ()
int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape)
def get_indexer_shape(self) -> tuple[int, ...]:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers] # type: ignore
return (*self.int_indexer_shape, *other_shape)

View File

@ -42,12 +42,12 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core from jax._src.pallas import core
from jax._src.pallas import indexing
from jax._src.pallas import primitives from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.state import discharge as state_discharge from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives from jax._src.state import primitives as state_primitives
from jax._src.util import safe_map from jax._src.util import safe_map
from jax._src.util import safe_zip from jax._src.util import safe_zip
@ -610,12 +610,16 @@ def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
def _get_lowering_rule( def _get_lowering_rule(
ctx: LoweringRuleContext, ref, *non_slice_idx, indexed_dims: Sequence[bool] ctx: LoweringRuleContext, ref, *idx, tree,
): ):
indexers = tree_util.tree_unflatten(tree, idx)
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
if len(indexers) > 1:
raise NotImplementedError("Only one indexer currently supported.")
# Call _load_lowering_rule (since it's more general) # Call _load_lowering_rule (since it's more general)
ref_aval, *non_slice_idx_avals = ctx.avals_in ref_aval, *non_slice_idx_avals = ctx.avals_in
nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer( nd_indexer = indexers[0]
ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims) nd_indexer_avals = indexers_avals[0]
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None)) args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None))
avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None)) avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None))
ctx = ctx.replace(avals_in=avals_flat) ctx = ctx.replace(avals_in=avals_flat)
@ -630,13 +634,17 @@ def _swap_lowering_rule(
ctx: LoweringRuleContext, ctx: LoweringRuleContext,
ref, ref,
val, val,
*non_slice_idx, *idx,
indexed_dims: Sequence[bool], tree
): ):
indexers = tree_util.tree_unflatten(tree, idx)
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
if len(indexers) > 1:
raise NotImplementedError("Only one indexer currently supported.")
# Call _masked_swap_lowering_rule (since it's more general) # Call _masked_swap_lowering_rule (since it's more general)
ref_aval, val_aval, *non_slice_idx_avals = ctx.avals_in ref_aval, val_aval, *_ = ctx.avals_in
nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer( nd_indexer = indexers[0]
ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims) nd_indexer_avals = indexers_avals[0]
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None)) args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None))
avals_flat = tree_util.tree_leaves( avals_flat = tree_util.tree_leaves(
(ref_aval, nd_indexer_avals, val_aval, None) (ref_aval, nd_indexer_avals, val_aval, None)

View File

@ -29,10 +29,10 @@ from jax._src import pretty_printer as pp
from jax._src import state from jax._src import state
from jax._src import tree_util from jax._src import tree_util
from jax._src import util from jax._src import util
from jax._src.state import primitives as state_primitives from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import indexing
from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import core as tpu_core
import jax.numpy as jnp import jax.numpy as jnp
@ -264,38 +264,6 @@ def _dma_start_abstract_eval(*args, tree, device_id_type):
del args, tree, device_id_type del args, tree, device_id_type
return [] return []
def _pp_slice(slc: indexing.Slice, dim: int, context: jax_core.JaxprPpContext
) -> str:
start, size = slc.start, slc.size
if isinstance(start, jax_core.Var):
start_str = jax_core.pp_var(start, context)
end_str = f'{start_str}+{size}'
else:
start_str = '' if start == 0 else str(start)
end = start + size
end_str = '' if end == dim else str(end)
return f'{start_str}:{end_str}'
def _pp_indexer(indexer: indexing.NDIndexer,
context: jax_core.JaxprPpContext) -> pp.Doc:
indices = []
for idx, dim in zip(indexer.indices, indexer.shape):
if isinstance(idx, indexing.Slice):
indices.append(_pp_slice(idx, dim, context))
else:
indices.append(jax_core.pp_var(idx, context)) # type: ignore
return pp.text(','.join(indices))
def _pp_ref(ref, indexer, context):
return state_primitives.pp_ref(
pp.concat([
pp.text(jax_core.pp_var(ref, context)),
pp.text("["),
_pp_indexer(indexer, context),
pp.text("]"),
])
)
def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
context: jax_core.JaxprPpContext, context: jax_core.JaxprPpContext,
settings: jax_core.JaxprPpSettings): settings: jax_core.JaxprPpSettings):
@ -309,9 +277,9 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
return pp.concat([ return pp.concat([
pp.text('dma_start'), pp.text('dma_start'),
pp.text(' '), pp.text(' '),
_pp_ref(src_ref, src_indexer, context), sp.pp_ref_indexers(context, src_ref, (src_indexer,)),
pp.text(' -> '), pp.text(' -> '),
_pp_ref(dst_ref, dst_indexer, context), sp.pp_ref_indexers(context, dst_ref, (dst_indexer,)),
pp.text(' '), pp.text(' '),
pp.text(jax_core.pp_var(dst_sem, context)), pp.text(jax_core.pp_var(dst_sem, context)),
]) ])
@ -358,13 +326,14 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type):
def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
context: jax_core.JaxprPpContext, context: jax_core.JaxprPpContext,
settings: jax_core.JaxprPpSettings): settings: jax_core.JaxprPpSettings):
del settings
invars = eqn.invars invars = eqn.invars
tree = eqn.params["tree"] tree = eqn.params["tree"]
sem, ref, indexer = tree_util.tree_unflatten(tree, invars) sem, ref, indexer = tree_util.tree_unflatten(tree, invars)
return pp.concat([ return pp.concat([
pp.text('dma_wait'), pp.text('dma_wait'),
pp.text(' '), pp.text(' '),
_pp_ref(ref, indexer, context), sp.pp_ref_indexers(context, ref, (indexer,)),
pp.text(' '), pp.text(' '),
pp.text(jax_core.pp_var(sem, context)), pp.text(jax_core.pp_var(sem, context)),
]) ])
@ -373,7 +342,10 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
def _get_ref_and_indexer(ref): def _get_ref_and_indexer(ref):
if isinstance(ref, state.RefView): if isinstance(ref, state.RefView):
return ref.ref, ref.indexer indexers = ref.indexers
if len(indexers) > 1:
raise NotImplementedError("Only one indexer supported.")
return ref.ref, ref.indexers[0].indices
return ref, (slice(None),) * len(ref.shape) return ref, (slice(None),) * len(ref.shape)
def make_async_copy(src_ref, dst_ref, sem): def make_async_copy(src_ref, dst_ref, sem):

View File

@ -27,15 +27,15 @@ from jax._src import core as jax_core
from jax._src import pretty_printer as pp from jax._src import pretty_printer as pp
from jax._src import state from jax._src import state
from jax._src.util import (safe_map, safe_zip) from jax._src.util import (safe_map, safe_zip)
from jax._src.state import primitives as state_primitives
from jax._src.state import discharge as state_discharge from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax.interpreters import ad from jax.interpreters import ad
from jax.interpreters import mlir from jax.interpreters import mlir
from jax.interpreters import xla from jax.interpreters import xla
import jax.numpy as jnp import jax.numpy as jnp
from jax._src.pallas import core as pallas_core from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
# TODO(sharadmv): enable type checking # TODO(sharadmv): enable type checking
# mypy: ignore-errors # mypy: ignore-errors
@ -224,44 +224,13 @@ def _load_abstract_eval(*avals_flat, args_tree, **_):
load_p.def_effectful_abstract_eval(_load_abstract_eval) load_p.def_effectful_abstract_eval(_load_abstract_eval)
def _pp_dslice(dim: int, slice: Slice, context):
size = pp.text(str(slice.size))
if isinstance(slice.start, int):
if slice.start == 0:
start = pp.text("")
else:
start = pp.text(str(slice.start))
if slice.size == dim:
end = pp.text("")
else:
end = pp.text(str(slice.start + slice.size))
else:
start = pp.text(jax_core.pp_var(slice.start, context))
end = pp.concat([start, pp.text("+"), size])
return pp.concat([start, pp.text(":"), end])
def _pp_idx(ref_aval, idx: NDIndexer, context):
docs = [
_pp_dslice(d, s, context) if isinstance(s, Slice)
else pp.text(jax_core.pp_var(s, context))
for s, d in zip(idx.indices, ref_aval.shape)]
if not docs:
return pp.text("")
doc = [docs[0]]
for d in docs[1:]:
doc.append(pp.text(","))
doc.append(d)
return pp.concat(doc)
def _load_pp_rule(eqn, context, settings): def _load_pp_rule(eqn, context, settings):
# Pretty prints `a = load x i` as `x[i] <- a` # Pretty prints `a = load x i` as `x[i] <- a`
y, = eqn.outvars y, = eqn.outvars
x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars) x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars)
idx = _pp_idx(eqn.invars[0].aval, idx, context)
lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([ return pp.concat([
pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']') lhs, pp.text(' <- '), sp.pp_ref_indexers(context, x, (idx,))])
]))])
jax_core.pp_eqn_rules[load_p] = _load_pp_rule jax_core.pp_eqn_rules[load_p] = _load_pp_rule
@ -339,15 +308,14 @@ def _swap_pp_rule(eqn, context, settings):
# Pretty prints `_ = swap x v i` as `x[i] <- v` # Pretty prints `_ = swap x v i` as `x[i] <- v`
y, = eqn.outvars y, = eqn.outvars
x, idx, val, _ = eqn.params["args_tree"].unflatten(eqn.invars) x, idx, val, _ = eqn.params["args_tree"].unflatten(eqn.invars)
idx = _pp_idx(eqn.invars[0].aval, idx, context) x_i = sp.pp_ref_indexers(context, x, (idx,))
x_i = pp.concat([pp.text(jax_core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')])
if isinstance(y, jax_core.DropVar): if isinstance(y, jax_core.DropVar):
return pp.concat([state_primitives.pp_ref( return pp.concat([
x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]) x_i,
pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i), return pp.concat([y, pp.text(', '), x_i,
pp.text(' <- '), state_primitives.pp_ref(x_i), pp.text(' <- '), x_i,
pp.text(', '), pp.text(jax_core.pp_var(val, context))]) pp.text(', '), pp.text(jax_core.pp_var(val, context))])
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule

View File

@ -39,12 +39,12 @@ from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import hlo_helpers from jax._src.lib import hlo_helpers
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
from jax._src.pallas import primitives from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.state import AbstractRef from jax._src.state import AbstractRef
from jax._src.state import discharge from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp from jax._src.state import primitives as sp
from jax._src.util import merge_lists from jax._src.util import merge_lists
from jax._src.util import partition_list from jax._src.util import partition_list
@ -438,7 +438,7 @@ _TRITON_FN_MAPPING = {
lax.nextafter_p: tl.math.nextafter, lax.nextafter_p: tl.math.nextafter,
ad_util.add_any_p: tl.semantic.add, ad_util.add_any_p: tl.semantic.add,
# Other ops. # Other ops.
indexing.broadcast_to_p: tl.broadcast_to, sp.broadcast_to_p: tl.broadcast_to,
primitives.atomic_cas_p: tl.atomic_cas, primitives.atomic_cas_p: tl.atomic_cas,
primitives.max_contiguous_p: tl.max_contiguous, primitives.max_contiguous_p: tl.max_contiguous,
primitives.multiple_of_p: tl.multiple_of, primitives.multiple_of_p: tl.multiple_of,
@ -727,28 +727,16 @@ def _pack_indices(non_slice_idx, indexed_dims):
def _get_lowering_rule( def _get_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims ctx: TritonLoweringRuleContext, ptr, *idx, tree
): ):
indexers = tree_util.tree_unflatten(tree, idx)
if not isinstance(ptr.type, tl.pointer_type): if not isinstance(ptr.type, tl.pointer_type):
assert not non_slice_idx assert len(indexers) == 0
return ptr return ptr
if len(indexers) > 1:
ref_aval, *idx_avals = ctx.avals_in raise NotImplementedError("No support for multiple indexers yet.")
idx_avals = _pack_indices(idx_avals, indexed_dims) indexer = indexers[0]
if non_slice_idx: args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, None, None))
(int_indexer_shape,) = {
i.shape for i in idx_avals if not isinstance(i, slice)
}
else:
int_indexer_shape = ()
idx = _pack_indices(non_slice_idx, indexed_dims)
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(ref_aval.shape, idx)
)
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None))
return _masked_load_lowering_rule( return _masked_load_lowering_rule(
ctx, ctx,
*args_flat, *args_flat,
@ -794,24 +782,16 @@ triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
def _swap_lowering_rule( def _swap_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
): ):
ref_aval, _, *idx_avals = ctx.avals_in indexers = tree_util.tree_unflatten(tree, idx)
idx_avals = _pack_indices(idx_avals, indexed_dims) if not isinstance(ptr.type, tl.pointer_type):
if non_slice_idx: assert len(indexers) == 0
(int_indexer_shape,) = { return ptr
i.shape for i in idx_avals if not isinstance(i, slice) if len(indexers) > 1:
} raise NotImplementedError("No support for multiple indexers yet.")
else: indexer = indexers[0]
int_indexer_shape = () args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, value, None))
idx = _pack_indices(non_slice_idx, indexed_dims)
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(ref_aval.shape, idx)
)
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None))
return _masked_swap_lowering_rule( return _masked_swap_lowering_rule(
ctx, *args_flat, args_tree=args_tree, eviction_policy=None ctx, *args_flat, args_tree=args_tree, eviction_policy=None
) )
@ -850,24 +830,17 @@ triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
def _addupdate_lowering_rule( def _addupdate_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
): ):
ref_block_info, *_ = ctx.block_infos indexers = tree_util.tree_unflatten(tree, idx)
avals_in = ctx.avals_in if not isinstance(ptr.type, tl.pointer_type):
idx = _pack_indices(non_slice_idx, indexed_dims) assert len(indexers) == 0
if non_slice_idx: return ptr
(int_indexer_shape,) = { if len(indexers) > 1:
tuple(map(lambda x: x.value, i.shape)) for i in non_slice_idx raise NotImplementedError("No support for multiple indexers yet.")
} indexer = indexers[0]
else:
int_indexer_shape = ()
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(avals_in[0].shape, idx)
)
idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
ptr = _compute_pointers_from_indices( ptr = _compute_pointers_from_indices(
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder ptr, ctx.block_infos[0], indexer, ctx.avals_in[0].shape, ctx.builder
) )
tl.atomic_add(ptr, value, _builder=ctx.builder) tl.atomic_add(ptr, value, _builder=ctx.builder)
return [] return []

View File

@ -34,9 +34,11 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing from jax._src.lax import slicing as lax_slicing
from jax._src.state import indexing
from jax._src.state.types import AbstractRef, RefEffect from jax._src.state.types import AbstractRef, RefEffect
from jax._src.state.primitives import get_p, swap_p, addupdate_p from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.state.utils import hoist_consts_to_refs from jax._src.state.utils import hoist_consts_to_refs
from jax._src.typing import Array
from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache, from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
partition_list, merge_lists, split_dict) partition_list, merge_lists, split_dict)
@ -144,34 +146,112 @@ def _eval_jaxpr_discharge_state(
env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge]) env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge])
return out_vals + ref_vals return out_vals + ref_vals
def _is_trivial_indexer(indexer: indexing.NDIndexer):
for s, idx in zip(indexer.shape, indexer.indices):
if not isinstance(idx, indexing.Slice):
return False
if not isinstance(idx.start, int):
return False
if idx.start:
return False
if idx.size != s:
return False
return True
def _convert_to_array_indexer(indexer: indexing.NDIndexer
) -> tuple[int | Array, ...]:
# This is the general gather case. We need to create the gather arrays.
is_integer_indexer, _, integer_indexer = (
indexing.unpack_ndindexer(indexer)
)
total_shape = indexer.get_indexer_shape()
int_indexer_shape = indexer.int_indexer_shape
slice_shape = total_shape[len(int_indexer_shape):]
slice_dims = tuple(
i + len(int_indexer_shape) for i in range(len(slice_shape))
)
slice_dim_iter = iter(slice_dims)
slice_indexer: list[Array] = []
for idx, is_int_index in zip(indexer.indices, is_integer_indexer):
if not is_int_index:
assert isinstance(idx, indexing.Slice)
slice_indices = lax.broadcasted_iota(
np.dtype("int32"), total_shape, next(slice_dim_iter)
) + idx.start
slice_indexer.append(slice_indices)
integer_indexer = tuple(
lax.expand_dims(idx, (-1,)) for idx in integer_indexer
)
continue
assert next(slice_dim_iter, None) is None
return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer))
def _maybe_convert_to_dynamic_slice(
indexer: indexing.NDIndexer,
) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
# An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
# if each of the indexers is a `Slice` or a ()-shaped value.
if not all(isinstance(i, indexing.Slice) or not np.shape(i)
for i in indexer.indices):
return None
_convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
starts = tuple(
_convert_i32(i.start) if isinstance(i, indexing.Slice)
else _convert_i32(i) for i in indexer.indices
)
sizes = tuple(
i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices
)
squeeze_dims = tuple(
i
for i, idx in enumerate(indexer.indices)
if not isinstance(idx, indexing.Slice)
)
return starts, sizes, squeeze_dims
@register_discharge_rule(get_p) @register_discharge_rule(get_p)
def _get_discharge_rule( def _get_discharge_rule(
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, *non_slice_idx, out_avals: Sequence[core.AbstractValue], x, *idx,
indexed_dims: Sequence[bool]): tree):
del in_avals, out_avals del in_avals, out_avals
y = _get_discharge(x, non_slice_idx, indexed_dims) y = _get_discharge(x, idx, tree)
return (None,) * (len(non_slice_idx) + 1), y return (None,) * (len(idx) + 1), y
def _get_discharge(x, idx, indexed_dims): def _prepend_gather(x, indexer):
if not any(indexed_dims):
return x
if all(not i.shape for i in idx):
return _dynamic_index(x, idx, indexed_dims)
else:
return _prepend_gather(x, idx, indexed_dims)
def _prepend_gather(x, idx, indexed_dims):
indexer = _indexer(idx, indexed_dims)
# NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
return x[None][(np.array(0, 'int32'), *indexer)] return x[None][(np.array(0, 'int32'), *indexer)]
def _prepend_scatter(x, idx, indexed_dims, val, *, add=False): def _prepend_scatter(x, indexer, val, *, add=False):
indexer = _indexer(idx, indexed_dims) # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
# However, since this is scatter, we need to remove the 1-sized dimension
# we added at the front.
if add: if add:
return x[None].at[(0, *indexer)].add(val)[0] return x[None].at[(0, *indexer)].add(val)[0]
return x[None].at[(0, *indexer)].set(val)[0] return x[None].at[(0, *indexer)].set(val)[0]
def _get_discharge(x, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(y, squeeze_dims)
indexer = _convert_to_array_indexer(indexer)
if indexer is None:
return x
return x[None][(np.array(0, 'int32'), *indexer)]
def _indexer(idx, indexed_dims): def _indexer(idx, indexed_dims):
idx_ = iter(idx) idx_ = iter(idx)
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims) indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
@ -181,59 +261,59 @@ def _indexer(idx, indexed_dims):
@register_discharge_rule(swap_p) @register_discharge_rule(swap_p)
def _swap_discharge_rule( def _swap_discharge_rule(
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx, out_avals: Sequence[core.AbstractValue], x, val, *idx,
indexed_dims: Sequence[bool]): tree):
del in_avals, out_avals del in_avals, out_avals
if not any(indexed_dims): z, x_new = _swap_discharge(x, val, idx, tree)
z, x_new = x, val return (x_new, None) + (None,) * len(idx), z
z, x_new = _swap_discharge(x, val, non_slice_idx, indexed_dims)
return (x_new, None) + (None,) * len(non_slice_idx), z
def _swap_discharge(x, val, idx, indexed_dims): def _swap_discharge(x, val, idx, tree):
if not any(indexed_dims): indexers = tree_util.tree_unflatten(tree, idx)
z, x_new = x, val if len(indexers) > 1:
elif all(not i.shape for i in idx): raise NotImplementedError("Only single indexer is supported.")
z = _dynamic_index(x, idx, indexed_dims) indexer = indexers[0]
x_new = _dynamic_update_index(x, idx, val, indexed_dims) if _is_trivial_indexer(indexer):
else: return x, val
z = _prepend_gather(x, idx, indexed_dims) # If everything in the indexer is a slice or ()-shaped, we can also
x_new = _prepend_scatter(x, idx, indexed_dims, val) # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
return z, x_new # We need to squeeze out the the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
val = lax.expand_dims(val, squeeze_dims)
y = lax_slicing.dynamic_update_slice(x, val, starts)
return lax.squeeze(x_old, squeeze_dims), y
indexer = _convert_to_array_indexer(indexer)
x_old = _prepend_gather(x, indexer)
return x_old, _prepend_scatter(x, indexer, val)
@register_discharge_rule(addupdate_p) @register_discharge_rule(addupdate_p)
def _addupdate_discharge_rule( def _addupdate_discharge_rule(
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx, out_avals: Sequence[core.AbstractValue], x, val, *idx,
indexed_dims: Sequence[bool]): tree):
del in_avals, out_avals del in_avals, out_avals
ans = _addupdate_discharge(x, val, non_slice_idx, indexed_dims) ans = _addupdate_discharge(x, val, idx, tree)
return (ans, None) + (None,) * len(non_slice_idx), [] return (ans, None) + (None,) * len(idx), []
def _addupdate_discharge(x, val, idx, indexed_dims): def _addupdate_discharge(x, val, idx, tree):
if not any(indexed_dims): indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x + val return x + val
if all(not i.shape for i in idx): # If everything in the indexer is a slice or ()-shaped, we can also
y = val + _dynamic_index(x, idx, indexed_dims) # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
return _dynamic_update_index(x, idx, y, indexed_dims) # We need to squeeze out the the 1-sized slices at the end.
else: if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
return _prepend_scatter(x, idx, indexed_dims, val, add=True) starts, sizes, squeeze_dims = maybe_slice
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
def _dynamic_index(x, idx, indexed_dims): val = lax.expand_dims(val, squeeze_dims)
assert isinstance(idx, (list, tuple)) and idx y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
idx_ = iter(idx) return y
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims] indexer = _convert_to_array_indexer(indexer)
assert next(idx_, None) is None return _prepend_scatter(x, indexer, val, add=True)
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
out = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])
def _dynamic_update_index(x, idx, val, indexed_dims):
assert isinstance(idx, (list, tuple)) and idx
idx_ = iter(idx)
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)
@weakref_lru_cache @weakref_lru_cache
def _cached_closed_jaxpr_discharge(closed_jaxpr): def _cached_closed_jaxpr_discharge(closed_jaxpr):

192
jax/_src/state/indexing.py Normal file
View File

@ -0,0 +1,192 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains shared logic and abstractions for Pallas indexing ops."""
from __future__ import annotations
import dataclasses
from typing import Any, Union
from jax._src import core
from jax._src import tree_util
from jax._src.typing import Array
from jax._src.util import merge_lists
from jax._src.util import partition_list
import numpy as np
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
start: Any
size: int
def __post_init__(self):
if self.size < 0:
raise ValueError("`size` must not be negative.")
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
if isinstance(self.start, int):
return (), (self.start, self.size)
return (self.start,), (self.size,)
@classmethod
def tree_unflatten(cls, aux_data, children) -> Slice:
return cls(*children, *aux_data)
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop, step = slc.indices(size)
if step != 1:
raise ValueError(f"slice must have a step of 1 (found: {step})")
return cls(start, stop - start)
def dslice(start: int | Array | None, size: int | None = None
) -> slice | Slice:
"""Constructs a `Slice` from a start and a size."""
if start is None:
return slice(None)
if size is None:
if not isinstance(start, int):
raise ValueError("Non-static `dslice`")
return Slice(0, start)
return Slice(start, size)
ds = dslice # Handy alias
IntIndexer = Union[int, Array]
DimIndexer = Union[IntIndexer, Slice]
def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
tuple[Slice, ...],
tuple[IntIndexer, ...]]:
is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
slice_indexers, int_indexers = partition_list(
is_int_indexing, indexer.indices)
return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # type: ignore
def _maybe_concretize(x: Any):
try:
return core.concrete_or_error(None, x)
except core.ConcretizationTypeError:
return None
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: tuple[DimIndexer, ...]
shape: tuple[int, ...]
int_indexer_shape: tuple[int, ...]
validate: bool = False
def __post_init__(self):
if not self.validate:
return
if len(self.indices) != len(self.shape):
raise ValueError(
f"`indices` must be the same length as `Ref` shape.: {self}."
)
# We validate integer indexing shapes here
for idx, s in zip(self.indices, self.shape):
if isinstance(idx, Slice):
start = idx.start
if value := _maybe_concretize(start):
if value >= s:
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
if value + idx.size > s:
raise ValueError(
f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
)
continue
# The shape of indexer integers should be broadcastable up to the
# int_indexer_shape of the whole NDIndexer
if not np.shape(idx):
if (value := _maybe_concretize(idx)) and value >= s:
raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
# For ()-shaped indexers, we can broadcast no problm.
continue
# If we don't have a ()-shaped indexer, the rank must match
# int_indexer_shape
if np.ndim(idx) != len(self.int_indexer_shape):
raise ValueError(
f"Indexer must have rank {np.ndim(idx)}: {idx=} vs."
f" {self.int_indexer_shape=}"
)
# Here we check that the shapes broadcast.
try:
np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
except ValueError as e:
raise ValueError(
f"Could not broadcast integer indexer: {idx=} vs."
f" {self.int_indexer_shape=}"
) from e
def tree_flatten(self):
flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
@classmethod
def tree_unflatten(cls, data, flat_idx):
idx_tree, shape, int_indexer_shape = data
indices = tree_util.tree_unflatten(idx_tree, flat_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)
@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
if indices == ...:
indices = (slice(None),) * len(shape)
if not isinstance(indices, tuple):
indices = (indices,)
if any(idx is ... for idx in indices):
# TODO(sharadmv,mattjj): support patterns that include ellipsis in them
# e.g. x[0, ..., 1].
raise NotImplementedError("Ellipsis in indexer not supported yet.")
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`: "
f"{indices=}, {shape=}")
# Pad out indices with slice(None)
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
# Convert all `slice`s to `Slice`s
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
else i for i, s in zip(indices, shape))
is_int_indexing = [not isinstance(i, Slice) for i in indices]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
if indexer_shapes:
try:
bcast_shape = np.broadcast_shapes(*indexer_shapes)
except ValueError as e:
# Raise a nicer error than the NumPy one.
raise ValueError("Cannot broadcast shapes for indexing: "
f"{tuple(a for a in indexer_shapes)}") from e
else:
bcast_shape = ()
# Here we use the `broadcast_to` primitive instead of composing lax
# primitives together because it is easier to lower in targets like
# Triton/Mosaic.
from jax._src.state import primitives as sp # pytype: disable=import-error
int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
def get_indexer_shape(self) -> tuple[int, ...]:
_, slice_indexers, _ = unpack_ndindexer(self)
slice_shape = [s.size for s in slice_indexers]
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)

View File

@ -23,14 +23,17 @@ import numpy as np
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import pretty_printer as pp from jax._src import pretty_printer as pp
from jax._src import tree_util
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.typing import Array from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, from jax._src.state import indexing
from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect,
AccumEffect) AccumEffect)
from jax._src.util import safe_map, safe_zip, tuple_insert from jax._src.util import safe_map, safe_zip
## General utilities ## General utilities
@ -51,41 +54,14 @@ zip, unsafe_zip = safe_zip, zip
# a:f32[3] <- x[] # a:f32[3] <- x[]
get_p = core.Primitive("get") get_p = core.Primitive("get")
def _get_impl(ref: AbstractRef, *idx: int, **_): def _get_impl(ref: AbstractRef, *args: Any, tree):
del ref, idx del ref, args, tree
raise ValueError("Cannot run stateful primitive.") raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl) get_p.def_impl(_get_impl)
Indexer = tuple[Union[int, slice, Array], ...] Indexer = tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType) # or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
def _is_trivial_indexer(idx: Indexer) -> bool:
if idx is ...:
return True
if type(idx) is tuple:
if len(idx) == 0:
return True
return len(idx) == 1 and idx[0] is ...
return False
def _unpack_idx(idx: Indexer, ndim: int
) -> tuple[tuple[Array, ...], tuple[bool, ...]]:
if _is_trivial_indexer(idx):
idx = tuple(slice(None) for _ in range(ndim))
indexed_dims_ = []
non_slice_idx = []
for i in idx:
if isinstance(i, slice):
if i.start is not None or i.stop is not None or i.step is not None:
raise NotImplementedError("Reference indexing only supports trivial slices")
indexed_dims_.append(False)
else:
non_slice_idx.append(i)
indexed_dims_.append(True)
indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
import jax.numpy as jnp
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
def _get_slice_output_shape(in_shape: tuple[int, ...], def _get_slice_output_shape(in_shape: tuple[int, ...],
idx_shapes: tuple[tuple[int, ...], ...], idx_shapes: tuple[tuple[int, ...], ...],
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]: indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
@ -95,24 +71,30 @@ def _get_slice_output_shape(in_shape: tuple[int, ...],
shape = (*shape_prefix, *shape_suffix) shape = (*shape_prefix, *shape_suffix)
return shape return shape
def _get_indexer(ref: AbstractRef, idx: Indexer
) -> tuple[Indexer, tuple[bool, ...]]:
if isinstance(ref.inner_aval, core.ShapedArray):
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
else:
if not _is_trivial_indexer(idx):
raise ValueError(
f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
non_slice_idx, indexed_dims = (), ()
return non_slice_idx, indexed_dims
def ref_get(ref: Any, idx: Indexer) -> Array: def _get_indexer(
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" ref_or_view: Any, idx: Indexer | None, function_name: str
) -> tuple[Any, tuple[indexing.NDIndexer, ...]]:
if isinstance(ref_or_view, RefView):
ref, indexers = ref_or_view.ref, ref_or_view.indexers
else:
ref, indexers = ref_or_view, ()
ref_aval = core.get_aval(ref) ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef): if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `get` on a `Ref`: {ref}") raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
non_slice_idx, indexed_dims = _get_indexer(ref, idx) if not isinstance(ref_aval.inner_aval, core.ShapedArray):
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims) return ref, ()
if idx is None:
return ref, indexers
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
return ref, (*indexers, nd_indexer)
def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
ref, indexers = _get_indexer(ref_or_view, idx, "ref_get")
flat_indexers, tree = tree_util.tree_flatten(indexers)
return get_p.bind(ref, *flat_indexers, tree=tree)
# `swap` mutates a `Ref`, setting its value and returns its previous value. # `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a) # b = swap_p.bind(x, a)
@ -132,22 +114,21 @@ def ref_get(ref: Any, idx: Indexer) -> Array:
# x:Ref{f32[3]}[i, j] <- a # x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap") swap_p = core.Primitive("swap")
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_): def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
del ref, value, idx del ref, value, idx, tree
raise ValueError("Cannot run stateful primitive.") raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl) swap_p.def_impl(_swap_impl)
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array: def ref_swap(ref_or_view: AbstractRef, idx: Indexer | None, value: Array,
*, _function_name: str = "ref_swap") -> Array:
"""Sets a `Ref`'s value and returns the original value.""" """Sets a `Ref`'s value and returns the original value."""
ref_aval = core.get_aval(ref) ref, indexers = _get_indexer(ref_or_view, idx, _function_name)
if not isinstance(ref_aval, AbstractRef): flat_indexers, tree = tree_util.tree_flatten(indexers)
raise ValueError(f"Can only call `swap` on a `Ref`: {ref}") return swap_p.bind(ref, value, *flat_indexers, tree=tree)
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None: def ref_set(ref: AbstractRef, idx: Indexer | None, value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" """Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref, idx, value) ref_swap(ref, idx, value, _function_name="ref_set")
# `addupdate_p` mutates a `Ref`, adding a value to its existing value. # `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically, # Semantically,
@ -163,38 +144,40 @@ def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
addupdate_p = core.Primitive('addupdate') addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True addupdate_p.multiple_results = True
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int): def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
del ref, idx, value del ref, value, args, tree
raise ValueError("Can't evaluate `addupdate` outside a stateful context.") raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl) addupdate_p.def_impl(_addupdate_impl)
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None: def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`.""" """Mutates a ref with an additive update i.e. `ref[idx] += x`."""
ref_aval = core.get_aval(ref) ref, indexers = _get_indexer(ref_or_view, idx, "ref_addupdate")
if not isinstance(ref_aval, AbstractRef): flat_indexers, tree = tree_util.tree_flatten(indexers)
raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}") return addupdate_p.bind(ref, x, *flat_indexers, tree=tree)
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
## get/set/addupdate abstract evaluation rules ## get/set/addupdate abstract evaluation rules
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
indexed_dims): def _shape_after_indexing(
shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...]
) -> tuple[int, ...]:
for indexer in indexers:
# Run some simple checks that all the indexers have consistent shapes
assert indexer.shape == shape, (indexer.shape, shape)
shape = indexer.get_indexer_shape()
return shape
def _get_abstract_eval(ref_aval: AbstractRef, *args,
tree):
indexers = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef): if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray): if isinstance(ref_aval.inner_aval, core.ShapedArray):
if not isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_indexing(ref_aval.shape, indexers)
raise ValueError("`get` with nontrivial indexing must be called " out_aval = ref_aval.inner_aval.update(shape=out_shape)
f"on `ShapedArray` `Ref`: {ref_aval}.")
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
idx_shapes = tuple(i.shape for i in idx)
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
out_aval = ref_aval.inner_aval.update(shape=shape)
else: else:
if idx: if indexers:
raise ValueError("Cannot index non-shaped array with nontrivial indices.") raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)}) return (out_aval, {ReadEffect(0)})
@ -202,34 +185,29 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: AbstractRef, def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue, val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: tuple[bool]): *args: Any, tree):
indexers = tree_util.tree_unflatten(tree, args)
out_aval: core.AbstractValue out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef): if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray): if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval) val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray) assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx) expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers)
expected_output_shape = _get_slice_output_shape( if expected_out_shape != val_aval.shape:
ref_aval.shape, idx_shapes, indexed_dims)
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. " raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. " f"Ref shape: {ref_aval.shape}. "
f"Expected shape: {expected_out_shape}. "
f"Value shape: {val_aval.shape}. " f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ") f"Indices: {indexers}. ")
if ref_aval.dtype != val_aval.dtype: if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. " raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. " f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ") f"Value shape: {val_aval.dtype}. ")
out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype) out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype)
else: else:
if idx: if indexers:
raise ValueError("`swap` with nontrivial indexing must be called " raise ValueError("Cannot index non-shaped array with nontrivial indices.")
f"on `ShapedArray` `Ref`: {ref_aval}.")
out_aval = ref_aval.inner_aval out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)}) return (out_aval, {WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval) swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
@ -237,93 +215,118 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: AbstractRef, def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue, val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: tuple[bool]): *args: Any, tree):
indexers = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef): if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
raise ValueError("`addupdate` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray): if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval) val_aval = core.raise_to_shaped(val_aval)
slice_shape = _shape_after_indexing(ref_aval.shape, indexers)
assert isinstance(val_aval, core.ShapedArray) assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
slice_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if slice_shape != val_aval.shape: if slice_shape != val_aval.shape:
raise ValueError("Invalid shape for `addupdate`. " raise ValueError("Invalid shape for `addupdate`. "
f"Ref shape: {ref_aval.shape}. " f"Ref shape: {ref_aval.shape}. "
f"Slice shape: {slice_shape}. "
f"Value shape: {val_aval.shape}. " f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ") f"Indices: {indexers}. ")
if ref_aval.dtype != val_aval.dtype: if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `addupdate`. " raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. " f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ") f"Value shape: {val_aval.dtype}. ")
elif idx: else:
raise ValueError("`addupdate` with nontrivial indexing must be called " # Check that the indexers are valid
f"on `ShapedArray` `Ref`: {ref_aval}.") if indexers:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
return [], {AccumEffect(0)} return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs ## Pretty printing for `get` and `swap` in jaxprs
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL, pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL,
foreground=pp.Color.GREEN) foreground=pp.Color.GREEN)
def _pp_idx(context, non_slice_idx, indexed_dims): def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice
idx_iter = iter(non_slice_idx) ) -> str:
idx = ','.join(core.pp_var(next(idx_iter), context) if indexed else ':' start, size = slc.start, slc.size
for indexed in indexed_dims) if isinstance(start, core.Var):
assert next(idx_iter, None) is None start_str = core.pp_var(start, context)
return pp.text(idx) end_str = f'{start_str}+{size}'
else:
start_str = '' if start == 0 else str(start)
end = start + size
end_str = '' if end == dim else str(end)
return f'{start_str}:{end_str}'
def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer
) -> pp.Doc:
indices = []
for idx, dim in zip(indexer.indices, indexer.shape):
if isinstance(idx, indexing.Slice):
indices.append(_pp_slice(context, dim, idx))
else:
indices.append(core.pp_var(idx, context)) # type: ignore
return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")])
def _pp_indexers(
context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...],
):
return pp.concat(
[pp_indexer(context, indexer) for indexer in indexers]
)
def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers):
return pp_ref_var(
pp.concat([
pp.text(core.pp_var(ref, context)),
_pp_indexers(context, indexers),
])
)
def _get_pp_rule(eqn, context, settings) -> pp.Doc: def _get_pp_rule(eqn, context, settings) -> pp.Doc:
# Pretty prints `a = get x i` as `x[i] <- a` # Pretty prints `a = get x i` as `x[i] <- a`
y, = eqn.outvars y, = eqn.outvars
x, *idx = eqn.invars x, *flat_idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"]) indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
# TODO more general get return pp.concat([
return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([ lhs,
pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))]) pp.text(' <- '),
pp_ref_indexers(context, x, indexers)
])
core.pp_eqn_rules[get_p] = _get_pp_rule core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings) -> pp.Doc: def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
y, = eqn.outvars y, = eqn.outvars
x, v, *idx = eqn.invars x, v, *flat_idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"]) indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
if type(y) is core.DropVar: if type(y) is core.DropVar:
# In the case of a set (ignored return value), # In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v` # pretty print `_ = swap x v i` as `x[i] <- v`
del y del y
return pp.concat([ return pp.concat([
pp_ref(pp.concat([ pp_ref_indexers(context, x, indexers),
pp.text(core.pp_var(x, context)), pp.text(' <- '),
pp.text('['), idx, pp.text(']') pp.text(core.pp_var(v, context))
])), pp.text(' <- '), pp.text(core.pp_var(v, context))]) ])
else: else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
x_i = pp.concat([pp.text(core.pp_var(x, context)), x_i = pp_ref_indexers(context, x, indexers)
pp.text('['), idx, pp.text(']')])
y = core.pp_vars([y], context, print_shapes=settings.print_shapes) y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), pp_ref(x_i), pp.text(' <- '), return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
pp_ref(x_i), pp.text(', '), x_i, pp.text(', '),
pp.text(core.pp_var(v, context))]) pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[swap_p] = _swap_pp_rule core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
del settings
# pretty-print ` = addupdate x i v` as `x[i] += v` # pretty-print ` = addupdate x i v` as `x[i] += v`
() = eqn.outvars () = eqn.outvars
x, v, *idx = eqn.invars x, v, *flat_idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"]) indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
return pp.concat([ return pp.concat([
pp_ref(pp.concat([ pp_ref_indexers(context, x, indexers),
pp.text(core.pp_var(x, context)), pp.text(' += '),
pp.text('['), idx, pp.text(']') pp.text(core.pp_var(v, context))])
])), pp.text(' += '), pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules ## get/swap/addupdate JVP rules
@ -366,6 +369,7 @@ def _get_transpose(g, ref, *idx, **params):
ad.primitive_transposes[get_p] = _get_transpose ad.primitive_transposes[get_p] = _get_transpose
def _swap_transpose(g, ref, x, *idx, **params): def _swap_transpose(g, ref, x, *idx, **params):
del x # old value doesn't matter anymore
# swap transpose is swap # swap transpose is swap
x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params) x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
return [None, x_bar] + [None] * len(idx) return [None, x_bar] + [None] * len(idx)
@ -403,101 +407,162 @@ def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
num_idxs_to_left = sum(indexed_dims[:ref_dim]) num_idxs_to_left = sum(indexed_dims[:ref_dim])
return ref_dim - num_idxs_to_left + len(idxs_shape) return ref_dim - num_idxs_to_left + len(idxs_shape)
def _get_vmap(batched_args, batched_dims, *, indexed_dims): def _batch_indexer(indexer: indexing.NDIndexer, dims,
axis_size: int,
ref_shape: tuple[int, ...],
ref_dim: int | batching.NotMapped,
idx_is_batched: bool) -> indexing.NDIndexer:
indices = indexer.indices
indices_dims = dims.indices
new_indices: list[Array | indexing.Slice | int] = []
new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape)
for idx, dim in zip(indices, indices_dims):
if idx_is_batched:
# If at least one of the idx is batched, we broadcast them all and move the
# batch dim to the front.
if isinstance(idx, indexing.Slice):
# size is static, but start can be dynamic
# Check if start is static (which it can be)
is_static_slice = len(tree_util.tree_leaves(idx)) == 0
if is_static_slice:
new_indices.append(idx)
continue
dim = dim.start
if dim is batching.not_mapped:
# Broadcasting the slice is free (the start index stays the same)
new_indices.append(idx)
else:
raise NotImplementedError(
f"No support for vmapping over nontrivial slices just yet: {idx}")
else:
# Check if we are indexing with a scalar or not. If we are indexing
# with a scalar and we are not batched, we can avoid broadcasting it.
assert hasattr(idx, "shape")
if not idx.shape:
if dim is not batching.not_mapped:
assert idx.shape == (axis_size,)
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,))
new_indices.append(idx)
else:
if dim is batching.not_mapped:
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
bcast_dims)
else:
idx = batching.moveaxis(idx, dim, 0)
new_indices.append(idx)
else:
if ref_dim is not batching.not_mapped:
if not isinstance(idx, indexing.Slice):
assert hasattr(idx, "shape")
if idx.shape:
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
bcast_dims)
new_indices.append(idx)
if ref_dim is not batching.not_mapped:
iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0)
new_indices.insert(ref_dim, iota)
return indexing.NDIndexer(tuple(new_indices), ref_shape,
new_integer_indexer_shape,
validate=True)
def _get_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped} if d is not batching.not_mapped}
ref, *idxs = batched_args ref, *flat_idxs = batched_args
ref_dim, *idx_dims = batched_dims ref_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims) for i_dim in flat_idx_dims)
bdim_out = 0 if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
if idx_is_batched: # TODO(sharadmv): handle vmap of multiple indexers
# If at least one of the idx is batched, we broadcast them all and move the indexers = tuple(_batch_indexer(indexer, dims, axis_size,
# batch dim to the front. ref.shape, ref_dim, idx_is_batched)
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d for indexer, dims in zip(indexers, indexers_dims))
in zip(idxs, idx_dims)) flat_indexers, tree = tree_util.tree_flatten(indexers)
idxs_shape, = {i.shape for i in idxs} or [()] return get_p.bind(ref, *flat_indexers, tree=tree), 0
if ref_is_batched:
# If ref is batched, we are doing a `get` with an additional axis. If `idxs`
# are also batched, then we are indexing into the batch axis with an `iota`.
indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched)
if idx_is_batched:
# If we have batched idx, we need to insert the new iota index. The place
# where we add in the new `iota` index is `ref_dim` so we need to compute
# what `ref_dim` *would be* if we inserted it into `idxs` instead, because
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out
batching.primitive_batchers[get_p] = _get_vmap batching.primitive_batchers[get_p] = _get_vmap
def _swap_vmap(batched_args, batched_dims, *, indexed_dims): def _swap_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped} if d is not batching.not_mapped}
ref, val, *idxs = batched_args ref, val, *flat_idxs = batched_args
ref_dim, val_dim, *idx_dims = batched_dims ref_dim, val_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims) idx_is_batched = any(i_dim is not batching.not_mapped
if idx_is_batched: for i_dim in flat_idx_dims)
# If at least one of the idx is batched, we broadcast them all and move the if len(indexers) > 1:
# batch dim to the front. raise NotImplementedError("Batching with multiple indexers not supported.")
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d # TODO(sharadmv): handle vmap of multiple indexers
in zip(idxs, idx_dims)) indexers = tuple(_batch_indexer(indexer, dims, axis_size,
idxs_shape, = {i.shape for i in idxs} or [()] ref.shape, ref_dim, idx_is_batched)
if ref_is_batched and not idx_is_batched: for indexer, dims in zip(indexers, indexers_dims))
indexed_dims = tuple_insert(indexed_dims, ref_dim, False) flat_indexers, tree = tree_util.tree_flatten(indexers)
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) if (ref_is_batched or idx_is_batched) and not val_is_batched:
if not val_is_batched: val = batching.broadcast(val, axis_size, 0)
val = batching.broadcast(val, axis_size, 0) if val_is_batched:
val_dim = 0
val = batching.moveaxis(val, val_dim, bdim_out)
elif idx_is_batched:
assert ref_is_batched and val_is_batched
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0) val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0 return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0
return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out
batching.primitive_batchers[swap_p] = _swap_vmap batching.primitive_batchers[swap_p] = _swap_vmap
def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims): def _addupdate_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped} if d is not batching.not_mapped}
ref, val, *idxs = batched_args ref, val, *flat_idxs = batched_args
ref_dim, val_dim, *idx_dims = batched_dims ref_dim, val_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims) idx_is_batched = any(i_dim is not batching.not_mapped
if idx_is_batched: for i_dim in flat_idx_dims)
# If at least one of the idx is batched, we ensure all have bdims at front. if len(indexers) > 1:
idxs = tuple(batching.bdim_at_front(i, d, axis_size) raise NotImplementedError("Batching with multiple indexers not supported.")
for i, d in zip(idxs, idx_dims)) # TODO(sharadmv): handle vmap of multiple indexers
idxs_shape, = {i.shape for i in idxs} or [()] indexers = tuple(_batch_indexer(indexer, dims, axis_size,
if ref_is_batched and not idx_is_batched: ref.shape, ref_dim, idx_is_batched)
indexed_dims = tuple_insert(indexed_dims, ref_dim, False) for indexer, dims in zip(indexers, indexers_dims))
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) flat_indexers, tree = tree_util.tree_flatten(indexers)
if not val_is_batched: if (ref_is_batched or idx_is_batched) and not val_is_batched:
val = batching.broadcast(val, axis_size, 0) val = batching.broadcast(val, axis_size, 0)
val_dim = 0 if val_is_batched:
val = batching.moveaxis(val, val_dim, bdim_out)
elif idx_is_batched:
assert ref_is_batched and val_is_batched
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0) val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), [] return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), []
batching.primitive_batchers[addupdate_p] = _addupdate_vmap batching.primitive_batchers[addupdate_p] = _addupdate_vmap
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = core.Primitive('broadcast_to')
def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array:
import jax.numpy as jnp
a = jnp.asarray(a)
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
import jax.numpy as jnp
return jnp.broadcast_to(a, shape)
@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
return core.ShapedArray(shape, aval.dtype)
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)

View File

@ -23,6 +23,7 @@ from typing import Any, Generic, TypeVar, Union
from jax._src import core from jax._src import core
from jax._src import effects from jax._src import effects
from jax._src import pretty_printer as pp from jax._src import pretty_printer as pp
from jax._src.state import indexing
from jax._src.util import safe_map, safe_zip from jax._src.util import safe_map, safe_zip
from jax._src.typing import Array from jax._src.typing import Array
@ -77,21 +78,34 @@ Aval = TypeVar("Aval", bound=core.AbstractValue)
@dataclasses.dataclass @dataclasses.dataclass
class RefIndexer: class RefIndexer:
ref: Any ref_or_view: Any
def __getitem__(self, slc): def __getitem__(self, slc):
if not isinstance(slc, tuple): if not isinstance(slc, tuple):
slc = (slc,) slc = (slc,)
return RefView(self.ref, slc) indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape)
if isinstance(self.ref_or_view, RefView):
view = self.ref_or_view
return RefView(view.ref, (*view.indexers, indexer))
return RefView(self.ref_or_view, (indexer,))
Indexer = Any
@dataclasses.dataclass @dataclasses.dataclass
class RefView: class RefView:
ref: Any ref: Any
indexer: Any indexers: tuple[indexing.NDIndexer, ...]
@property @property
def at(self): def shape(self) -> tuple[int, ...]:
raise NotImplementedError("Can't call `.at` multiple times.") assert (
len(self.indexers) > 0
), "Should not be able to create a trivial RefView"
return self.indexers[-1].get_indexer_shape()
@property
def at(self) -> RefIndexer:
return RefIndexer(self)
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
@ -137,14 +151,10 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
return ref_set(tracer, idx, value) return ref_set(tracer, idx, value)
def _getitem(self, tracer, idx) -> Array: def _getitem(self, tracer, idx) -> Array:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_get # pytype: disable=import-error from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(tracer, idx) return ref_get(tracer, idx)
def _setitem(self, tracer, idx, value) -> None: def _setitem(self, tracer, idx, value) -> None:
if not isinstance(idx, tuple):
idx = idx,
from jax._src.state.primitives import ref_set # pytype: disable=import-error from jax._src.state.primitives import ref_set # pytype: disable=import-error
return ref_set(tracer, idx, value) return ref_set(tracer, idx, value)

View File

@ -1472,6 +1472,7 @@ tf_not_yet_impl = [
"for", "for",
"inspect_sharding", "inspect_sharding",
"io_callback", "io_callback",
"broadcast_to",
"shard_map", "shard_map",
"global_array_to_host_local_array", "global_array_to_host_local_array",
"host_local_array_to_global_array", "host_local_array_to_global_array",

View File

@ -17,10 +17,6 @@
from jax._src import pallas from jax._src import pallas
from jax._src.pallas.core import BlockSpec from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import no_block_spec
from jax._src.pallas.indexing import broadcast_to
from jax._src.pallas.indexing import ds
from jax._src.pallas.indexing import dslice
from jax._src.pallas.indexing import Slice
from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add from jax._src.pallas.primitives import atomic_add
@ -42,6 +38,10 @@ from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2 from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when from jax._src.pallas.utils import when
from jax._src.state.primitives import broadcast_to
from jax._src.state.indexing import ds
from jax._src.state.indexing import dslice
from jax._src.state.indexing import Slice
try: try:
from jax.experimental.pallas import gpu # pytype: disable=import-error from jax.experimental.pallas import gpu # pytype: disable=import-error

View File

@ -22,7 +22,7 @@ from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import jax import jax
from jax._src import util from jax._src import util
from jax._src.pallas import indexing from jax._src.state import indexing
import numpy as np import numpy as np
try: try:
@ -49,7 +49,8 @@ def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
@hps.composite @hps.composite
def slice_indexer_strategy(draw, dim) -> Slice | slice: def slice_indexer_strategy(draw, dim) -> Slice | slice:
start = draw(int_indexer_strategy(dim)) start = draw(int_indexer_strategy(dim))
size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max)) max_size = dim - start
size = draw(hps.integers(min_value=0, max_value=max_size))
return draw( return draw(
hps.one_of( hps.one_of(
hps.just(Slice(start, size)), hps.just(slice(start, start + size)) hps.just(Slice(start, size)), hps.just(slice(start, start + size))
@ -78,8 +79,8 @@ def indexer_strategy(draw, dim, int_indexer_shape
def nd_indexer_strategy(draw, shape) -> NDIndexer: def nd_indexer_strategy(draw, shape) -> NDIndexer:
num_indices = draw(hps.integers(min_value=0, max_value=len(shape))) num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
int_indexer_shape = draw(hnp.array_shapes()) int_indexer_shape = draw(hnp.array_shapes())
indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim indices = tuple(draw(indexer_strategy(dim, int_indexer_shape))
in shape[:num_indices]] for dim in shape[:num_indices])
return NDIndexer.from_indices_shape(indices, shape) return NDIndexer.from_indices_shape(indices, shape)
@ -97,6 +98,24 @@ class IndexerTest(parameterized.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape) _ = NDIndexer.from_indices_shape(indices, shape)
def test_invalid_ndindexer_oob_int(self):
indices = (4, 0)
shape = (3, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_invalid_ndindexer_oob_slice_start(self):
indices = (slice(3, 2), 0)
shape = (3, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_invalid_ndindexer_oob_slice_end(self):
indices = (Slice(2, 2), 0)
shape = (3, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_ndindexer_with_padding(self): def test_ndindexer_with_padding(self):
indices = () indices = ()
shape = (5, 5) shape = (5, 5)
@ -137,17 +156,17 @@ class IndexerTest(parameterized.TestCase):
indexer = NDIndexer.from_indices_shape(indices, shape) indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3)) self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3))
indices = (0, slice(4, 10), np.arange(5)) indices = (0, slice(2, 10), np.arange(5))
indexer = NDIndexer.from_indices_shape(indices, shape) indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 0)) self.assertTupleEqual(indexer.get_indexer_shape(), (5, 1))
indices = (0, 5, np.arange(5)) indices = (0, 1, np.arange(5))
indexer = NDIndexer.from_indices_shape(indices, shape) indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5,)) self.assertTupleEqual(indexer.get_indexer_shape(), (5,))
indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None]) indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None])
indexer = NDIndexer.from_indices_shape(indices, shape) indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3)) self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2))
@hp.given(hps.data()) @hp.given(hps.data())
def test_ndindexer(self, data): def test_ndindexer(self, data):

View File

@ -1471,7 +1471,7 @@ class PallasPrimitivesTest(PallasTest):
@parameterized.parameters(*[ @parameterized.parameters(*[
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"), (lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"),
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:4,:,:4]"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:,:,:4]"),
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"),
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"),
]) ])
@ -1486,7 +1486,7 @@ class PallasPrimitivesTest(PallasTest):
@parameterized.parameters(*[ @parameterized.parameters(*[
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"), (lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"),
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:4,:,:4] <-"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:,:,:4] <-"),
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"),
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"),
]) ])
@ -1504,7 +1504,7 @@ class PallasPrimitivesTest(PallasTest):
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), (lambda: (pl.dslice(0, 3), slice(None), slice(None)),
"c:i32[3,3,2], a[:3,:,:] <-"), "c:i32[3,3,2], a[:3,:,:] <-"),
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)),
"c:i32[3,3,4], a[1:4,:,:4] <-"), "c:i32[3,3,4], a[1:,:,:4] <-"),
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)),
"e:i32[5,3,4], a[b,:,:4] <-"), "e:i32[5,3,4], a[b,:,:4] <-"),
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)),

View File

@ -56,15 +56,15 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def test_cant_eval_get_primitive(self): def test_cant_eval_get_primitive(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
get_p.bind(jnp.ones(5)) get_p.bind(jnp.ones(5), tree=None)
def test_cant_eval_swap_primitive(self): def test_cant_eval_swap_primitive(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
swap_p.bind(jnp.ones(5), jnp.zeros(5)) swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
def test_cant_eval_addupdate_primitive(self): def test_cant_eval_addupdate_primitive(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
addupdate_p.bind(jnp.ones(5), jnp.zeros(5)) addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
def test_get_abstract_aval_must_take_in_refs(self): def test_get_abstract_aval_must_take_in_refs(self):
ref_aval = core.ShapedArray((), jnp.float32) ref_aval = core.ShapedArray((), jnp.float32)
@ -95,11 +95,37 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])), idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32), out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_nontrivial_slice",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_nontrivial_slice2",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
dict(testcase_name="get_with_ref_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(1, 3), slice(None), slice(None)),
out_shape=(2, 2, 4), out_dtype=jnp.float32,
at_indices=((0,),)),
dict(testcase_name="get_with_ref_simple_at2",
ref_shape=(6, 1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 2), slice(0, 1), slice(1, 3), slice(None), slice(None)),
out_shape=(2, 1, 2, 2, 4), out_dtype=jnp.float32,
at_indices=((slice(2, 6),),)),
dict(testcase_name="get_with_ref_multiple_at",
ref_shape=(1, 3, 5, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(None), slice(0, 2)),
out_shape=(3, 1, 2), out_dtype=jnp.float32,
at_indices=((0,), (slice(None), slice(0, 1)))),
) )
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None, def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
out_dtype=None, should_error=False): out_dtype=None, at_indices=(),
should_error=False):
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
def f(x_ref): def f(x_ref):
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
out = ref_get(x_ref, idx) out = ref_get(x_ref, idx)
return [out] return [out]
if should_error: if should_error:
@ -160,13 +186,43 @@ class StatePrimitivesTest(jtu.JaxTestCase):
val_shape=(2, 1, 2), val_dtype=jnp.float32, val_shape=(2, 1, 2), val_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])), idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32), out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nontrivial_slice",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
val_shape=(2, 1, 2), val_dtype=jnp.float32,
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nontrivial_slice2",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
val_shape=(1, 2, 2, 4), val_dtype=jnp.float32,
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
dict(testcase_name="swap_with_ref_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None),),
val_shape=(1, 1, 4), val_dtype=jnp.float32,
out_shape=(1, 1, 4), out_dtype=jnp.float32,
at_indices=((0,),),),
dict(testcase_name="swap_with_ref_simple_at2",
ref_shape=(4, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
at_indices=((slice(0, 2),),),),
dict(testcase_name="swap_with_ref_multiple_at2",
ref_shape=(1, 4, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
at_indices=((slice(None), slice(0, 2),), (0,)),),
) )
def test_swap_abstract_eval(self, ref_shape, ref_dtype, def test_swap_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None, val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
should_error=False): at_indices=(), should_error=False):
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype) val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val): def f(x_ref, val):
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
out = ref_swap(x_ref, idx, val) out = ref_swap(x_ref, idx, val)
return [out] return [out]
if should_error: if should_error:
@ -192,13 +248,13 @@ class StatePrimitivesTest(jtu.JaxTestCase):
idx=(slice(None),), should_error=True), idx=(slice(None),), should_error=True),
dict(testcase_name="trivial_addupdate", ref_shape=(1, 2), dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32, ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), out_shape=(1, 2), out_dtype=jnp.float32), idx=(),),
dict(testcase_name="bad_dtype", ref_shape=(1, 2), dict(testcase_name="bad_dtype", ref_shape=(1, 2),
ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32, ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), should_error=True), idx=(), should_error=True),
dict(testcase_name="addupdate_with_index", ref_shape=(1, 2), dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32, ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(0,), out_shape=(2,), out_dtype=jnp.float32), idx=(0,),),
dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2), dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32, ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
idx=(slice(None), 0)), idx=(slice(None), 0)),
@ -216,13 +272,34 @@ class StatePrimitivesTest(jtu.JaxTestCase):
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 1, 2), val_dtype=jnp.float32, val_shape=(2, 1, 2), val_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))), idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
dict(testcase_name="ref_with_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((0,),)),
dict(testcase_name="ref_with_simple_at2",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 3, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(0, 3),),)),
dict(testcase_name="ref_with_multiple_at",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(0, 3),), (0,))),
dict(testcase_name="ref_with_multiple_at2",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(None), slice(0, 3),), (0,))),
) )
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype, def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None, val_shape, val_dtype, idx, at_indices=(), should_error=False):
should_error=False):
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype) val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val): def f(x_ref, val):
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
ref_addupdate(x_ref, idx, val) ref_addupdate(x_ref, idx, val)
return [] return []
if should_error: if should_error:
@ -1595,7 +1672,6 @@ if CAN_USE_HYPOTHESIS:
def test_vjp(self, data): def test_vjp(self, data):
spec = data.draw(func_spec()) spec = data.draw(func_spec())
print(spec)
def impl(x): def impl(x):
return spec.call((x, jnp.zeros_like(x)))[1] return spec.call((x, jnp.zeros_like(x)))[1]