[Pallas] Refactor indexing primitives to use NDIndexer abstraction
Some notes about this change:

* This change upgrades the `RefView` abstraction to store multiple indexers.
  This allows doing things like `ref.at[0].at[0]` to recursively create a view
  of a `Ref`. `RefView`s therefore encapsulate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p,
  addupdate_p) but does *not* generalize their rules. Most of the rules will
  raise a NotImplementedError if you use multiple `NDIndexer`s. Adding support
  will be done in a future CL.
* With the above in mind, this change only preserves existing public-facing
  APIs; adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
parent 8c5e7b26ba
commit 836563fadf
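A minimal sketch of the first note in the commit message above (not from the
commit itself; `ref` stands for a hypothetical shaped `Ref` inside a
stateful/Pallas context):

    view = ref.at[0]      # RefView holding one NDIndexer (selects row 0)
    view2 = view.at[0]    # RefView holding two NDIndexers, i.e. ref.at[0].at[0]
    value = view2[...]    # a `get` through the view; most rules currently
                          # raise NotImplementedError for more than one NDIndexer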
@@ -738,15 +738,17 @@ pytype_strict_library(
     name = "state_types",
     srcs = [
         "_src/state/__init__.py",
+        "_src/state/indexing.py",
         "_src/state/types.py",
     ],
     deps = [
         ":core",
         ":effects",
+        ":pretty_printer",
         ":tree_util",
         ":typing",
         ":util",
-    ],
+    ] + py_deps("numpy"),
 )

 pytype_strict_library(
jax/_src/pallas/indexing.py (deleted file)
@@ -1,155 +0,0 @@
-# Copyright 2023 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Contains shared logic and abstractions for Pallas indexing ops."""
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any
-
-import jax
-from jax import core as jax_core
-from jax import tree_util
-from jax._src.interpreters import mlir
-from jax._src.util import merge_lists
-from jax._src.util import partition_list
-import jax.numpy as jnp
-import numpy as np
-
-
-# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
-# We could use `jnp.broadcast_to` but that lowers to squeezing,
-# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
-# so in the lowering, we have to expand out those squeezed dimensions again.
-# Having a simple `broadcast_to` primitive allows us to lower directly
-# to `tl.broadcast_to`.
-broadcast_to_p = jax_core.Primitive('broadcast_to')
-
-def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array:
-  if a.shape == shape:
-    return a
-  return broadcast_to_p.bind(a, shape=shape)
-
-@broadcast_to_p.def_impl
-def _broadcast_to_impl(a, *, shape):
-  return jnp.broadcast_to(a, shape)
-
-@broadcast_to_p.def_abstract_eval
-def _broadcast_to_abstract_eval(aval, *, shape):
-  return jax_core.ShapedArray(shape, aval.dtype)
-
-mlir.register_lowering(
-    broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
-)
-
-
-@tree_util.register_pytree_node_class
-@dataclasses.dataclass
-class Slice:
-  """Represents a slice with a dynamic start index and a fixed size."""
-  start: Any
-  size: int
-
-  def __post_init__(self):
-    if self.size < 0:
-      raise ValueError("`size` must not be negative.")
-
-  def tree_flatten(self):
-    # If `start` is statically known, we treat it as static information
-    if isinstance(self.start, int):
-      return (), (self.start, self.size)
-    return (self.start,), (self.size,)
-
-  @classmethod
-  def tree_unflatten(cls, aux_data, children) -> Slice:
-    return cls(*children, *aux_data)
-
-  @classmethod
-  def from_slice(cls, slc: slice, size: int) -> Slice:
-    start, stop, step = slc.indices(size)
-    if step != 1:
-      raise ValueError(f"slice must have a step of 1 (found: {step})")
-    return cls(start, stop - start)
-
-
-def dslice(start: int | jax.Array | None, size: int | None = None
-           ) -> slice | Slice:
-  """Constructs a `Slice` from a start and a size."""
-  if start is None:
-    return slice(None)
-  if size is None:
-    if not isinstance(start, int):
-      raise ValueError("Non-static `dslice`")
-    return Slice(0, start)
-  return Slice(start, size)
-ds = dslice  # Handy alias
-
-
-@tree_util.register_pytree_node_class
-@dataclasses.dataclass
-class NDIndexer:
-  indices: tuple[int | Slice | jax.Array, ...]
-  shape: tuple[int, ...]
-  int_indexer_shape: tuple[int, ...]
-
-  def __post_init__(self):
-    if len(self.indices) != len(self.shape):
-      raise ValueError("`indices` must be the same length as `Ref` shape.")
-
-  def tree_flatten(self):
-    indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
-    slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
-    flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
-    return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
-                      self.int_indexer_shape)
-
-  @classmethod
-  def tree_unflatten(cls, data, flat_idx):
-    slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
-    non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
-    indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
-    return NDIndexer(tuple(indices), shape, int_indexer_shape)
-
-  @classmethod
-  def from_indices_shape(cls, indices, shape) -> NDIndexer:
-    if len(indices) > len(shape):
-      raise ValueError("`indices` must not be longer than `shape`.")
-    # Pad out indices with slice(None)
-    indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
-    # Convert all `slice`s to `Slice`s
-    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
-                    else i for i, s in zip(indices, shape))
-    is_int_indexing = [not isinstance(i, Slice) for i in indices]
-    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
-    int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
-                    int_indexers]
-    indexer_shapes = [i.shape for i in int_indexers]
-    if indexer_shapes:
-      try:
-        bcast_shape = np.broadcast_shapes(*indexer_shapes)
-      except ValueError as e:
-        # Raise a nicer error than the NumPy one.
-        raise ValueError("Cannot broadcast shapes for indexing: "
-                         f"{tuple(a for a in indexer_shapes)}") from e
-    else:
-      bcast_shape = ()
-    int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
-    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
-    return NDIndexer(tuple(indices), shape, bcast_shape)
-
-  def get_indexer_shape(self) -> tuple[int, ...]:
-    is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
-    other_indexers, _ = partition_list(is_int_indexing, self.indices)
-    other_shape = [s.size for s in other_indexers]  # type: ignore
-    return (*self.int_indexer_shape, *other_shape)
@@ -42,12 +42,12 @@ from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 from jax._src.pallas import core
-from jax._src.pallas import indexing
 from jax._src.pallas import primitives
 from jax._src.pallas import utils as pallas_utils
 from jax._src.pallas.mosaic import core as tpu_core
 from jax._src.pallas.mosaic import primitives as tpu_primitives
 from jax._src.state import discharge as state_discharge
+from jax._src.state import indexing
 from jax._src.state import primitives as state_primitives
 from jax._src.util import safe_map
 from jax._src.util import safe_zip
@@ -610,12 +610,16 @@ def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,


 def _get_lowering_rule(
-    ctx: LoweringRuleContext, ref, *non_slice_idx, indexed_dims: Sequence[bool]
+    ctx: LoweringRuleContext, ref, *idx, tree,
 ):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
+  if len(indexers) > 1:
+    raise NotImplementedError("Only one indexer currently supported.")
   # Call _load_lowering_rule (since it's more general)
   ref_aval, *non_slice_idx_avals = ctx.avals_in
-  nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer(
-      ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims)
+  nd_indexer = indexers[0]
+  nd_indexer_avals = indexers_avals[0]
   args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None))
   avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None))
   ctx = ctx.replace(avals_in=avals_flat)
@@ -630,13 +634,17 @@ def _swap_lowering_rule(
     ctx: LoweringRuleContext,
     ref,
     val,
-    *non_slice_idx,
-    indexed_dims: Sequence[bool],
+    *idx,
+    tree
 ):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
+  if len(indexers) > 1:
+    raise NotImplementedError("Only one indexer currently supported.")
   # Call _masked_swap_lowering_rule (since it's more general)
-  ref_aval, val_aval, *non_slice_idx_avals = ctx.avals_in
-  nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer(
-      ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims)
+  ref_aval, val_aval, *_ = ctx.avals_in
+  nd_indexer = indexers[0]
+  nd_indexer_avals = indexers_avals[0]
   args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None))
   avals_flat = tree_util.tree_leaves(
       (ref_aval, nd_indexer_avals, val_aval, None)
@@ -29,10 +29,10 @@ from jax._src import pretty_printer as pp
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
-from jax._src.state import primitives as state_primitives
+from jax._src.state import indexing
+from jax._src.state import primitives as sp
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
-from jax._src.pallas import indexing
 from jax._src.pallas.mosaic import core as tpu_core
 import jax.numpy as jnp
@@ -264,38 +264,6 @@ def _dma_start_abstract_eval(*args, tree, device_id_type):
   del args, tree, device_id_type
   return []

-def _pp_slice(slc: indexing.Slice, dim: int, context: jax_core.JaxprPpContext
-              ) -> str:
-  start, size = slc.start, slc.size
-  if isinstance(start, jax_core.Var):
-    start_str = jax_core.pp_var(start, context)
-    end_str = f'{start_str}+{size}'
-  else:
-    start_str = '' if start == 0 else str(start)
-    end = start + size
-    end_str = '' if end == dim else str(end)
-  return f'{start_str}:{end_str}'
-
-def _pp_indexer(indexer: indexing.NDIndexer,
-                context: jax_core.JaxprPpContext) -> pp.Doc:
-  indices = []
-  for idx, dim in zip(indexer.indices, indexer.shape):
-    if isinstance(idx, indexing.Slice):
-      indices.append(_pp_slice(idx, dim, context))
-    else:
-      indices.append(jax_core.pp_var(idx, context))  # type: ignore
-  return pp.text(','.join(indices))
-
-def _pp_ref(ref, indexer, context):
-  return state_primitives.pp_ref(
-      pp.concat([
-          pp.text(jax_core.pp_var(ref, context)),
-          pp.text("["),
-          _pp_indexer(indexer, context),
-          pp.text("]"),
-      ])
-  )
-
 def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
                       context: jax_core.JaxprPpContext,
                       settings: jax_core.JaxprPpSettings):
@@ -309,9 +277,9 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
   return pp.concat([
       pp.text('dma_start'),
       pp.text(' '),
-      _pp_ref(src_ref, src_indexer, context),
+      sp.pp_ref_indexers(context, src_ref, (src_indexer,)),
       pp.text(' -> '),
-      _pp_ref(dst_ref, dst_indexer, context),
+      sp.pp_ref_indexers(context, dst_ref, (dst_indexer,)),
       pp.text(' '),
       pp.text(jax_core.pp_var(dst_sem, context)),
   ])
@@ -358,13 +326,14 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type):
 def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
                      context: jax_core.JaxprPpContext,
                      settings: jax_core.JaxprPpSettings):
   del settings
   invars = eqn.invars
   tree = eqn.params["tree"]
   sem, ref, indexer = tree_util.tree_unflatten(tree, invars)
   return pp.concat([
       pp.text('dma_wait'),
       pp.text(' '),
-      _pp_ref(ref, indexer, context),
+      sp.pp_ref_indexers(context, ref, (indexer,)),
       pp.text(' '),
       pp.text(jax_core.pp_var(sem, context)),
   ])
@@ -373,7 +342,10 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

 def _get_ref_and_indexer(ref):
   if isinstance(ref, state.RefView):
-    return ref.ref, ref.indexer
+    indexers = ref.indexers
+    if len(indexers) > 1:
+      raise NotImplementedError("Only one indexer supported.")
+    return ref.ref, ref.indexers[0].indices
   return ref, (slice(None),) * len(ref.shape)

 def make_async_copy(src_ref, dst_ref, sem):
@@ -27,15 +27,15 @@ from jax._src import core as jax_core
 from jax._src import pretty_printer as pp
 from jax._src import state
 from jax._src.util import (safe_map, safe_zip)
-from jax._src.state import primitives as state_primitives
 from jax._src.state import discharge as state_discharge
+from jax._src.state import indexing
+from jax._src.state import primitives as sp
 from jax.interpreters import ad
 from jax.interpreters import mlir
 from jax.interpreters import xla
 import jax.numpy as jnp

 from jax._src.pallas import core as pallas_core
-from jax._src.pallas import indexing

 # TODO(sharadmv): enable type checking
 # mypy: ignore-errors
@@ -224,44 +224,13 @@ def _load_abstract_eval(*avals_flat, args_tree, **_):

 load_p.def_effectful_abstract_eval(_load_abstract_eval)

-def _pp_dslice(dim: int, slice: Slice, context):
-  size = pp.text(str(slice.size))
-  if isinstance(slice.start, int):
-    if slice.start == 0:
-      start = pp.text("")
-    else:
-      start = pp.text(str(slice.start))
-    if slice.size == dim:
-      end = pp.text("")
-    else:
-      end = pp.text(str(slice.start + slice.size))
-  else:
-    start = pp.text(jax_core.pp_var(slice.start, context))
-    end = pp.concat([start, pp.text("+"), size])
-  return pp.concat([start, pp.text(":"), end])
-
-def _pp_idx(ref_aval, idx: NDIndexer, context):
-  docs = [
-      _pp_dslice(d, s, context) if isinstance(s, Slice)
-      else pp.text(jax_core.pp_var(s, context))
-      for s, d in zip(idx.indices, ref_aval.shape)]
-  if not docs:
-    return pp.text("")
-  doc = [docs[0]]
-  for d in docs[1:]:
-    doc.append(pp.text(","))
-    doc.append(d)
-  return pp.concat(doc)
-
 def _load_pp_rule(eqn, context, settings):
   # Pretty prints `a = load x i` as `x[i] <- a`
   y, = eqn.outvars
   x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars)
-  idx = _pp_idx(eqn.invars[0].aval, idx, context)
   lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
-  return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
-      pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
-  ]))])
+  return pp.concat([
+      lhs, pp.text(' <- '), sp.pp_ref_indexers(context, x, (idx,))])
 jax_core.pp_eqn_rules[load_p] = _load_pp_rule
@@ -339,15 +308,14 @@ def _swap_pp_rule(eqn, context, settings):
   # Pretty prints `_ = swap x v i` as `x[i] <- v`
   y, = eqn.outvars
   x, idx, val, _ = eqn.params["args_tree"].unflatten(eqn.invars)
-  idx = _pp_idx(eqn.invars[0].aval, idx, context)
-  x_i = pp.concat([pp.text(jax_core.pp_var(x, context)),
-                   pp.text('['), idx, pp.text(']')])
+  x_i = sp.pp_ref_indexers(context, x, (idx,))
   if isinstance(y, jax_core.DropVar):
-    return pp.concat([state_primitives.pp_ref(
-        x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
+    return pp.concat([
+        x_i,
+        pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
   y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
-  return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i),
-                    pp.text(' <- '), state_primitives.pp_ref(x_i),
+  return pp.concat([y, pp.text(', '), x_i,
+                    pp.text(' <- '), x_i,
                     pp.text(', '), pp.text(jax_core.pp_var(val, context))])
 jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
@@ -39,12 +39,12 @@ from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.lib import hlo_helpers
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
-from jax._src.pallas import indexing
 from jax._src.pallas import primitives
 from jax._src.pallas import utils as pallas_utils
 from jax._src.pallas.pallas_call import pallas_call_p
 from jax._src.state import AbstractRef
 from jax._src.state import discharge
+from jax._src.state import indexing
+from jax._src.state import primitives as sp
 from jax._src.util import merge_lists
 from jax._src.util import partition_list
@@ -438,7 +438,7 @@ _TRITON_FN_MAPPING = {
     lax.nextafter_p: tl.math.nextafter,
     ad_util.add_any_p: tl.semantic.add,
     # Other ops.
-    indexing.broadcast_to_p: tl.broadcast_to,
+    sp.broadcast_to_p: tl.broadcast_to,
     primitives.atomic_cas_p: tl.atomic_cas,
     primitives.max_contiguous_p: tl.max_contiguous,
     primitives.multiple_of_p: tl.multiple_of,
@@ -727,28 +727,16 @@ def _pack_indices(non_slice_idx, indexed_dims):


 def _get_lowering_rule(
-    ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
+    ctx: TritonLoweringRuleContext, ptr, *idx, tree
 ):
+  indexers = tree_util.tree_unflatten(tree, idx)
   if not isinstance(ptr.type, tl.pointer_type):
-    assert not non_slice_idx
+    assert len(indexers) == 0
     return ptr
-
-  ref_aval, *idx_avals = ctx.avals_in
-  idx_avals = _pack_indices(idx_avals, indexed_dims)
-  if non_slice_idx:
-    (int_indexer_shape,) = {
-        i.shape for i in idx_avals if not isinstance(i, slice)
-    }
-  else:
-    int_indexer_shape = ()
-
-  idx = _pack_indices(non_slice_idx, indexed_dims)
-  idx = tuple(
-      primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
-      for s, slc in zip(ref_aval.shape, idx)
-  )
-  idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
-  args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None))
+  if len(indexers) > 1:
+    raise NotImplementedError("No support for multiple indexers yet.")
+  indexer = indexers[0]
+  args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, None, None))
   return _masked_load_lowering_rule(
       ctx,
       *args_flat,
@@ -794,24 +782,16 @@ triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule


 def _swap_lowering_rule(
-    ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
+    ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
 ):
-  ref_aval, _, *idx_avals = ctx.avals_in
-  idx_avals = _pack_indices(idx_avals, indexed_dims)
-  if non_slice_idx:
-    (int_indexer_shape,) = {
-        i.shape for i in idx_avals if not isinstance(i, slice)
-    }
-  else:
-    int_indexer_shape = ()
-
-  idx = _pack_indices(non_slice_idx, indexed_dims)
-  idx = tuple(
-      primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
-      for s, slc in zip(ref_aval.shape, idx)
-  )
-  idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
-  args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None))
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if not isinstance(ptr.type, tl.pointer_type):
+    assert len(indexers) == 0
+    return ptr
+  if len(indexers) > 1:
+    raise NotImplementedError("No support for multiple indexers yet.")
+  indexer = indexers[0]
+  args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, value, None))
   return _masked_swap_lowering_rule(
       ctx, *args_flat, args_tree=args_tree, eviction_policy=None
   )
@@ -850,24 +830,17 @@ triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule


 def _addupdate_lowering_rule(
-    ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
+    ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
 ):
-  ref_block_info, *_ = ctx.block_infos
-  avals_in = ctx.avals_in
-  idx = _pack_indices(non_slice_idx, indexed_dims)
-  if non_slice_idx:
-    (int_indexer_shape,) = {
-        tuple(map(lambda x: x.value, i.shape)) for i in non_slice_idx
-    }
-  else:
-    int_indexer_shape = ()
-  idx = tuple(
-      primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
-      for s, slc in zip(avals_in[0].shape, idx)
-  )
-  idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if not isinstance(ptr.type, tl.pointer_type):
+    assert len(indexers) == 0
+    return ptr
+  if len(indexers) > 1:
+    raise NotImplementedError("No support for multiple indexers yet.")
+  indexer = indexers[0]
   ptr = _compute_pointers_from_indices(
-      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
+      ptr, ctx.block_infos[0], indexer, ctx.avals_in[0].shape, ctx.builder
   )
   tl.atomic_add(ptr, value, _builder=ctx.builder)
   return []
@@ -34,9 +34,11 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax
 from jax._src.lax import slicing as lax_slicing
+from jax._src.state import indexing
 from jax._src.state.types import AbstractRef, RefEffect
 from jax._src.state.primitives import get_p, swap_p, addupdate_p
 from jax._src.state.utils import hoist_consts_to_refs
 from jax._src.typing import Array
 from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
                            partition_list, merge_lists, split_dict)
@@ -144,34 +146,112 @@ def _eval_jaxpr_discharge_state(
       env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge])
   return out_vals + ref_vals

+def _is_trivial_indexer(indexer: indexing.NDIndexer):
+  for s, idx in zip(indexer.shape, indexer.indices):
+    if not isinstance(idx, indexing.Slice):
+      return False
+    if not isinstance(idx.start, int):
+      return False
+    if idx.start:
+      return False
+    if idx.size != s:
+      return False
+  return True
+
+def _convert_to_array_indexer(indexer: indexing.NDIndexer
+                              ) -> tuple[int | Array, ...]:
+  # This is the general gather case. We need to create the gather arrays.
+  is_integer_indexer, _, integer_indexer = (
+      indexing.unpack_ndindexer(indexer)
+  )
+  total_shape = indexer.get_indexer_shape()
+  int_indexer_shape = indexer.int_indexer_shape
+  slice_shape = total_shape[len(int_indexer_shape):]
+  slice_dims = tuple(
+      i + len(int_indexer_shape) for i in range(len(slice_shape))
+  )
+  slice_dim_iter = iter(slice_dims)
+  slice_indexer: list[Array] = []
+  for idx, is_int_index in zip(indexer.indices, is_integer_indexer):
+    if not is_int_index:
+      assert isinstance(idx, indexing.Slice)
+      slice_indices = lax.broadcasted_iota(
+          np.dtype("int32"), total_shape, next(slice_dim_iter)
+      ) + idx.start
+      slice_indexer.append(slice_indices)
+      integer_indexer = tuple(
+          lax.expand_dims(idx, (-1,)) for idx in integer_indexer
+      )
+      continue
+  assert next(slice_dim_iter, None) is None
+  return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer))
+
+
+def _maybe_convert_to_dynamic_slice(
+    indexer: indexing.NDIndexer,
+) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
+  # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
+  # if each of the indexers is a `Slice` or a ()-shaped value.
+  if not all(isinstance(i, indexing.Slice) or not np.shape(i)
+             for i in indexer.indices):
+    return None
+  _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
+  starts = tuple(
+      _convert_i32(i.start) if isinstance(i, indexing.Slice)
+      else _convert_i32(i) for i in indexer.indices
+  )
+  sizes = tuple(
+      i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices
+  )
+  squeeze_dims = tuple(
+      i
+      for i, idx in enumerate(indexer.indices)
+      if not isinstance(idx, indexing.Slice)
+  )
+  return starts, sizes, squeeze_dims
+
+
 @register_discharge_rule(get_p)
 def _get_discharge_rule(
     in_avals: Sequence[core.AbstractValue],
-    out_avals: Sequence[core.AbstractValue], x, *non_slice_idx,
-    indexed_dims: Sequence[bool]):
+    out_avals: Sequence[core.AbstractValue], x, *idx,
+    tree):
   del in_avals, out_avals
-  y = _get_discharge(x, non_slice_idx, indexed_dims)
-  return (None,) * (len(non_slice_idx) + 1), y
+  y = _get_discharge(x, idx, tree)
+  return (None,) * (len(idx) + 1), y

-def _get_discharge(x, idx, indexed_dims):
-  if not any(indexed_dims):
-    return x
-  if all(not i.shape for i in idx):
-    return _dynamic_index(x, idx, indexed_dims)
-  else:
-    return _prepend_gather(x, idx, indexed_dims)

-def _prepend_gather(x, idx, indexed_dims):
-  indexer = _indexer(idx, indexed_dims)
+def _prepend_gather(x, indexer):
   # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
   return x[None][(np.array(0, 'int32'), *indexer)]

-def _prepend_scatter(x, idx, indexed_dims, val, *, add=False):
-  indexer = _indexer(idx, indexed_dims)
+def _prepend_scatter(x, indexer, val, *, add=False):
   # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
   # However, since this is scatter, we need to remove the 1-sized dimension
   # we added at the front.
   if add:
     return x[None].at[(0, *indexer)].add(val)[0]
   return x[None].at[(0, *indexer)].set(val)[0]


+def _get_discharge(x, idx, tree):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if len(indexers) > 1:
+    raise NotImplementedError("Only single indexer is supported.")
+  indexer = indexers[0]
+  if _is_trivial_indexer(indexer):
+    return x
+  # If everything in the indexer is a slice or ()-shaped, we can also
+  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+  # We need to squeeze out the 1-sized slices at the end.
+  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+    starts, sizes, squeeze_dims = maybe_slice
+    y = lax_slicing.dynamic_slice(x, starts, sizes)
+    return lax.squeeze(y, squeeze_dims)
+  indexer = _convert_to_array_indexer(indexer)
+  if indexer is None:
+    return x
+  return x[None][(np.array(0, 'int32'), *indexer)]

 def _indexer(idx, indexed_dims):
   idx_ = iter(idx)
   indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
@@ -181,59 +261,59 @@ def _indexer(idx, indexed_dims):

 @register_discharge_rule(swap_p)
 def _swap_discharge_rule(
     in_avals: Sequence[core.AbstractValue],
-    out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
-    indexed_dims: Sequence[bool]):
+    out_avals: Sequence[core.AbstractValue], x, val, *idx,
+    tree):
   del in_avals, out_avals
-  if not any(indexed_dims):
-    z, x_new = x, val
-  z, x_new = _swap_discharge(x, val, non_slice_idx, indexed_dims)
-  return (x_new, None) + (None,) * len(non_slice_idx), z
+  z, x_new = _swap_discharge(x, val, idx, tree)
+  return (x_new, None) + (None,) * len(idx), z

-def _swap_discharge(x, val, idx, indexed_dims):
-  if not any(indexed_dims):
-    z, x_new = x, val
-  elif all(not i.shape for i in idx):
-    z = _dynamic_index(x, idx, indexed_dims)
-    x_new = _dynamic_update_index(x, idx, val, indexed_dims)
-  else:
-    z = _prepend_gather(x, idx, indexed_dims)
-    x_new = _prepend_scatter(x, idx, indexed_dims, val)
-  return z, x_new
+def _swap_discharge(x, val, idx, tree):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if len(indexers) > 1:
+    raise NotImplementedError("Only single indexer is supported.")
+  indexer = indexers[0]
+  if _is_trivial_indexer(indexer):
+    return x, val
+  # If everything in the indexer is a slice or ()-shaped, we can also
+  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+  # We need to squeeze out the 1-sized slices at the end.
+  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+    starts, sizes, squeeze_dims = maybe_slice
+    x_old = lax_slicing.dynamic_slice(x, starts, sizes)
+    val = lax.expand_dims(val, squeeze_dims)
+    y = lax_slicing.dynamic_update_slice(x, val, starts)
+    return lax.squeeze(x_old, squeeze_dims), y
+  indexer = _convert_to_array_indexer(indexer)
+  x_old = _prepend_gather(x, indexer)
+  return x_old, _prepend_scatter(x, indexer, val)

 @register_discharge_rule(addupdate_p)
 def _addupdate_discharge_rule(
     in_avals: Sequence[core.AbstractValue],
-    out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
-    indexed_dims: Sequence[bool]):
+    out_avals: Sequence[core.AbstractValue], x, val, *idx,
+    tree):
   del in_avals, out_avals
-  ans = _addupdate_discharge(x, val, non_slice_idx, indexed_dims)
-  return (ans, None) + (None,) * len(non_slice_idx), []
+  ans = _addupdate_discharge(x, val, idx, tree)
+  return (ans, None) + (None,) * len(idx), []

-def _addupdate_discharge(x, val, idx, indexed_dims):
-  if not any(indexed_dims):
+def _addupdate_discharge(x, val, idx, tree):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if len(indexers) > 1:
+    raise NotImplementedError("Only single indexer is supported.")
+  indexer = indexers[0]
+  if _is_trivial_indexer(indexer):
     return x + val
-  if all(not i.shape for i in idx):
-    y = val + _dynamic_index(x, idx, indexed_dims)
-    return _dynamic_update_index(x, idx, y, indexed_dims)
-  else:
-    return _prepend_scatter(x, idx, indexed_dims, val, add=True)

-def _dynamic_index(x, idx, indexed_dims):
-  assert isinstance(idx, (list, tuple)) and idx
-  idx_ = iter(idx)
-  starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
-  assert next(idx_, None) is None
-  sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
-  out = lax_slicing.dynamic_slice(x, starts, sizes)
-  return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])

-def _dynamic_update_index(x, idx, val, indexed_dims):
-  assert isinstance(idx, (list, tuple)) and idx
-  idx_ = iter(idx)
-  starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
-  assert next(idx_, None) is None
-  sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
-  return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)
+  # If everything in the indexer is a slice or ()-shaped, we can also
+  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+  # We need to squeeze out the 1-sized slices at the end.
+  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+    starts, sizes, squeeze_dims = maybe_slice
+    x_old = lax_slicing.dynamic_slice(x, starts, sizes)
+    val = lax.expand_dims(val, squeeze_dims)
+    y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
+    return y
+  indexer = _convert_to_array_indexer(indexer)
+  return _prepend_scatter(x, indexer, val, add=True)

 @weakref_lru_cache
 def _cached_closed_jaxpr_discharge(closed_jaxpr):
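A minimal sketch of the `_maybe_convert_to_dynamic_slice` path above (assuming
the `Slice`/`NDIndexer` definitions introduced by this change; not from the
commit itself): for indices `(Slice(i, 2), j)` into an (8, 8)-shaped array,
`starts=(i, j)`, `sizes=(2, 1)`, and `squeeze_dims=(1,)`, so the discharged
`get` computes:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(64.).reshape(8, 8)
    i, j = jnp.int32(3), jnp.int32(5)
    # ()-shaped integer indices become 1-sized slices that are squeezed away.
    y = lax.squeeze(lax.dynamic_slice(x, (i, j), (2, 1)), (1,))
    assert y.shape == (2,)  # same elements as x[3:5, 5]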
jax/_src/state/indexing.py (new file, 192 lines)
@@ -0,0 +1,192 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains shared logic and abstractions for Pallas indexing ops."""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Any, Union
+
+from jax._src import core
+from jax._src import tree_util
+from jax._src.typing import Array
+from jax._src.util import merge_lists
+from jax._src.util import partition_list
+import numpy as np
+
+
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class Slice:
+  """Represents a slice with a dynamic start index and a fixed size."""
+  start: Any
+  size: int
+
+  def __post_init__(self):
+    if self.size < 0:
+      raise ValueError("`size` must not be negative.")
+
+  def tree_flatten(self):
+    # If `start` is statically known, we treat it as static information
+    if isinstance(self.start, int):
+      return (), (self.start, self.size)
+    return (self.start,), (self.size,)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children) -> Slice:
+    return cls(*children, *aux_data)
+
+  @classmethod
+  def from_slice(cls, slc: slice, size: int) -> Slice:
+    start, stop, step = slc.indices(size)
+    if step != 1:
+      raise ValueError(f"slice must have a step of 1 (found: {step})")
+    return cls(start, stop - start)
+
+
+def dslice(start: int | Array | None, size: int | None = None
+           ) -> slice | Slice:
+  """Constructs a `Slice` from a start and a size."""
+  if start is None:
+    return slice(None)
+  if size is None:
+    if not isinstance(start, int):
+      raise ValueError("Non-static `dslice`")
+    return Slice(0, start)
+  return Slice(start, size)
+ds = dslice  # Handy alias
+
+
+IntIndexer = Union[int, Array]
+DimIndexer = Union[IntIndexer, Slice]
+
+def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
+                                                  tuple[Slice, ...],
+                                                  tuple[IntIndexer, ...]]:
+  is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
+  slice_indexers, int_indexers = partition_list(
+      is_int_indexing, indexer.indices)
+  return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers)  # type: ignore
+
+def _maybe_concretize(x: Any):
+  try:
+    return core.concrete_or_error(None, x)
+  except core.ConcretizationTypeError:
+    return None
+
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class NDIndexer:
+  indices: tuple[DimIndexer, ...]
+  shape: tuple[int, ...]
+  int_indexer_shape: tuple[int, ...]
+  validate: bool = False
+
+  def __post_init__(self):
+    if not self.validate:
+      return
+    if len(self.indices) != len(self.shape):
+      raise ValueError(
+          f"`indices` must be the same length as `Ref` shape: {self}."
+      )
+    # We validate integer indexing shapes here
+    for idx, s in zip(self.indices, self.shape):
+      if isinstance(idx, Slice):
+        start = idx.start
+        if value := _maybe_concretize(start):
+          if value >= s:
+            raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
+          if value + idx.size > s:
+            raise ValueError(
+                f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
+            )
+        continue
+      # The shape of indexer integers should be broadcastable up to the
+      # int_indexer_shape of the whole NDIndexer
+      if not np.shape(idx):
+        if (value := _maybe_concretize(idx)) and value >= s:
+          raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
+        # For ()-shaped indexers, we can broadcast no problem.
+        continue
+      # If we don't have a ()-shaped indexer, the rank must match
+      # int_indexer_shape
+      if np.ndim(idx) != len(self.int_indexer_shape):
+        raise ValueError(
+            f"Indexer must have rank {np.ndim(idx)}: {idx=} vs."
+            f" {self.int_indexer_shape=}"
+        )
+      # Here we check that the shapes broadcast.
+      try:
+        np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
+      except ValueError as e:
+        raise ValueError(
+            f"Could not broadcast integer indexer: {idx=} vs."
+            f" {self.int_indexer_shape=}"
+        ) from e
+
+  def tree_flatten(self):
+    flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
+    return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
+
+  @classmethod
+  def tree_unflatten(cls, data, flat_idx):
+    idx_tree, shape, int_indexer_shape = data
+    indices = tree_util.tree_unflatten(idx_tree, flat_idx)
+    return NDIndexer(tuple(indices), shape, int_indexer_shape)
+
+  @classmethod
+  def from_indices_shape(cls, indices, shape) -> NDIndexer:
+    if indices == ...:
+      indices = (slice(None),) * len(shape)
+    if not isinstance(indices, tuple):
+      indices = (indices,)
+    if any(idx is ... for idx in indices):
+      # TODO(sharadmv,mattjj): support patterns that include ellipsis in them
+      # e.g. x[0, ..., 1].
+      raise NotImplementedError("Ellipsis in indexer not supported yet.")
+    if len(indices) > len(shape):
+      raise ValueError("`indices` must not be longer than `shape`: "
+                       f"{indices=}, {shape=}")
+    # Pad out indices with slice(None)
+    indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
+    # Convert all `slice`s to `Slice`s
+    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
+                    else i for i, s in zip(indices, shape))
+    is_int_indexing = [not isinstance(i, Slice) for i in indices]
+    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
+    indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
+    if indexer_shapes:
+      try:
+        bcast_shape = np.broadcast_shapes(*indexer_shapes)
+      except ValueError as e:
+        # Raise a nicer error than the NumPy one.
+        raise ValueError("Cannot broadcast shapes for indexing: "
+                         f"{tuple(a for a in indexer_shapes)}") from e
+    else:
+      bcast_shape = ()
+    # Here we use the `broadcast_to` primitive instead of composing lax
+    # primitives together because it is easier to lower in targets like
+    # Triton/Mosaic.
+    from jax._src.state import primitives as sp  # pytype: disable=import-error
+    int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
+    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
+    return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
+
+  def get_indexer_shape(self) -> tuple[int, ...]:
+    _, slice_indexers, _ = unpack_ndindexer(self)
+    slice_shape = [s.size for s in slice_indexers]
+    # In NDIndexers, the int_indexer_shape is *always* at the front of the
+    # result.
+    return (*self.int_indexer_shape, *slice_shape)
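A minimal sketch of what `from_indices_shape` computes (assuming the
definitions above; the broadcast shape of the integer indexers becomes
`int_indexer_shape`, which always leads the result shape):

    import numpy as np

    idx = (np.arange(3), slice(None), np.int32(0))  # int array, slice, ()-shaped int
    indexer = NDIndexer.from_indices_shape(idx, (4, 5, 6))
    assert indexer.int_indexer_shape == (3,)
    assert indexer.get_indexer_shape() == (3, 5)    # (*int_indexer_shape, slice size)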
@@ -23,14 +23,17 @@ import numpy as np
 from jax._src import ad_util
 from jax._src import core
 from jax._src import pretty_printer as pp
+from jax._src import tree_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import partial_eval as pe
+from jax._src.interpreters import mlir
 from jax._src.lax import lax
 from jax._src.typing import Array
-from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
+from jax._src.state import indexing
+from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect,
                                   AccumEffect)
-from jax._src.util import safe_map, safe_zip, tuple_insert
+from jax._src.util import safe_map, safe_zip
@@ -51,41 +54,14 @@ zip, unsafe_zip = safe_zip, zip
 # a:f32[3] <- x[]
 get_p = core.Primitive("get")

-def _get_impl(ref: AbstractRef, *idx: int, **_):
-  del ref, idx
+def _get_impl(ref: AbstractRef, *args: Any, tree):
+  del ref, args, tree
   raise ValueError("Cannot run stateful primitive.")
 get_p.def_impl(_get_impl)

 Indexer = tuple[Union[int, slice, Array], ...]
 # or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)

-def _is_trivial_indexer(idx: Indexer) -> bool:
-  if idx is ...:
-    return True
-  if type(idx) is tuple:
-    if len(idx) == 0:
-      return True
-    return len(idx) == 1 and idx[0] is ...
-  return False
-
-def _unpack_idx(idx: Indexer, ndim: int
-                ) -> tuple[tuple[Array, ...], tuple[bool, ...]]:
-  if _is_trivial_indexer(idx):
-    idx = tuple(slice(None) for _ in range(ndim))
-  indexed_dims_ = []
-  non_slice_idx = []
-  for i in idx:
-    if isinstance(i, slice):
-      if i.start is not None or i.stop is not None or i.step is not None:
-        raise NotImplementedError("Reference indexing only supports trivial slices")
-      indexed_dims_.append(False)
-    else:
-      non_slice_idx.append(i)
-      indexed_dims_.append(True)
-  indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
-  import jax.numpy as jnp
-  return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
-
 def _get_slice_output_shape(in_shape: tuple[int, ...],
                             idx_shapes: tuple[tuple[int, ...], ...],
                             indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
@@ -95,24 +71,30 @@ def _get_slice_output_shape(in_shape: tuple[int, ...],
   shape = (*shape_prefix, *shape_suffix)
   return shape

-def _get_indexer(ref: AbstractRef, idx: Indexer
-                 ) -> tuple[Indexer, tuple[bool, ...]]:
-  if isinstance(ref.inner_aval, core.ShapedArray):
-    non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
-  else:
-    if not _is_trivial_indexer(idx):
-      raise ValueError(
-          f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
-    non_slice_idx, indexed_dims = (), ()
-  return non_slice_idx, indexed_dims
-
-def ref_get(ref: Any, idx: Indexer) -> Array:
-  """Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
+def _get_indexer(
+    ref_or_view: Any, idx: Indexer | None, function_name: str
+) -> tuple[Any, tuple[indexing.NDIndexer, ...]]:
+  if isinstance(ref_or_view, RefView):
+    ref, indexers = ref_or_view.ref, ref_or_view.indexers
+  else:
+    ref, indexers = ref_or_view, ()
   ref_aval = core.get_aval(ref)
   if not isinstance(ref_aval, AbstractRef):
-    raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
-  non_slice_idx, indexed_dims = _get_indexer(ref, idx)
-  return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
+    raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
+  if not isinstance(ref_aval.inner_aval, core.ShapedArray):
+    return ref, ()
+  if idx is None:
+    return ref, indexers
+  nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
+  return ref, (*indexers, nd_indexer)
+
+
+def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
+  """Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
+  ref, indexers = _get_indexer(ref_or_view, idx, "ref_get")
+  flat_indexers, tree = tree_util.tree_flatten(indexers)
+  return get_p.bind(ref, *flat_indexers, tree=tree)

 # `swap` mutates a `Ref`, setting its value and returns its previous value.
 # b = swap_p.bind(x, a)
@@ -132,22 +114,21 @@ def ref_get(ref: Any, idx: Indexer) -> Array:
 # x:Ref{f32[3]}[i, j] <- a
 swap_p = core.Primitive("swap")

-def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
-  del ref, value, idx
+def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
+  del ref, value, idx, tree
   raise ValueError("Cannot run stateful primitive.")
 swap_p.def_impl(_swap_impl)

-def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
+def ref_swap(ref_or_view: AbstractRef, idx: Indexer | None, value: Array,
+             *, _function_name: str = "ref_swap") -> Array:
   """Sets a `Ref`'s value and returns the original value."""
-  ref_aval = core.get_aval(ref)
-  if not isinstance(ref_aval, AbstractRef):
-    raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
-  non_slice_idx, indexed_dims = _get_indexer(ref, idx)
-  return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
+  ref, indexers = _get_indexer(ref_or_view, idx, _function_name)
+  flat_indexers, tree = tree_util.tree_flatten(indexers)
+  return swap_p.bind(ref, value, *flat_indexers, tree=tree)

-def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
+def ref_set(ref: AbstractRef, idx: Indexer | None, value: Array) -> None:
   """Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
-  ref_swap(ref, idx, value)
+  ref_swap(ref, idx, value, _function_name="ref_set")

 # `addupdate_p` mutates a `Ref`, adding a value to its existing value.
 # Semantically,
@@ -163,38 +144,40 @@ def ref_set(ref: AbstractRef, idx: Indexer | None, value: Array) -> None:
 addupdate_p = core.Primitive('addupdate')
 addupdate_p.multiple_results = True

-def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
-  del ref, idx, value
+def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
+  del ref, value, args, tree
   raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
 addupdate_p.def_impl(_addupdate_impl)

-def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
+def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
   """Mutates a ref with an additive update i.e. `ref[idx] += x`."""
-  ref_aval = core.get_aval(ref)
-  if not isinstance(ref_aval, AbstractRef):
-    raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
-  non_slice_idx, indexed_dims = _get_indexer(ref, idx)
-  return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
+  ref, indexers = _get_indexer(ref_or_view, idx, "ref_addupdate")
+  flat_indexers, tree = tree_util.tree_flatten(indexers)
+  return addupdate_p.bind(ref, x, *flat_indexers, tree=tree)

 ## get/set/addupdate abstract evaluation rules

-def _get_abstract_eval(ref_aval: AbstractRef, *idx,
-                       indexed_dims):
+
+def _shape_after_indexing(
+    shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...]
+) -> tuple[int, ...]:
+  for indexer in indexers:
+    # Run some simple checks that all the indexers have consistent shapes
+    assert indexer.shape == shape, (indexer.shape, shape)
+    shape = indexer.get_indexer_shape()
+  return shape
+
+
+def _get_abstract_eval(ref_aval: AbstractRef, *args,
+                       tree):
+  indexers = tree_util.tree_unflatten(tree, args)
   if not isinstance(ref_aval, AbstractRef):
     raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
   if isinstance(ref_aval.inner_aval, core.ShapedArray):
-    if not isinstance(ref_aval.inner_aval, core.ShapedArray):
-      raise ValueError("`get` with nontrivial indexing must be called "
-                       f"on `ShapedArray` `Ref`: {ref_aval}.")
-    if len(indexed_dims) != len(ref_aval.shape):
-      raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
-    if sum(indexed_dims) != len(idx):
-      raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
-    idx_shapes = tuple(i.shape for i in idx)
-    shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
-    out_aval = ref_aval.inner_aval.update(shape=shape)
+    out_shape = _shape_after_indexing(ref_aval.shape, indexers)
+    out_aval = ref_aval.inner_aval.update(shape=out_shape)
   else:
-    if idx:
+    if indexers:
       raise ValueError("Cannot index non-shaped array with nontrivial indices.")
     out_aval = ref_aval.inner_aval
   return (out_aval, {ReadEffect(0)})
@@ -202,34 +185,29 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval)

 def _swap_abstract_eval(ref_aval: AbstractRef,
                         val_aval: core.AbstractValue,
-                        *idx: core.ShapedArray, indexed_dims: tuple[bool]):
+                        *args: Any, tree):
+  indexers = tree_util.tree_unflatten(tree, args)
   out_aval: core.AbstractValue
   if not isinstance(ref_aval, AbstractRef):
     raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
   if isinstance(ref_aval.inner_aval, core.ShapedArray):
-    if len(indexed_dims) != len(ref_aval.shape):
-      raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
-    if sum(indexed_dims) != len(idx):
-      raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
     val_aval = core.raise_to_shaped(val_aval)
     assert isinstance(val_aval, core.ShapedArray)
-    idx_shapes = tuple(i.shape for i in idx)
-    expected_output_shape = _get_slice_output_shape(
-        ref_aval.shape, idx_shapes, indexed_dims)
-    if expected_output_shape != val_aval.shape:
+    expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers)
+    if expected_out_shape != val_aval.shape:
       raise ValueError("Invalid shape for `swap`. "
                        f"Ref shape: {ref_aval.shape}. "
+                       f"Expected shape: {expected_out_shape}. "
                        f"Value shape: {val_aval.shape}. "
-                       f"Indices: {idx}. ")
+                       f"Indices: {indexers}. ")
     if ref_aval.dtype != val_aval.dtype:
       raise ValueError("Invalid dtype for `swap`. "
                        f"Ref dtype: {ref_aval.dtype}. "
                        f"Value shape: {val_aval.dtype}. ")
-    out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
+    out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype)
   else:
-    if idx:
-      raise ValueError("`swap` with nontrivial indexing must be called "
-                       f"on `ShapedArray` `Ref`: {ref_aval}.")
+    if indexers:
+      raise ValueError("Cannot index non-shaped array with nontrivial indices.")
     out_aval = ref_aval.inner_aval
   return (out_aval, {WriteEffect(0)})
 swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
@ -237,93 +215,118 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||
|
||||
def _addupdate_abstract_eval(ref_aval: AbstractRef,
|
||||
val_aval: core.AbstractValue,
|
||||
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
|
||||
*args: Any, tree):
|
||||
indexers = tree_util.tree_unflatten(tree, args)
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
|
||||
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
raise ValueError("`addupdate` with nontrivial indexing must be called "
|
||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
slice_shape = _shape_after_indexing(ref_aval.shape, indexers)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
slice_shape = _get_slice_output_shape(
|
||||
ref_aval.shape, idx_shapes, indexed_dims)
|
||||
if slice_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `addupdate`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Slice shape: {slice_shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
f"Indices: {indexers}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `addupdate`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
elif idx:
|
||||
raise ValueError("`addupdate` with nontrivial indexing must be called "
|
||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
||||
else:
|
||||
# Check that the indexers are valid
|
||||
if indexers:
|
||||
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||
return [], {AccumEffect(0)}
|
||||
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||||
|
||||
## Pretty printing for `get` and `swap` in jaxprs

pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL,
                     foreground=pp.Color.GREEN)

def _pp_idx(context, non_slice_idx, indexed_dims):
  idx_iter = iter(non_slice_idx)
  idx = ','.join(core.pp_var(next(idx_iter), context) if indexed else ':'
                 for indexed in indexed_dims)
  assert next(idx_iter, None) is None
  return pp.text(idx)

def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice
              ) -> str:
  start, size = slc.start, slc.size
  if isinstance(start, core.Var):
    start_str = core.pp_var(start, context)
    end_str = f'{start_str}+{size}'
  else:
    start_str = '' if start == 0 else str(start)
    end = start + size
    end_str = '' if end == dim else str(end)
  return f'{start_str}:{end_str}'

def pp_indexer(context: core.JaxprPpContext, indexer: indexing.NDIndexer
               ) -> pp.Doc:
  indices = []
  for idx, dim in zip(indexer.indices, indexer.shape):
    if isinstance(idx, indexing.Slice):
      indices.append(_pp_slice(context, dim, idx))
    else:
      indices.append(core.pp_var(idx, context))  # type: ignore
  return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")])

def _pp_indexers(
    context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...],
):
  return pp.concat(
      [pp_indexer(context, indexer) for indexer in indexers]
  )

def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers):
  return pp_ref_var(
      pp.concat([
          pp.text(core.pp_var(ref, context)),
          _pp_indexers(context, indexers),
      ])
  )

def _get_pp_rule(eqn, context, settings) -> pp.Doc:
  # Pretty prints `a = get x i` as `x[i] <- a`
  y, = eqn.outvars
  x, *idx = eqn.invars
  idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
  x, *flat_idx = eqn.invars
  indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
  lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
  # TODO more general get
  return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([
      pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))])
  return pp.concat([
      lhs,
      pp.text(' <- '),
      pp_ref_indexers(context, x, indexers)
  ])
core.pp_eqn_rules[get_p] = _get_pp_rule

def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
  y, = eqn.outvars
  x, v, *idx = eqn.invars
  idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
  x, v, *flat_idx = eqn.invars
  indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
  if type(y) is core.DropVar:
    # In the case of a set (ignored return value),
    # pretty print `_ = swap x v i` as `x[i] <- v`
    del y
    return pp.concat([
        pp_ref(pp.concat([
            pp.text(core.pp_var(x, context)),
            pp.text('['), idx, pp.text(']')
        ])), pp.text(' <- '), pp.text(core.pp_var(v, context))])
        pp_ref_indexers(context, x, indexers),
        pp.text(' <- '),
        pp.text(core.pp_var(v, context))
    ])
  else:
    # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
    x_i = pp.concat([pp.text(core.pp_var(x, context)),
                     pp.text('['), idx, pp.text(']')])
    x_i = pp_ref_indexers(context, x, indexers)
    y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
    return pp.concat([y, pp.text(', '), pp_ref(x_i), pp.text(' <- '),
                      pp_ref(x_i), pp.text(', '),
    return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
                      x_i, pp.text(', '),
                      pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[swap_p] = _swap_pp_rule

def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
  del settings
  # pretty-print ` = addupdate x i v` as `x[i] += v`
  () = eqn.outvars
  x, v, *idx = eqn.invars
  idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
  x, v, *flat_idx = eqn.invars
  indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
  return pp.concat([
      pp_ref(pp.concat([
          pp.text(core.pp_var(x, context)),
          pp.text('['), idx, pp.text(']')
      ])), pp.text(' += '), pp.text(core.pp_var(v, context))])
      pp_ref_indexers(context, x, indexers),
      pp.text(' += '),
      pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule

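For intuition, here is how `_pp_slice` renders a static slice, consistent with the updated expectations in the Pallas primitives tests later in this diff (the dimension size of 5 below is illustrative):

# For a dimension of size 5:
#   Slice(0, 5) -> ':'     start 0 and end == dim are both elided
#   Slice(0, 3) -> ':3'
#   Slice(1, 4) -> '1:'    1 + 4 == 5, so the end is elided
#   Slice(1, 3) -> '1:4'
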
## get/swap/addupdate JVP rules

@ -366,6 +369,7 @@ def _get_transpose(g, ref, *idx, **params):
ad.primitive_transposes[get_p] = _get_transpose

def _swap_transpose(g, ref, x, *idx, **params):
  del x  # old value doesn't matter anymore
  # swap transpose is swap
  x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
  return [None, x_bar] + [None] * len(idx)

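A pure-Python analogy for the rule above (illustrative only, no JAX state involved): transposing a swap is another swap, because writing the cotangent into the ref hands back whatever the ref held.

def swap(cell, v):
  # Returns the old value and stores the new one, like swap_p.
  old, cell[0] = cell[0], v
  return old

cell = [2.0]
g = 5.0                # incoming cotangent for the swap's output
g_ref = swap(cell, g)  # the transpose is itself a swap
assert g_ref == 2.0 and cell == [5.0]
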
@ -403,101 +407,162 @@ def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
  num_idxs_to_left = sum(indexed_dims[:ref_dim])
  return ref_dim - num_idxs_to_left + len(idxs_shape)

def _get_vmap(batched_args, batched_dims, *, indexed_dims):
def _batch_indexer(indexer: indexing.NDIndexer, dims,
                   axis_size: int,
                   ref_shape: tuple[int, ...],
                   ref_dim: int | batching.NotMapped,
                   idx_is_batched: bool) -> indexing.NDIndexer:
  indices = indexer.indices
  indices_dims = dims.indices
  new_indices: list[Array | indexing.Slice | int] = []
  new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape)
  for idx, dim in zip(indices, indices_dims):
    if idx_is_batched:
      # If at least one of the idx is batched, we broadcast them all and move the
      # batch dim to the front.
      if isinstance(idx, indexing.Slice):
        # size is static, but start can be dynamic
        # Check if start is static (which it can be)
        is_static_slice = len(tree_util.tree_leaves(idx)) == 0
        if is_static_slice:
          new_indices.append(idx)
          continue
        dim = dim.start
        if dim is batching.not_mapped:
          # Broadcasting the slice is free (the start index stays the same)
          new_indices.append(idx)
        else:
          raise NotImplementedError(
              f"No support for vmapping over nontrivial slices just yet: {idx}")
      else:
        # Check if we are indexing with a scalar or not. If we are indexing
        # with a scalar and we are not batched, we can avoid broadcasting it.
        assert hasattr(idx, "shape")
        if not idx.shape:
          if dim is not batching.not_mapped:
            assert idx.shape == (axis_size,)
            idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,))
          new_indices.append(idx)
        else:
          if dim is batching.not_mapped:
            bcast_dims = tuple(range(1, np.ndim(idx) + 1))
            idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
                                       bcast_dims)
          else:
            idx = batching.moveaxis(idx, dim, 0)
          new_indices.append(idx)
    else:
      if ref_dim is not batching.not_mapped:
        if not isinstance(idx, indexing.Slice):
          assert hasattr(idx, "shape")
          if idx.shape:
            bcast_dims = tuple(range(1, np.ndim(idx) + 1))
            idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
                                       bcast_dims)
      new_indices.append(idx)
  if ref_dim is not batching.not_mapped:
    iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0)
    new_indices.insert(ref_dim, iota)
  return indexing.NDIndexer(tuple(new_indices), ref_shape,
                            new_integer_indexer_shape,
                            validate=True)

def _get_vmap(batched_args, batched_dims, *, tree):
  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                if d is not batching.not_mapped}
  ref, *idxs = batched_args
  ref_dim, *idx_dims = batched_dims
  ref, *flat_idxs = batched_args
  ref_dim, *flat_idx_dims = batched_dims
  indexers = tree_util.tree_unflatten(tree, flat_idxs)
  indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)

  ref_is_batched = ref_dim is not batching.not_mapped
  idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
  bdim_out = 0

  if idx_is_batched:
    # If at least one of the idx is batched, we broadcast them all and move the
    # batch dim to the front.
    idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
                 in zip(idxs, idx_dims))
    idxs_shape, = {i.shape for i in idxs} or [()]
  if ref_is_batched:
    # If ref is batched, we are doing a `get` with an additional axis. If `idxs`
    # are also batched, then we are indexing into the batch axis with an `iota`.
    indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched)
    if idx_is_batched:
      # If we have batched idx, we need to insert the new iota index. The place
      # where we add in the new `iota` index is `ref_dim` so we need to compute
      # what `ref_dim` *would be* if we inserted it into `idxs` instead, because
      # `idxs` doesn't include the non indexed dims.
      idx_place = [i for i, i_dim in enumerate(indexed_dims)
                   if i_dim].index(ref_dim)
      iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
      idxs = tuple_insert(idxs, idx_place, iota)
    else:
      bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
  return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out
  idx_is_batched = any(i_dim is not batching.not_mapped
                       for i_dim in flat_idx_dims)
  if len(indexers) > 1:
    raise NotImplementedError("Batching with multiple indexers not supported.")
  # TODO(sharadmv): handle vmap of multiple indexers
  indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                  ref.shape, ref_dim, idx_is_batched)
                   for indexer, dims in zip(indexers, indexers_dims))
  flat_indexers, tree = tree_util.tree_flatten(indexers)
  return get_p.bind(ref, *flat_indexers, tree=tree), 0
batching.primitive_batchers[get_p] = _get_vmap

def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
def _swap_vmap(batched_args, batched_dims, *, tree):
  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                if d is not batching.not_mapped}
  ref, val, *idxs = batched_args
  ref_dim, val_dim, *idx_dims = batched_dims
  ref, val, *flat_idxs = batched_args
  ref_dim, val_dim, *flat_idx_dims = batched_dims
  indexers = tree_util.tree_unflatten(tree, flat_idxs)
  indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)

  ref_is_batched = ref_dim is not batching.not_mapped
  val_is_batched = val_dim is not batching.not_mapped
  idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
  if idx_is_batched:
    # If at least one of the idx is batched, we broadcast them all and move the
    # batch dim to the front.
    idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
                 in zip(idxs, idx_dims))
    idxs_shape, = {i.shape for i in idxs} or [()]
  if ref_is_batched and not idx_is_batched:
    indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
    bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
    if not val_is_batched:
      val = batching.broadcast(val, axis_size, 0)
      val_dim = 0
    val = batching.moveaxis(val, val_dim, bdim_out)
  elif idx_is_batched:
    assert ref_is_batched and val_is_batched
    indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
    idx_place = [i for i, i_dim in enumerate(indexed_dims)
                 if i_dim].index(ref_dim)
    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
    idxs = tuple_insert(idxs, idx_place, iota)
  idx_is_batched = any(i_dim is not batching.not_mapped
                       for i_dim in flat_idx_dims)
  if len(indexers) > 1:
    raise NotImplementedError("Batching with multiple indexers not supported.")
  # TODO(sharadmv): handle vmap of multiple indexers
  indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                  ref.shape, ref_dim, idx_is_batched)
                   for indexer, dims in zip(indexers, indexers_dims))
  flat_indexers, tree = tree_util.tree_flatten(indexers)
  if (ref_is_batched or idx_is_batched) and not val_is_batched:
    val = batching.broadcast(val, axis_size, 0)
  if val_is_batched:
    val = batching.moveaxis(val, val_dim, 0)
  bdim_out = 0
  return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out
  return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0
batching.primitive_batchers[swap_p] = _swap_vmap

def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
def _addupdate_vmap(batched_args, batched_dims, *, tree):
  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                if d is not batching.not_mapped}
  ref, val, *idxs = batched_args
  ref_dim, val_dim, *idx_dims = batched_dims
  ref, val, *flat_idxs = batched_args
  ref_dim, val_dim, *flat_idx_dims = batched_dims
  indexers = tree_util.tree_unflatten(tree, flat_idxs)
  indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)

  ref_is_batched = ref_dim is not batching.not_mapped
  val_is_batched = val_dim is not batching.not_mapped
  idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
  if idx_is_batched:
    # If at least one of the idx is batched, we ensure all have bdims at front.
    idxs = tuple(batching.bdim_at_front(i, d, axis_size)
                 for i, d in zip(idxs, idx_dims))
    idxs_shape, = {i.shape for i in idxs} or [()]
  if ref_is_batched and not idx_is_batched:
    indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
    bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
    if not val_is_batched:
      val = batching.broadcast(val, axis_size, 0)
      val_dim = 0
    val = batching.moveaxis(val, val_dim, bdim_out)
  elif idx_is_batched:
    assert ref_is_batched and val_is_batched
    indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
    idx_place = [i for i, i_dim in enumerate(indexed_dims)
                 if i_dim].index(ref_dim)
    idxs_shape, = {i.shape for i in idxs} or [()]
    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
    idxs = tuple_insert(idxs, idx_place, iota)
  idx_is_batched = any(i_dim is not batching.not_mapped
                       for i_dim in flat_idx_dims)
  if len(indexers) > 1:
    raise NotImplementedError("Batching with multiple indexers not supported.")
  # TODO(sharadmv): handle vmap of multiple indexers
  indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                  ref.shape, ref_dim, idx_is_batched)
                   for indexer, dims in zip(indexers, indexers_dims))
  flat_indexers, tree = tree_util.tree_flatten(indexers)
  if (ref_is_batched or idx_is_batched) and not val_is_batched:
    val = batching.broadcast(val, axis_size, 0)
  if val_is_batched:
    val = batching.moveaxis(val, val_dim, 0)
  return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
  return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), []
batching.primitive_batchers[addupdate_p] = _addupdate_vmap

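For intuition, a plain-array sketch of the iota trick used by `_batch_indexer` (the shapes here are hypothetical, not from the diff): when both the ref and the integer indices are batched, an iota inserted at the ref's batch dimension makes each batch element gather from its own slice.

import jax
import jax.numpy as jnp

batched_ref = jnp.arange(4 * 5).reshape(4, 5)  # leading batch axis of size 4
batched_idx = jnp.array([0, 2, 4, 1])          # one index per batch element
iota = jax.lax.broadcasted_iota(jnp.int32, (4,), 0)
# Pairing (iota, batched_idx) reproduces jax.vmap(lambda r, i: r[i]) applied
# to (batched_ref, batched_idx).
out = batched_ref[iota, batched_idx]
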
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = core.Primitive('broadcast_to')

def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array:
  import jax.numpy as jnp
  a = jnp.asarray(a)
  if a.shape == shape:
    return a
  return broadcast_to_p.bind(a, shape=shape)

@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
  import jax.numpy as jnp
  return jnp.broadcast_to(a, shape)

@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
  return core.ShapedArray(shape, aval.dtype)

mlir.register_lowering(
    broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)

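A short usage sketch (it assumes the `jax.experimental.pallas` re-export added later in this diff; the shapes are illustrative):

import jax.numpy as jnp
from jax.experimental import pallas as pl

x = jnp.ones((1, 4))
y = pl.broadcast_to(x, (8, 4))  # equal-rank broadcast; a no-op if shapes match
assert y.shape == (8, 4)
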
@ -23,6 +23,7 @@ from typing import Any, Generic, TypeVar, Union
from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src.state import indexing
from jax._src.util import safe_map, safe_zip
from jax._src.typing import Array

@ -77,21 +78,34 @@ Aval = TypeVar("Aval", bound=core.AbstractValue)

@dataclasses.dataclass
class RefIndexer:
  ref: Any
  ref_or_view: Any

  def __getitem__(self, slc):
    if not isinstance(slc, tuple):
      slc = (slc,)
    return RefView(self.ref, slc)
    indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape)
    if isinstance(self.ref_or_view, RefView):
      view = self.ref_or_view
      return RefView(view.ref, (*view.indexers, indexer))
    return RefView(self.ref_or_view, (indexer,))

Indexer = Any

@dataclasses.dataclass
class RefView:
  ref: Any
  indexer: Any
  indexers: tuple[indexing.NDIndexer, ...]

  @property
  def at(self):
    raise NotImplementedError("Can't call `.at` multiple times.")
  def shape(self) -> tuple[int, ...]:
    assert (
        len(self.indexers) > 0
    ), "Should not be able to create a trivial RefView"
    return self.indexers[-1].get_indexer_shape()

  @property
  def at(self) -> RefIndexer:
    return RefIndexer(self)


# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
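A sketch of the chained views this enables, per the commit message (`x_ref` is a hypothetical `Ref` of shape (4, 3, 2) inside a stateful region; not runnable standalone):

view = x_ref.at[0]          # RefView carrying one NDIndexer; shape (3, 2)
view2 = view.at[1]          # `.at` stacks a second NDIndexer; shape (2,)
value = ref_get(view2, ())  # reads through both indexers, as in the tests below
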
@ -137,14 +151,10 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
    return ref_set(tracer, idx, value)

  def _getitem(self, tracer, idx) -> Array:
    if not isinstance(idx, tuple):
      idx = idx,
    from jax._src.state.primitives import ref_get  # pytype: disable=import-error
    return ref_get(tracer, idx)

  def _setitem(self, tracer, idx, value) -> None:
    if not isinstance(idx, tuple):
      idx = idx,
    from jax._src.state.primitives import ref_set  # pytype: disable=import-error
    return ref_set(tracer, idx, value)

@ -1472,6 +1472,7 @@ tf_not_yet_impl = [
    "for",
    "inspect_sharding",
    "io_callback",
    "broadcast_to",
    "shard_map",
    "global_array_to_host_local_array",
    "host_local_array_to_global_array",

@ -17,10 +17,6 @@
from jax._src import pallas
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.indexing import broadcast_to
from jax._src.pallas.indexing import ds
from jax._src.pallas.indexing import dslice
from jax._src.pallas.indexing import Slice
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
@ -42,6 +38,10 @@ from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when
from jax._src.state.primitives import broadcast_to
from jax._src.state.indexing import ds
from jax._src.state.indexing import dslice
from jax._src.state.indexing import Slice

try:
  from jax.experimental.pallas import gpu  # pytype: disable=import-error

@ -22,7 +22,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import util
from jax._src.pallas import indexing
from jax._src.state import indexing
import numpy as np

try:
@ -49,7 +49,8 @@ def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
@hps.composite
def slice_indexer_strategy(draw, dim) -> Slice | slice:
  start = draw(int_indexer_strategy(dim))
  size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max))
  max_size = dim - start
  size = draw(hps.integers(min_value=0, max_value=max_size))
  return draw(
      hps.one_of(
          hps.just(Slice(start, size)), hps.just(slice(start, start + size))
@ -78,8 +79,8 @@ def indexer_strategy(draw, dim, int_indexer_shape
def nd_indexer_strategy(draw, shape) -> NDIndexer:
  num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
  int_indexer_shape = draw(hnp.array_shapes())
  indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim
             in shape[:num_indices]]
  indices = tuple(draw(indexer_strategy(dim, int_indexer_shape))
                  for dim in shape[:num_indices])
  return NDIndexer.from_indices_shape(indices, shape)

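The new bound keeps generated slices in range. For example (values illustrative), with dim = 5 and start = 3, max_size = 5 - 3 = 2, so size is drawn from {0, 1, 2} and start + size never exceeds the dimension, matching the out-of-bounds checks added below.
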
@ -97,6 +98,24 @@ class IndexerTest(parameterized.TestCase):
    with self.assertRaises(ValueError):
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_invalid_ndindexer_oob_int(self):
    indices = (4, 0)
    shape = (3, 5)
    with self.assertRaises(ValueError):
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_invalid_ndindexer_oob_slice_start(self):
    indices = (slice(3, 2), 0)
    shape = (3, 5)
    with self.assertRaises(ValueError):
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_invalid_ndindexer_oob_slice_end(self):
    indices = (Slice(2, 2), 0)
    shape = (3, 5)
    with self.assertRaises(ValueError):
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_ndindexer_with_padding(self):
    indices = ()
    shape = (5, 5)
@ -137,17 +156,17 @@ class IndexerTest(parameterized.TestCase):
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3))

    indices = (0, slice(4, 10), np.arange(5))
    indices = (0, slice(2, 10), np.arange(5))
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 0))
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 1))

    indices = (0, 5, np.arange(5))
    indices = (0, 1, np.arange(5))
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5,))

    indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None])
    indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None])
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3))
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2))

  @hp.given(hps.data())
  def test_ndindexer(self, data):

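To recap the shape rule these tests exercise (the (3, 5, 4) ref shape below is an assumption for illustration; the test's actual `shape` is defined outside this excerpt): array indices broadcast together and their common shape leads the result, with the remaining slice sizes appended in order.

import numpy as np
from jax._src.state.indexing import NDIndexer, ds

idx = NDIndexer.from_indices_shape(
    (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None]), (3, 5, 4))
# (5, 1) and (1, 4) broadcast to (5, 4); the ds(0, 2) slice contributes the 2.
assert idx.get_indexer_shape() == (5, 4, 2)
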
@ -1471,7 +1471,7 @@ class PallasPrimitivesTest(PallasTest):
  @parameterized.parameters(*[
      (lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"),
      (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"),
      (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:4,:,:4]"),
      (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:,:,:4]"),
      (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"),
      (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"),
  ])
@ -1486,7 +1486,7 @@ class PallasPrimitivesTest(PallasTest):
  @parameterized.parameters(*[
      (lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"),
      (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"),
      (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:4,:,:4] <-"),
      (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:,:,:4] <-"),
      (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"),
      (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"),
  ])
@ -1504,7 +1504,7 @@ class PallasPrimitivesTest(PallasTest):
      (lambda: (pl.dslice(0, 3), slice(None), slice(None)),
       "c:i32[3,3,2], a[:3,:,:] <-"),
      (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)),
       "c:i32[3,3,4], a[1:4,:,:4] <-"),
       "c:i32[3,3,4], a[1:,:,:4] <-"),
      (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)),
       "e:i32[5,3,4], a[b,:,:4] <-"),
      (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)),

@ -56,15 +56,15 @@ class StatePrimitivesTest(jtu.JaxTestCase):

  def test_cant_eval_get_primitive(self):
    with self.assertRaises(ValueError):
      get_p.bind(jnp.ones(5))
      get_p.bind(jnp.ones(5), tree=None)

  def test_cant_eval_swap_primitive(self):
    with self.assertRaises(ValueError):
      swap_p.bind(jnp.ones(5), jnp.zeros(5))
      swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)

  def test_cant_eval_addupdate_primitive(self):
    with self.assertRaises(ValueError):
      addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
      addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)

  def test_get_abstract_aval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
@ -95,11 +95,37 @@ class StatePrimitivesTest(jtu.JaxTestCase):
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
      dict(testcase_name="get_with_nontrivial_slice",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
      dict(testcase_name="get_with_nontrivial_slice2",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
           out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
      dict(testcase_name="get_with_ref_simple_at",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(1, 3), slice(None), slice(None)),
           out_shape=(2, 2, 4), out_dtype=jnp.float32,
           at_indices=((0,),)),
      dict(testcase_name="get_with_ref_simple_at2",
           ref_shape=(6, 1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(0, 2), slice(0, 1), slice(1, 3), slice(None), slice(None)),
           out_shape=(2, 1, 2, 2, 4), out_dtype=jnp.float32,
           at_indices=((slice(2, 6),),)),
      dict(testcase_name="get_with_ref_multiple_at",
           ref_shape=(1, 3, 5, 4), ref_dtype=jnp.float32,
           idx=(slice(None), slice(None), slice(0, 2)),
           out_shape=(3, 1, 2), out_dtype=jnp.float32,
           at_indices=((0,), (slice(None), slice(0, 1)))),
  )
  def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
                             out_dtype=None, should_error=False):
                             out_dtype=None, at_indices=(),
                             should_error=False):
    ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
    def f(x_ref):
      for at_idx in at_indices:
        x_ref = x_ref.at[at_idx]
      out = ref_get(x_ref, idx)
      return [out]
    if should_error:
@ -160,13 +186,43 @@ class StatePrimitivesTest(jtu.JaxTestCase):
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nontrivial_slice",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nontrivial_slice2",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
           val_shape=(1, 2, 2, 4), val_dtype=jnp.float32,
           out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_ref_simple_at",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(0, 1), slice(1, 3), slice(None),),
           val_shape=(1, 1, 4), val_dtype=jnp.float32,
           out_shape=(1, 1, 4), out_dtype=jnp.float32,
           at_indices=((0,),),),
      dict(testcase_name="swap_with_ref_simple_at2",
           ref_shape=(4, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
           val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
           out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
           at_indices=((slice(0, 2),),),),
      dict(testcase_name="swap_with_ref_multiple_at2",
           ref_shape=(1, 4, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
           val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
           out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
           at_indices=((slice(None), slice(0, 2),), (0,)),),
  )
  def test_swap_abstract_eval(self, ref_shape, ref_dtype,
      val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
      should_error=False):
      val_shape, val_dtype, idx, at_indices=(), should_error=False):
    ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
    val_aval = core.ShapedArray(val_shape, val_dtype)
    def f(x_ref, val):
      for at_idx in at_indices:
        x_ref = x_ref.at[at_idx]
      out = ref_swap(x_ref, idx, val)
      return [out]
    if should_error:
@ -192,13 +248,13 @@ class StatePrimitivesTest(jtu.JaxTestCase):
           idx=(slice(None),), should_error=True),
      dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
           idx=(),),
      dict(testcase_name="bad_dtype", ref_shape=(1, 2),
           ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
           idx=(0,),),
      dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
           idx=(slice(None), 0)),
@ -216,13 +272,34 @@ class StatePrimitivesTest(jtu.JaxTestCase):
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
      dict(testcase_name="ref_with_simple_at",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
           at_indices=((0,),)),
      dict(testcase_name="ref_with_simple_at2",
           ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 3, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
           at_indices=((slice(0, 3),),)),
      dict(testcase_name="ref_with_multiple_at",
           ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
           at_indices=((slice(0, 3),), (0,))),
      dict(testcase_name="ref_with_multiple_at2",
           ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
           at_indices=((slice(None), slice(0, 3),), (0,))),
  )
  def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
      val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
      should_error=False):
      val_shape, val_dtype, idx, at_indices=(), should_error=False):
    ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
    val_aval = core.ShapedArray(val_shape, val_dtype)
    def f(x_ref, val):
      for at_idx in at_indices:
        x_ref = x_ref.at[at_idx]
      ref_addupdate(x_ref, idx, val)
      return []
    if should_error:
@ -1595,7 +1672,6 @@ if CAN_USE_HYPOTHESIS:
    def test_vjp(self, data):

      spec = data.draw(func_spec())
      print(spec)

      def impl(x):
        return spec.call((x, jnp.zeros_like(x)))[1]