mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas] Refactor indexing primitives to use NDIndexer abstraction
Some notes about this change: * This change upgrades the `RefView` abstraction to store multiple indexers. This allows doing things like `ref.at[0].at[0]` to recursively create a view of a `Ref`. `RefView`s therefore encapsluate multiple `NDIndexer`s. * This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p) but does *not* generalize their rules. Most of the rules will raise a NotImplementedError if you use multiple `NDIndexer`s. Adding support will be done in a future CL. * With the above in mind, this change only preserves existing public facing APIs and adding actual support will involve updating the rules. PiperOrigin-RevId: 595229523
This commit is contained in:
parent
8c5e7b26ba
commit
836563fadf
@ -738,15 +738,17 @@ pytype_strict_library(
|
|||||||
name = "state_types",
|
name = "state_types",
|
||||||
srcs = [
|
srcs = [
|
||||||
"_src/state/__init__.py",
|
"_src/state/__init__.py",
|
||||||
|
"_src/state/indexing.py",
|
||||||
"_src/state/types.py",
|
"_src/state/types.py",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
":core",
|
||||||
":effects",
|
":effects",
|
||||||
":pretty_printer",
|
":pretty_printer",
|
||||||
|
":tree_util",
|
||||||
":typing",
|
":typing",
|
||||||
":util",
|
":util",
|
||||||
],
|
] + py_deps("numpy"),
|
||||||
)
|
)
|
||||||
|
|
||||||
pytype_strict_library(
|
pytype_strict_library(
|
||||||
|
@ -1,155 +0,0 @@
|
|||||||
# Copyright 2023 The JAX Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Contains shared logic and abstractions for Pallas indexing ops."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import jax
|
|
||||||
from jax import core as jax_core
|
|
||||||
from jax import tree_util
|
|
||||||
from jax._src.interpreters import mlir
|
|
||||||
from jax._src.util import merge_lists
|
|
||||||
from jax._src.util import partition_list
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
|
|
||||||
# We could use `jnp.broadcast_to` but that lowers to squeezing,
|
|
||||||
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
|
|
||||||
# so in the lowering, we have to expand out those squeezed dimensions again.
|
|
||||||
# Having a simple `broadcast_to` primitive allows us to lower directly
|
|
||||||
# to `tl.broadcast_to`.
|
|
||||||
broadcast_to_p = jax_core.Primitive('broadcast_to')
|
|
||||||
|
|
||||||
def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array:
|
|
||||||
if a.shape == shape:
|
|
||||||
return a
|
|
||||||
return broadcast_to_p.bind(a, shape=shape)
|
|
||||||
|
|
||||||
@broadcast_to_p.def_impl
|
|
||||||
def _broadcast_to_impl(a, *, shape):
|
|
||||||
return jnp.broadcast_to(a, shape)
|
|
||||||
|
|
||||||
@broadcast_to_p.def_abstract_eval
|
|
||||||
def _broadcast_to_abstract_eval(aval, *, shape):
|
|
||||||
return jax_core.ShapedArray(shape, aval.dtype)
|
|
||||||
|
|
||||||
mlir.register_lowering(
|
|
||||||
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tree_util.register_pytree_node_class
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Slice:
|
|
||||||
"""Represents a slice with a dynamic start index and a fixed size."""
|
|
||||||
start: Any
|
|
||||||
size: int
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.size < 0:
|
|
||||||
raise ValueError("`size` must not be negative.")
|
|
||||||
|
|
||||||
def tree_flatten(self):
|
|
||||||
# If `start` is statically known, we treat it as static information
|
|
||||||
if isinstance(self.start, int):
|
|
||||||
return (), (self.start, self.size)
|
|
||||||
return (self.start,), (self.size,)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tree_unflatten(cls, aux_data, children) -> Slice:
|
|
||||||
return cls(*children, *aux_data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_slice(cls, slc: slice, size: int) -> Slice:
|
|
||||||
start, stop, step = slc.indices(size)
|
|
||||||
if step != 1:
|
|
||||||
raise ValueError(f"slice must have a step of 1 (found: {step})")
|
|
||||||
return cls(start, stop - start)
|
|
||||||
|
|
||||||
|
|
||||||
def dslice(start: int | jax.Array | None, size: int | None = None
|
|
||||||
) -> slice | Slice:
|
|
||||||
"""Constructs a `Slice` from a start and a size."""
|
|
||||||
if start is None:
|
|
||||||
return slice(None)
|
|
||||||
if size is None:
|
|
||||||
if not isinstance(start, int):
|
|
||||||
raise ValueError("Non-static `dslice`")
|
|
||||||
return Slice(0, start)
|
|
||||||
return Slice(start, size)
|
|
||||||
ds = dslice # Handy alias
|
|
||||||
|
|
||||||
@tree_util.register_pytree_node_class
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class NDIndexer:
|
|
||||||
indices: tuple[int | Slice | jax.Array, ...]
|
|
||||||
shape: tuple[int, ...]
|
|
||||||
int_indexer_shape: tuple[int, ...]
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if len(self.indices) != len(self.shape):
|
|
||||||
raise ValueError("`indices` must be the same length as `Ref` shape.")
|
|
||||||
|
|
||||||
def tree_flatten(self):
|
|
||||||
indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
|
|
||||||
slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
|
|
||||||
flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
|
|
||||||
return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
|
|
||||||
self.int_indexer_shape)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tree_unflatten(cls, data, flat_idx):
|
|
||||||
slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
|
|
||||||
non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
|
|
||||||
indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
|
|
||||||
return NDIndexer(tuple(indices), shape, int_indexer_shape)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_indices_shape(cls, indices, shape) -> NDIndexer:
|
|
||||||
if len(indices) > len(shape):
|
|
||||||
raise ValueError("`indices` must not be longer than `shape`.")
|
|
||||||
# Pad out indices with slice(None)
|
|
||||||
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
|
|
||||||
# Convert all `slice`s to `Slice`s
|
|
||||||
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
|
|
||||||
else i for i, s in zip(indices, shape))
|
|
||||||
is_int_indexing = [not isinstance(i, Slice) for i in indices]
|
|
||||||
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
|
|
||||||
int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
|
|
||||||
int_indexers]
|
|
||||||
indexer_shapes = [i.shape for i in int_indexers]
|
|
||||||
if indexer_shapes:
|
|
||||||
try:
|
|
||||||
bcast_shape = np.broadcast_shapes(*indexer_shapes)
|
|
||||||
except ValueError as e:
|
|
||||||
# Raise a nicer error than the NumPy one.
|
|
||||||
raise ValueError("Cannot broadcast shapes for indexing: "
|
|
||||||
f"{tuple(a for a in indexer_shapes)}") from e
|
|
||||||
else:
|
|
||||||
bcast_shape = ()
|
|
||||||
int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
|
|
||||||
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
|
|
||||||
return NDIndexer(tuple(indices), shape, bcast_shape)
|
|
||||||
|
|
||||||
def get_indexer_shape(self) -> tuple[int, ...]:
|
|
||||||
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
|
|
||||||
other_indexers, _ = partition_list(is_int_indexing, self.indices)
|
|
||||||
other_shape = [s.size for s in other_indexers] # type: ignore
|
|
||||||
return (*self.int_indexer_shape, *other_shape)
|
|
@ -42,12 +42,12 @@ from jax._src.lib.mlir.dialects import memref
|
|||||||
from jax._src.lib.mlir.dialects import scf
|
from jax._src.lib.mlir.dialects import scf
|
||||||
from jax._src.lib.mlir.dialects import vector
|
from jax._src.lib.mlir.dialects import vector
|
||||||
from jax._src.pallas import core
|
from jax._src.pallas import core
|
||||||
from jax._src.pallas import indexing
|
|
||||||
from jax._src.pallas import primitives
|
from jax._src.pallas import primitives
|
||||||
from jax._src.pallas import utils as pallas_utils
|
from jax._src.pallas import utils as pallas_utils
|
||||||
from jax._src.pallas.mosaic import core as tpu_core
|
from jax._src.pallas.mosaic import core as tpu_core
|
||||||
from jax._src.pallas.mosaic import primitives as tpu_primitives
|
from jax._src.pallas.mosaic import primitives as tpu_primitives
|
||||||
from jax._src.state import discharge as state_discharge
|
from jax._src.state import discharge as state_discharge
|
||||||
|
from jax._src.state import indexing
|
||||||
from jax._src.state import primitives as state_primitives
|
from jax._src.state import primitives as state_primitives
|
||||||
from jax._src.util import safe_map
|
from jax._src.util import safe_map
|
||||||
from jax._src.util import safe_zip
|
from jax._src.util import safe_zip
|
||||||
@ -610,12 +610,16 @@ def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
|
|||||||
|
|
||||||
|
|
||||||
def _get_lowering_rule(
|
def _get_lowering_rule(
|
||||||
ctx: LoweringRuleContext, ref, *non_slice_idx, indexed_dims: Sequence[bool]
|
ctx: LoweringRuleContext, ref, *idx, tree,
|
||||||
):
|
):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
|
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
|
||||||
|
if len(indexers) > 1:
|
||||||
|
raise NotImplementedError("Only one indexer currently supported.")
|
||||||
# Call _load_lowering_rule (since it's more general)
|
# Call _load_lowering_rule (since it's more general)
|
||||||
ref_aval, *non_slice_idx_avals = ctx.avals_in
|
ref_aval, *non_slice_idx_avals = ctx.avals_in
|
||||||
nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer(
|
nd_indexer = indexers[0]
|
||||||
ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims)
|
nd_indexer_avals = indexers_avals[0]
|
||||||
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None))
|
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None))
|
||||||
avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None))
|
avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None))
|
||||||
ctx = ctx.replace(avals_in=avals_flat)
|
ctx = ctx.replace(avals_in=avals_flat)
|
||||||
@ -630,13 +634,17 @@ def _swap_lowering_rule(
|
|||||||
ctx: LoweringRuleContext,
|
ctx: LoweringRuleContext,
|
||||||
ref,
|
ref,
|
||||||
val,
|
val,
|
||||||
*non_slice_idx,
|
*idx,
|
||||||
indexed_dims: Sequence[bool],
|
tree
|
||||||
):
|
):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
|
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
|
||||||
|
if len(indexers) > 1:
|
||||||
|
raise NotImplementedError("Only one indexer currently supported.")
|
||||||
# Call _masked_swap_lowering_rule (since it's more general)
|
# Call _masked_swap_lowering_rule (since it's more general)
|
||||||
ref_aval, val_aval, *non_slice_idx_avals = ctx.avals_in
|
ref_aval, val_aval, *_ = ctx.avals_in
|
||||||
nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer(
|
nd_indexer = indexers[0]
|
||||||
ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims)
|
nd_indexer_avals = indexers_avals[0]
|
||||||
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None))
|
args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None))
|
||||||
avals_flat = tree_util.tree_leaves(
|
avals_flat = tree_util.tree_leaves(
|
||||||
(ref_aval, nd_indexer_avals, val_aval, None)
|
(ref_aval, nd_indexer_avals, val_aval, None)
|
||||||
|
@ -29,10 +29,10 @@ from jax._src import pretty_printer as pp
|
|||||||
from jax._src import state
|
from jax._src import state
|
||||||
from jax._src import tree_util
|
from jax._src import tree_util
|
||||||
from jax._src import util
|
from jax._src import util
|
||||||
from jax._src.state import primitives as state_primitives
|
from jax._src.state import indexing
|
||||||
|
from jax._src.state import primitives as sp
|
||||||
from jax._src.interpreters import mlir
|
from jax._src.interpreters import mlir
|
||||||
from jax._src.interpreters import partial_eval as pe
|
from jax._src.interpreters import partial_eval as pe
|
||||||
from jax._src.pallas import indexing
|
|
||||||
from jax._src.pallas.mosaic import core as tpu_core
|
from jax._src.pallas.mosaic import core as tpu_core
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
@ -264,38 +264,6 @@ def _dma_start_abstract_eval(*args, tree, device_id_type):
|
|||||||
del args, tree, device_id_type
|
del args, tree, device_id_type
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _pp_slice(slc: indexing.Slice, dim: int, context: jax_core.JaxprPpContext
|
|
||||||
) -> str:
|
|
||||||
start, size = slc.start, slc.size
|
|
||||||
if isinstance(start, jax_core.Var):
|
|
||||||
start_str = jax_core.pp_var(start, context)
|
|
||||||
end_str = f'{start_str}+{size}'
|
|
||||||
else:
|
|
||||||
start_str = '' if start == 0 else str(start)
|
|
||||||
end = start + size
|
|
||||||
end_str = '' if end == dim else str(end)
|
|
||||||
return f'{start_str}:{end_str}'
|
|
||||||
|
|
||||||
def _pp_indexer(indexer: indexing.NDIndexer,
|
|
||||||
context: jax_core.JaxprPpContext) -> pp.Doc:
|
|
||||||
indices = []
|
|
||||||
for idx, dim in zip(indexer.indices, indexer.shape):
|
|
||||||
if isinstance(idx, indexing.Slice):
|
|
||||||
indices.append(_pp_slice(idx, dim, context))
|
|
||||||
else:
|
|
||||||
indices.append(jax_core.pp_var(idx, context)) # type: ignore
|
|
||||||
return pp.text(','.join(indices))
|
|
||||||
|
|
||||||
def _pp_ref(ref, indexer, context):
|
|
||||||
return state_primitives.pp_ref(
|
|
||||||
pp.concat([
|
|
||||||
pp.text(jax_core.pp_var(ref, context)),
|
|
||||||
pp.text("["),
|
|
||||||
_pp_indexer(indexer, context),
|
|
||||||
pp.text("]"),
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
|
def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
|
||||||
context: jax_core.JaxprPpContext,
|
context: jax_core.JaxprPpContext,
|
||||||
settings: jax_core.JaxprPpSettings):
|
settings: jax_core.JaxprPpSettings):
|
||||||
@ -309,9 +277,9 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
|
|||||||
return pp.concat([
|
return pp.concat([
|
||||||
pp.text('dma_start'),
|
pp.text('dma_start'),
|
||||||
pp.text(' '),
|
pp.text(' '),
|
||||||
_pp_ref(src_ref, src_indexer, context),
|
sp.pp_ref_indexers(context, src_ref, (src_indexer,)),
|
||||||
pp.text(' -> '),
|
pp.text(' -> '),
|
||||||
_pp_ref(dst_ref, dst_indexer, context),
|
sp.pp_ref_indexers(context, dst_ref, (dst_indexer,)),
|
||||||
pp.text(' '),
|
pp.text(' '),
|
||||||
pp.text(jax_core.pp_var(dst_sem, context)),
|
pp.text(jax_core.pp_var(dst_sem, context)),
|
||||||
])
|
])
|
||||||
@ -358,13 +326,14 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type):
|
|||||||
def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
|
def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
|
||||||
context: jax_core.JaxprPpContext,
|
context: jax_core.JaxprPpContext,
|
||||||
settings: jax_core.JaxprPpSettings):
|
settings: jax_core.JaxprPpSettings):
|
||||||
|
del settings
|
||||||
invars = eqn.invars
|
invars = eqn.invars
|
||||||
tree = eqn.params["tree"]
|
tree = eqn.params["tree"]
|
||||||
sem, ref, indexer = tree_util.tree_unflatten(tree, invars)
|
sem, ref, indexer = tree_util.tree_unflatten(tree, invars)
|
||||||
return pp.concat([
|
return pp.concat([
|
||||||
pp.text('dma_wait'),
|
pp.text('dma_wait'),
|
||||||
pp.text(' '),
|
pp.text(' '),
|
||||||
_pp_ref(ref, indexer, context),
|
sp.pp_ref_indexers(context, ref, (indexer,)),
|
||||||
pp.text(' '),
|
pp.text(' '),
|
||||||
pp.text(jax_core.pp_var(sem, context)),
|
pp.text(jax_core.pp_var(sem, context)),
|
||||||
])
|
])
|
||||||
@ -373,7 +342,10 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
|
|||||||
|
|
||||||
def _get_ref_and_indexer(ref):
|
def _get_ref_and_indexer(ref):
|
||||||
if isinstance(ref, state.RefView):
|
if isinstance(ref, state.RefView):
|
||||||
return ref.ref, ref.indexer
|
indexers = ref.indexers
|
||||||
|
if len(indexers) > 1:
|
||||||
|
raise NotImplementedError("Only one indexer supported.")
|
||||||
|
return ref.ref, ref.indexers[0].indices
|
||||||
return ref, (slice(None),) * len(ref.shape)
|
return ref, (slice(None),) * len(ref.shape)
|
||||||
|
|
||||||
def make_async_copy(src_ref, dst_ref, sem):
|
def make_async_copy(src_ref, dst_ref, sem):
|
||||||
|
@ -27,15 +27,15 @@ from jax._src import core as jax_core
|
|||||||
from jax._src import pretty_printer as pp
|
from jax._src import pretty_printer as pp
|
||||||
from jax._src import state
|
from jax._src import state
|
||||||
from jax._src.util import (safe_map, safe_zip)
|
from jax._src.util import (safe_map, safe_zip)
|
||||||
from jax._src.state import primitives as state_primitives
|
|
||||||
from jax._src.state import discharge as state_discharge
|
from jax._src.state import discharge as state_discharge
|
||||||
|
from jax._src.state import indexing
|
||||||
|
from jax._src.state import primitives as sp
|
||||||
from jax.interpreters import ad
|
from jax.interpreters import ad
|
||||||
from jax.interpreters import mlir
|
from jax.interpreters import mlir
|
||||||
from jax.interpreters import xla
|
from jax.interpreters import xla
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from jax._src.pallas import core as pallas_core
|
from jax._src.pallas import core as pallas_core
|
||||||
from jax._src.pallas import indexing
|
|
||||||
|
|
||||||
# TODO(sharadmv): enable type checking
|
# TODO(sharadmv): enable type checking
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
@ -224,44 +224,13 @@ def _load_abstract_eval(*avals_flat, args_tree, **_):
|
|||||||
|
|
||||||
load_p.def_effectful_abstract_eval(_load_abstract_eval)
|
load_p.def_effectful_abstract_eval(_load_abstract_eval)
|
||||||
|
|
||||||
def _pp_dslice(dim: int, slice: Slice, context):
|
|
||||||
size = pp.text(str(slice.size))
|
|
||||||
if isinstance(slice.start, int):
|
|
||||||
if slice.start == 0:
|
|
||||||
start = pp.text("")
|
|
||||||
else:
|
|
||||||
start = pp.text(str(slice.start))
|
|
||||||
if slice.size == dim:
|
|
||||||
end = pp.text("")
|
|
||||||
else:
|
|
||||||
end = pp.text(str(slice.start + slice.size))
|
|
||||||
else:
|
|
||||||
start = pp.text(jax_core.pp_var(slice.start, context))
|
|
||||||
end = pp.concat([start, pp.text("+"), size])
|
|
||||||
return pp.concat([start, pp.text(":"), end])
|
|
||||||
|
|
||||||
def _pp_idx(ref_aval, idx: NDIndexer, context):
|
|
||||||
docs = [
|
|
||||||
_pp_dslice(d, s, context) if isinstance(s, Slice)
|
|
||||||
else pp.text(jax_core.pp_var(s, context))
|
|
||||||
for s, d in zip(idx.indices, ref_aval.shape)]
|
|
||||||
if not docs:
|
|
||||||
return pp.text("")
|
|
||||||
doc = [docs[0]]
|
|
||||||
for d in docs[1:]:
|
|
||||||
doc.append(pp.text(","))
|
|
||||||
doc.append(d)
|
|
||||||
return pp.concat(doc)
|
|
||||||
|
|
||||||
def _load_pp_rule(eqn, context, settings):
|
def _load_pp_rule(eqn, context, settings):
|
||||||
# Pretty prints `a = load x i` as `x[i] <- a`
|
# Pretty prints `a = load x i` as `x[i] <- a`
|
||||||
y, = eqn.outvars
|
y, = eqn.outvars
|
||||||
x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars)
|
x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars)
|
||||||
idx = _pp_idx(eqn.invars[0].aval, idx, context)
|
|
||||||
lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||||
return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
|
return pp.concat([
|
||||||
pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
|
lhs, pp.text(' <- '), sp.pp_ref_indexers(context, x, (idx,))])
|
||||||
]))])
|
|
||||||
jax_core.pp_eqn_rules[load_p] = _load_pp_rule
|
jax_core.pp_eqn_rules[load_p] = _load_pp_rule
|
||||||
|
|
||||||
|
|
||||||
@ -339,15 +308,14 @@ def _swap_pp_rule(eqn, context, settings):
|
|||||||
# Pretty prints `_ = swap x v i` as `x[i] <- v`
|
# Pretty prints `_ = swap x v i` as `x[i] <- v`
|
||||||
y, = eqn.outvars
|
y, = eqn.outvars
|
||||||
x, idx, val, _ = eqn.params["args_tree"].unflatten(eqn.invars)
|
x, idx, val, _ = eqn.params["args_tree"].unflatten(eqn.invars)
|
||||||
idx = _pp_idx(eqn.invars[0].aval, idx, context)
|
x_i = sp.pp_ref_indexers(context, x, (idx,))
|
||||||
x_i = pp.concat([pp.text(jax_core.pp_var(x, context)),
|
|
||||||
pp.text('['), idx, pp.text(']')])
|
|
||||||
if isinstance(y, jax_core.DropVar):
|
if isinstance(y, jax_core.DropVar):
|
||||||
return pp.concat([state_primitives.pp_ref(
|
return pp.concat([
|
||||||
x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
|
x_i,
|
||||||
|
pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
|
||||||
y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||||
return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i),
|
return pp.concat([y, pp.text(', '), x_i,
|
||||||
pp.text(' <- '), state_primitives.pp_ref(x_i),
|
pp.text(' <- '), x_i,
|
||||||
pp.text(', '), pp.text(jax_core.pp_var(val, context))])
|
pp.text(', '), pp.text(jax_core.pp_var(val, context))])
|
||||||
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
||||||
|
|
||||||
|
@ -39,12 +39,12 @@ from jax._src.lib import gpu_triton as triton_kernel_call_lib
|
|||||||
from jax._src.lib import hlo_helpers
|
from jax._src.lib import hlo_helpers
|
||||||
from jax._src.lib.mlir import ir
|
from jax._src.lib.mlir import ir
|
||||||
from jax._src.pallas import core as pallas_core
|
from jax._src.pallas import core as pallas_core
|
||||||
from jax._src.pallas import indexing
|
|
||||||
from jax._src.pallas import primitives
|
from jax._src.pallas import primitives
|
||||||
from jax._src.pallas import utils as pallas_utils
|
from jax._src.pallas import utils as pallas_utils
|
||||||
from jax._src.pallas.pallas_call import pallas_call_p
|
from jax._src.pallas.pallas_call import pallas_call_p
|
||||||
from jax._src.state import AbstractRef
|
from jax._src.state import AbstractRef
|
||||||
from jax._src.state import discharge
|
from jax._src.state import discharge
|
||||||
|
from jax._src.state import indexing
|
||||||
from jax._src.state import primitives as sp
|
from jax._src.state import primitives as sp
|
||||||
from jax._src.util import merge_lists
|
from jax._src.util import merge_lists
|
||||||
from jax._src.util import partition_list
|
from jax._src.util import partition_list
|
||||||
@ -438,7 +438,7 @@ _TRITON_FN_MAPPING = {
|
|||||||
lax.nextafter_p: tl.math.nextafter,
|
lax.nextafter_p: tl.math.nextafter,
|
||||||
ad_util.add_any_p: tl.semantic.add,
|
ad_util.add_any_p: tl.semantic.add,
|
||||||
# Other ops.
|
# Other ops.
|
||||||
indexing.broadcast_to_p: tl.broadcast_to,
|
sp.broadcast_to_p: tl.broadcast_to,
|
||||||
primitives.atomic_cas_p: tl.atomic_cas,
|
primitives.atomic_cas_p: tl.atomic_cas,
|
||||||
primitives.max_contiguous_p: tl.max_contiguous,
|
primitives.max_contiguous_p: tl.max_contiguous,
|
||||||
primitives.multiple_of_p: tl.multiple_of,
|
primitives.multiple_of_p: tl.multiple_of,
|
||||||
@ -727,28 +727,16 @@ def _pack_indices(non_slice_idx, indexed_dims):
|
|||||||
|
|
||||||
|
|
||||||
def _get_lowering_rule(
|
def _get_lowering_rule(
|
||||||
ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
|
ctx: TritonLoweringRuleContext, ptr, *idx, tree
|
||||||
):
|
):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
if not isinstance(ptr.type, tl.pointer_type):
|
if not isinstance(ptr.type, tl.pointer_type):
|
||||||
assert not non_slice_idx
|
assert len(indexers) == 0
|
||||||
return ptr
|
return ptr
|
||||||
|
if len(indexers) > 1:
|
||||||
ref_aval, *idx_avals = ctx.avals_in
|
raise NotImplementedError("No support for multiple indexers yet.")
|
||||||
idx_avals = _pack_indices(idx_avals, indexed_dims)
|
indexer = indexers[0]
|
||||||
if non_slice_idx:
|
args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, None, None))
|
||||||
(int_indexer_shape,) = {
|
|
||||||
i.shape for i in idx_avals if not isinstance(i, slice)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
int_indexer_shape = ()
|
|
||||||
|
|
||||||
idx = _pack_indices(non_slice_idx, indexed_dims)
|
|
||||||
idx = tuple(
|
|
||||||
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
|
|
||||||
for s, slc in zip(ref_aval.shape, idx)
|
|
||||||
)
|
|
||||||
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
|
|
||||||
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None))
|
|
||||||
return _masked_load_lowering_rule(
|
return _masked_load_lowering_rule(
|
||||||
ctx,
|
ctx,
|
||||||
*args_flat,
|
*args_flat,
|
||||||
@ -794,24 +782,16 @@ triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
|
|||||||
|
|
||||||
|
|
||||||
def _swap_lowering_rule(
|
def _swap_lowering_rule(
|
||||||
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
|
ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
|
||||||
):
|
):
|
||||||
ref_aval, _, *idx_avals = ctx.avals_in
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
idx_avals = _pack_indices(idx_avals, indexed_dims)
|
if not isinstance(ptr.type, tl.pointer_type):
|
||||||
if non_slice_idx:
|
assert len(indexers) == 0
|
||||||
(int_indexer_shape,) = {
|
return ptr
|
||||||
i.shape for i in idx_avals if not isinstance(i, slice)
|
if len(indexers) > 1:
|
||||||
}
|
raise NotImplementedError("No support for multiple indexers yet.")
|
||||||
else:
|
indexer = indexers[0]
|
||||||
int_indexer_shape = ()
|
args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, value, None))
|
||||||
|
|
||||||
idx = _pack_indices(non_slice_idx, indexed_dims)
|
|
||||||
idx = tuple(
|
|
||||||
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
|
|
||||||
for s, slc in zip(ref_aval.shape, idx)
|
|
||||||
)
|
|
||||||
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
|
|
||||||
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None))
|
|
||||||
return _masked_swap_lowering_rule(
|
return _masked_swap_lowering_rule(
|
||||||
ctx, *args_flat, args_tree=args_tree, eviction_policy=None
|
ctx, *args_flat, args_tree=args_tree, eviction_policy=None
|
||||||
)
|
)
|
||||||
@ -850,24 +830,17 @@ triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
|
|||||||
|
|
||||||
|
|
||||||
def _addupdate_lowering_rule(
|
def _addupdate_lowering_rule(
|
||||||
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
|
ctx: TritonLoweringRuleContext, ptr, value, *idx, tree
|
||||||
):
|
):
|
||||||
ref_block_info, *_ = ctx.block_infos
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
avals_in = ctx.avals_in
|
if not isinstance(ptr.type, tl.pointer_type):
|
||||||
idx = _pack_indices(non_slice_idx, indexed_dims)
|
assert len(indexers) == 0
|
||||||
if non_slice_idx:
|
return ptr
|
||||||
(int_indexer_shape,) = {
|
if len(indexers) > 1:
|
||||||
tuple(map(lambda x: x.value, i.shape)) for i in non_slice_idx
|
raise NotImplementedError("No support for multiple indexers yet.")
|
||||||
}
|
indexer = indexers[0]
|
||||||
else:
|
|
||||||
int_indexer_shape = ()
|
|
||||||
idx = tuple(
|
|
||||||
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
|
|
||||||
for s, slc in zip(avals_in[0].shape, idx)
|
|
||||||
)
|
|
||||||
idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
|
|
||||||
ptr = _compute_pointers_from_indices(
|
ptr = _compute_pointers_from_indices(
|
||||||
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
|
ptr, ctx.block_infos[0], indexer, ctx.avals_in[0].shape, ctx.builder
|
||||||
)
|
)
|
||||||
tl.atomic_add(ptr, value, _builder=ctx.builder)
|
tl.atomic_add(ptr, value, _builder=ctx.builder)
|
||||||
return []
|
return []
|
||||||
|
@ -34,9 +34,11 @@ from jax._src.interpreters import mlir
|
|||||||
from jax._src.interpreters import partial_eval as pe
|
from jax._src.interpreters import partial_eval as pe
|
||||||
from jax._src.lax import lax
|
from jax._src.lax import lax
|
||||||
from jax._src.lax import slicing as lax_slicing
|
from jax._src.lax import slicing as lax_slicing
|
||||||
|
from jax._src.state import indexing
|
||||||
from jax._src.state.types import AbstractRef, RefEffect
|
from jax._src.state.types import AbstractRef, RefEffect
|
||||||
from jax._src.state.primitives import get_p, swap_p, addupdate_p
|
from jax._src.state.primitives import get_p, swap_p, addupdate_p
|
||||||
from jax._src.state.utils import hoist_consts_to_refs
|
from jax._src.state.utils import hoist_consts_to_refs
|
||||||
|
from jax._src.typing import Array
|
||||||
from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
|
from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
|
||||||
partition_list, merge_lists, split_dict)
|
partition_list, merge_lists, split_dict)
|
||||||
|
|
||||||
@ -144,34 +146,112 @@ def _eval_jaxpr_discharge_state(
|
|||||||
env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge])
|
env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge])
|
||||||
return out_vals + ref_vals
|
return out_vals + ref_vals
|
||||||
|
|
||||||
|
def _is_trivial_indexer(indexer: indexing.NDIndexer):
|
||||||
|
for s, idx in zip(indexer.shape, indexer.indices):
|
||||||
|
if not isinstance(idx, indexing.Slice):
|
||||||
|
return False
|
||||||
|
if not isinstance(idx.start, int):
|
||||||
|
return False
|
||||||
|
if idx.start:
|
||||||
|
return False
|
||||||
|
if idx.size != s:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _convert_to_array_indexer(indexer: indexing.NDIndexer
|
||||||
|
) -> tuple[int | Array, ...]:
|
||||||
|
# This is the general gather case. We need to create the gather arrays.
|
||||||
|
is_integer_indexer, _, integer_indexer = (
|
||||||
|
indexing.unpack_ndindexer(indexer)
|
||||||
|
)
|
||||||
|
total_shape = indexer.get_indexer_shape()
|
||||||
|
int_indexer_shape = indexer.int_indexer_shape
|
||||||
|
slice_shape = total_shape[len(int_indexer_shape):]
|
||||||
|
slice_dims = tuple(
|
||||||
|
i + len(int_indexer_shape) for i in range(len(slice_shape))
|
||||||
|
)
|
||||||
|
slice_dim_iter = iter(slice_dims)
|
||||||
|
slice_indexer: list[Array] = []
|
||||||
|
for idx, is_int_index in zip(indexer.indices, is_integer_indexer):
|
||||||
|
if not is_int_index:
|
||||||
|
assert isinstance(idx, indexing.Slice)
|
||||||
|
slice_indices = lax.broadcasted_iota(
|
||||||
|
np.dtype("int32"), total_shape, next(slice_dim_iter)
|
||||||
|
) + idx.start
|
||||||
|
slice_indexer.append(slice_indices)
|
||||||
|
integer_indexer = tuple(
|
||||||
|
lax.expand_dims(idx, (-1,)) for idx in integer_indexer
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
assert next(slice_dim_iter, None) is None
|
||||||
|
return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer))
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_convert_to_dynamic_slice(
|
||||||
|
indexer: indexing.NDIndexer,
|
||||||
|
) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
|
||||||
|
# An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
|
||||||
|
# if each of the indexers is a `Slice` or a ()-shaped value.
|
||||||
|
if not all(isinstance(i, indexing.Slice) or not np.shape(i)
|
||||||
|
for i in indexer.indices):
|
||||||
|
return None
|
||||||
|
_convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
|
||||||
|
starts = tuple(
|
||||||
|
_convert_i32(i.start) if isinstance(i, indexing.Slice)
|
||||||
|
else _convert_i32(i) for i in indexer.indices
|
||||||
|
)
|
||||||
|
sizes = tuple(
|
||||||
|
i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices
|
||||||
|
)
|
||||||
|
squeeze_dims = tuple(
|
||||||
|
i
|
||||||
|
for i, idx in enumerate(indexer.indices)
|
||||||
|
if not isinstance(idx, indexing.Slice)
|
||||||
|
)
|
||||||
|
return starts, sizes, squeeze_dims
|
||||||
|
|
||||||
|
|
||||||
@register_discharge_rule(get_p)
|
@register_discharge_rule(get_p)
|
||||||
def _get_discharge_rule(
|
def _get_discharge_rule(
|
||||||
in_avals: Sequence[core.AbstractValue],
|
in_avals: Sequence[core.AbstractValue],
|
||||||
out_avals: Sequence[core.AbstractValue], x, *non_slice_idx,
|
out_avals: Sequence[core.AbstractValue], x, *idx,
|
||||||
indexed_dims: Sequence[bool]):
|
tree):
|
||||||
del in_avals, out_avals
|
del in_avals, out_avals
|
||||||
y = _get_discharge(x, non_slice_idx, indexed_dims)
|
y = _get_discharge(x, idx, tree)
|
||||||
return (None,) * (len(non_slice_idx) + 1), y
|
return (None,) * (len(idx) + 1), y
|
||||||
|
|
||||||
def _get_discharge(x, idx, indexed_dims):
|
def _prepend_gather(x, indexer):
|
||||||
if not any(indexed_dims):
|
|
||||||
return x
|
|
||||||
if all(not i.shape for i in idx):
|
|
||||||
return _dynamic_index(x, idx, indexed_dims)
|
|
||||||
else:
|
|
||||||
return _prepend_gather(x, idx, indexed_dims)
|
|
||||||
|
|
||||||
def _prepend_gather(x, idx, indexed_dims):
|
|
||||||
indexer = _indexer(idx, indexed_dims)
|
|
||||||
# NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
|
# NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
|
||||||
return x[None][(np.array(0, 'int32'), *indexer)]
|
return x[None][(np.array(0, 'int32'), *indexer)]
|
||||||
|
|
||||||
def _prepend_scatter(x, idx, indexed_dims, val, *, add=False):
|
def _prepend_scatter(x, indexer, val, *, add=False):
|
||||||
indexer = _indexer(idx, indexed_dims)
|
# NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
|
||||||
|
# However, since this is scatter, we need to remove the 1-sized dimension
|
||||||
|
# we added at the front.
|
||||||
if add:
|
if add:
|
||||||
return x[None].at[(0, *indexer)].add(val)[0]
|
return x[None].at[(0, *indexer)].add(val)[0]
|
||||||
return x[None].at[(0, *indexer)].set(val)[0]
|
return x[None].at[(0, *indexer)].set(val)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_discharge(x, idx, tree):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
|
if len(indexers) > 1:
|
||||||
|
raise NotImplementedError("Only single indexer is supported.")
|
||||||
|
indexer = indexers[0]
|
||||||
|
if _is_trivial_indexer(indexer):
|
||||||
|
return x
|
||||||
|
# If everything in the indexer is a slice or ()-shaped, we can also
|
||||||
|
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
|
||||||
|
# We need to squeeze out the the 1-sized slices at the end.
|
||||||
|
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
|
||||||
|
starts, sizes, squeeze_dims = maybe_slice
|
||||||
|
y = lax_slicing.dynamic_slice(x, starts, sizes)
|
||||||
|
return lax.squeeze(y, squeeze_dims)
|
||||||
|
indexer = _convert_to_array_indexer(indexer)
|
||||||
|
if indexer is None:
|
||||||
|
return x
|
||||||
|
return x[None][(np.array(0, 'int32'), *indexer)]
|
||||||
|
|
||||||
def _indexer(idx, indexed_dims):
|
def _indexer(idx, indexed_dims):
|
||||||
idx_ = iter(idx)
|
idx_ = iter(idx)
|
||||||
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
|
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
|
||||||
@ -181,59 +261,59 @@ def _indexer(idx, indexed_dims):
|
|||||||
@register_discharge_rule(swap_p)
|
@register_discharge_rule(swap_p)
|
||||||
def _swap_discharge_rule(
|
def _swap_discharge_rule(
|
||||||
in_avals: Sequence[core.AbstractValue],
|
in_avals: Sequence[core.AbstractValue],
|
||||||
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
|
out_avals: Sequence[core.AbstractValue], x, val, *idx,
|
||||||
indexed_dims: Sequence[bool]):
|
tree):
|
||||||
del in_avals, out_avals
|
del in_avals, out_avals
|
||||||
if not any(indexed_dims):
|
z, x_new = _swap_discharge(x, val, idx, tree)
|
||||||
z, x_new = x, val
|
return (x_new, None) + (None,) * len(idx), z
|
||||||
z, x_new = _swap_discharge(x, val, non_slice_idx, indexed_dims)
|
|
||||||
return (x_new, None) + (None,) * len(non_slice_idx), z
|
|
||||||
|
|
||||||
def _swap_discharge(x, val, idx, indexed_dims):
|
def _swap_discharge(x, val, idx, tree):
|
||||||
if not any(indexed_dims):
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
z, x_new = x, val
|
if len(indexers) > 1:
|
||||||
elif all(not i.shape for i in idx):
|
raise NotImplementedError("Only single indexer is supported.")
|
||||||
z = _dynamic_index(x, idx, indexed_dims)
|
indexer = indexers[0]
|
||||||
x_new = _dynamic_update_index(x, idx, val, indexed_dims)
|
if _is_trivial_indexer(indexer):
|
||||||
else:
|
return x, val
|
||||||
z = _prepend_gather(x, idx, indexed_dims)
|
# If everything in the indexer is a slice or ()-shaped, we can also
|
||||||
x_new = _prepend_scatter(x, idx, indexed_dims, val)
|
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
|
||||||
return z, x_new
|
# We need to squeeze out the the 1-sized slices at the end.
|
||||||
|
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
|
||||||
|
starts, sizes, squeeze_dims = maybe_slice
|
||||||
|
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
|
||||||
|
val = lax.expand_dims(val, squeeze_dims)
|
||||||
|
y = lax_slicing.dynamic_update_slice(x, val, starts)
|
||||||
|
return lax.squeeze(x_old, squeeze_dims), y
|
||||||
|
indexer = _convert_to_array_indexer(indexer)
|
||||||
|
x_old = _prepend_gather(x, indexer)
|
||||||
|
return x_old, _prepend_scatter(x, indexer, val)
|
||||||
|
|
||||||
@register_discharge_rule(addupdate_p)
|
@register_discharge_rule(addupdate_p)
|
||||||
def _addupdate_discharge_rule(
|
def _addupdate_discharge_rule(
|
||||||
in_avals: Sequence[core.AbstractValue],
|
in_avals: Sequence[core.AbstractValue],
|
||||||
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
|
out_avals: Sequence[core.AbstractValue], x, val, *idx,
|
||||||
indexed_dims: Sequence[bool]):
|
tree):
|
||||||
del in_avals, out_avals
|
del in_avals, out_avals
|
||||||
ans = _addupdate_discharge(x, val, non_slice_idx, indexed_dims)
|
ans = _addupdate_discharge(x, val, idx, tree)
|
||||||
return (ans, None) + (None,) * len(non_slice_idx), []
|
return (ans, None) + (None,) * len(idx), []
|
||||||
|
|
||||||
def _addupdate_discharge(x, val, idx, indexed_dims):
|
def _addupdate_discharge(x, val, idx, tree):
|
||||||
if not any(indexed_dims):
|
indexers = tree_util.tree_unflatten(tree, idx)
|
||||||
|
if len(indexers) > 1:
|
||||||
|
raise NotImplementedError("Only single indexer is supported.")
|
||||||
|
indexer = indexers[0]
|
||||||
|
if _is_trivial_indexer(indexer):
|
||||||
return x + val
|
return x + val
|
||||||
if all(not i.shape for i in idx):
|
# If everything in the indexer is a slice or ()-shaped, we can also
|
||||||
y = val + _dynamic_index(x, idx, indexed_dims)
|
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
|
||||||
return _dynamic_update_index(x, idx, y, indexed_dims)
|
# We need to squeeze out the the 1-sized slices at the end.
|
||||||
else:
|
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
|
||||||
return _prepend_scatter(x, idx, indexed_dims, val, add=True)
|
starts, sizes, squeeze_dims = maybe_slice
|
||||||
|
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
|
||||||
def _dynamic_index(x, idx, indexed_dims):
|
val = lax.expand_dims(val, squeeze_dims)
|
||||||
assert isinstance(idx, (list, tuple)) and idx
|
y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
|
||||||
idx_ = iter(idx)
|
return y
|
||||||
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
|
indexer = _convert_to_array_indexer(indexer)
|
||||||
assert next(idx_, None) is None
|
return _prepend_scatter(x, indexer, val, add=True)
|
||||||
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
|
|
||||||
out = lax_slicing.dynamic_slice(x, starts, sizes)
|
|
||||||
return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])
|
|
||||||
|
|
||||||
def _dynamic_update_index(x, idx, val, indexed_dims):
|
|
||||||
assert isinstance(idx, (list, tuple)) and idx
|
|
||||||
idx_ = iter(idx)
|
|
||||||
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
|
|
||||||
assert next(idx_, None) is None
|
|
||||||
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
|
|
||||||
return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)
|
|
||||||
|
|
||||||
@weakref_lru_cache
|
@weakref_lru_cache
|
||||||
def _cached_closed_jaxpr_discharge(closed_jaxpr):
|
def _cached_closed_jaxpr_discharge(closed_jaxpr):
|
||||||
|
192
jax/_src/state/indexing.py
Normal file
192
jax/_src/state/indexing.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
# Copyright 2023 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Contains shared logic and abstractions for Pallas indexing ops."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from jax._src import core
|
||||||
|
from jax._src import tree_util
|
||||||
|
from jax._src.typing import Array
|
||||||
|
from jax._src.util import merge_lists
|
||||||
|
from jax._src.util import partition_list
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@tree_util.register_pytree_node_class
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Slice:
|
||||||
|
"""Represents a slice with a dynamic start index and a fixed size."""
|
||||||
|
start: Any
|
||||||
|
size: int
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.size < 0:
|
||||||
|
raise ValueError("`size` must not be negative.")
|
||||||
|
|
||||||
|
def tree_flatten(self):
|
||||||
|
# If `start` is statically known, we treat it as static information
|
||||||
|
if isinstance(self.start, int):
|
||||||
|
return (), (self.start, self.size)
|
||||||
|
return (self.start,), (self.size,)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tree_unflatten(cls, aux_data, children) -> Slice:
|
||||||
|
return cls(*children, *aux_data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_slice(cls, slc: slice, size: int) -> Slice:
|
||||||
|
start, stop, step = slc.indices(size)
|
||||||
|
if step != 1:
|
||||||
|
raise ValueError(f"slice must have a step of 1 (found: {step})")
|
||||||
|
return cls(start, stop - start)
|
||||||
|
|
||||||
|
|
||||||
|
def dslice(start: int | Array | None, size: int | None = None
|
||||||
|
) -> slice | Slice:
|
||||||
|
"""Constructs a `Slice` from a start and a size."""
|
||||||
|
if start is None:
|
||||||
|
return slice(None)
|
||||||
|
if size is None:
|
||||||
|
if not isinstance(start, int):
|
||||||
|
raise ValueError("Non-static `dslice`")
|
||||||
|
return Slice(0, start)
|
||||||
|
return Slice(start, size)
|
||||||
|
ds = dslice # Handy alias
|
||||||
|
|
||||||
|
|
||||||
|
IntIndexer = Union[int, Array]
|
||||||
|
DimIndexer = Union[IntIndexer, Slice]
|
||||||
|
|
||||||
|
def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
|
||||||
|
tuple[Slice, ...],
|
||||||
|
tuple[IntIndexer, ...]]:
|
||||||
|
is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
|
||||||
|
slice_indexers, int_indexers = partition_list(
|
||||||
|
is_int_indexing, indexer.indices)
|
||||||
|
return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # type: ignore
|
||||||
|
|
||||||
|
def _maybe_concretize(x: Any):
|
||||||
|
try:
|
||||||
|
return core.concrete_or_error(None, x)
|
||||||
|
except core.ConcretizationTypeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@tree_util.register_pytree_node_class
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class NDIndexer:
|
||||||
|
indices: tuple[DimIndexer, ...]
|
||||||
|
shape: tuple[int, ...]
|
||||||
|
int_indexer_shape: tuple[int, ...]
|
||||||
|
validate: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.validate:
|
||||||
|
return
|
||||||
|
if len(self.indices) != len(self.shape):
|
||||||
|
raise ValueError(
|
||||||
|
f"`indices` must be the same length as `Ref` shape.: {self}."
|
||||||
|
)
|
||||||
|
# We validate integer indexing shapes here
|
||||||
|
for idx, s in zip(self.indices, self.shape):
|
||||||
|
if isinstance(idx, Slice):
|
||||||
|
start = idx.start
|
||||||
|
if value := _maybe_concretize(start):
|
||||||
|
if value >= s:
|
||||||
|
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
|
||||||
|
if value + idx.size > s:
|
||||||
|
raise ValueError(
|
||||||
|
f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
# The shape of indexer integers should be broadcastable up to the
|
||||||
|
# int_indexer_shape of the whole NDIndexer
|
||||||
|
if not np.shape(idx):
|
||||||
|
if (value := _maybe_concretize(idx)) and value >= s:
|
||||||
|
raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
|
||||||
|
# For ()-shaped indexers, we can broadcast no problm.
|
||||||
|
continue
|
||||||
|
# If we don't have a ()-shaped indexer, the rank must match
|
||||||
|
# int_indexer_shape
|
||||||
|
if np.ndim(idx) != len(self.int_indexer_shape):
|
||||||
|
raise ValueError(
|
||||||
|
f"Indexer must have rank {np.ndim(idx)}: {idx=} vs."
|
||||||
|
f" {self.int_indexer_shape=}"
|
||||||
|
)
|
||||||
|
# Here we check that the shapes broadcast.
|
||||||
|
try:
|
||||||
|
np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not broadcast integer indexer: {idx=} vs."
|
||||||
|
f" {self.int_indexer_shape=}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def tree_flatten(self):
|
||||||
|
flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
|
||||||
|
return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tree_unflatten(cls, data, flat_idx):
|
||||||
|
idx_tree, shape, int_indexer_shape = data
|
||||||
|
indices = tree_util.tree_unflatten(idx_tree, flat_idx)
|
||||||
|
return NDIndexer(tuple(indices), shape, int_indexer_shape)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_indices_shape(cls, indices, shape) -> NDIndexer:
|
||||||
|
if indices == ...:
|
||||||
|
indices = (slice(None),) * len(shape)
|
||||||
|
if not isinstance(indices, tuple):
|
||||||
|
indices = (indices,)
|
||||||
|
if any(idx is ... for idx in indices):
|
||||||
|
# TODO(sharadmv,mattjj): support patterns that include ellipsis in them
|
||||||
|
# e.g. x[0, ..., 1].
|
||||||
|
raise NotImplementedError("Ellipsis in indexer not supported yet.")
|
||||||
|
if len(indices) > len(shape):
|
||||||
|
raise ValueError("`indices` must not be longer than `shape`: "
|
||||||
|
f"{indices=}, {shape=}")
|
||||||
|
# Pad out indices with slice(None)
|
||||||
|
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
|
||||||
|
# Convert all `slice`s to `Slice`s
|
||||||
|
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
|
||||||
|
else i for i, s in zip(indices, shape))
|
||||||
|
is_int_indexing = [not isinstance(i, Slice) for i in indices]
|
||||||
|
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
|
||||||
|
indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
|
||||||
|
if indexer_shapes:
|
||||||
|
try:
|
||||||
|
bcast_shape = np.broadcast_shapes(*indexer_shapes)
|
||||||
|
except ValueError as e:
|
||||||
|
# Raise a nicer error than the NumPy one.
|
||||||
|
raise ValueError("Cannot broadcast shapes for indexing: "
|
||||||
|
f"{tuple(a for a in indexer_shapes)}") from e
|
||||||
|
else:
|
||||||
|
bcast_shape = ()
|
||||||
|
# Here we use the `broadcast_to` primitive instead of composing lax
|
||||||
|
# primitives together because it is easier to lower in targets like
|
||||||
|
# Triton/Mosaic.
|
||||||
|
from jax._src.state import primitives as sp # pytype: disable=import-error
|
||||||
|
int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
|
||||||
|
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
|
||||||
|
return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
|
||||||
|
|
||||||
|
def get_indexer_shape(self) -> tuple[int, ...]:
|
||||||
|
_, slice_indexers, _ = unpack_ndindexer(self)
|
||||||
|
slice_shape = [s.size for s in slice_indexers]
|
||||||
|
# In NDIndexers, the int_indexer_shape is *always* at the front of the
|
||||||
|
# result.
|
||||||
|
return (*self.int_indexer_shape, *slice_shape)
|
@ -23,14 +23,17 @@ import numpy as np
|
|||||||
from jax._src import ad_util
|
from jax._src import ad_util
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import pretty_printer as pp
|
from jax._src import pretty_printer as pp
|
||||||
|
from jax._src import tree_util
|
||||||
from jax._src.interpreters import ad
|
from jax._src.interpreters import ad
|
||||||
from jax._src.interpreters import batching
|
from jax._src.interpreters import batching
|
||||||
from jax._src.interpreters import partial_eval as pe
|
from jax._src.interpreters import partial_eval as pe
|
||||||
|
from jax._src.interpreters import mlir
|
||||||
from jax._src.lax import lax
|
from jax._src.lax import lax
|
||||||
from jax._src.typing import Array
|
from jax._src.typing import Array
|
||||||
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
|
from jax._src.state import indexing
|
||||||
|
from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect,
|
||||||
AccumEffect)
|
AccumEffect)
|
||||||
from jax._src.util import safe_map, safe_zip, tuple_insert
|
from jax._src.util import safe_map, safe_zip
|
||||||
|
|
||||||
|
|
||||||
## General utilities
|
## General utilities
|
||||||
@ -51,41 +54,14 @@ zip, unsafe_zip = safe_zip, zip
|
|||||||
# a:f32[3] <- x[]
|
# a:f32[3] <- x[]
|
||||||
get_p = core.Primitive("get")
|
get_p = core.Primitive("get")
|
||||||
|
|
||||||
def _get_impl(ref: AbstractRef, *idx: int, **_):
|
def _get_impl(ref: AbstractRef, *args: Any, tree):
|
||||||
del ref, idx
|
del ref, args, tree
|
||||||
raise ValueError("Cannot run stateful primitive.")
|
raise ValueError("Cannot run stateful primitive.")
|
||||||
get_p.def_impl(_get_impl)
|
get_p.def_impl(_get_impl)
|
||||||
|
|
||||||
Indexer = tuple[Union[int, slice, Array], ...]
|
Indexer = tuple[Union[int, slice, Array], ...]
|
||||||
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
|
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
|
||||||
|
|
||||||
def _is_trivial_indexer(idx: Indexer) -> bool:
|
|
||||||
if idx is ...:
|
|
||||||
return True
|
|
||||||
if type(idx) is tuple:
|
|
||||||
if len(idx) == 0:
|
|
||||||
return True
|
|
||||||
return len(idx) == 1 and idx[0] is ...
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _unpack_idx(idx: Indexer, ndim: int
|
|
||||||
) -> tuple[tuple[Array, ...], tuple[bool, ...]]:
|
|
||||||
if _is_trivial_indexer(idx):
|
|
||||||
idx = tuple(slice(None) for _ in range(ndim))
|
|
||||||
indexed_dims_ = []
|
|
||||||
non_slice_idx = []
|
|
||||||
for i in idx:
|
|
||||||
if isinstance(i, slice):
|
|
||||||
if i.start is not None or i.stop is not None or i.step is not None:
|
|
||||||
raise NotImplementedError("Reference indexing only supports trivial slices")
|
|
||||||
indexed_dims_.append(False)
|
|
||||||
else:
|
|
||||||
non_slice_idx.append(i)
|
|
||||||
indexed_dims_.append(True)
|
|
||||||
indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
|
|
||||||
import jax.numpy as jnp
|
|
||||||
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
|
|
||||||
|
|
||||||
def _get_slice_output_shape(in_shape: tuple[int, ...],
|
def _get_slice_output_shape(in_shape: tuple[int, ...],
|
||||||
idx_shapes: tuple[tuple[int, ...], ...],
|
idx_shapes: tuple[tuple[int, ...], ...],
|
||||||
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
|
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
|
||||||
@ -95,24 +71,30 @@ def _get_slice_output_shape(in_shape: tuple[int, ...],
|
|||||||
shape = (*shape_prefix, *shape_suffix)
|
shape = (*shape_prefix, *shape_suffix)
|
||||||
return shape
|
return shape
|
||||||
|
|
||||||
def _get_indexer(ref: AbstractRef, idx: Indexer
|
|
||||||
) -> tuple[Indexer, tuple[bool, ...]]:
|
|
||||||
if isinstance(ref.inner_aval, core.ShapedArray):
|
|
||||||
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
|
|
||||||
else:
|
|
||||||
if not _is_trivial_indexer(idx):
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
|
|
||||||
non_slice_idx, indexed_dims = (), ()
|
|
||||||
return non_slice_idx, indexed_dims
|
|
||||||
|
|
||||||
def ref_get(ref: Any, idx: Indexer) -> Array:
|
def _get_indexer(
|
||||||
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
|
ref_or_view: Any, idx: Indexer | None, function_name: str
|
||||||
|
) -> tuple[Any, tuple[indexing.NDIndexer, ...]]:
|
||||||
|
if isinstance(ref_or_view, RefView):
|
||||||
|
ref, indexers = ref_or_view.ref, ref_or_view.indexers
|
||||||
|
else:
|
||||||
|
ref, indexers = ref_or_view, ()
|
||||||
ref_aval = core.get_aval(ref)
|
ref_aval = core.get_aval(ref)
|
||||||
if not isinstance(ref_aval, AbstractRef):
|
if not isinstance(ref_aval, AbstractRef):
|
||||||
raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
|
raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
|
||||||
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
|
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||||
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
|
return ref, ()
|
||||||
|
if idx is None:
|
||||||
|
return ref, indexers
|
||||||
|
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
|
||||||
|
return ref, (*indexers, nd_indexer)
|
||||||
|
|
||||||
|
|
||||||
|
def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
|
||||||
|
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
|
||||||
|
ref, indexers = _get_indexer(ref_or_view, idx, "ref_get")
|
||||||
|
flat_indexers, tree = tree_util.tree_flatten(indexers)
|
||||||
|
return get_p.bind(ref, *flat_indexers, tree=tree)
|
||||||
|
|
||||||
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
||||||
# b = swap_p.bind(x, a)
|
# b = swap_p.bind(x, a)
|
||||||
@ -132,22 +114,21 @@ def ref_get(ref: Any, idx: Indexer) -> Array:
|
|||||||
# x:Ref{f32[3]}[i, j] <- a
|
# x:Ref{f32[3]}[i, j] <- a
|
||||||
swap_p = core.Primitive("swap")
|
swap_p = core.Primitive("swap")
|
||||||
|
|
||||||
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
|
def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
|
||||||
del ref, value, idx
|
del ref, value, idx, tree
|
||||||
raise ValueError("Cannot run stateful primitive.")
|
raise ValueError("Cannot run stateful primitive.")
|
||||||
swap_p.def_impl(_swap_impl)
|
swap_p.def_impl(_swap_impl)
|
||||||
|
|
||||||
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
|
def ref_swap(ref_or_view: AbstractRef, idx: Indexer | None, value: Array,
|
||||||
|
*, _function_name: str = "ref_swap") -> Array:
|
||||||
"""Sets a `Ref`'s value and returns the original value."""
|
"""Sets a `Ref`'s value and returns the original value."""
|
||||||
ref_aval = core.get_aval(ref)
|
ref, indexers = _get_indexer(ref_or_view, idx, _function_name)
|
||||||
if not isinstance(ref_aval, AbstractRef):
|
flat_indexers, tree = tree_util.tree_flatten(indexers)
|
||||||
raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
|
return swap_p.bind(ref, value, *flat_indexers, tree=tree)
|
||||||
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
|
|
||||||
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
|
|
||||||
|
|
||||||
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
|
def ref_set(ref: AbstractRef, idx: Indexer | None, value: Array) -> None:
|
||||||
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
|
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
|
||||||
ref_swap(ref, idx, value)
|
ref_swap(ref, idx, value, _function_name="ref_set")
|
||||||
|
|
||||||
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
|
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
|
||||||
# Semantically,
|
# Semantically,
|
||||||
@ -163,38 +144,40 @@ def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
|
|||||||
addupdate_p = core.Primitive('addupdate')
|
addupdate_p = core.Primitive('addupdate')
|
||||||
addupdate_p.multiple_results = True
|
addupdate_p.multiple_results = True
|
||||||
|
|
||||||
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
|
def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
|
||||||
del ref, idx, value
|
del ref, value, args, tree
|
||||||
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
||||||
addupdate_p.def_impl(_addupdate_impl)
|
addupdate_p.def_impl(_addupdate_impl)
|
||||||
|
|
||||||
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
|
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
|
||||||
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
||||||
ref_aval = core.get_aval(ref)
|
ref, indexers = _get_indexer(ref_or_view, idx, "ref_addupdate")
|
||||||
if not isinstance(ref_aval, AbstractRef):
|
flat_indexers, tree = tree_util.tree_flatten(indexers)
|
||||||
raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
|
return addupdate_p.bind(ref, x, *flat_indexers, tree=tree)
|
||||||
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
|
|
||||||
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
|
|
||||||
|
|
||||||
## get/set/addupdate abstract evaluation rules
|
## get/set/addupdate abstract evaluation rules
|
||||||
|
|
||||||
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
|
|
||||||
indexed_dims):
|
def _shape_after_indexing(
|
||||||
|
shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...]
|
||||||
|
) -> tuple[int, ...]:
|
||||||
|
for indexer in indexers:
|
||||||
|
# Run some simple checks that all the indexers have consistent shapes
|
||||||
|
assert indexer.shape == shape, (indexer.shape, shape)
|
||||||
|
shape = indexer.get_indexer_shape()
|
||||||
|
return shape
|
||||||
|
|
||||||
|
|
||||||
|
def _get_abstract_eval(ref_aval: AbstractRef, *args,
|
||||||
|
tree):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, args)
|
||||||
if not isinstance(ref_aval, AbstractRef):
|
if not isinstance(ref_aval, AbstractRef):
|
||||||
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
|
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
|
||||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||||
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
out_shape = _shape_after_indexing(ref_aval.shape, indexers)
|
||||||
raise ValueError("`get` with nontrivial indexing must be called "
|
out_aval = ref_aval.inner_aval.update(shape=out_shape)
|
||||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
|
||||||
if len(indexed_dims) != len(ref_aval.shape):
|
|
||||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
|
||||||
if sum(indexed_dims) != len(idx):
|
|
||||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
|
||||||
idx_shapes = tuple(i.shape for i in idx)
|
|
||||||
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
|
|
||||||
out_aval = ref_aval.inner_aval.update(shape=shape)
|
|
||||||
else:
|
else:
|
||||||
if idx:
|
if indexers:
|
||||||
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||||
out_aval = ref_aval.inner_aval
|
out_aval = ref_aval.inner_aval
|
||||||
return (out_aval, {ReadEffect(0)})
|
return (out_aval, {ReadEffect(0)})
|
||||||
@ -202,34 +185,29 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval)
|
|||||||
|
|
||||||
def _swap_abstract_eval(ref_aval: AbstractRef,
|
def _swap_abstract_eval(ref_aval: AbstractRef,
|
||||||
val_aval: core.AbstractValue,
|
val_aval: core.AbstractValue,
|
||||||
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
|
*args: Any, tree):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, args)
|
||||||
out_aval: core.AbstractValue
|
out_aval: core.AbstractValue
|
||||||
if not isinstance(ref_aval, AbstractRef):
|
if not isinstance(ref_aval, AbstractRef):
|
||||||
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
|
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
|
||||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||||
if len(indexed_dims) != len(ref_aval.shape):
|
|
||||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
|
||||||
if sum(indexed_dims) != len(idx):
|
|
||||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
|
||||||
val_aval = core.raise_to_shaped(val_aval)
|
val_aval = core.raise_to_shaped(val_aval)
|
||||||
assert isinstance(val_aval, core.ShapedArray)
|
assert isinstance(val_aval, core.ShapedArray)
|
||||||
idx_shapes = tuple(i.shape for i in idx)
|
expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers)
|
||||||
expected_output_shape = _get_slice_output_shape(
|
if expected_out_shape != val_aval.shape:
|
||||||
ref_aval.shape, idx_shapes, indexed_dims)
|
|
||||||
if expected_output_shape != val_aval.shape:
|
|
||||||
raise ValueError("Invalid shape for `swap`. "
|
raise ValueError("Invalid shape for `swap`. "
|
||||||
f"Ref shape: {ref_aval.shape}. "
|
f"Ref shape: {ref_aval.shape}. "
|
||||||
|
f"Expected shape: {expected_out_shape}. "
|
||||||
f"Value shape: {val_aval.shape}. "
|
f"Value shape: {val_aval.shape}. "
|
||||||
f"Indices: {idx}. ")
|
f"Indices: {indexers}. ")
|
||||||
if ref_aval.dtype != val_aval.dtype:
|
if ref_aval.dtype != val_aval.dtype:
|
||||||
raise ValueError("Invalid dtype for `swap`. "
|
raise ValueError("Invalid dtype for `swap`. "
|
||||||
f"Ref dtype: {ref_aval.dtype}. "
|
f"Ref dtype: {ref_aval.dtype}. "
|
||||||
f"Value shape: {val_aval.dtype}. ")
|
f"Value shape: {val_aval.dtype}. ")
|
||||||
out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
|
out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype)
|
||||||
else:
|
else:
|
||||||
if idx:
|
if indexers:
|
||||||
raise ValueError("`swap` with nontrivial indexing must be called "
|
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
|
||||||
out_aval = ref_aval.inner_aval
|
out_aval = ref_aval.inner_aval
|
||||||
return (out_aval, {WriteEffect(0)})
|
return (out_aval, {WriteEffect(0)})
|
||||||
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||||
@ -237,93 +215,118 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
|||||||
|
|
||||||
def _addupdate_abstract_eval(ref_aval: AbstractRef,
|
def _addupdate_abstract_eval(ref_aval: AbstractRef,
|
||||||
val_aval: core.AbstractValue,
|
val_aval: core.AbstractValue,
|
||||||
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
|
*args: Any, tree):
|
||||||
|
indexers = tree_util.tree_unflatten(tree, args)
|
||||||
if not isinstance(ref_aval, AbstractRef):
|
if not isinstance(ref_aval, AbstractRef):
|
||||||
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
|
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
|
||||||
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
|
||||||
raise ValueError("`addupdate` with nontrivial indexing must be called "
|
|
||||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
|
||||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||||
if len(indexed_dims) != len(ref_aval.shape):
|
|
||||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
|
||||||
if sum(indexed_dims) != len(idx):
|
|
||||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
|
||||||
val_aval = core.raise_to_shaped(val_aval)
|
val_aval = core.raise_to_shaped(val_aval)
|
||||||
|
slice_shape = _shape_after_indexing(ref_aval.shape, indexers)
|
||||||
assert isinstance(val_aval, core.ShapedArray)
|
assert isinstance(val_aval, core.ShapedArray)
|
||||||
idx_shapes = tuple(i.shape for i in idx)
|
|
||||||
slice_shape = _get_slice_output_shape(
|
|
||||||
ref_aval.shape, idx_shapes, indexed_dims)
|
|
||||||
if slice_shape != val_aval.shape:
|
if slice_shape != val_aval.shape:
|
||||||
raise ValueError("Invalid shape for `addupdate`. "
|
raise ValueError("Invalid shape for `addupdate`. "
|
||||||
f"Ref shape: {ref_aval.shape}. "
|
f"Ref shape: {ref_aval.shape}. "
|
||||||
|
f"Slice shape: {slice_shape}. "
|
||||||
f"Value shape: {val_aval.shape}. "
|
f"Value shape: {val_aval.shape}. "
|
||||||
f"Indices: {idx}. ")
|
f"Indices: {indexers}. ")
|
||||||
if ref_aval.dtype != val_aval.dtype:
|
if ref_aval.dtype != val_aval.dtype:
|
||||||
raise ValueError("Invalid dtype for `addupdate`. "
|
raise ValueError("Invalid dtype for `addupdate`. "
|
||||||
f"Ref dtype: {ref_aval.dtype}. "
|
f"Ref dtype: {ref_aval.dtype}. "
|
||||||
f"Value shape: {val_aval.dtype}. ")
|
f"Value shape: {val_aval.dtype}. ")
|
||||||
elif idx:
|
else:
|
||||||
raise ValueError("`addupdate` with nontrivial indexing must be called "
|
# Check that the indexers are valid
|
||||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
if indexers:
|
||||||
|
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||||
return [], {AccumEffect(0)}
|
return [], {AccumEffect(0)}
|
||||||
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||||||
|
|
||||||
## Pretty printing for `get` and `swap` in jaxprs
|
## Pretty printing for `get` and `swap` in jaxprs
|
||||||
|
|
||||||
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
|
pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL,
|
||||||
foreground=pp.Color.GREEN)
|
foreground=pp.Color.GREEN)
|
||||||
|
|
||||||
def _pp_idx(context, non_slice_idx, indexed_dims):
|
def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice
|
||||||
idx_iter = iter(non_slice_idx)
|
) -> str:
|
||||||
idx = ','.join(core.pp_var(next(idx_iter), context) if indexed else ':'
|
start, size = slc.start, slc.size
|
||||||
for indexed in indexed_dims)
|
if isinstance(start, core.Var):
|
||||||
assert next(idx_iter, None) is None
|
start_str = core.pp_var(start, context)
|
||||||
return pp.text(idx)
|
end_str = f'{start_str}+{size}'
|
||||||
|
else:
|
||||||
|
start_str = '' if start == 0 else str(start)
|
||||||
|
end = start + size
|
||||||
|
end_str = '' if end == dim else str(end)
|
||||||
|
return f'{start_str}:{end_str}'
|
||||||
|
|
||||||
|
def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer
|
||||||
|
) -> pp.Doc:
|
||||||
|
indices = []
|
||||||
|
for idx, dim in zip(indexer.indices, indexer.shape):
|
||||||
|
if isinstance(idx, indexing.Slice):
|
||||||
|
indices.append(_pp_slice(context, dim, idx))
|
||||||
|
else:
|
||||||
|
indices.append(core.pp_var(idx, context)) # type: ignore
|
||||||
|
return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")])
|
||||||
|
|
||||||
|
def _pp_indexers(
|
||||||
|
context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...],
|
||||||
|
):
|
||||||
|
return pp.concat(
|
||||||
|
[pp_indexer(context, indexer) for indexer in indexers]
|
||||||
|
)
|
||||||
|
|
||||||
|
def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers):
|
||||||
|
return pp_ref_var(
|
||||||
|
pp.concat([
|
||||||
|
pp.text(core.pp_var(ref, context)),
|
||||||
|
_pp_indexers(context, indexers),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
|
||||||
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
|
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
|
||||||
# Pretty prints `a = get x i` as `x[i] <- a`
|
# Pretty prints `a = get x i` as `x[i] <- a`
|
||||||
y, = eqn.outvars
|
y, = eqn.outvars
|
||||||
x, *idx = eqn.invars
|
x, *flat_idx = eqn.invars
|
||||||
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
|
indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
|
||||||
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||||
# TODO more general get
|
return pp.concat([
|
||||||
return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([
|
lhs,
|
||||||
pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))])
|
pp.text(' <- '),
|
||||||
|
pp_ref_indexers(context, x, indexers)
|
||||||
|
])
|
||||||
core.pp_eqn_rules[get_p] = _get_pp_rule
|
core.pp_eqn_rules[get_p] = _get_pp_rule
|
||||||
|
|
||||||
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
|
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
|
||||||
y, = eqn.outvars
|
y, = eqn.outvars
|
||||||
x, v, *idx = eqn.invars
|
x, v, *flat_idx = eqn.invars
|
||||||
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
|
indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
|
||||||
if type(y) is core.DropVar:
|
if type(y) is core.DropVar:
|
||||||
# In the case of a set (ignored return value),
|
# In the case of a set (ignored return value),
|
||||||
# pretty print `_ = swap x v i` as `x[i] <- v`
|
# pretty print `_ = swap x v i` as `x[i] <- v`
|
||||||
del y
|
del y
|
||||||
return pp.concat([
|
return pp.concat([
|
||||||
pp_ref(pp.concat([
|
pp_ref_indexers(context, x, indexers),
|
||||||
pp.text(core.pp_var(x, context)),
|
pp.text(' <- '),
|
||||||
pp.text('['), idx, pp.text(']')
|
pp.text(core.pp_var(v, context))
|
||||||
])), pp.text(' <- '), pp.text(core.pp_var(v, context))])
|
])
|
||||||
else:
|
else:
|
||||||
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
|
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
|
||||||
x_i = pp.concat([pp.text(core.pp_var(x, context)),
|
x_i = pp_ref_indexers(context, x, indexers)
|
||||||
pp.text('['), idx, pp.text(']')])
|
|
||||||
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
|
||||||
return pp.concat([y, pp.text(', '), pp_ref(x_i), pp.text(' <- '),
|
return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
|
||||||
pp_ref(x_i), pp.text(', '),
|
x_i, pp.text(', '),
|
||||||
pp.text(core.pp_var(v, context))])
|
pp.text(core.pp_var(v, context))])
|
||||||
core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
||||||
|
|
||||||
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
|
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
|
||||||
|
del settings
|
||||||
# pretty-print ` = addupdate x i v` as `x[i] += v`
|
# pretty-print ` = addupdate x i v` as `x[i] += v`
|
||||||
() = eqn.outvars
|
() = eqn.outvars
|
||||||
x, v, *idx = eqn.invars
|
x, v, *flat_idx = eqn.invars
|
||||||
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
|
indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
|
||||||
return pp.concat([
|
return pp.concat([
|
||||||
pp_ref(pp.concat([
|
pp_ref_indexers(context, x, indexers),
|
||||||
pp.text(core.pp_var(x, context)),
|
pp.text(' += '),
|
||||||
pp.text('['), idx, pp.text(']')
|
pp.text(core.pp_var(v, context))])
|
||||||
])), pp.text(' += '), pp.text(core.pp_var(v, context))])
|
|
||||||
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
|
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
|
||||||
|
|
||||||
## get/swap/addupdate JVP rules
|
## get/swap/addupdate JVP rules
|
||||||
@ -366,6 +369,7 @@ def _get_transpose(g, ref, *idx, **params):
|
|||||||
ad.primitive_transposes[get_p] = _get_transpose
|
ad.primitive_transposes[get_p] = _get_transpose
|
||||||
|
|
||||||
def _swap_transpose(g, ref, x, *idx, **params):
|
def _swap_transpose(g, ref, x, *idx, **params):
|
||||||
|
del x # old value doesn't matter anymore
|
||||||
# swap transpose is swap
|
# swap transpose is swap
|
||||||
x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
|
x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
|
||||||
return [None, x_bar] + [None] * len(idx)
|
return [None, x_bar] + [None] * len(idx)
|
||||||
@ -403,101 +407,162 @@ def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
|
|||||||
num_idxs_to_left = sum(indexed_dims[:ref_dim])
|
num_idxs_to_left = sum(indexed_dims[:ref_dim])
|
||||||
return ref_dim - num_idxs_to_left + len(idxs_shape)
|
return ref_dim - num_idxs_to_left + len(idxs_shape)
|
||||||
|
|
||||||
def _get_vmap(batched_args, batched_dims, *, indexed_dims):
|
def _batch_indexer(indexer: indexing.NDIndexer, dims,
|
||||||
|
axis_size: int,
|
||||||
|
ref_shape: tuple[int, ...],
|
||||||
|
ref_dim: int | batching.NotMapped,
|
||||||
|
idx_is_batched: bool) -> indexing.NDIndexer:
|
||||||
|
indices = indexer.indices
|
||||||
|
indices_dims = dims.indices
|
||||||
|
new_indices: list[Array | indexing.Slice | int] = []
|
||||||
|
new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape)
|
||||||
|
for idx, dim in zip(indices, indices_dims):
|
||||||
|
if idx_is_batched:
|
||||||
|
# If at least one of the idx is batched, we broadcast them all and move the
|
||||||
|
# batch dim to the front.
|
||||||
|
if isinstance(idx, indexing.Slice):
|
||||||
|
# size is static, but start can be dynamic
|
||||||
|
# Check if start is static (which it can be)
|
||||||
|
is_static_slice = len(tree_util.tree_leaves(idx)) == 0
|
||||||
|
if is_static_slice:
|
||||||
|
new_indices.append(idx)
|
||||||
|
continue
|
||||||
|
dim = dim.start
|
||||||
|
if dim is batching.not_mapped:
|
||||||
|
# Broadcasting the slice is free (the start index stays the same)
|
||||||
|
new_indices.append(idx)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"No support for vmapping over nontrivial slices just yet: {idx}")
|
||||||
|
else:
|
||||||
|
# Check if we are indexing with a scalar or not. If we are indexing
|
||||||
|
# with a scalar and we are not batched, we can avoid broadcasting it.
|
||||||
|
assert hasattr(idx, "shape")
|
||||||
|
if not idx.shape:
|
||||||
|
if dim is not batching.not_mapped:
|
||||||
|
assert idx.shape == (axis_size,)
|
||||||
|
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,))
|
||||||
|
new_indices.append(idx)
|
||||||
|
else:
|
||||||
|
if dim is batching.not_mapped:
|
||||||
|
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
|
||||||
|
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
|
||||||
|
bcast_dims)
|
||||||
|
else:
|
||||||
|
idx = batching.moveaxis(idx, dim, 0)
|
||||||
|
new_indices.append(idx)
|
||||||
|
else:
|
||||||
|
if ref_dim is not batching.not_mapped:
|
||||||
|
if not isinstance(idx, indexing.Slice):
|
||||||
|
assert hasattr(idx, "shape")
|
||||||
|
if idx.shape:
|
||||||
|
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
|
||||||
|
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
|
||||||
|
bcast_dims)
|
||||||
|
new_indices.append(idx)
|
||||||
|
if ref_dim is not batching.not_mapped:
|
||||||
|
iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0)
|
||||||
|
new_indices.insert(ref_dim, iota)
|
||||||
|
return indexing.NDIndexer(tuple(new_indices), ref_shape,
|
||||||
|
new_integer_indexer_shape,
|
||||||
|
validate=True)
|
||||||
|
|
||||||
|
def _get_vmap(batched_args, batched_dims, *, tree):
|
||||||
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
|
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
|
||||||
if d is not batching.not_mapped}
|
if d is not batching.not_mapped}
|
||||||
ref, *idxs = batched_args
|
ref, *flat_idxs = batched_args
|
||||||
ref_dim, *idx_dims = batched_dims
|
ref_dim, *flat_idx_dims = batched_dims
|
||||||
|
indexers = tree_util.tree_unflatten(tree, flat_idxs)
|
||||||
|
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
|
||||||
|
|
||||||
ref_is_batched = ref_dim is not batching.not_mapped
|
idx_is_batched = any(i_dim is not batching.not_mapped
|
||||||
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
|
for i_dim in flat_idx_dims)
|
||||||
bdim_out = 0
|
if len(indexers) > 1:
|
||||||
|
raise NotImplementedError("Batching with multiple indexers not supported.")
|
||||||
if idx_is_batched:
|
# TODO(sharadmv): handle vmap of multiple indexers
|
||||||
# If at least one of the idx is batched, we broadcast them all and move the
|
indexers = tuple(_batch_indexer(indexer, dims, axis_size,
|
||||||
# batch dim to the front.
|
ref.shape, ref_dim, idx_is_batched)
|
||||||
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
|
for indexer, dims in zip(indexers, indexers_dims))
|
||||||
in zip(idxs, idx_dims))
|
flat_indexers, tree = tree_util.tree_flatten(indexers)
|
||||||
idxs_shape, = {i.shape for i in idxs} or [()]
|
return get_p.bind(ref, *flat_indexers, tree=tree), 0
|
||||||
if ref_is_batched:
|
|
||||||
# If ref is batched, we are doing a `get` with an additional axis. If `idxs`
|
|
||||||
# are also batched, then we are indexing into the batch axis with an `iota`.
|
|
||||||
indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched)
|
|
||||||
if idx_is_batched:
|
|
||||||
# If we have batched idx, we need to insert the new iota index. The place
|
|
||||||
# where we add in the new `iota` index is `ref_dim` so we need to compute
|
|
||||||
# what `ref_dim` *would be* if we inserted it into `idxs` instead, because
|
|
||||||
# `idxs` doesn't include the non indexed dims.
|
|
||||||
idx_place = [i for i, i_dim in enumerate(indexed_dims)
|
|
||||||
if i_dim].index(ref_dim)
|
|
||||||
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
|
|
||||||
idxs = tuple_insert(idxs, idx_place, iota)
|
|
||||||
else:
|
|
||||||
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
|
|
||||||
return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out
|
|
||||||
batching.primitive_batchers[get_p] = _get_vmap
|
batching.primitive_batchers[get_p] = _get_vmap
|
||||||
|
|
||||||
def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
|
def _swap_vmap(batched_args, batched_dims, *, tree):
|
||||||
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
|
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
|
||||||
if d is not batching.not_mapped}
|
if d is not batching.not_mapped}
|
||||||
ref, val, *idxs = batched_args
|
ref, val, *flat_idxs = batched_args
|
||||||
ref_dim, val_dim, *idx_dims = batched_dims
|
ref_dim, val_dim, *flat_idx_dims = batched_dims
|
||||||
|
indexers = tree_util.tree_unflatten(tree, flat_idxs)
|
||||||
|
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
|
||||||
|
|
||||||
ref_is_batched = ref_dim is not batching.not_mapped
|
ref_is_batched = ref_dim is not batching.not_mapped
|
||||||
val_is_batched = val_dim is not batching.not_mapped
|
val_is_batched = val_dim is not batching.not_mapped
|
||||||
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
|
idx_is_batched = any(i_dim is not batching.not_mapped
|
||||||
if idx_is_batched:
|
for i_dim in flat_idx_dims)
|
||||||
# If at least one of the idx is batched, we broadcast them all and move the
|
if len(indexers) > 1:
|
||||||
# batch dim to the front.
|
raise NotImplementedError("Batching with multiple indexers not supported.")
|
||||||
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
|
# TODO(sharadmv): handle vmap of multiple indexers
|
||||||
in zip(idxs, idx_dims))
|
indexers = tuple(_batch_indexer(indexer, dims, axis_size,
|
||||||
idxs_shape, = {i.shape for i in idxs} or [()]
|
ref.shape, ref_dim, idx_is_batched)
|
||||||
if ref_is_batched and not idx_is_batched:
|
for indexer, dims in zip(indexers, indexers_dims))
|
||||||
indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
|
flat_indexers, tree = tree_util.tree_flatten(indexers)
|
||||||
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
|
if (ref_is_batched or idx_is_batched) and not val_is_batched:
|
||||||
if not val_is_batched:
|
val = batching.broadcast(val, axis_size, 0)
|
||||||
val = batching.broadcast(val, axis_size, 0)
|
if val_is_batched:
|
||||||
val_dim = 0
|
|
||||||
val = batching.moveaxis(val, val_dim, bdim_out)
|
|
||||||
elif idx_is_batched:
|
|
||||||
assert ref_is_batched and val_is_batched
|
|
||||||
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
|
|
||||||
idx_place = [i for i, i_dim in enumerate(indexed_dims)
|
|
||||||
if i_dim].index(ref_dim)
|
|
||||||
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
|
|
||||||
idxs = tuple_insert(idxs, idx_place, iota)
|
|
||||||
val = batching.moveaxis(val, val_dim, 0)
|
val = batching.moveaxis(val, val_dim, 0)
|
||||||
bdim_out = 0
|
return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0
|
||||||
return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out
|
|
||||||
batching.primitive_batchers[swap_p] = _swap_vmap
|
batching.primitive_batchers[swap_p] = _swap_vmap
|
||||||
|
|
||||||
def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
|
def _addupdate_vmap(batched_args, batched_dims, *, tree):
|
||||||
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
|
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
|
||||||
if d is not batching.not_mapped}
|
if d is not batching.not_mapped}
|
||||||
ref, val, *idxs = batched_args
|
ref, val, *flat_idxs = batched_args
|
||||||
ref_dim, val_dim, *idx_dims = batched_dims
|
ref_dim, val_dim, *flat_idx_dims = batched_dims
|
||||||
|
indexers = tree_util.tree_unflatten(tree, flat_idxs)
|
||||||
|
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
|
||||||
|
|
||||||
ref_is_batched = ref_dim is not batching.not_mapped
|
ref_is_batched = ref_dim is not batching.not_mapped
|
||||||
val_is_batched = val_dim is not batching.not_mapped
|
val_is_batched = val_dim is not batching.not_mapped
|
||||||
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
|
idx_is_batched = any(i_dim is not batching.not_mapped
|
||||||
if idx_is_batched:
|
for i_dim in flat_idx_dims)
|
||||||
# If at least one of the idx is batched, we ensure all have bdims at front.
|
if len(indexers) > 1:
|
||||||
idxs = tuple(batching.bdim_at_front(i, d, axis_size)
|
raise NotImplementedError("Batching with multiple indexers not supported.")
|
||||||
for i, d in zip(idxs, idx_dims))
|
# TODO(sharadmv): handle vmap of multiple indexers
|
||||||
idxs_shape, = {i.shape for i in idxs} or [()]
|
indexers = tuple(_batch_indexer(indexer, dims, axis_size,
|
||||||
if ref_is_batched and not idx_is_batched:
|
ref.shape, ref_dim, idx_is_batched)
|
||||||
indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
|
for indexer, dims in zip(indexers, indexers_dims))
|
||||||
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
|
flat_indexers, tree = tree_util.tree_flatten(indexers)
|
||||||
if not val_is_batched:
|
if (ref_is_batched or idx_is_batched) and not val_is_batched:
|
||||||
val = batching.broadcast(val, axis_size, 0)
|
val = batching.broadcast(val, axis_size, 0)
|
||||||
val_dim = 0
|
if val_is_batched:
|
||||||
val = batching.moveaxis(val, val_dim, bdim_out)
|
|
||||||
elif idx_is_batched:
|
|
||||||
assert ref_is_batched and val_is_batched
|
|
||||||
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
|
|
||||||
idx_place = [i for i, i_dim in enumerate(indexed_dims)
|
|
||||||
if i_dim].index(ref_dim)
|
|
||||||
idxs_shape, = {i.shape for i in idxs} or [()]
|
|
||||||
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
|
|
||||||
idxs = tuple_insert(idxs, idx_place, iota)
|
|
||||||
val = batching.moveaxis(val, val_dim, 0)
|
val = batching.moveaxis(val, val_dim, 0)
|
||||||
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
|
return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), []
|
||||||
batching.primitive_batchers[addupdate_p] = _addupdate_vmap
|
batching.primitive_batchers[addupdate_p] = _addupdate_vmap
|
||||||
|
|
||||||
|
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
|
||||||
|
# We could use `jnp.broadcast_to` but that lowers to squeezing,
|
||||||
|
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
|
||||||
|
# so in the lowering, we have to expand out those squeezed dimensions again.
|
||||||
|
# Having a simple `broadcast_to` primitive allows us to lower directly
|
||||||
|
# to `tl.broadcast_to`.
|
||||||
|
broadcast_to_p = core.Primitive('broadcast_to')
|
||||||
|
|
||||||
|
def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array:
|
||||||
|
import jax.numpy as jnp
|
||||||
|
a = jnp.asarray(a)
|
||||||
|
if a.shape == shape:
|
||||||
|
return a
|
||||||
|
return broadcast_to_p.bind(a, shape=shape)
|
||||||
|
|
||||||
|
@broadcast_to_p.def_impl
|
||||||
|
def _broadcast_to_impl(a, *, shape):
|
||||||
|
import jax.numpy as jnp
|
||||||
|
return jnp.broadcast_to(a, shape)
|
||||||
|
|
||||||
|
@broadcast_to_p.def_abstract_eval
|
||||||
|
def _broadcast_to_abstract_eval(aval, *, shape):
|
||||||
|
return core.ShapedArray(shape, aval.dtype)
|
||||||
|
|
||||||
|
mlir.register_lowering(
|
||||||
|
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
|
||||||
|
)
|
||||||
|
@ -23,6 +23,7 @@ from typing import Any, Generic, TypeVar, Union
|
|||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import effects
|
from jax._src import effects
|
||||||
from jax._src import pretty_printer as pp
|
from jax._src import pretty_printer as pp
|
||||||
|
from jax._src.state import indexing
|
||||||
from jax._src.util import safe_map, safe_zip
|
from jax._src.util import safe_map, safe_zip
|
||||||
from jax._src.typing import Array
|
from jax._src.typing import Array
|
||||||
|
|
||||||
@ -77,21 +78,34 @@ Aval = TypeVar("Aval", bound=core.AbstractValue)
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class RefIndexer:
|
class RefIndexer:
|
||||||
ref: Any
|
ref_or_view: Any
|
||||||
|
|
||||||
def __getitem__(self, slc):
|
def __getitem__(self, slc):
|
||||||
if not isinstance(slc, tuple):
|
if not isinstance(slc, tuple):
|
||||||
slc = (slc,)
|
slc = (slc,)
|
||||||
return RefView(self.ref, slc)
|
indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape)
|
||||||
|
if isinstance(self.ref_or_view, RefView):
|
||||||
|
view = self.ref_or_view
|
||||||
|
return RefView(view.ref, (*view.indexers, indexer))
|
||||||
|
return RefView(self.ref_or_view, (indexer,))
|
||||||
|
|
||||||
|
Indexer = Any
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class RefView:
|
class RefView:
|
||||||
ref: Any
|
ref: Any
|
||||||
indexer: Any
|
indexers: tuple[indexing.NDIndexer, ...]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def at(self):
|
def shape(self) -> tuple[int, ...]:
|
||||||
raise NotImplementedError("Can't call `.at` multiple times.")
|
assert (
|
||||||
|
len(self.indexers) > 0
|
||||||
|
), "Should not be able to create a trivial RefView"
|
||||||
|
return self.indexers[-1].get_indexer_shape()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def at(self) -> RefIndexer:
|
||||||
|
return RefIndexer(self)
|
||||||
|
|
||||||
|
|
||||||
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
||||||
@ -137,14 +151,10 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
|
|||||||
return ref_set(tracer, idx, value)
|
return ref_set(tracer, idx, value)
|
||||||
|
|
||||||
def _getitem(self, tracer, idx) -> Array:
|
def _getitem(self, tracer, idx) -> Array:
|
||||||
if not isinstance(idx, tuple):
|
|
||||||
idx = idx,
|
|
||||||
from jax._src.state.primitives import ref_get # pytype: disable=import-error
|
from jax._src.state.primitives import ref_get # pytype: disable=import-error
|
||||||
return ref_get(tracer, idx)
|
return ref_get(tracer, idx)
|
||||||
|
|
||||||
def _setitem(self, tracer, idx, value) -> None:
|
def _setitem(self, tracer, idx, value) -> None:
|
||||||
if not isinstance(idx, tuple):
|
|
||||||
idx = idx,
|
|
||||||
from jax._src.state.primitives import ref_set # pytype: disable=import-error
|
from jax._src.state.primitives import ref_set # pytype: disable=import-error
|
||||||
return ref_set(tracer, idx, value)
|
return ref_set(tracer, idx, value)
|
||||||
|
|
||||||
|
@ -1472,6 +1472,7 @@ tf_not_yet_impl = [
|
|||||||
"for",
|
"for",
|
||||||
"inspect_sharding",
|
"inspect_sharding",
|
||||||
"io_callback",
|
"io_callback",
|
||||||
|
"broadcast_to",
|
||||||
"shard_map",
|
"shard_map",
|
||||||
"global_array_to_host_local_array",
|
"global_array_to_host_local_array",
|
||||||
"host_local_array_to_global_array",
|
"host_local_array_to_global_array",
|
||||||
|
@ -17,10 +17,6 @@
|
|||||||
from jax._src import pallas
|
from jax._src import pallas
|
||||||
from jax._src.pallas.core import BlockSpec
|
from jax._src.pallas.core import BlockSpec
|
||||||
from jax._src.pallas.core import no_block_spec
|
from jax._src.pallas.core import no_block_spec
|
||||||
from jax._src.pallas.indexing import broadcast_to
|
|
||||||
from jax._src.pallas.indexing import ds
|
|
||||||
from jax._src.pallas.indexing import dslice
|
|
||||||
from jax._src.pallas.indexing import Slice
|
|
||||||
from jax._src.pallas.pallas_call import pallas_call
|
from jax._src.pallas.pallas_call import pallas_call
|
||||||
from jax._src.pallas.pallas_call import pallas_call_p
|
from jax._src.pallas.pallas_call import pallas_call_p
|
||||||
from jax._src.pallas.primitives import atomic_add
|
from jax._src.pallas.primitives import atomic_add
|
||||||
@ -42,6 +38,10 @@ from jax._src.pallas.utils import cdiv
|
|||||||
from jax._src.pallas.utils import next_power_of_2
|
from jax._src.pallas.utils import next_power_of_2
|
||||||
from jax._src.pallas.utils import strides_from_shape
|
from jax._src.pallas.utils import strides_from_shape
|
||||||
from jax._src.pallas.utils import when
|
from jax._src.pallas.utils import when
|
||||||
|
from jax._src.state.primitives import broadcast_to
|
||||||
|
from jax._src.state.indexing import ds
|
||||||
|
from jax._src.state.indexing import dslice
|
||||||
|
from jax._src.state.indexing import Slice
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from jax.experimental.pallas import gpu # pytype: disable=import-error
|
from jax.experimental.pallas import gpu # pytype: disable=import-error
|
||||||
|
@ -22,7 +22,7 @@ from absl.testing import absltest
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import jax
|
import jax
|
||||||
from jax._src import util
|
from jax._src import util
|
||||||
from jax._src.pallas import indexing
|
from jax._src.state import indexing
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -49,7 +49,8 @@ def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
|
|||||||
@hps.composite
|
@hps.composite
|
||||||
def slice_indexer_strategy(draw, dim) -> Slice | slice:
|
def slice_indexer_strategy(draw, dim) -> Slice | slice:
|
||||||
start = draw(int_indexer_strategy(dim))
|
start = draw(int_indexer_strategy(dim))
|
||||||
size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max))
|
max_size = dim - start
|
||||||
|
size = draw(hps.integers(min_value=0, max_value=max_size))
|
||||||
return draw(
|
return draw(
|
||||||
hps.one_of(
|
hps.one_of(
|
||||||
hps.just(Slice(start, size)), hps.just(slice(start, start + size))
|
hps.just(Slice(start, size)), hps.just(slice(start, start + size))
|
||||||
@ -78,8 +79,8 @@ def indexer_strategy(draw, dim, int_indexer_shape
|
|||||||
def nd_indexer_strategy(draw, shape) -> NDIndexer:
|
def nd_indexer_strategy(draw, shape) -> NDIndexer:
|
||||||
num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
|
num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
|
||||||
int_indexer_shape = draw(hnp.array_shapes())
|
int_indexer_shape = draw(hnp.array_shapes())
|
||||||
indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim
|
indices = tuple(draw(indexer_strategy(dim, int_indexer_shape))
|
||||||
in shape[:num_indices]]
|
for dim in shape[:num_indices])
|
||||||
return NDIndexer.from_indices_shape(indices, shape)
|
return NDIndexer.from_indices_shape(indices, shape)
|
||||||
|
|
||||||
|
|
||||||
@ -97,6 +98,24 @@ class IndexerTest(parameterized.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = NDIndexer.from_indices_shape(indices, shape)
|
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||||
|
|
||||||
|
def test_invalid_ndindexer_oob_int(self):
|
||||||
|
indices = (4, 0)
|
||||||
|
shape = (3, 5)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||||
|
|
||||||
|
def test_invalid_ndindexer_oob_slice_start(self):
|
||||||
|
indices = (slice(3, 2), 0)
|
||||||
|
shape = (3, 5)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||||
|
|
||||||
|
def test_invalid_ndindexer_oob_slice_end(self):
|
||||||
|
indices = (Slice(2, 2), 0)
|
||||||
|
shape = (3, 5)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||||
|
|
||||||
def test_ndindexer_with_padding(self):
|
def test_ndindexer_with_padding(self):
|
||||||
indices = ()
|
indices = ()
|
||||||
shape = (5, 5)
|
shape = (5, 5)
|
||||||
@ -137,17 +156,17 @@ class IndexerTest(parameterized.TestCase):
|
|||||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||||
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3))
|
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3))
|
||||||
|
|
||||||
indices = (0, slice(4, 10), np.arange(5))
|
indices = (0, slice(2, 10), np.arange(5))
|
||||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||||
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 0))
|
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 1))
|
||||||
|
|
||||||
indices = (0, 5, np.arange(5))
|
indices = (0, 1, np.arange(5))
|
||||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||||
self.assertTupleEqual(indexer.get_indexer_shape(), (5,))
|
self.assertTupleEqual(indexer.get_indexer_shape(), (5,))
|
||||||
|
|
||||||
indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None])
|
indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None])
|
||||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||||
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3))
|
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2))
|
||||||
|
|
||||||
@hp.given(hps.data())
|
@hp.given(hps.data())
|
||||||
def test_ndindexer(self, data):
|
def test_ndindexer(self, data):
|
||||||
|
@ -1471,7 +1471,7 @@ class PallasPrimitivesTest(PallasTest):
|
|||||||
@parameterized.parameters(*[
|
@parameterized.parameters(*[
|
||||||
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"),
|
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"),
|
||||||
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"),
|
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"),
|
||||||
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:4,:,:4]"),
|
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:,:,:4]"),
|
||||||
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"),
|
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"),
|
||||||
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"),
|
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"),
|
||||||
])
|
])
|
||||||
@ -1486,7 +1486,7 @@ class PallasPrimitivesTest(PallasTest):
|
|||||||
@parameterized.parameters(*[
|
@parameterized.parameters(*[
|
||||||
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"),
|
(lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"),
|
||||||
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"),
|
(lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"),
|
||||||
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:4,:,:4] <-"),
|
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:,:,:4] <-"),
|
||||||
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"),
|
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"),
|
||||||
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"),
|
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"),
|
||||||
])
|
])
|
||||||
@ -1504,7 +1504,7 @@ class PallasPrimitivesTest(PallasTest):
|
|||||||
(lambda: (pl.dslice(0, 3), slice(None), slice(None)),
|
(lambda: (pl.dslice(0, 3), slice(None), slice(None)),
|
||||||
"c:i32[3,3,2], a[:3,:,:] <-"),
|
"c:i32[3,3,2], a[:3,:,:] <-"),
|
||||||
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)),
|
(lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)),
|
||||||
"c:i32[3,3,4], a[1:4,:,:4] <-"),
|
"c:i32[3,3,4], a[1:,:,:4] <-"),
|
||||||
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)),
|
(lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)),
|
||||||
"e:i32[5,3,4], a[b,:,:4] <-"),
|
"e:i32[5,3,4], a[b,:,:4] <-"),
|
||||||
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)),
|
(lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)),
|
||||||
|
@ -56,15 +56,15 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
def test_cant_eval_get_primitive(self):
|
def test_cant_eval_get_primitive(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
get_p.bind(jnp.ones(5))
|
get_p.bind(jnp.ones(5), tree=None)
|
||||||
|
|
||||||
def test_cant_eval_swap_primitive(self):
|
def test_cant_eval_swap_primitive(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
swap_p.bind(jnp.ones(5), jnp.zeros(5))
|
swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
|
||||||
|
|
||||||
def test_cant_eval_addupdate_primitive(self):
|
def test_cant_eval_addupdate_primitive(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
addupdate_p.bind(jnp.ones(5), jnp.zeros(5))
|
addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
|
||||||
|
|
||||||
def test_get_abstract_aval_must_take_in_refs(self):
|
def test_get_abstract_aval_must_take_in_refs(self):
|
||||||
ref_aval = core.ShapedArray((), jnp.float32)
|
ref_aval = core.ShapedArray((), jnp.float32)
|
||||||
@ -95,11 +95,37 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
|||||||
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
|
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
out_shape=(2, 1, 2), out_dtype=jnp.float32),
|
out_shape=(2, 1, 2), out_dtype=jnp.float32),
|
||||||
|
dict(testcase_name="get_with_nontrivial_slice",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
|
out_shape=(2, 1, 2), out_dtype=jnp.float32),
|
||||||
|
dict(testcase_name="get_with_nontrivial_slice2",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
|
||||||
|
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
|
||||||
|
dict(testcase_name="get_with_ref_simple_at",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(1, 3), slice(None), slice(None)),
|
||||||
|
out_shape=(2, 2, 4), out_dtype=jnp.float32,
|
||||||
|
at_indices=((0,),)),
|
||||||
|
dict(testcase_name="get_with_ref_simple_at2",
|
||||||
|
ref_shape=(6, 1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(0, 2), slice(0, 1), slice(1, 3), slice(None), slice(None)),
|
||||||
|
out_shape=(2, 1, 2, 2, 4), out_dtype=jnp.float32,
|
||||||
|
at_indices=((slice(2, 6),),)),
|
||||||
|
dict(testcase_name="get_with_ref_multiple_at",
|
||||||
|
ref_shape=(1, 3, 5, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(None), slice(None), slice(0, 2)),
|
||||||
|
out_shape=(3, 1, 2), out_dtype=jnp.float32,
|
||||||
|
at_indices=((0,), (slice(None), slice(0, 1)))),
|
||||||
)
|
)
|
||||||
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
|
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
|
||||||
out_dtype=None, should_error=False):
|
out_dtype=None, at_indices=(),
|
||||||
|
should_error=False):
|
||||||
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
||||||
def f(x_ref):
|
def f(x_ref):
|
||||||
|
for at_idx in at_indices:
|
||||||
|
x_ref = x_ref.at[at_idx]
|
||||||
out = ref_get(x_ref, idx)
|
out = ref_get(x_ref, idx)
|
||||||
return [out]
|
return [out]
|
||||||
if should_error:
|
if should_error:
|
||||||
@ -160,13 +186,43 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
|||||||
val_shape=(2, 1, 2), val_dtype=jnp.float32,
|
val_shape=(2, 1, 2), val_dtype=jnp.float32,
|
||||||
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
|
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
out_shape=(2, 1, 2), out_dtype=jnp.float32),
|
out_shape=(2, 1, 2), out_dtype=jnp.float32),
|
||||||
|
dict(testcase_name="swap_with_nontrivial_slice",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
|
val_shape=(2, 1, 2), val_dtype=jnp.float32,
|
||||||
|
out_shape=(2, 1, 2), out_dtype=jnp.float32),
|
||||||
|
dict(testcase_name="swap_with_nontrivial_slice2",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
|
||||||
|
val_shape=(1, 2, 2, 4), val_dtype=jnp.float32,
|
||||||
|
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
|
||||||
|
dict(testcase_name="swap_with_ref_simple_at",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(0, 1), slice(1, 3), slice(None),),
|
||||||
|
val_shape=(1, 1, 4), val_dtype=jnp.float32,
|
||||||
|
out_shape=(1, 1, 4), out_dtype=jnp.float32,
|
||||||
|
at_indices=((0,),),),
|
||||||
|
dict(testcase_name="swap_with_ref_simple_at2",
|
||||||
|
ref_shape=(4, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
|
||||||
|
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
|
||||||
|
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
|
||||||
|
at_indices=((slice(0, 2),),),),
|
||||||
|
dict(testcase_name="swap_with_ref_multiple_at2",
|
||||||
|
ref_shape=(1, 4, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
|
||||||
|
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
|
||||||
|
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
|
||||||
|
at_indices=((slice(None), slice(0, 2),), (0,)),),
|
||||||
)
|
)
|
||||||
def test_swap_abstract_eval(self, ref_shape, ref_dtype,
|
def test_swap_abstract_eval(self, ref_shape, ref_dtype,
|
||||||
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
|
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
|
||||||
should_error=False):
|
at_indices=(), should_error=False):
|
||||||
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
||||||
val_aval = core.ShapedArray(val_shape, val_dtype)
|
val_aval = core.ShapedArray(val_shape, val_dtype)
|
||||||
def f(x_ref, val):
|
def f(x_ref, val):
|
||||||
|
for at_idx in at_indices:
|
||||||
|
x_ref = x_ref.at[at_idx]
|
||||||
out = ref_swap(x_ref, idx, val)
|
out = ref_swap(x_ref, idx, val)
|
||||||
return [out]
|
return [out]
|
||||||
if should_error:
|
if should_error:
|
||||||
@ -192,13 +248,13 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
|||||||
idx=(slice(None),), should_error=True),
|
idx=(slice(None),), should_error=True),
|
||||||
dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
|
dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
|
||||||
ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
|
ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
|
||||||
idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
|
idx=(),),
|
||||||
dict(testcase_name="bad_dtype", ref_shape=(1, 2),
|
dict(testcase_name="bad_dtype", ref_shape=(1, 2),
|
||||||
ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
|
ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
|
||||||
idx=(), should_error=True),
|
idx=(), should_error=True),
|
||||||
dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
|
dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
|
||||||
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
|
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
|
||||||
idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
|
idx=(0,),),
|
||||||
dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
|
dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
|
||||||
ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
|
ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
|
||||||
idx=(slice(None), 0)),
|
idx=(slice(None), 0)),
|
||||||
@ -216,13 +272,34 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
|||||||
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
val_shape=(2, 1, 2), val_dtype=jnp.float32,
|
val_shape=(2, 1, 2), val_dtype=jnp.float32,
|
||||||
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
|
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
|
||||||
|
dict(testcase_name="ref_with_simple_at",
|
||||||
|
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
val_shape=(2, 2), val_dtype=jnp.float32,
|
||||||
|
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
|
at_indices=((0,),)),
|
||||||
|
dict(testcase_name="ref_with_simple_at2",
|
||||||
|
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
val_shape=(2, 3, 4), val_dtype=jnp.float32,
|
||||||
|
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
|
at_indices=((slice(0, 3),),)),
|
||||||
|
dict(testcase_name="ref_with_multiple_at",
|
||||||
|
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
val_shape=(2, 2), val_dtype=jnp.float32,
|
||||||
|
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
|
at_indices=((slice(0, 3),), (0,))),
|
||||||
|
dict(testcase_name="ref_with_multiple_at2",
|
||||||
|
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
|
||||||
|
val_shape=(2, 2), val_dtype=jnp.float32,
|
||||||
|
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
|
||||||
|
at_indices=((slice(None), slice(0, 3),), (0,))),
|
||||||
)
|
)
|
||||||
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
|
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
|
||||||
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
|
val_shape, val_dtype, idx, at_indices=(), should_error=False):
|
||||||
should_error=False):
|
|
||||||
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
||||||
val_aval = core.ShapedArray(val_shape, val_dtype)
|
val_aval = core.ShapedArray(val_shape, val_dtype)
|
||||||
def f(x_ref, val):
|
def f(x_ref, val):
|
||||||
|
for at_idx in at_indices:
|
||||||
|
x_ref = x_ref.at[at_idx]
|
||||||
ref_addupdate(x_ref, idx, val)
|
ref_addupdate(x_ref, idx, val)
|
||||||
return []
|
return []
|
||||||
if should_error:
|
if should_error:
|
||||||
@ -1595,7 +1672,6 @@ if CAN_USE_HYPOTHESIS:
|
|||||||
def test_vjp(self, data):
|
def test_vjp(self, data):
|
||||||
|
|
||||||
spec = data.draw(func_spec())
|
spec = data.draw(func_spec())
|
||||||
print(spec)
|
|
||||||
|
|
||||||
def impl(x):
|
def impl(x):
|
||||||
return spec.call((x, jnp.zeros_like(x)))[1]
|
return spec.call((x, jnp.zeros_like(x)))[1]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user