[typing] annotate lax.slicing

This commit is contained in:
Jake VanderPlas 2022-10-09 04:20:46 -07:00
parent 9cabd227d7
commit ae9f8eeb0c

View File

@ -14,7 +14,7 @@
import enum
from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
import weakref
import numpy as np
@ -40,20 +40,18 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.typing import Array, ArrayLike, Shape
xb = xla_bridge
xc = xla_client
Array = Any
Shape = core.Shape
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
_dtype = partial(dtypes.dtype, canonicalize=True)
def slice(operand: Array, start_indices: Sequence[int],
def slice(operand: ArrayLike, start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `Slice
@ -64,7 +62,7 @@ def slice(operand: Array, start_indices: Sequence[int],
limit_indices=tuple(limit_indices),
strides=None if strides is None else tuple(strides))
def dynamic_slice(operand: Array, start_indices: Sequence[Array],
def dynamic_slice(operand: Array, start_indices: Union[Array, Sequence[ArrayLike]],
slice_sizes: Shape) -> Array:
"""Wraps XLA's `DynamicSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
@ -112,8 +110,8 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
slice_sizes=tuple(static_sizes))
def dynamic_update_slice(operand: Array, update: Array,
start_indices: Array) -> Array:
def dynamic_update_slice(operand: Array, update: ArrayLike,
start_indices: Union[Array, Sequence[ArrayLike]]) -> Array:
"""Wraps XLA's `DynamicUpdateSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
operator.
@ -222,7 +220,7 @@ class GatherScatterMode(enum.Enum):
raise ValueError(f'Unknown gather mode "{s}"')
def gather(operand: Array, start_indices: Array,
def gather(operand: ArrayLike, start_indices: ArrayLike,
dimension_numbers: GatherDimensionNumbers,
slice_sizes: Shape,
*,
@ -320,7 +318,7 @@ class ScatterDimensionNumbers(NamedTuple):
scatter_dims_to_operand_dims: Sequence[int]
def scatter_add(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
@ -367,7 +365,7 @@ def scatter_add(
mode=GatherScatterMode.from_any(mode))
def scatter_mul(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
@ -414,7 +412,7 @@ def scatter_mul(
mode=GatherScatterMode.from_any(mode))
def scatter_min(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
@ -461,7 +459,7 @@ def scatter_min(
mode=GatherScatterMode.from_any(mode))
def scatter_max(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
@ -573,7 +571,7 @@ def scatter_apply(
_scatter_reduction_computation = lambda x, y: y
def scatter(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
@ -686,10 +684,10 @@ def index_in_dim(operand: Array, index: int, axis: int = 0,
return lax.squeeze(result, (axis,))
def dynamic_slice_in_dim(operand: Array, start_index: Array,
def dynamic_slice_in_dim(operand: Array, start_index: ArrayLike,
slice_size: int, axis: int = 0) -> Array:
"""Convenience wrapper around dynamic_slice applying to one dimension."""
start_indices = [np.zeros((), dtype=dtypes.dtype(start_index))] * operand.ndim
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
slice_sizes = list(operand.shape)
axis = int(axis)
@ -708,18 +706,18 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
return lax.squeeze(result, (axis,))
def dynamic_update_slice_in_dim(operand: Array, update: Array,
start_index: Array, axis: int) -> Array:
def dynamic_update_slice_in_dim(operand: Array, update: ArrayLike,
start_index: ArrayLike, axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
in a single ``axis``.
"""
axis = int(axis)
start_indices = [lax._zero(start_index)] * lax._ndim(operand)
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices[axis] = start_index
return dynamic_update_slice(operand, update, start_indices)
def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
def dynamic_update_index_in_dim(operand: Array, update: ArrayLike, index: ArrayLike,
axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
of size 1 in a single ``axis``.
@ -731,7 +729,6 @@ def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
return dynamic_update_slice_in_dim(operand, update, index, axis)
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
lax._check_shapelike("slice", "start_indices", start_indices)
lax._check_shapelike("slice", "limit_indices", limit_indices)
@ -2094,26 +2091,30 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
def _dynamic_slice_indices(operand, start_indices: Any):
def _dynamic_slice_indices(
operand: Array,
start_indices: Union[Array, Sequence[ArrayLike]]
) -> List[Array]:
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")
raise ValueError(msg.format(len(start_indices), operand.shape))
if not isinstance(start_indices, (tuple, list)):
if start_indices.ndim != 1:
if start_indices.ndim != 1: # type: ignore[union-attr]
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
.format(start_indices.shape)) # type: ignore[union-attr]
start_indices = list(start_indices)
result = []
result: List[Array] = []
for i, d in zip(start_indices, operand.shape):
# We test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i)
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
continue
d = core.dimension_as_value(d)
if isinstance(i, (int, np.integer)):
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0
else lax.convert_element_type(i, _dtype(i)))
continue
d = lax.convert_element_type(d, _dtype(i))
result.append(lax.select(i < 0, i + d, i))