2024-01-02 15:52:57 -08:00
|
|
|
# Copyright 2023 The JAX Authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
"""Contains shared logic and abstractions for Pallas indexing ops."""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import dataclasses
|
|
|
|
from typing import Any, Union
|
|
|
|
|
|
|
|
from jax._src import core
|
|
|
|
from jax._src import tree_util
|
|
|
|
from jax._src.typing import Array
|
|
|
|
from jax._src.util import merge_lists
|
|
|
|
from jax._src.util import partition_list
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  """Represents a slice with a dynamic start index and a fixed size.

  Attributes:
    start: First index covered by the slice. A Python `int` is treated as
      static pytree metadata; any other value is treated as a pytree child.
    size: Number of elements covered by the slice. Must not be negative.
  """
  start: Any
  size: int

  def __post_init__(self):
    """Rejects slices whose size is negative.

    Raises:
      ValueError: If `size` is negative.
    """
    if self.size < 0:
      raise ValueError("`size` must not be negative.")

  def tree_flatten(self):
    """Returns `(children, aux_data)` for pytree registration."""
    # If `start` is statically known, we treat it as static information
    if isinstance(self.start, int):
      return (), (self.start, self.size)
    return (self.start,), (self.size,)

  @classmethod
  def tree_unflatten(cls, aux_data, children) -> Slice:
    """Rebuilds a `Slice` from the pieces produced by `tree_flatten`."""
    # Either `children` is empty and `aux_data` is `(start, size)`, or
    # `children` is `(start,)` and `aux_data` is `(size,)`. In both cases the
    # concatenation lines up with the `(start, size)` constructor order.
    return cls(*children, *aux_data)

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
    """Converts a Python `slice` over a dimension of length `size`.

    Args:
      slc: The Python slice to convert. Only unit steps are supported.
      size: Length of the dimension the slice is applied to; used to resolve
        `None` and negative bounds via `slice.indices`.

    Returns:
      A `Slice` with a static start and the number of selected elements.

    Raises:
      ValueError: If the resolved step of `slc` is not 1.
    """
    start, stop, step = slc.indices(size)
    if step != 1:
      raise ValueError(f"slice must have a step of 1 (found: {step})")
    # `slice.indices` can return `stop < start` (e.g. `slice(5, 2)`), which in
    # Python selects nothing. Clamp so this becomes an empty `Slice` instead of
    # tripping the negative-size check in `__post_init__`.
    return cls(start, max(stop - start, 0))
|
|
|
|
|
|
|
|
|
|
|
|
def dslice(start: int | Array | None, size: int | None = None
           ) -> slice | Slice:
  """Builds a `Slice` (or a full Python `slice`) from a start and a size.

  Args:
    start: Start of the slice. `None` requests a full `slice(None)`. When
      `size` is omitted, `start` is interpreted as the size of a slice that
      begins at 0 and must therefore be a Python `int`.
    size: Number of elements in the slice, or `None` for the one-argument
      form described above.

  Returns:
    `slice(None)` when `start` is `None`, otherwise a `Slice`.

  Raises:
    ValueError: If `size` is omitted and `start` is not a Python `int`.
  """
  if start is None:
    return slice(None)
  if size is not None:
    return Slice(start, size)
  # One-argument form: the single argument is the (static) size.
  if isinstance(start, int):
    return Slice(0, start)
  raise ValueError("Non-static `dslice`")


ds = dslice  # Handy alias
|
|
|
|
|
|
|
|
|
|
|
|
# An integer indexer for a single dimension: a Python `int` or an `Array`.
IntIndexer = Union[int, Array]
|
|
|
|
# Any indexer for a single dimension: an `IntIndexer` or a `Slice`.
DimIndexer = Union[IntIndexer, Slice]
|
|
|
|
|
|
|
|
def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
                                                  tuple[Slice, ...],
                                                  tuple[IntIndexer, ...]]:
  """Splits the indices of `indexer` into `Slice`s and integer indexers.

  Args:
    indexer: Object whose `indices` attribute is split.

  Returns:
    A triple `(is_int_indexing, slice_indexers, int_indexers)`:
    `is_int_indexing` holds one flag per index that is true when the index is
    not a `Slice`; the other two hold the `Slice` and non-`Slice` indices
    respectively, each in their original relative order.
  """
  dims = indexer.indices
  int_mask = tuple(not isinstance(dim, Slice) for dim in dims)
  # `partition_list` puts entries with a false flag in the first list and
  # entries with a true flag in the second.
  slices, ints = partition_list(int_mask, dims)
  return int_mask, tuple(slices), tuple(ints)  # type: ignore
|
|
|
|
|
|
|
|
def _maybe_concretize(x: Any):
  """Returns the concrete value of `x`, or `None` if it is not available.

  Delegates to `core.concrete_or_error`; a `core.ConcretizationTypeError`
  raised by that call is converted into a `None` result.
  """
  try:
    concrete = core.concrete_or_error(None, x)
  except core.ConcretizationTypeError:
    concrete = None
  return concrete
