mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

I initially wanted to upgrade to 1.15, but it seems to have a bug in how ternary expressions are type-checked. For example, given `def f(x: int) -> str: ...` and `def g(x: int) -> str: ...`, the expression `callback = f if ... else g` is inferred to have type `object`!
262 lines
9.1 KiB
Python
262 lines
9.1 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Contains shared logic and abstractions for Pallas indexing ops."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from typing import Any, Sequence, Union
|
|
|
|
from jax._src import core
|
|
from jax._src import tree_util
|
|
from jax._src.typing import Array
|
|
from jax._src.util import merge_lists
|
|
from jax._src.util import partition_list
|
|
import numpy as np
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
@dataclasses.dataclass
|
|
class Slice:
|
|
"""A slice with a start index and a size.
|
|
|
|
Both start index and size can either be static, i.e. known at tracing
|
|
and compilation time, or dynamic.
|
|
"""
|
|
|
|
start: int | Array
|
|
size: int | Array
|
|
stride: int = 1
|
|
|
|
def __post_init__(self):
|
|
if self.stride < 1:
|
|
raise ValueError("`stride` must be >= 1.")
|
|
|
|
@property
|
|
def is_dynamic_start(self):
|
|
return not core.is_dim(self.start)
|
|
|
|
@property
|
|
def is_dynamic_size(self):
|
|
return not core.is_dim(self.size)
|
|
|
|
def tree_flatten(self):
|
|
# If `start` is statically known, we treat it as static information
|
|
xs = ()
|
|
data = ()
|
|
xs += (self.start,) if self.is_dynamic_start else (None,)
|
|
data += (None,) if self.is_dynamic_start else (self.start,)
|
|
xs += (self.size,) if self.is_dynamic_size else (None,)
|
|
data += (None,) if self.is_dynamic_size else (self.size,)
|
|
data += (self.stride,)
|
|
return xs, data
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, aux_data, children) -> Slice:
|
|
start, size = (
|
|
a if a is not None else b for a, b in zip(children, aux_data[:2])
|
|
)
|
|
return cls(start, size, aux_data[2])
|
|
|
|
@classmethod
|
|
def from_slice(cls, slc: slice, size: int) -> Slice:
|
|
start, step, size = core.canonicalize_slice(slc, size)
|
|
if step < 1:
|
|
raise ValueError(f"slice must have a step >= 1 (found: {step})")
|
|
return cls(start, size, step)
|
|
|
|
|
|
def dslice(
    start: int | Array | None,
    size: int | Array | None = None,
    stride: int | None = None,
) -> slice | Slice:
  """Constructs a ``Slice`` from a start index and a size.

  The semantics of ``dslice`` mirror those of the builtin ``slice`` type:

  * ``dslice(None)`` is ``:``
  * ``dslice(j)`` is ``:j``
  * ``dslice(i, j)`` is ``i:i+j``
  * ``dslice(i, j, stride)`` is ``i:i+j:stride``

  Raises:
    ValueError: if ``stride`` is not a static int, or if the one-argument
      form is used with a non-static size.
  """
  if start is None:
    # ``dslice(None)`` selects the whole dimension, like a bare ``:``.
    return slice(None)
  stride = 1 if stride is None else stride
  if not isinstance(stride, int):
    raise ValueError("Non-static stride in `dslice`")
  if size is not None:
    return Slice(start, size, stride)
  # One-argument form: ``start`` actually plays the role of the size, which
  # must then be statically known.
  if not isinstance(start, int):
    raise ValueError("Non-static `dslice`")
  return Slice(0, start, stride)


ds = dslice  # Handy alias
|
|
|
|
|
|
IntIndexer = Union[int, Array]
DimIndexer = Union[IntIndexer, Slice]


def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
                                                  tuple[Slice, ...],
                                                  tuple[IntIndexer, ...]]:
  """Splits an ``NDIndexer`` into its slice and integer components.

  Returns a per-dimension boolean mask (True where the dimension uses
  integer indexing), the ``Slice`` indexers, and the integer indexers,
  each preserving their relative order in ``indexer.indices``.
  """
  int_mask = [not isinstance(idx, Slice) for idx in indexer.indices]
  slice_idxs, int_idxs = partition_list(int_mask, indexer.indices)
  return tuple(int_mask), tuple(slice_idxs), tuple(int_idxs)  # type: ignore
