mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[typing] annotate lax.slicing
This commit is contained in:
parent
9cabd227d7
commit
ae9f8eeb0c
@ -14,7 +14,7 @@
|
||||
|
||||
import enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
@ -40,20 +40,18 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.typing import Array, ArrayLike, Shape
|
||||
|
||||
xb = xla_bridge
|
||||
xc = xla_client
|
||||
|
||||
Array = Any
|
||||
Shape = core.Shape
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
|
||||
def slice(operand: Array, start_indices: Sequence[int],
|
||||
def slice(operand: ArrayLike, start_indices: Sequence[int],
|
||||
limit_indices: Sequence[int],
|
||||
strides: Optional[Sequence[int]] = None) -> Array:
|
||||
"""Wraps XLA's `Slice
|
||||
@ -64,7 +62,7 @@ def slice(operand: Array, start_indices: Sequence[int],
|
||||
limit_indices=tuple(limit_indices),
|
||||
strides=None if strides is None else tuple(strides))
|
||||
|
||||
def dynamic_slice(operand: Array, start_indices: Sequence[Array],
|
||||
def dynamic_slice(operand: Array, start_indices: Union[Array, Sequence[ArrayLike]],
|
||||
slice_sizes: Shape) -> Array:
|
||||
"""Wraps XLA's `DynamicSlice
|
||||
<https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
|
||||
@ -112,8 +110,8 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
|
||||
return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
|
||||
slice_sizes=tuple(static_sizes))
|
||||
|
||||
def dynamic_update_slice(operand: Array, update: Array,
|
||||
start_indices: Array) -> Array:
|
||||
def dynamic_update_slice(operand: Array, update: ArrayLike,
|
||||
start_indices: Union[Array, Sequence[ArrayLike]]) -> Array:
|
||||
"""Wraps XLA's `DynamicUpdateSlice
|
||||
<https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
|
||||
operator.
|
||||
@ -222,7 +220,7 @@ class GatherScatterMode(enum.Enum):
|
||||
raise ValueError(f'Unknown gather mode "{s}"')
|
||||
|
||||
|
||||
def gather(operand: Array, start_indices: Array,
|
||||
def gather(operand: ArrayLike, start_indices: ArrayLike,
|
||||
dimension_numbers: GatherDimensionNumbers,
|
||||
slice_sizes: Shape,
|
||||
*,
|
||||
@ -320,7 +318,7 @@ class ScatterDimensionNumbers(NamedTuple):
|
||||
scatter_dims_to_operand_dims: Sequence[int]
|
||||
|
||||
def scatter_add(
|
||||
operand: Array, scatter_indices: Array, updates: Array,
|
||||
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers, *,
|
||||
indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
|
||||
@ -367,7 +365,7 @@ def scatter_add(
|
||||
mode=GatherScatterMode.from_any(mode))
|
||||
|
||||
def scatter_mul(
|
||||
operand: Array, scatter_indices: Array, updates: Array,
|
||||
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers, *,
|
||||
indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
|
||||
@ -414,7 +412,7 @@ def scatter_mul(
|
||||
mode=GatherScatterMode.from_any(mode))
|
||||
|
||||
def scatter_min(
|
||||
operand: Array, scatter_indices: Array, updates: Array,
|
||||
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers, *,
|
||||
indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
|
||||
@ -461,7 +459,7 @@ def scatter_min(
|
||||
mode=GatherScatterMode.from_any(mode))
|
||||
|
||||
def scatter_max(
|
||||
operand: Array, scatter_indices: Array, updates: Array,
|
||||
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers, *,
|
||||
indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
|
||||
@ -573,7 +571,7 @@ def scatter_apply(
|
||||
_scatter_reduction_computation = lambda x, y: y
|
||||
|
||||
def scatter(
|
||||
operand: Array, scatter_indices: Array, updates: Array,
|
||||
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers, *,
|
||||
indices_are_sorted: bool = False, unique_indices: bool = False,
|
||||
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
|
||||
@ -686,10 +684,10 @@ def index_in_dim(operand: Array, index: int, axis: int = 0,
|
||||
return lax.squeeze(result, (axis,))
|
||||
|
||||
|
||||
def dynamic_slice_in_dim(operand: Array, start_index: Array,
|
||||
def dynamic_slice_in_dim(operand: Array, start_index: ArrayLike,
|
||||
slice_size: int, axis: int = 0) -> Array:
|
||||
"""Convenience wrapper around dynamic_slice applying to one dimension."""
|
||||
start_indices = [np.zeros((), dtype=dtypes.dtype(start_index))] * operand.ndim
|
||||
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
|
||||
slice_sizes = list(operand.shape)
|
||||
|
||||
axis = int(axis)
|
||||
@ -708,18 +706,18 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
|
||||
return lax.squeeze(result, (axis,))
|
||||
|
||||
|
||||
def dynamic_update_slice_in_dim(operand: Array, update: Array,
|
||||
start_index: Array, axis: int) -> Array:
|
||||
def dynamic_update_slice_in_dim(operand: Array, update: ArrayLike,
|
||||
start_index: ArrayLike, axis: int) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
in a single ``axis``.
|
||||
"""
|
||||
axis = int(axis)
|
||||
start_indices = [lax._zero(start_index)] * lax._ndim(operand)
|
||||
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
|
||||
start_indices[axis] = start_index
|
||||
return dynamic_update_slice(operand, update, start_indices)
|
||||
|
||||
|
||||
def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
|
||||
def dynamic_update_index_in_dim(operand: Array, update: ArrayLike, index: ArrayLike,
|
||||
axis: int) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
of size 1 in a single ``axis``.
|
||||
@ -731,7 +729,6 @@ def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
|
||||
return dynamic_update_slice_in_dim(operand, update, index, axis)
|
||||
|
||||
|
||||
|
||||
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
|
||||
lax._check_shapelike("slice", "start_indices", start_indices)
|
||||
lax._check_shapelike("slice", "limit_indices", limit_indices)
|
||||
@ -2094,26 +2091,30 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
|
||||
|
||||
|
||||
def _dynamic_slice_indices(operand, start_indices: Any):
|
||||
def _dynamic_slice_indices(
|
||||
operand: Array,
|
||||
start_indices: Union[Array, Sequence[ArrayLike]]
|
||||
) -> List[Array]:
|
||||
# Normalize the start_indices w.r.t. operand.shape
|
||||
if len(start_indices) != operand.ndim:
|
||||
msg = ("Length of slice indices must match number of operand dimensions ({} "
|
||||
"vs {})")
|
||||
raise ValueError(msg.format(len(start_indices), operand.shape))
|
||||
if not isinstance(start_indices, (tuple, list)):
|
||||
if start_indices.ndim != 1:
|
||||
if start_indices.ndim != 1: # type: ignore[union-attr]
|
||||
raise ValueError("Slice indices must be a 1D sequence, got {}"
|
||||
.format(start_indices.shape))
|
||||
.format(start_indices.shape)) # type: ignore[union-attr]
|
||||
start_indices = list(start_indices)
|
||||
result = []
|
||||
result: List[Array] = []
|
||||
for i, d in zip(start_indices, operand.shape):
|
||||
# We test whether i and d are static to avoid unnecessary staging.
|
||||
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
|
||||
result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i)
|
||||
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
|
||||
continue
|
||||
d = core.dimension_as_value(d)
|
||||
if isinstance(i, (int, np.integer)):
|
||||
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
|
||||
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0
|
||||
else lax.convert_element_type(i, _dtype(i)))
|
||||
continue
|
||||
d = lax.convert_element_type(d, _dtype(i))
|
||||
result.append(lax.select(i < 0, i + d, i))
|
||||
|
Loading…
x
Reference in New Issue
Block a user