|
|
|
|
|
|
|
|
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
  """A collection of per-dimension indexers for a value of a given shape.

  Attributes:
    indices: One indexer per dimension of `shape`; each entry is either a
      `Slice` or an integer indexer.
    shape: Shape of the value being indexed.
    int_indexer_shape: Shape that the integer indexers broadcast to.
    validate: When true, `__post_init__` checks `indices` against `shape` and
      `int_indexer_shape`.
  """
  indices: tuple[DimIndexer, ...]
  shape: tuple[int, ...]
  int_indexer_shape: tuple[int, ...]
  validate: bool = False

  def __post_init__(self):
    """Validates the indexer if `validate` is set.

    Raises:
      ValueError: If the number of indices differs from the rank of `shape`,
        if a statically known slice or scalar index is out of bounds, or if
        an integer indexer does not have the rank of / does not broadcast
        against `int_indexer_shape`.
    """
    if not self.validate:
      return
    if len(self.indices) != len(self.shape):
      raise ValueError(
          f"`indices` must be the same length as `Ref` shape.: {self}."
      )
    expected_rank = len(self.int_indexer_shape)
    # We validate integer indexing shapes here
    for idx, s in zip(self.indices, self.shape):
      if isinstance(idx, Slice):
        # Only starts whose value is statically known can be bounds-checked.
        start = _maybe_concretize(idx.start)
        if start is not None:
          # Compare against `None` instead of relying on truthiness so that a
          # start of 0 still has its extent checked. A zero start is exempt
          # from the start-position check so that `Slice(0, 0)` on an empty
          # dimension remains valid.
          if start != 0 and start >= s:
            raise ValueError(f"Out of bound slice: start={start}, dim={s}.")
          if start + idx.size > s:
            raise ValueError(
                f"Out of bound slice: start={start}, size={idx.size}, dim={s}."
            )
        continue
      # The shape of indexer integers should be broadcastable up to the
      # int_indexer_shape of the whole NDIndexer
      if not np.shape(idx):
        value = _maybe_concretize(idx)
        # `is not None` (rather than truthiness) so an index of 0 is checked.
        if value is not None and value >= s:
          raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
        # For ()-shaped indexers, we can broadcast no problem.
        continue
      # If we don't have a ()-shaped indexer, the rank must match
      # int_indexer_shape
      if np.ndim(idx) != expected_rank:
        # Report the rank that is required, not the rank that was passed.
        raise ValueError(
            f"Indexer must have rank {expected_rank}: {idx=} vs."
            f" {self.int_indexer_shape=}"
        )
      # Here we check that the shapes broadcast.
      try:
        np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
      except ValueError as e:
        raise ValueError(
            f"Could not broadcast integer indexer: {idx=} vs."
            f" {self.int_indexer_shape=}"
        ) from e

  def tree_flatten(self):
    """Returns `(children, aux_data)` for pytree registration.

    The children are the flattened leaves of `indices`; the auxiliary data is
    the treedef of `indices` together with `shape` and `int_indexer_shape`.
    """
    flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
    return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    """Rebuilds an indexer from the output of `tree_flatten`.

    The result is constructed with the default `validate=False`.
    """
    idx_tree, shape, int_indexer_shape = data
    indices = tree_util.tree_unflatten(idx_tree, flat_idx)
    return cls(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    """Builds a validated indexer from user-style `indices` and a `shape`.

    Args:
      indices: A single index or a tuple of indices. A lone `...` selects
        every dimension in full. Python `slice`s are converted to `Slice`s,
        and missing trailing dimensions are filled with full slices.
      shape: Shape of the value being indexed.

    Returns:
      An `NDIndexer` whose integer indexers have all been broadcast to a
      common shape, constructed with `validate=True`.

    Raises:
      NotImplementedError: If `...` appears alongside other indices.
      ValueError: If there are more indices than dimensions, if the integer
        indexers cannot be broadcast together, or if conversion/validation
        of an index fails.
    """
    if not isinstance(indices, tuple):
      indices = (indices,)
    if len(indices) == 1 and indices[0] is ...:
      indices = (slice(None),) * len(shape)
    if any(idx is ... for idx in indices):
      # TODO(sharadmv,mattjj): support patterns that include ellipsis in them
      # e.g. x[0, ..., 1].
      raise NotImplementedError("Ellipsis in indexer not supported yet.")
    if len(indices) > len(shape):
      raise ValueError("`indices` must not be longer than `shape`: "
                       f"{indices=}, {shape=}")
    # Pad out indices with slice(None)
    indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
    # Convert all `slice`s to `Slice`s
    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
                    else i for i, s in zip(indices, shape))
    is_int_indexing = [not isinstance(i, Slice) for i in indices]
    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
    indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
    if indexer_shapes:
      try:
        bcast_shape = np.broadcast_shapes(*indexer_shapes)
      except ValueError as e:
        # Raise a nicer error than the NumPy one.
        raise ValueError("Cannot broadcast shapes for indexing: "
                         f"{tuple(indexer_shapes)}") from e
    else:
      bcast_shape = ()
    # Here we use the `broadcast_to` primitive instead of composing lax
    # primitives together because it is easier to lower in targets like
    # Triton/Mosaic.
    from jax._src.state import primitives as sp  # pytype: disable=import-error
    int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
    return cls(tuple(indices), shape, bcast_shape, validate=True)

  def get_indexer_shape(self) -> tuple[int, ...]:
    """Returns `int_indexer_shape` followed by the size of each `Slice`."""
    _, slice_indexers, _ = unpack_ndindexer(self)
    slice_shape = [s.size for s in slice_indexers]
    # In NDIndexers, the int_indexer_shape is *always* at the front of the
    # result.
    return (*self.int_indexer_shape, *slice_shape)
|