|
|
|
|
def _maybe_concretize(x: Any):
|
|
# This is roughly the same logic as core.concrete_or_error, but we avoid
|
|
# calling that because constructing the ConcretizationTypeError can be
|
|
# expensive as the size of the tracing context (i.e. the jaxpr) grows.
|
|
return core.to_concrete_value(x)
|
|
|
|
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
  """A multi-dimensional indexer: one ``DimIndexer`` per dimension of a `Ref`.

  All integer (advanced) indexers must broadcast together to
  ``int_indexer_shape``; the broadcasted integer-indexing shape always ends
  up at the *front* of the resulting indexed shape (see
  ``get_indexer_shape``).
  """

  # One indexer (int, integer Array, or Slice) per dimension of the `Ref`.
  indices: tuple[DimIndexer, ...]
  # The shape of the `Ref` being indexed.
  shape: tuple[int, ...]
  # The broadcasted shape of all integer indexers in `indices`.
  int_indexer_shape: tuple[int, ...]
  # Off by default to avoid doing validation during pytree operations.
  validate: bool = False

  def __post_init__(self):
    """Validates `indices` against `shape` when `validate` is set.

    Raises:
      ValueError: if the number of indices does not match the rank of the
        `Ref`, a slice is statically known to be out of bounds, a scalar
        integer indexer is statically out of bounds, or an integer indexer
        does not broadcast with `int_indexer_shape`.
    """
    if not self.validate:
      return
    if len(self.indices) != len(self.shape):
      raise ValueError(
          f"`indices` must be the same length as `Ref` shape.: {self}."
      )
    # We validate integer indexing shapes here
    for idx, s in zip(self.indices, self.shape):
      if isinstance(idx, Slice):
        start = idx.start
        # BUG FIX: compare against None instead of relying on truthiness so
        # that a static start of 0 (falsy!) still has its bounds validated.
        if (value := _maybe_concretize(start)) is not None:
          if value >= s:
            raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
          if (size := _maybe_concretize(idx.size)) is not None:
            # Last touched element is start + (size - 1) * stride.
            if value + (size - 1) * idx.stride >= s:
              raise ValueError(
                  f"Out of bound slice: start={value}, size={size},"
                  f" stride={idx.stride}, dim={s}."
              )
        continue
      # The shape of indexer integers should be broadcastable up to the
      # int_indexer_shape of the whole NDIndexer
      if not np.shape(idx):
        if (value := _maybe_concretize(idx)) is not None and value >= s:
          raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
        # For ()-shaped indexers, we can broadcast no problem.
        continue
      # If we don't have a ()-shaped indexer, the rank must match
      # int_indexer_shape
      if np.ndim(idx) != len(self.int_indexer_shape):
        # BUG FIX: report the *required* rank, not the rank the indexer has.
        raise ValueError(
            f"Indexer must have rank {len(self.int_indexer_shape)}: {idx=} vs."
            f" {self.int_indexer_shape=}"
        )
      # Here we check that the shapes broadcast.
      try:
        np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
      except ValueError as e:
        raise ValueError(
            f"Could not broadcast integer indexer: {idx=} vs."
            f" {self.int_indexer_shape=}"
        ) from e

  @property
  def is_dynamic_size(self):
    # True if any slice's size is only known at run time.
    return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices)

  def tree_flatten(self):
    # `indices` may contain tracers, so they become the pytree leaves; the
    # static shapes travel as aux data. `validate` is deliberately dropped
    # so pytree round-trips do not re-run (or require) validation.
    flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
    return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    idx_tree, shape, int_indexer_shape = data
    indices = tree_util.tree_unflatten(idx_tree, flat_idx)
    return cls(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    """Builds a validated ``NDIndexer`` from user-provided `indices`.

    Handles promotion of a single index to a 1-tuple, expansion of a single
    ellipsis, padding short index tuples with full slices, promotion of
    builtin ``slice``s to ``Slice``, and broadcasting of all integer
    indexers to a common shape.
    """
    if not isinstance(indices, tuple):
      # TODO(slebedev): Consider requiring `indices` to be a Sequence.
      indices = (indices,)

    if num_ellipsis := sum(idx is ... for idx in indices):
      if num_ellipsis > 1:
        raise ValueError("Only one ellipsis is supported.")
      # Expand ... so that `indices` has the same length as `shape`.
      ip = indices.index(...)
      indices = list(indices)
      indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
      indices = tuple(indices)
    if len(indices) > len(shape):
      raise ValueError("`indices` must not be longer than `shape`: "
                       f"{indices=}, {shape=}")
    elif len(indices) < len(shape):
      # Pad `indices` to have the same length as `shape`.
      indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))

    # Promote all builtin `slice`s to `Slice`.
    indices = tuple(
        Slice.from_slice(i, s) if isinstance(i, slice) else i
        for i, s in zip(indices, shape))

    is_int_indexing = [not isinstance(i, Slice) for i in indices]
    if any(is_int_indexing):
      int_indexers: Sequence[Any]
      other_indexers, int_indexers = partition_list(is_int_indexing, indices)
      indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
      try:
        int_indexer_shape = np.broadcast_shapes(*indexer_shapes)
      except ValueError as e:
        # Raise a nicer error than the NumPy one.
        raise ValueError(
            f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e

      # Here we use the `broadcast_to` primitive instead of composing lax
      # primitives together because it is easier to lower in targets like
      # Triton/Mosaic.
      #
      # The local import avoids a circular dependency between primitives
      # and this module.
      from jax._src.state import primitives as sp  # pytype: disable=import-error
      int_indexers = [
          sp.broadcast_to(i, int_indexer_shape) for i in int_indexers
      ]
      indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers))
    else:
      int_indexer_shape = ()

    return cls(indices, shape, int_indexer_shape, validate=True)

  def get_indexer_shape(self) -> tuple[int | Array, ...]:
    """Returns the shape produced by applying this indexer to a `Ref`."""
    _, slice_indexers, _ = unpack_ndindexer(self)
    slice_shape = [s.size for s in slice_indexers]
    # In NDIndexers, the int_indexer_shape is *always* at the front of the
    # result.
    return (*self.int_indexer_shape, *slice_shape)

  def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
    """Returns the shape after this transform; the input shape is ignored."""
    del shape  # Unused
    return self.get_indexer_shape()

  def transform_dtype(self, dtype):
    """Indexing never changes the element dtype."""
    return dtype