From 836563fadfa27c754889de647bfbeff8ab851a92 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Tue, 2 Jan 2024 15:52:57 -0800
Subject: [PATCH] [Pallas] Refactor indexing primitives to use NDIndexer
 abstraction

Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple
  indexers. This allows doing things like `ref.at[0].at[0]` to recursively
  create a view of a `Ref`. `RefView`s therefore encapsulate multiple
  `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p,
  addupdate_p) but does *not* generalize their rules. Most of the rules will
  raise a NotImplementedError if you use multiple `NDIndexer`s. Adding
  support will be done in a future CL.
* With the above in mind, this change only preserves the existing
  public-facing APIs; adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
---
 jax/BUILD                            |   4 +-
 jax/_src/pallas/indexing.py          | 155 ---------
 jax/_src/pallas/mosaic/lowering.py   |  26 +-
 jax/_src/pallas/mosaic/primitives.py |  48 +--
 jax/_src/pallas/primitives.py        |  52 +--
 jax/_src/pallas/triton/lowering.py   |  81 ++---
 jax/_src/state/discharge.py          | 200 +++++++----
 jax/_src/state/indexing.py           | 192 +++++++++++
 jax/_src/state/primitives.py         | 493 +++++++++++++++------------
 jax/_src/state/types.py              |  28 +-
 jax/experimental/jax2tf/jax2tf.py    |   1 +
 jax/experimental/pallas/__init__.py  |   8 +-
 tests/pallas/indexing_test.py        |  37 +-
 tests/pallas/pallas_test.py          |   6 +-
 tests/state_test.py                  |  96 +++++-
 15 files changed, 819 insertions(+), 608 deletions(-)
 delete mode 100644 jax/_src/pallas/indexing.py
 create mode 100644 jax/_src/state/indexing.py

diff --git a/jax/BUILD b/jax/BUILD
index 02deeb543..1a5b8bf60 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -738,15 +738,17 @@ pytype_strict_library(
     name = "state_types",
     srcs = [
         "_src/state/__init__.py",
+        "_src/state/indexing.py",
         "_src/state/types.py",
     ],
     deps = [
         ":core",
         ":effects",
         ":pretty_printer",
+        ":tree_util",
         ":typing",
         ":util",
-    ],
+    ] + py_deps("numpy"),
 )

 pytype_strict_library(
diff --git a/jax/_src/pallas/indexing.py b/jax/_src/pallas/indexing.py
deleted file mode 100644
index 390db1cbd..000000000
--- a/jax/_src/pallas/indexing.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright 2023 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Contains shared logic and abstractions for Pallas indexing ops."""
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any
-
-import jax
-from jax import core as jax_core
-from jax import tree_util
-from jax._src.interpreters import mlir
-from jax._src.util import merge_lists
-from jax._src.util import partition_list
-import jax.numpy as jnp
-import numpy as np
-
-
-# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
-# We could use `jnp.broadcast_to` but that lowers to squeezing,
-# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
-# so in the lowering, we have to expand out those squeezed dimensions again.
-# Having a simple `broadcast_to` primitive allows us to lower directly -# to `tl.broadcast_to`. -broadcast_to_p = jax_core.Primitive('broadcast_to') - -def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array: - if a.shape == shape: - return a - return broadcast_to_p.bind(a, shape=shape) - -@broadcast_to_p.def_impl -def _broadcast_to_impl(a, *, shape): - return jnp.broadcast_to(a, shape) - -@broadcast_to_p.def_abstract_eval -def _broadcast_to_abstract_eval(aval, *, shape): - return jax_core.ShapedArray(shape, aval.dtype) - -mlir.register_lowering( - broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False) -) - - -@tree_util.register_pytree_node_class -@dataclasses.dataclass -class Slice: - """Represents a slice with a dynamic start index and a fixed size.""" - start: Any - size: int - - def __post_init__(self): - if self.size < 0: - raise ValueError("`size` must not be negative.") - - def tree_flatten(self): - # If `start` is statically known, we treat it as static information - if isinstance(self.start, int): - return (), (self.start, self.size) - return (self.start,), (self.size,) - - @classmethod - def tree_unflatten(cls, aux_data, children) -> Slice: - return cls(*children, *aux_data) - - @classmethod - def from_slice(cls, slc: slice, size: int) -> Slice: - start, stop, step = slc.indices(size) - if step != 1: - raise ValueError(f"slice must have a step of 1 (found: {step})") - return cls(start, stop - start) - - -def dslice(start: int | jax.Array | None, size: int | None = None - ) -> slice | Slice: - """Constructs a `Slice` from a start and a size.""" - if start is None: - return slice(None) - if size is None: - if not isinstance(start, int): - raise ValueError("Non-static `dslice`") - return Slice(0, start) - return Slice(start, size) -ds = dslice # Handy alias - -@tree_util.register_pytree_node_class -@dataclasses.dataclass -class NDIndexer: - indices: tuple[int | Slice | jax.Array, ...] - shape: tuple[int, ...] - int_indexer_shape: tuple[int, ...] 
- - def __post_init__(self): - if len(self.indices) != len(self.shape): - raise ValueError("`indices` must be the same length as `Ref` shape.") - - def tree_flatten(self): - indexed_dims = [not isinstance(idx, slice) for idx in self.indices] - slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices) - flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx) - return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape, - self.int_indexer_shape) - - @classmethod - def tree_unflatten(cls, data, flat_idx): - slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data - non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx) - indices = merge_lists(indexed_dims, slice_idx, non_slice_idx) - return NDIndexer(tuple(indices), shape, int_indexer_shape) - - @classmethod - def from_indices_shape(cls, indices, shape) -> NDIndexer: - if len(indices) > len(shape): - raise ValueError("`indices` must not be longer than `shape`.") - # Pad out indices with slice(None) - indices = [*indices, *[slice(None)] * (len(shape) - len(indices))] - # Convert all `slice`s to `Slice`s - indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice) - else i for i, s in zip(indices, shape)) - is_int_indexing = [not isinstance(i, Slice) for i in indices] - other_indexers, int_indexers = partition_list(is_int_indexing, indices) - int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in - int_indexers] - indexer_shapes = [i.shape for i in int_indexers] - if indexer_shapes: - try: - bcast_shape = np.broadcast_shapes(*indexer_shapes) - except ValueError as e: - # Raise a nicer error than the NumPy one. - raise ValueError("Cannot broadcast shapes for indexing: " - f"{tuple(a for a in indexer_shapes)}") from e - else: - bcast_shape = () - int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers] - indices = merge_lists(is_int_indexing, other_indexers, int_indexers) - return NDIndexer(tuple(indices), shape, bcast_shape) - - def get_indexer_shape(self) -> tuple[int, ...]: - is_int_indexing = [not isinstance(i, Slice) for i in self.indices] - other_indexers, _ = partition_list(is_int_indexing, self.indices) - other_shape = [s.size for s in other_indexers] # type: ignore - return (*self.int_indexer_shape, *other_shape) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b8f920089..9d3787fc3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -42,12 +42,12 @@ from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core -from jax._src.pallas import indexing from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives from jax._src.state import discharge as state_discharge +from jax._src.state import indexing from jax._src.state import primitives as state_primitives from jax._src.util import safe_map from jax._src.util import safe_zip @@ -610,12 +610,16 @@ def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx, def _get_lowering_rule( - ctx: LoweringRuleContext, ref, *non_slice_idx, indexed_dims: Sequence[bool] + ctx: LoweringRuleContext, ref, *idx, tree, ): + indexers = tree_util.tree_unflatten(tree, idx) + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:]) + if len(indexers) > 1: + raise NotImplementedError("Only one indexer 
currently supported.") # Call _load_lowering_rule (since it's more general) ref_aval, *non_slice_idx_avals = ctx.avals_in - nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer( - ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims) + nd_indexer = indexers[0] + nd_indexer_avals = indexers_avals[0] args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, None, None)) avals_flat = tree_util.tree_leaves((ref_aval, nd_indexer_avals, None, None)) ctx = ctx.replace(avals_in=avals_flat) @@ -630,13 +634,17 @@ def _swap_lowering_rule( ctx: LoweringRuleContext, ref, val, - *non_slice_idx, - indexed_dims: Sequence[bool], + *idx, + tree ): + indexers = tree_util.tree_unflatten(tree, idx) + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:]) + if len(indexers) > 1: + raise NotImplementedError("Only one indexer currently supported.") # Call _masked_swap_lowering_rule (since it's more general) - ref_aval, val_aval, *non_slice_idx_avals = ctx.avals_in - nd_indexer, nd_indexer_avals = _convert_flat_indexing_to_indexer( - ref_aval, non_slice_idx, non_slice_idx_avals, indexed_dims) + ref_aval, val_aval, *_ = ctx.avals_in + nd_indexer = indexers[0] + nd_indexer_avals = indexers_avals[0] args_flat, args_tree = tree_util.tree_flatten((ref, nd_indexer, val, None)) avals_flat = tree_util.tree_leaves( (ref_aval, nd_indexer_avals, val_aval, None) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 1a0eaeca4..1447015c6 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -29,10 +29,10 @@ from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util from jax._src import util -from jax._src.state import primitives as state_primitives +from jax._src.state import indexing +from jax._src.state import primitives as sp from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.pallas import indexing from jax._src.pallas.mosaic import core as tpu_core import jax.numpy as jnp @@ -264,38 +264,6 @@ def _dma_start_abstract_eval(*args, tree, device_id_type): del args, tree, device_id_type return [] -def _pp_slice(slc: indexing.Slice, dim: int, context: jax_core.JaxprPpContext - ) -> str: - start, size = slc.start, slc.size - if isinstance(start, jax_core.Var): - start_str = jax_core.pp_var(start, context) - end_str = f'{start_str}+{size}' - else: - start_str = '' if start == 0 else str(start) - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' - -def _pp_indexer(indexer: indexing.NDIndexer, - context: jax_core.JaxprPpContext) -> pp.Doc: - indices = [] - for idx, dim in zip(indexer.indices, indexer.shape): - if isinstance(idx, indexing.Slice): - indices.append(_pp_slice(idx, dim, context)) - else: - indices.append(jax_core.pp_var(idx, context)) # type: ignore - return pp.text(','.join(indices)) - -def _pp_ref(ref, indexer, context): - return state_primitives.pp_ref( - pp.concat([ - pp.text(jax_core.pp_var(ref, context)), - pp.text("["), - _pp_indexer(indexer, context), - pp.text("]"), - ]) - ) - def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, context: jax_core.JaxprPpContext, settings: jax_core.JaxprPpSettings): @@ -309,9 +277,9 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, return pp.concat([ pp.text('dma_start'), pp.text(' '), - _pp_ref(src_ref, src_indexer, context), + sp.pp_ref_indexers(context, src_ref, (src_indexer,)), pp.text(' -> '), - _pp_ref(dst_ref, dst_indexer, 
context), + sp.pp_ref_indexers(context, dst_ref, (dst_indexer,)), pp.text(' '), pp.text(jax_core.pp_var(dst_sem, context)), ]) @@ -358,13 +326,14 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type): def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, context: jax_core.JaxprPpContext, settings: jax_core.JaxprPpSettings): + del settings invars = eqn.invars tree = eqn.params["tree"] sem, ref, indexer = tree_util.tree_unflatten(tree, invars) return pp.concat([ pp.text('dma_wait'), pp.text(' '), - _pp_ref(ref, indexer, context), + sp.pp_ref_indexers(context, ref, (indexer,)), pp.text(' '), pp.text(jax_core.pp_var(sem, context)), ]) @@ -373,7 +342,10 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn def _get_ref_and_indexer(ref): if isinstance(ref, state.RefView): - return ref.ref, ref.indexer + indexers = ref.indexers + if len(indexers) > 1: + raise NotImplementedError("Only one indexer supported.") + return ref.ref, ref.indexers[0].indices return ref, (slice(None),) * len(ref.shape) def make_async_copy(src_ref, dst_ref, sem): diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 8ceee708f..c1ab4f8ed 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -27,15 +27,15 @@ from jax._src import core as jax_core from jax._src import pretty_printer as pp from jax._src import state from jax._src.util import (safe_map, safe_zip) -from jax._src.state import primitives as state_primitives from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as sp from jax.interpreters import ad from jax.interpreters import mlir from jax.interpreters import xla import jax.numpy as jnp from jax._src.pallas import core as pallas_core -from jax._src.pallas import indexing # TODO(sharadmv): enable type checking # mypy: ignore-errors @@ -224,44 +224,13 @@ def _load_abstract_eval(*avals_flat, args_tree, **_): load_p.def_effectful_abstract_eval(_load_abstract_eval) -def _pp_dslice(dim: int, slice: Slice, context): - size = pp.text(str(slice.size)) - if isinstance(slice.start, int): - if slice.start == 0: - start = pp.text("") - else: - start = pp.text(str(slice.start)) - if slice.size == dim: - end = pp.text("") - else: - end = pp.text(str(slice.start + slice.size)) - else: - start = pp.text(jax_core.pp_var(slice.start, context)) - end = pp.concat([start, pp.text("+"), size]) - return pp.concat([start, pp.text(":"), end]) - -def _pp_idx(ref_aval, idx: NDIndexer, context): - docs = [ - _pp_dslice(d, s, context) if isinstance(s, Slice) - else pp.text(jax_core.pp_var(s, context)) - for s, d in zip(idx.indices, ref_aval.shape)] - if not docs: - return pp.text("") - doc = [docs[0]] - for d in docs[1:]: - doc.append(pp.text(",")) - doc.append(d) - return pp.concat(doc) - def _load_pp_rule(eqn, context, settings): # Pretty prints `a = load x i` as `x[i] <- a` y, = eqn.outvars x, idx, _, _ = eqn.params["args_tree"].unflatten(eqn.invars) - idx = _pp_idx(eqn.invars[0].aval, idx, context) lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) - return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([ - pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']') - ]))]) + return pp.concat([ + lhs, pp.text(' <- '), sp.pp_ref_indexers(context, x, (idx,))]) jax_core.pp_eqn_rules[load_p] = _load_pp_rule @@ -339,15 +308,14 @@ def _swap_pp_rule(eqn, context, settings): # Pretty prints `_ = swap x v i` as `x[i] <- v` y, = eqn.outvars x, idx, val, _ = 
eqn.params["args_tree"].unflatten(eqn.invars) - idx = _pp_idx(eqn.invars[0].aval, idx, context) - x_i = pp.concat([pp.text(jax_core.pp_var(x, context)), - pp.text('['), idx, pp.text(']')]) + x_i = sp.pp_ref_indexers(context, x, (idx,)) if isinstance(y, jax_core.DropVar): - return pp.concat([state_primitives.pp_ref( - x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]) + return pp.concat([ + x_i, + pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]) y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) - return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i), - pp.text(' <- '), state_primitives.pp_ref(x_i), + return pp.concat([y, pp.text(', '), x_i, + pp.text(' <- '), x_i, pp.text(', '), pp.text(jax_core.pp_var(val, context))]) jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 22ce0bb84..2427a836c 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -39,12 +39,12 @@ from jax._src.lib import gpu_triton as triton_kernel_call_lib from jax._src.lib import hlo_helpers from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core -from jax._src.pallas import indexing from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.pallas_call import pallas_call_p from jax._src.state import AbstractRef from jax._src.state import discharge +from jax._src.state import indexing from jax._src.state import primitives as sp from jax._src.util import merge_lists from jax._src.util import partition_list @@ -438,7 +438,7 @@ _TRITON_FN_MAPPING = { lax.nextafter_p: tl.math.nextafter, ad_util.add_any_p: tl.semantic.add, # Other ops. - indexing.broadcast_to_p: tl.broadcast_to, + sp.broadcast_to_p: tl.broadcast_to, primitives.atomic_cas_p: tl.atomic_cas, primitives.max_contiguous_p: tl.max_contiguous, primitives.multiple_of_p: tl.multiple_of, @@ -727,28 +727,16 @@ def _pack_indices(non_slice_idx, indexed_dims): def _get_lowering_rule( - ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims + ctx: TritonLoweringRuleContext, ptr, *idx, tree ): + indexers = tree_util.tree_unflatten(tree, idx) if not isinstance(ptr.type, tl.pointer_type): - assert not non_slice_idx + assert len(indexers) == 0 return ptr - - ref_aval, *idx_avals = ctx.avals_in - idx_avals = _pack_indices(idx_avals, indexed_dims) - if non_slice_idx: - (int_indexer_shape,) = { - i.shape for i in idx_avals if not isinstance(i, slice) - } - else: - int_indexer_shape = () - - idx = _pack_indices(non_slice_idx, indexed_dims) - idx = tuple( - primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc - for s, slc in zip(ref_aval.shape, idx) - ) - idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape) - args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None)) + if len(indexers) > 1: + raise NotImplementedError("No support for multiple indexers yet.") + indexer = indexers[0] + args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, None, None)) return _masked_load_lowering_rule( ctx, *args_flat, @@ -794,24 +782,16 @@ triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule def _swap_lowering_rule( - ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims + ctx: TritonLoweringRuleContext, ptr, value, *idx, tree ): - ref_aval, _, *idx_avals = ctx.avals_in - idx_avals = _pack_indices(idx_avals, indexed_dims) - if non_slice_idx: - (int_indexer_shape,) = { - 
i.shape for i in idx_avals if not isinstance(i, slice) - } - else: - int_indexer_shape = () - - idx = _pack_indices(non_slice_idx, indexed_dims) - idx = tuple( - primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc - for s, slc in zip(ref_aval.shape, idx) - ) - idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape) - args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None)) + indexers = tree_util.tree_unflatten(tree, idx) + if not isinstance(ptr.type, tl.pointer_type): + assert len(indexers) == 0 + return ptr + if len(indexers) > 1: + raise NotImplementedError("No support for multiple indexers yet.") + indexer = indexers[0] + args_flat, args_tree = tree_util.tree_flatten((ptr, indexer, value, None)) return _masked_swap_lowering_rule( ctx, *args_flat, args_tree=args_tree, eviction_policy=None ) @@ -850,24 +830,17 @@ triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule def _addupdate_lowering_rule( - ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims + ctx: TritonLoweringRuleContext, ptr, value, *idx, tree ): - ref_block_info, *_ = ctx.block_infos - avals_in = ctx.avals_in - idx = _pack_indices(non_slice_idx, indexed_dims) - if non_slice_idx: - (int_indexer_shape,) = { - tuple(map(lambda x: x.value, i.shape)) for i in non_slice_idx - } - else: - int_indexer_shape = () - idx = tuple( - primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc - for s, slc in zip(avals_in[0].shape, idx) - ) - idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape) + indexers = tree_util.tree_unflatten(tree, idx) + if not isinstance(ptr.type, tl.pointer_type): + assert len(indexers) == 0 + return ptr + if len(indexers) > 1: + raise NotImplementedError("No support for multiple indexers yet.") + indexer = indexers[0] ptr = _compute_pointers_from_indices( - ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder + ptr, ctx.block_infos[0], indexer, ctx.avals_in[0].shape, ctx.builder ) tl.atomic_add(ptr, value, _builder=ctx.builder) return [] diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index c63a2e9b7..32576d988 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -34,9 +34,11 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax from jax._src.lax import slicing as lax_slicing +from jax._src.state import indexing from jax._src.state.types import AbstractRef, RefEffect from jax._src.state.primitives import get_p, swap_p, addupdate_p from jax._src.state.utils import hoist_consts_to_refs +from jax._src.typing import Array from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache, partition_list, merge_lists, split_dict) @@ -144,34 +146,112 @@ def _eval_jaxpr_discharge_state( env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge]) return out_vals + ref_vals +def _is_trivial_indexer(indexer: indexing.NDIndexer): + for s, idx in zip(indexer.shape, indexer.indices): + if not isinstance(idx, indexing.Slice): + return False + if not isinstance(idx.start, int): + return False + if idx.start: + return False + if idx.size != s: + return False + return True + +def _convert_to_array_indexer(indexer: indexing.NDIndexer + ) -> tuple[int | Array, ...]: + # This is the general gather case. We need to create the gather arrays. 
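+  # Illustrative example (values assumed, not from this change): indexing a
+  # ref of shape (5, 7) with (jnp.array([0, 2, 4]), Slice(1, 6)) has
+  # int_indexer_shape (3,) and slice shape (6,). The integer index is
+  # expanded to shape (3, 1) and the slice becomes a broadcasted iota of
+  # shape (3, 6) offset by its start, so both broadcast to the (3, 6) result.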
+  is_integer_indexer, _, integer_indexer = (
+      indexing.unpack_ndindexer(indexer)
+  )
+  total_shape = indexer.get_indexer_shape()
+  int_indexer_shape = indexer.int_indexer_shape
+  slice_shape = total_shape[len(int_indexer_shape):]
+  slice_dims = tuple(
+      i + len(int_indexer_shape) for i in range(len(slice_shape))
+  )
+  slice_dim_iter = iter(slice_dims)
+  slice_indexer: list[Array] = []
+  for idx, is_int_index in zip(indexer.indices, is_integer_indexer):
+    if not is_int_index:
+      assert isinstance(idx, indexing.Slice)
+      slice_indices = lax.broadcasted_iota(
+          np.dtype("int32"), total_shape, next(slice_dim_iter)
+      ) + idx.start
+      slice_indexer.append(slice_indices)
+      integer_indexer = tuple(
+          lax.expand_dims(idx, (-1,)) for idx in integer_indexer
+      )
+      continue
+  assert next(slice_dim_iter, None) is None
+  return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer))
+
+
+def _maybe_convert_to_dynamic_slice(
+    indexer: indexing.NDIndexer,
+) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
+  # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
+  # if each of the indexers is a `Slice` or a ()-shaped value.
+  if not all(isinstance(i, indexing.Slice) or not np.shape(i)
+             for i in indexer.indices):
+    return None
+  _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
+  starts = tuple(
+      _convert_i32(i.start) if isinstance(i, indexing.Slice)
+      else _convert_i32(i) for i in indexer.indices
+  )
+  sizes = tuple(
+      i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices
+  )
+  squeeze_dims = tuple(
+      i
+      for i, idx in enumerate(indexer.indices)
+      if not isinstance(idx, indexing.Slice)
+  )
+  return starts, sizes, squeeze_dims
+
+
 @register_discharge_rule(get_p)
 def _get_discharge_rule(
     in_avals: Sequence[core.AbstractValue],
-    out_avals: Sequence[core.AbstractValue], x, *non_slice_idx,
-    indexed_dims: Sequence[bool]):
+    out_avals: Sequence[core.AbstractValue], x, *idx,
+    tree):
   del in_avals, out_avals
-  y = _get_discharge(x, non_slice_idx, indexed_dims)
-  return (None,) * (len(non_slice_idx) + 1), y
+  y = _get_discharge(x, idx, tree)
+  return (None,) * (len(idx) + 1), y

-def _get_discharge(x, idx, indexed_dims):
-  if not any(indexed_dims):
-    return x
-  if all(not i.shape for i in idx):
-    return _dynamic_index(x, idx, indexed_dims)
-  else:
-    return _prepend_gather(x, idx, indexed_dims)
-
-def _prepend_gather(x, idx, indexed_dims):
-  indexer = _indexer(idx, indexed_dims)
+def _prepend_gather(x, indexer):
   # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
   return x[None][(np.array(0, 'int32'), *indexer)]

-def _prepend_scatter(x, idx, indexed_dims, val, *, add=False):
-  indexer = _indexer(idx, indexed_dims)
+def _prepend_scatter(x, indexer, val, *, add=False):
+  # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy.
+  # However, since this is scatter, we need to remove the 1-sized dimension
+  # we added at the front.
   if add:
     return x[None].at[(0, *indexer)].add(val)[0]
   return x[None].at[(0, *indexer)].set(val)[0]
+
+
+def _get_discharge(x, idx, tree):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if len(indexers) > 1:
+    raise NotImplementedError("Only single indexer is supported.")
+  indexer = indexers[0]
+  if _is_trivial_indexer(indexer):
+    return x
+  # If everything in the indexer is a slice or ()-shaped, we can also
+  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+  # We need to squeeze out the 1-sized slices at the end.
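+  # E.g. (assumed values): indices (Slice(2, 3), i) with scalar i on shape
+  # (8, 4) become starts=(2, i), sizes=(3, 1), squeeze_dims=(1,); the
+  # dynamic_slice result of shape (3, 1) is squeezed down to shape (3,).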
+  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+    starts, sizes, squeeze_dims = maybe_slice
+    y = lax_slicing.dynamic_slice(x, starts, sizes)
+    return lax.squeeze(y, squeeze_dims)
+  indexer = _convert_to_array_indexer(indexer)
+  if indexer is None:
+    return x
+  return x[None][(np.array(0, 'int32'), *indexer)]
+
 def _indexer(idx, indexed_dims):
   idx_ = iter(idx)
   indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
@@ -181,59 +261,59 @@
 @register_discharge_rule(swap_p)
 def _swap_discharge_rule(
     in_avals: Sequence[core.AbstractValue],
-    out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
-    indexed_dims: Sequence[bool]):
+    out_avals: Sequence[core.AbstractValue], x, val, *idx,
+    tree):
   del in_avals, out_avals
-  if not any(indexed_dims):
-    z, x_new = x, val
-  z, x_new = _swap_discharge(x, val, non_slice_idx, indexed_dims)
-  return (x_new, None) + (None,) * len(non_slice_idx), z
+  z, x_new = _swap_discharge(x, val, idx, tree)
+  return (x_new, None) + (None,) * len(idx), z

-def _swap_discharge(x, val, idx, indexed_dims):
-  if not any(indexed_dims):
-    z, x_new = x, val
-  elif all(not i.shape for i in idx):
-    z = _dynamic_index(x, idx, indexed_dims)
-    x_new = _dynamic_update_index(x, idx, val, indexed_dims)
-  else:
-    z = _prepend_gather(x, idx, indexed_dims)
-    x_new = _prepend_scatter(x, idx, indexed_dims, val)
-  return z, x_new
+def _swap_discharge(x, val, idx, tree):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if len(indexers) > 1:
+    raise NotImplementedError("Only single indexer is supported.")
+  indexer = indexers[0]
+  if _is_trivial_indexer(indexer):
+    return x, val
+  # If everything in the indexer is a slice or ()-shaped, we can also
+  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+  # We need to squeeze out the 1-sized slices at the end.
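+  # (The old value is sliced out before the update below so that `swap` can
+  # return it alongside the updated buffer.)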
+  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+    starts, sizes, squeeze_dims = maybe_slice
+    x_old = lax_slicing.dynamic_slice(x, starts, sizes)
+    val = lax.expand_dims(val, squeeze_dims)
+    y = lax_slicing.dynamic_update_slice(x, val, starts)
+    return lax.squeeze(x_old, squeeze_dims), y
+  indexer = _convert_to_array_indexer(indexer)
+  x_old = _prepend_gather(x, indexer)
+  return x_old, _prepend_scatter(x, indexer, val)

 @register_discharge_rule(addupdate_p)
 def _addupdate_discharge_rule(
     in_avals: Sequence[core.AbstractValue],
-    out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
-    indexed_dims: Sequence[bool]):
+    out_avals: Sequence[core.AbstractValue], x, val, *idx,
+    tree):
   del in_avals, out_avals
-  ans = _addupdate_discharge(x, val, non_slice_idx, indexed_dims)
-  return (ans, None) + (None,) * len(non_slice_idx), []
+  ans = _addupdate_discharge(x, val, idx, tree)
+  return (ans, None) + (None,) * len(idx), []

-def _addupdate_discharge(x, val, idx, indexed_dims):
-  if not any(indexed_dims):
+def _addupdate_discharge(x, val, idx, tree):
+  indexers = tree_util.tree_unflatten(tree, idx)
+  if len(indexers) > 1:
+    raise NotImplementedError("Only single indexer is supported.")
+  indexer = indexers[0]
+  if _is_trivial_indexer(indexer):
     return x + val
-  if all(not i.shape for i in idx):
-    y = val + _dynamic_index(x, idx, indexed_dims)
-    return _dynamic_update_index(x, idx, y, indexed_dims)
-  else:
-    return _prepend_scatter(x, idx, indexed_dims, val, add=True)
-
-def _dynamic_index(x, idx, indexed_dims):
-  assert isinstance(idx, (list, tuple)) and idx
-  idx_ = iter(idx)
-  starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
-  assert next(idx_, None) is None
-  sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
-  out = lax_slicing.dynamic_slice(x, starts, sizes)
-  return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])
-
-def _dynamic_update_index(x, idx, val, indexed_dims):
-  assert isinstance(idx, (list, tuple)) and idx
-  idx_ = iter(idx)
-  starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
-  assert next(idx_, None) is None
-  sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
-  return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)
+  # If everything in the indexer is a slice or ()-shaped, we can also
+  # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+  # We need to squeeze out the 1-sized slices at the end.
+  if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+    starts, sizes, squeeze_dims = maybe_slice
+    x_old = lax_slicing.dynamic_slice(x, starts, sizes)
+    val = lax.expand_dims(val, squeeze_dims)
+    y = lax_slicing.dynamic_update_slice(x, x_old + val, starts)
+    return y
+  indexer = _convert_to_array_indexer(indexer)
+  return _prepend_scatter(x, indexer, val, add=True)

 @weakref_lru_cache
 def _cached_closed_jaxpr_discharge(closed_jaxpr):
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
new file mode 100644
index 000000000..a4cd4a140
--- /dev/null
+++ b/jax/_src/state/indexing.py
@@ -0,0 +1,192 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains shared logic and abstractions for Pallas indexing ops."""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Any, Union
+
+from jax._src import core
+from jax._src import tree_util
+from jax._src.typing import Array
+from jax._src.util import merge_lists
+from jax._src.util import partition_list
+import numpy as np
+
+
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class Slice:
+  """Represents a slice with a dynamic start index and a fixed size."""
+  start: Any
+  size: int
+
+  def __post_init__(self):
+    if self.size < 0:
+      raise ValueError("`size` must not be negative.")
+
+  def tree_flatten(self):
+    # If `start` is statically known, we treat it as static information
+    if isinstance(self.start, int):
+      return (), (self.start, self.size)
+    return (self.start,), (self.size,)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children) -> Slice:
+    return cls(*children, *aux_data)
+
+  @classmethod
+  def from_slice(cls, slc: slice, size: int) -> Slice:
+    start, stop, step = slc.indices(size)
+    if step != 1:
+      raise ValueError(f"slice must have a step of 1 (found: {step})")
+    return cls(start, stop - start)
+
+
+def dslice(start: int | Array | None, size: int | None = None
+           ) -> slice | Slice:
+  """Constructs a `Slice` from a start and a size."""
+  if start is None:
+    return slice(None)
+  if size is None:
+    if not isinstance(start, int):
+      raise ValueError("Non-static `dslice`")
+    return Slice(0, start)
+  return Slice(start, size)
+ds = dslice  # Handy alias
+
+
+IntIndexer = Union[int, Array]
+DimIndexer = Union[IntIndexer, Slice]
+
+def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
+                                                  tuple[Slice, ...],
+                                                  tuple[IntIndexer, ...]]:
+  is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
+  slice_indexers, int_indexers = partition_list(
+      is_int_indexing, indexer.indices)
+  return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers)  # type: ignore
+
+def _maybe_concretize(x: Any):
+  try:
+    return core.concrete_or_error(None, x)
+  except core.ConcretizationTypeError:
+    return None
+
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class NDIndexer:
+  indices: tuple[DimIndexer, ...]
+  shape: tuple[int, ...]
+  int_indexer_shape: tuple[int, ...]
+  validate: bool = False
+
+  def __post_init__(self):
+    if not self.validate:
+      return
+    if len(self.indices) != len(self.shape):
+      raise ValueError(
+          f"`indices` must be the same length as `Ref` shape: {self}."
+      )
+    # We validate integer indexing shapes here
+    for idx, s in zip(self.indices, self.shape):
+      if isinstance(idx, Slice):
+        start = idx.start
+        if value := _maybe_concretize(start):
+          if value >= s:
+            raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
+          if value + idx.size > s:
+            raise ValueError(
+                f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
+            )
+        continue
+      # The shape of indexer integers should be broadcastable up to the
+      # int_indexer_shape of the whole NDIndexer
+      if not np.shape(idx):
+        if (value := _maybe_concretize(idx)) and value >= s:
+          raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
+        # For ()-shaped indexers, we can broadcast with no problem.
+        continue
+      # If we don't have a ()-shaped indexer, the rank must match
+      # int_indexer_shape
+      if np.ndim(idx) != len(self.int_indexer_shape):
+        raise ValueError(
+            f"Indexer must have rank {len(self.int_indexer_shape)}: {idx=} vs."
+            f" {self.int_indexer_shape=}"
+        )
+      # Here we check that the shapes broadcast.
+      try:
+        np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
+      except ValueError as e:
+        raise ValueError(
+            f"Could not broadcast integer indexer: {idx=} vs."
+            f" {self.int_indexer_shape=}"
+        ) from e
+
+  def tree_flatten(self):
+    flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
+    return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
+
+  @classmethod
+  def tree_unflatten(cls, data, flat_idx):
+    idx_tree, shape, int_indexer_shape = data
+    indices = tree_util.tree_unflatten(idx_tree, flat_idx)
+    return NDIndexer(tuple(indices), shape, int_indexer_shape)
+
+  @classmethod
+  def from_indices_shape(cls, indices, shape) -> NDIndexer:
+    if indices == ...:
+      indices = (slice(None),) * len(shape)
+    if not isinstance(indices, tuple):
+      indices = (indices,)
+    if any(idx is ... for idx in indices):
+      # TODO(sharadmv,mattjj): support patterns that include ellipsis in them
+      # e.g. x[0, ..., 1].
+      raise NotImplementedError("Ellipsis in indexer not supported yet.")
+    if len(indices) > len(shape):
+      raise ValueError("`indices` must not be longer than `shape`: "
+                       f"{indices=}, {shape=}")
+    # Pad out indices with slice(None)
+    indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
+    # Convert all `slice`s to `Slice`s
+    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
+                    else i for i, s in zip(indices, shape))
+    is_int_indexing = [not isinstance(i, Slice) for i in indices]
+    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
+    indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
+    if indexer_shapes:
+      try:
+        bcast_shape = np.broadcast_shapes(*indexer_shapes)
+      except ValueError as e:
+        # Raise a nicer error than the NumPy one.
+        raise ValueError("Cannot broadcast shapes for indexing: "
+                         f"{tuple(a for a in indexer_shapes)}") from e
+    else:
+      bcast_shape = ()
+    # Here we use the `broadcast_to` primitive instead of composing lax
+    # primitives together because it is easier to lower in targets like
+    # Triton/Mosaic.
+    from jax._src.state import primitives as sp  # pytype: disable=import-error
+    int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
+    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
+    return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
+
+  def get_indexer_shape(self) -> tuple[int, ...]:
+    _, slice_indexers, _ = unpack_ndindexer(self)
+    slice_shape = [s.size for s in slice_indexers]
+    # In NDIndexers, the int_indexer_shape is *always* at the front of the
+    # result.
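+    # E.g. (assumed values): indices (i32[3], Slice(0, 4), i32[3]) on shape
+    # (5, 4, 5) broadcast the integer indexers to (3,), giving result shape
+    # (3, 4): the integer-indexer shape first, then the slice sizes.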
+ return (*self.int_indexer_shape, *slice_shape) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index b5463e71b..93c389189 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -23,14 +23,17 @@ import numpy as np from jax._src import ad_util from jax._src import core from jax._src import pretty_printer as pp +from jax._src import tree_util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import mlir from jax._src.lax import lax from jax._src.typing import Array -from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, +from jax._src.state import indexing +from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect, AccumEffect) -from jax._src.util import safe_map, safe_zip, tuple_insert +from jax._src.util import safe_map, safe_zip ## General utilities @@ -51,41 +54,14 @@ zip, unsafe_zip = safe_zip, zip # a:f32[3] <- x[] get_p = core.Primitive("get") -def _get_impl(ref: AbstractRef, *idx: int, **_): - del ref, idx +def _get_impl(ref: AbstractRef, *args: Any, tree): + del ref, args, tree raise ValueError("Cannot run stateful primitive.") get_p.def_impl(_get_impl) Indexer = tuple[Union[int, slice, Array], ...] # or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType) -def _is_trivial_indexer(idx: Indexer) -> bool: - if idx is ...: - return True - if type(idx) is tuple: - if len(idx) == 0: - return True - return len(idx) == 1 and idx[0] is ... - return False - -def _unpack_idx(idx: Indexer, ndim: int - ) -> tuple[tuple[Array, ...], tuple[bool, ...]]: - if _is_trivial_indexer(idx): - idx = tuple(slice(None) for _ in range(ndim)) - indexed_dims_ = [] - non_slice_idx = [] - for i in idx: - if isinstance(i, slice): - if i.start is not None or i.stop is not None or i.step is not None: - raise NotImplementedError("Reference indexing only supports trivial slices") - indexed_dims_.append(False) - else: - non_slice_idx.append(i) - indexed_dims_.append(True) - indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_)) - import jax.numpy as jnp - return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims)) - def _get_slice_output_shape(in_shape: tuple[int, ...], idx_shapes: tuple[tuple[int, ...], ...], indexed_dims: tuple[bool, ...]) -> tuple[int, ...]: @@ -95,24 +71,30 @@ def _get_slice_output_shape(in_shape: tuple[int, ...], shape = (*shape_prefix, *shape_suffix) return shape -def _get_indexer(ref: AbstractRef, idx: Indexer - ) -> tuple[Indexer, tuple[bool, ...]]: - if isinstance(ref.inner_aval, core.ShapedArray): - non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim) - else: - if not _is_trivial_indexer(idx): - raise ValueError( - f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.") - non_slice_idx, indexed_dims = (), () - return non_slice_idx, indexed_dims -def ref_get(ref: Any, idx: Indexer) -> Array: - """Reads a value from a `Ref`, a.k.a. 
value <- ref[idx].""" +def _get_indexer( + ref_or_view: Any, idx: Indexer | None, function_name: str +) -> tuple[Any, tuple[indexing.NDIndexer, ...]]: + if isinstance(ref_or_view, RefView): + ref, indexers = ref_or_view.ref, ref_or_view.indexers + else: + ref, indexers = ref_or_view, () ref_aval = core.get_aval(ref) if not isinstance(ref_aval, AbstractRef): - raise ValueError(f"Can only call `get` on a `Ref`: {ref}") - non_slice_idx, indexed_dims = _get_indexer(ref, idx) - return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims) + raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.") + if not isinstance(ref_aval.inner_aval, core.ShapedArray): + return ref, () + if idx is None: + return ref, indexers + nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape) + return ref, (*indexers, nd_indexer) + + +def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: + """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" + ref, indexers = _get_indexer(ref_or_view, idx, "ref_get") + flat_indexers, tree = tree_util.tree_flatten(indexers) + return get_p.bind(ref, *flat_indexers, tree=tree) # `swap` mutates a `Ref`, setting its value and returns its previous value. # b = swap_p.bind(x, a) @@ -132,22 +114,21 @@ def ref_get(ref: Any, idx: Indexer) -> Array: # x:Ref{f32[3]}[i, j] <- a swap_p = core.Primitive("swap") -def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_): - del ref, value, idx +def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree): + del ref, value, idx, tree raise ValueError("Cannot run stateful primitive.") swap_p.def_impl(_swap_impl) -def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array: +def ref_swap(ref_or_view: AbstractRef, idx: Indexer | None, value: Array, + *, _function_name: str = "ref_swap") -> Array: """Sets a `Ref`'s value and returns the original value.""" - ref_aval = core.get_aval(ref) - if not isinstance(ref_aval, AbstractRef): - raise ValueError(f"Can only call `swap` on a `Ref`: {ref}") - non_slice_idx, indexed_dims = _get_indexer(ref, idx) - return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims) + ref, indexers = _get_indexer(ref_or_view, idx, _function_name) + flat_indexers, tree = tree_util.tree_flatten(indexers) + return swap_p.bind(ref, value, *flat_indexers, tree=tree) -def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None: +def ref_set(ref: AbstractRef, idx: Indexer | None, value: Array) -> None: """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" - ref_swap(ref, idx, value) + ref_swap(ref, idx, value, _function_name="ref_set") # `addupdate_p` mutates a `Ref`, adding a value to its existing value. # Semantically, @@ -163,38 +144,40 @@ def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None: addupdate_p = core.Primitive('addupdate') addupdate_p.multiple_results = True -def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int): - del ref, idx, value +def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree): + del ref, value, args, tree raise ValueError("Can't evaluate `addupdate` outside a stateful context.") addupdate_p.def_impl(_addupdate_impl) -def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None: +def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None: """Mutates a ref with an additive update i.e. 
`ref[idx] += x`.""" - ref_aval = core.get_aval(ref) - if not isinstance(ref_aval, AbstractRef): - raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}") - non_slice_idx, indexed_dims = _get_indexer(ref, idx) - return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims) + ref, indexers = _get_indexer(ref_or_view, idx, "ref_addupdate") + flat_indexers, tree = tree_util.tree_flatten(indexers) + return addupdate_p.bind(ref, x, *flat_indexers, tree=tree) ## get/set/addupdate abstract evaluation rules -def _get_abstract_eval(ref_aval: AbstractRef, *idx, - indexed_dims): + +def _shape_after_indexing( + shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...] +) -> tuple[int, ...]: + for indexer in indexers: + # Run some simple checks that all the indexers have consistent shapes + assert indexer.shape == shape, (indexer.shape, shape) + shape = indexer.get_indexer_shape() + return shape + + +def _get_abstract_eval(ref_aval: AbstractRef, *args, + tree): + indexers = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - if not isinstance(ref_aval.inner_aval, core.ShapedArray): - raise ValueError("`get` with nontrivial indexing must be called " - f"on `ShapedArray` `Ref`: {ref_aval}.") - if len(indexed_dims) != len(ref_aval.shape): - raise ValueError("`indexed_dims` must be the same length as `Ref` shape.") - if sum(indexed_dims) != len(idx): - raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}") - idx_shapes = tuple(i.shape for i in idx) - shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims) - out_aval = ref_aval.inner_aval.update(shape=shape) + out_shape = _shape_after_indexing(ref_aval.shape, indexers) + out_aval = ref_aval.inner_aval.update(shape=out_shape) else: - if idx: + if indexers: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {ReadEffect(0)}) @@ -202,34 +185,29 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval) def _swap_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, - *idx: core.ShapedArray, indexed_dims: tuple[bool]): + *args: Any, tree): + indexers = tree_util.tree_unflatten(tree, args) out_aval: core.AbstractValue if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - if len(indexed_dims) != len(ref_aval.shape): - raise ValueError("`indexed_dims` must be the same length as `Ref` shape.") - if sum(indexed_dims) != len(idx): - raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}") val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) - idx_shapes = tuple(i.shape for i in idx) - expected_output_shape = _get_slice_output_shape( - ref_aval.shape, idx_shapes, indexed_dims) - if expected_output_shape != val_aval.shape: + expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers) + if expected_out_shape != val_aval.shape: raise ValueError("Invalid shape for `swap`. " f"Ref shape: {ref_aval.shape}. " + f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " - f"Indices: {idx}. ") + f"Indices: {indexers}. ") if ref_aval.dtype != val_aval.dtype: raise ValueError("Invalid dtype for `swap`. " f"Ref dtype: {ref_aval.dtype}. " f"Value shape: {val_aval.dtype}. 
") - out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype) + out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype) else: - if idx: - raise ValueError("`swap` with nontrivial indexing must be called " - f"on `ShapedArray` `Ref`: {ref_aval}.") + if indexers: + raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {WriteEffect(0)}) swap_p.def_effectful_abstract_eval(_swap_abstract_eval) @@ -237,93 +215,118 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval) def _addupdate_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, - *idx: core.ShapedArray, indexed_dims: tuple[bool]): + *args: Any, tree): + indexers = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") - if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray): - raise ValueError("`addupdate` with nontrivial indexing must be called " - f"on `ShapedArray` `Ref`: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - if len(indexed_dims) != len(ref_aval.shape): - raise ValueError("`indexed_dims` must be the same length as `Ref` shape.") - if sum(indexed_dims) != len(idx): - raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}") val_aval = core.raise_to_shaped(val_aval) + slice_shape = _shape_after_indexing(ref_aval.shape, indexers) assert isinstance(val_aval, core.ShapedArray) - idx_shapes = tuple(i.shape for i in idx) - slice_shape = _get_slice_output_shape( - ref_aval.shape, idx_shapes, indexed_dims) if slice_shape != val_aval.shape: raise ValueError("Invalid shape for `addupdate`. " f"Ref shape: {ref_aval.shape}. " + f"Slice shape: {slice_shape}. " f"Value shape: {val_aval.shape}. " - f"Indices: {idx}. ") + f"Indices: {indexers}. ") if ref_aval.dtype != val_aval.dtype: raise ValueError("Invalid dtype for `addupdate`. " f"Ref dtype: {ref_aval.dtype}. " f"Value shape: {val_aval.dtype}. 
") - elif idx: - raise ValueError("`addupdate` with nontrivial indexing must be called " - f"on `ShapedArray` `Ref`: {ref_aval}.") + else: + # Check that the indexers are valid + if indexers: + raise ValueError("Cannot index non-shaped array with nontrivial indices.") return [], {AccumEffect(0)} addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) ## Pretty printing for `get` and `swap` in jaxprs -pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL, +pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) -def _pp_idx(context, non_slice_idx, indexed_dims): - idx_iter = iter(non_slice_idx) - idx = ','.join(core.pp_var(next(idx_iter), context) if indexed else ':' - for indexed in indexed_dims) - assert next(idx_iter, None) is None - return pp.text(idx) +def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice + ) -> str: + start, size = slc.start, slc.size + if isinstance(start, core.Var): + start_str = core.pp_var(start, context) + end_str = f'{start_str}+{size}' + else: + start_str = '' if start == 0 else str(start) + end = start + size + end_str = '' if end == dim else str(end) + return f'{start_str}:{end_str}' + +def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer + ) -> pp.Doc: + indices = [] + for idx, dim in zip(indexer.indices, indexer.shape): + if isinstance(idx, indexing.Slice): + indices.append(_pp_slice(context, dim, idx)) + else: + indices.append(core.pp_var(idx, context)) # type: ignore + return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) + +def _pp_indexers( + context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...], +): + return pp.concat( + [pp_indexer(context, indexer) for indexer in indexers] + ) + +def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers): + return pp_ref_var( + pp.concat([ + pp.text(core.pp_var(ref, context)), + _pp_indexers(context, indexers), + ]) + ) def _get_pp_rule(eqn, context, settings) -> pp.Doc: # Pretty prints `a = get x i` as `x[i] <- a` y, = eqn.outvars - x, *idx = eqn.invars - idx = _pp_idx(context, idx, eqn.params["indexed_dims"]) + x, *flat_idx = eqn.invars + indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) - # TODO more general get - return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([ - pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))]) + return pp.concat([ + lhs, + pp.text(' <- '), + pp_ref_indexers(context, x, indexers) + ]) core.pp_eqn_rules[get_p] = _get_pp_rule def _swap_pp_rule(eqn, context, settings) -> pp.Doc: y, = eqn.outvars - x, v, *idx = eqn.invars - idx = _pp_idx(context, idx, eqn.params["indexed_dims"]) + x, v, *flat_idx = eqn.invars + indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) if type(y) is core.DropVar: # In the case of a set (ignored return value), # pretty print `_ = swap x v i` as `x[i] <- v` del y return pp.concat([ - pp_ref(pp.concat([ - pp.text(core.pp_var(x, context)), - pp.text('['), idx, pp.text(']') - ])), pp.text(' <- '), pp.text(core.pp_var(v, context))]) + pp_ref_indexers(context, x, indexers), + pp.text(' <- '), + pp.text(core.pp_var(v, context)) + ]) else: # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` - x_i = pp.concat([pp.text(core.pp_var(x, context)), - pp.text('['), idx, pp.text(']')]) + x_i = pp_ref_indexers(context, x, indexers) y = core.pp_vars([y], context, print_shapes=settings.print_shapes) - return pp.concat([y, pp.text(', 
'), pp_ref(x_i), pp.text(' <- '), - pp_ref(x_i), pp.text(', '), + return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '), + x_i, pp.text(', '), pp.text(core.pp_var(v, context))]) core.pp_eqn_rules[swap_p] = _swap_pp_rule def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: + del settings # pretty-print ` = addupdate x i v` as `x[i] += v` () = eqn.outvars - x, v, *idx = eqn.invars - idx = _pp_idx(context, idx, eqn.params["indexed_dims"]) + x, v, *flat_idx = eqn.invars + indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) return pp.concat([ - pp_ref(pp.concat([ - pp.text(core.pp_var(x, context)), - pp.text('['), idx, pp.text(']') - ])), pp.text(' += '), pp.text(core.pp_var(v, context))]) + pp_ref_indexers(context, x, indexers), + pp.text(' += '), + pp.text(core.pp_var(v, context))]) core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule ## get/swap/addupdate JVP rules @@ -366,6 +369,7 @@ def _get_transpose(g, ref, *idx, **params): ad.primitive_transposes[get_p] = _get_transpose def _swap_transpose(g, ref, x, *idx, **params): + del x # old value doesn't matter anymore # swap transpose is swap x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params) return [None, x_bar] + [None] * len(idx) @@ -403,101 +407,162 @@ def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int, num_idxs_to_left = sum(indexed_dims[:ref_dim]) return ref_dim - num_idxs_to_left + len(idxs_shape) -def _get_vmap(batched_args, batched_dims, *, indexed_dims): +def _batch_indexer(indexer: indexing.NDIndexer, dims, + axis_size: int, + ref_shape: tuple[int, ...], + ref_dim: int | batching.NotMapped, + idx_is_batched: bool) -> indexing.NDIndexer: + indices = indexer.indices + indices_dims = dims.indices + new_indices: list[Array | indexing.Slice | int] = [] + new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape) + for idx, dim in zip(indices, indices_dims): + if idx_is_batched: + # If at least one of the idx is batched, we broadcast them all and move the + # batch dim to the front. + if isinstance(idx, indexing.Slice): + # size is static, but start can be dynamic + # Check if start is static (which it can be) + is_static_slice = len(tree_util.tree_leaves(idx)) == 0 + if is_static_slice: + new_indices.append(idx) + continue + dim = dim.start + if dim is batching.not_mapped: + # Broadcasting the slice is free (the start index stays the same) + new_indices.append(idx) + else: + raise NotImplementedError( + f"No support for vmapping over nontrivial slices just yet: {idx}") + else: + # Check if we are indexing with a scalar or not. If we are indexing + # with a scalar and we are not batched, we can avoid broadcasting it. 
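+        # (A scalar, unmapped index needs no broadcasting; mapped or
+        # array-shaped indices are broadcast below so the batch axis ends
+        # up at the front of new_integer_indexer_shape.)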
+ assert hasattr(idx, "shape") + if not idx.shape: + if dim is not batching.not_mapped: + assert idx.shape == (axis_size,) + idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,)) + new_indices.append(idx) + else: + if dim is batching.not_mapped: + bcast_dims = tuple(range(1, np.ndim(idx) + 1)) + idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, + bcast_dims) + else: + idx = batching.moveaxis(idx, dim, 0) + new_indices.append(idx) + else: + if ref_dim is not batching.not_mapped: + if not isinstance(idx, indexing.Slice): + assert hasattr(idx, "shape") + if idx.shape: + bcast_dims = tuple(range(1, np.ndim(idx) + 1)) + idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, + bcast_dims) + new_indices.append(idx) + if ref_dim is not batching.not_mapped: + iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0) + new_indices.insert(ref_dim, iota) + return indexing.NDIndexer(tuple(new_indices), ref_shape, + new_integer_indexer_shape, + validate=True) + +def _get_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) if d is not batching.not_mapped} - ref, *idxs = batched_args - ref_dim, *idx_dims = batched_dims + ref, *flat_idxs = batched_args + ref_dim, *flat_idx_dims = batched_dims + indexers = tree_util.tree_unflatten(tree, flat_idxs) + indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) - ref_is_batched = ref_dim is not batching.not_mapped - idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims) - bdim_out = 0 - - if idx_is_batched: - # If at least one of the idx is batched, we broadcast them all and move the - # batch dim to the front. - idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d - in zip(idxs, idx_dims)) - idxs_shape, = {i.shape for i in idxs} or [()] - if ref_is_batched: - # If ref is batched, we are doing a `get` with an additional axis. If `idxs` - # are also batched, then we are indexing into the batch axis with an `iota`. - indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched) - if idx_is_batched: - # If we have batched idx, we need to insert the new iota index. The place - # where we add in the new `iota` index is `ref_dim` so we need to compute - # what `ref_dim` *would be* if we inserted it into `idxs` instead, because - # `idxs` doesn't include the non indexed dims. 
- idx_place = [i for i, i_dim in enumerate(indexed_dims) - if i_dim].index(ref_dim) - iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) - idxs = tuple_insert(idxs, idx_place, iota) - else: - bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) - return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out + idx_is_batched = any(i_dim is not batching.not_mapped + for i_dim in flat_idx_dims) + if len(indexers) > 1: + raise NotImplementedError("Batching with multiple indexers not supported.") + # TODO(sharadmv): handle vmap of multiple indexers + indexers = tuple(_batch_indexer(indexer, dims, axis_size, + ref.shape, ref_dim, idx_is_batched) + for indexer, dims in zip(indexers, indexers_dims)) + flat_indexers, tree = tree_util.tree_flatten(indexers) + return get_p.bind(ref, *flat_indexers, tree=tree), 0 batching.primitive_batchers[get_p] = _get_vmap -def _swap_vmap(batched_args, batched_dims, *, indexed_dims): +def _swap_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) if d is not batching.not_mapped} - ref, val, *idxs = batched_args - ref_dim, val_dim, *idx_dims = batched_dims + ref, val, *flat_idxs = batched_args + ref_dim, val_dim, *flat_idx_dims = batched_dims + indexers = tree_util.tree_unflatten(tree, flat_idxs) + indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) + ref_is_batched = ref_dim is not batching.not_mapped val_is_batched = val_dim is not batching.not_mapped - idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims) - if idx_is_batched: - # If at least one of the idx is batched, we broadcast them all and move the - # batch dim to the front. - idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d - in zip(idxs, idx_dims)) - idxs_shape, = {i.shape for i in idxs} or [()] - if ref_is_batched and not idx_is_batched: - indexed_dims = tuple_insert(indexed_dims, ref_dim, False) - bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) - if not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - val_dim = 0 - val = batching.moveaxis(val, val_dim, bdim_out) - elif idx_is_batched: - assert ref_is_batched and val_is_batched - indexed_dims = tuple_insert(indexed_dims, ref_dim, True) - idx_place = [i for i, i_dim in enumerate(indexed_dims) - if i_dim].index(ref_dim) - iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) - idxs = tuple_insert(idxs, idx_place, iota) + idx_is_batched = any(i_dim is not batching.not_mapped + for i_dim in flat_idx_dims) + if len(indexers) > 1: + raise NotImplementedError("Batching with multiple indexers not supported.") + # TODO(sharadmv): handle vmap of multiple indexers + indexers = tuple(_batch_indexer(indexer, dims, axis_size, + ref.shape, ref_dim, idx_is_batched) + for indexer, dims in zip(indexers, indexers_dims)) + flat_indexers, tree = tree_util.tree_flatten(indexers) + if (ref_is_batched or idx_is_batched) and not val_is_batched: + val = batching.broadcast(val, axis_size, 0) + if val_is_batched: val = batching.moveaxis(val, val_dim, 0) - bdim_out = 0 - return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out + return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0 batching.primitive_batchers[swap_p] = _swap_vmap -def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims): +def _addupdate_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) if d is not batching.not_mapped} - ref, val, *idxs = batched_args - ref_dim, val_dim, 
*idx_dims = batched_dims + ref, val, *flat_idxs = batched_args + ref_dim, val_dim, *flat_idx_dims = batched_dims + indexers = tree_util.tree_unflatten(tree, flat_idxs) + indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) + ref_is_batched = ref_dim is not batching.not_mapped val_is_batched = val_dim is not batching.not_mapped - idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims) - if idx_is_batched: - # If at least one of the idx is batched, we ensure all have bdims at front. - idxs = tuple(batching.bdim_at_front(i, d, axis_size) - for i, d in zip(idxs, idx_dims)) - idxs_shape, = {i.shape for i in idxs} or [()] - if ref_is_batched and not idx_is_batched: - indexed_dims = tuple_insert(indexed_dims, ref_dim, False) - bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) - if not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - val_dim = 0 - val = batching.moveaxis(val, val_dim, bdim_out) - elif idx_is_batched: - assert ref_is_batched and val_is_batched - indexed_dims = tuple_insert(indexed_dims, ref_dim, True) - idx_place = [i for i, i_dim in enumerate(indexed_dims) - if i_dim].index(ref_dim) - idxs_shape, = {i.shape for i in idxs} or [()] - iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) - idxs = tuple_insert(idxs, idx_place, iota) + idx_is_batched = any(i_dim is not batching.not_mapped + for i_dim in flat_idx_dims) + if len(indexers) > 1: + raise NotImplementedError("Batching with multiple indexers not supported.") + # TODO(sharadmv): handle vmap of multiple indexers + indexers = tuple(_batch_indexer(indexer, dims, axis_size, + ref.shape, ref_dim, idx_is_batched) + for indexer, dims in zip(indexers, indexers_dims)) + flat_indexers, tree = tree_util.tree_flatten(indexers) + if (ref_is_batched or idx_is_batched) and not val_is_batched: + val = batching.broadcast(val, axis_size, 0) + if val_is_batched: val = batching.moveaxis(val, val_dim, 0) - return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), [] + return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] batching.primitive_batchers[addupdate_p] = _addupdate_vmap + +# Currently, JAX doesn't have a primitive that does an equal-rank broadcast. +# We could use `jnp.broadcast_to` but that lowers to squeezing, +# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`) +# so in the lowering, we have to expand out those squeezed dimensions again. +# Having a simple `broadcast_to` primitive allows us to lower directly +# to `tl.broadcast_to`. 
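+# For example, broadcasting a (1, 128) operand to (8, 128) with
+# `jnp.broadcast_to` would squeeze away the size-1 dimension and then
+# `broadcast_in_dim` it back; `broadcast_to_p` keeps the operand at rank 2
+# the whole way through, mapping directly onto `tl.broadcast_to`.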
+broadcast_to_p = core.Primitive('broadcast_to') + +def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: + import jax.numpy as jnp + a = jnp.asarray(a) + if a.shape == shape: + return a + return broadcast_to_p.bind(a, shape=shape) + +@broadcast_to_p.def_impl +def _broadcast_to_impl(a, *, shape): + import jax.numpy as jnp + return jnp.broadcast_to(a, shape) + +@broadcast_to_p.def_abstract_eval +def _broadcast_to_abstract_eval(aval, *, shape): + return core.ShapedArray(shape, aval.dtype) + +mlir.register_lowering( + broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False) +) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index c2cb653a9..c8c8df889 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -23,6 +23,7 @@ from typing import Any, Generic, TypeVar, Union from jax._src import core from jax._src import effects from jax._src import pretty_printer as pp +from jax._src.state import indexing from jax._src.util import safe_map, safe_zip from jax._src.typing import Array @@ -77,21 +78,34 @@ Aval = TypeVar("Aval", bound=core.AbstractValue) @dataclasses.dataclass class RefIndexer: - ref: Any + ref_or_view: Any def __getitem__(self, slc): if not isinstance(slc, tuple): slc = (slc,) - return RefView(self.ref, slc) + indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape) + if isinstance(self.ref_or_view, RefView): + view = self.ref_or_view + return RefView(view.ref, (*view.indexers, indexer)) + return RefView(self.ref_or_view, (indexer,)) + +Indexer = Any @dataclasses.dataclass class RefView: ref: Any - indexer: Any + indexers: tuple[indexing.NDIndexer, ...] @property - def at(self): - raise NotImplementedError("Can't call `.at` multiple times.") + def shape(self) -> tuple[int, ...]: + assert ( + len(self.indexers) > 0 + ), "Should not be able to create a trivial RefView" + return self.indexers[-1].get_indexer_shape() + + @property + def at(self) -> RefIndexer: + return RefIndexer(self) # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. 
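A minimal sketch of the new chaining behavior (the `Ref` shape here is
illustrative; `ref_get` accepting a view matches the updated tests further
below):

    # x_ref: a Ref of shape (4, 3, 2)
    view = x_ref.at[0]       # RefView holding one NDIndexer; view.shape == (3, 2)
    view = view.at[1]        # appends a second NDIndexer; view.shape == (2,)
    out = ref_get(view, ())  # reads through both indexers

Each `.at[...]` builds an `NDIndexer` against the current view's shape, so
`RefView.shape` is the shape produced by the last indexer.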
@@ -137,14 +151,10 @@ class AbstractRef(core.AbstractValue, Generic[Aval]): return ref_set(tracer, idx, value) def _getitem(self, tracer, idx) -> Array: - if not isinstance(idx, tuple): - idx = idx, from jax._src.state.primitives import ref_get # pytype: disable=import-error return ref_get(tracer, idx) def _setitem(self, tracer, idx, value) -> None: - if not isinstance(idx, tuple): - idx = idx, from jax._src.state.primitives import ref_set # pytype: disable=import-error return ref_set(tracer, idx, value) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 371f49089..11f2d7ba3 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1472,6 +1472,7 @@ tf_not_yet_impl = [ "for", "inspect_sharding", "io_callback", + "broadcast_to", "shard_map", "global_array_to_host_local_array", "host_local_array_to_global_array", diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 958712d4c..a0680e9d9 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -17,10 +17,6 @@ from jax._src import pallas from jax._src.pallas.core import BlockSpec from jax._src.pallas.core import no_block_spec -from jax._src.pallas.indexing import broadcast_to -from jax._src.pallas.indexing import ds -from jax._src.pallas.indexing import dslice -from jax._src.pallas.indexing import Slice from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add @@ -42,6 +38,10 @@ from jax._src.pallas.utils import cdiv from jax._src.pallas.utils import next_power_of_2 from jax._src.pallas.utils import strides_from_shape from jax._src.pallas.utils import when +from jax._src.state.primitives import broadcast_to +from jax._src.state.indexing import ds +from jax._src.state.indexing import dslice +from jax._src.state.indexing import Slice try: from jax.experimental.pallas import gpu # pytype: disable=import-error diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 344774c01..581a7284b 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -22,7 +22,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import util -from jax._src.pallas import indexing +from jax._src.state import indexing import numpy as np try: @@ -49,7 +49,8 @@ def int_indexer_strategy(dim) -> hps.SearchStrategy[int]: @hps.composite def slice_indexer_strategy(draw, dim) -> Slice | slice: start = draw(int_indexer_strategy(dim)) - size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max)) + max_size = dim - start + size = draw(hps.integers(min_value=0, max_value=max_size)) return draw( hps.one_of( hps.just(Slice(start, size)), hps.just(slice(start, start + size)) @@ -78,8 +79,8 @@ def indexer_strategy(draw, dim, int_indexer_shape def nd_indexer_strategy(draw, shape) -> NDIndexer: num_indices = draw(hps.integers(min_value=0, max_value=len(shape))) int_indexer_shape = draw(hnp.array_shapes()) - indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim - in shape[:num_indices]] + indices = tuple(draw(indexer_strategy(dim, int_indexer_shape)) + for dim in shape[:num_indices]) return NDIndexer.from_indices_shape(indices, shape) @@ -97,6 +98,24 @@ class IndexerTest(parameterized.TestCase): with self.assertRaises(ValueError): _ = NDIndexer.from_indices_shape(indices, shape) + def test_invalid_ndindexer_oob_int(self): + indices 
= (4, 0) + shape = (3, 5) + with self.assertRaises(ValueError): + _ = NDIndexer.from_indices_shape(indices, shape) + + def test_invalid_ndindexer_oob_slice_start(self): + indices = (slice(3, 2), 0) + shape = (3, 5) + with self.assertRaises(ValueError): + _ = NDIndexer.from_indices_shape(indices, shape) + + def test_invalid_ndindexer_oob_slice_end(self): + indices = (Slice(2, 2), 0) + shape = (3, 5) + with self.assertRaises(ValueError): + _ = NDIndexer.from_indices_shape(indices, shape) + def test_ndindexer_with_padding(self): indices = () shape = (5, 5) @@ -137,17 +156,17 @@ class IndexerTest(parameterized.TestCase): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3)) - indices = (0, slice(4, 10), np.arange(5)) + indices = (0, slice(2, 10), np.arange(5)) indexer = NDIndexer.from_indices_shape(indices, shape) - self.assertTupleEqual(indexer.get_indexer_shape(), (5, 0)) + self.assertTupleEqual(indexer.get_indexer_shape(), (5, 1)) - indices = (0, 5, np.arange(5)) + indices = (0, 1, np.arange(5)) indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), (5,)) - indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None]) + indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None]) indexer = NDIndexer.from_indices_shape(indices, shape) - self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3)) + self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2)) @hp.given(hps.data()) def test_ndindexer(self, data): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 2962ff92e..82b2c50e0 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1471,7 +1471,7 @@ class PallasPrimitivesTest(PallasTest): @parameterized.parameters(*[ (lambda: (pl.dslice(0, 4), slice(None), slice(None)), "<- a[:,:,:]"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "<- a[:3,:,:]"), - (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:4,:,:4]"), + (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "<- a[1:,:,:4]"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "<- a[b,:,:4]"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.ds(4)), "<- a[f,g,:4]"), ]) @@ -1486,7 +1486,7 @@ class PallasPrimitivesTest(PallasTest): @parameterized.parameters(*[ (lambda: (pl.dslice(0, 4), slice(None), slice(None)), "a[:,:,:] <-"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "a[:3,:,:] <-"), - (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:4,:,:4] <-"), + (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), "a[1:,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "a[b,:,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), "a[m,n,:4] <-"), ]) @@ -1504,7 +1504,7 @@ class PallasPrimitivesTest(PallasTest): (lambda: (pl.dslice(0, 3), slice(None), slice(None)), "c:i32[3,3,2], a[:3,:,:] <-"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), - "c:i32[3,3,4], a[1:4,:,:4] <-"), + "c:i32[3,3,4], a[1:,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), "e:i32[5,3,4], a[b,:,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), diff --git a/tests/state_test.py b/tests/state_test.py index 4570c7958..13893d1af 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -56,15 +56,15 @@ class StatePrimitivesTest(jtu.JaxTestCase): def test_cant_eval_get_primitive(self): with self.assertRaises(ValueError): - 
get_p.bind(jnp.ones(5)) + get_p.bind(jnp.ones(5), tree=None) def test_cant_eval_swap_primitive(self): with self.assertRaises(ValueError): - swap_p.bind(jnp.ones(5), jnp.zeros(5)) + swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None) def test_cant_eval_addupdate_primitive(self): with self.assertRaises(ValueError): - addupdate_p.bind(jnp.ones(5), jnp.zeros(5)) + addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None) def test_get_abstract_aval_must_take_in_refs(self): ref_aval = core.ShapedArray((), jnp.float32) @@ -95,11 +95,37 @@ class StatePrimitivesTest(jtu.JaxTestCase): ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])), out_shape=(2, 1, 2), out_dtype=jnp.float32), + dict(testcase_name="get_with_nontrivial_slice", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])), + out_shape=(2, 1, 2), out_dtype=jnp.float32), + dict(testcase_name="get_with_nontrivial_slice2", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)), + out_shape=(1, 2, 2, 4), out_dtype=jnp.float32), + dict(testcase_name="get_with_ref_simple_at", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(1, 3), slice(None), slice(None)), + out_shape=(2, 2, 4), out_dtype=jnp.float32, + at_indices=((0,),)), + dict(testcase_name="get_with_ref_simple_at2", + ref_shape=(6, 1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(0, 2), slice(0, 1), slice(1, 3), slice(None), slice(None)), + out_shape=(2, 1, 2, 2, 4), out_dtype=jnp.float32, + at_indices=((slice(2, 6),),)), + dict(testcase_name="get_with_ref_multiple_at", + ref_shape=(1, 3, 5, 4), ref_dtype=jnp.float32, + idx=(slice(None), slice(None), slice(0, 2)), + out_shape=(3, 1, 2), out_dtype=jnp.float32, + at_indices=((0,), (slice(None), slice(0, 1)))), ) def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None, - out_dtype=None, should_error=False): + out_dtype=None, at_indices=(), + should_error=False): ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) def f(x_ref): + for at_idx in at_indices: + x_ref = x_ref.at[at_idx] out = ref_get(x_ref, idx) return [out] if should_error: @@ -160,13 +186,43 @@ class StatePrimitivesTest(jtu.JaxTestCase): val_shape=(2, 1, 2), val_dtype=jnp.float32, idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])), out_shape=(2, 1, 2), out_dtype=jnp.float32), + dict(testcase_name="swap_with_nontrivial_slice", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])), + val_shape=(2, 1, 2), val_dtype=jnp.float32, + out_shape=(2, 1, 2), out_dtype=jnp.float32), + dict(testcase_name="swap_with_nontrivial_slice2", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)), + val_shape=(1, 2, 2, 4), val_dtype=jnp.float32, + out_shape=(1, 2, 2, 4), out_dtype=jnp.float32), + dict(testcase_name="swap_with_ref_simple_at", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(0, 1), slice(1, 3), slice(None),), + val_shape=(1, 1, 4), val_dtype=jnp.float32, + out_shape=(1, 1, 4), out_dtype=jnp.float32, + at_indices=((0,),),), + dict(testcase_name="swap_with_ref_simple_at2", + ref_shape=(4, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),), + val_shape=(2, 1, 1, 4), val_dtype=jnp.float32, + out_shape=(2, 1, 1, 4), out_dtype=jnp.float32, + at_indices=((slice(0, 2),),),), + 
dict(testcase_name="swap_with_ref_multiple_at2", + ref_shape=(1, 4, 3, 2, 4), ref_dtype=jnp.float32, + idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),), + val_shape=(2, 1, 1, 4), val_dtype=jnp.float32, + out_shape=(2, 1, 1, 4), out_dtype=jnp.float32, + at_indices=((slice(None), slice(0, 2),), (0,)),), ) def test_swap_abstract_eval(self, ref_shape, ref_dtype, val_shape, val_dtype, idx, out_shape=None, out_dtype=None, - should_error=False): + at_indices=(), should_error=False): ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) val_aval = core.ShapedArray(val_shape, val_dtype) def f(x_ref, val): + for at_idx in at_indices: + x_ref = x_ref.at[at_idx] out = ref_swap(x_ref, idx, val) return [out] if should_error: @@ -192,13 +248,13 @@ class StatePrimitivesTest(jtu.JaxTestCase): idx=(slice(None),), should_error=True), dict(testcase_name="trivial_addupdate", ref_shape=(1, 2), ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32, - idx=(), out_shape=(1, 2), out_dtype=jnp.float32), + idx=(),), dict(testcase_name="bad_dtype", ref_shape=(1, 2), ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32, idx=(), should_error=True), dict(testcase_name="addupdate_with_index", ref_shape=(1, 2), ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32, - idx=(0,), out_shape=(2,), out_dtype=jnp.float32), + idx=(0,),), dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2), ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32, idx=(slice(None), 0)), @@ -216,13 +272,34 @@ class StatePrimitivesTest(jtu.JaxTestCase): ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, val_shape=(2, 1, 2), val_dtype=jnp.float32, idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))), + dict(testcase_name="ref_with_simple_at", + ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32, + val_shape=(2, 2), val_dtype=jnp.float32, + idx=(np.array([0, 1]), slice(None), np.array([0, 1])), + at_indices=((0,),)), + dict(testcase_name="ref_with_simple_at2", + ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32, + val_shape=(2, 3, 4), val_dtype=jnp.float32, + idx=(np.array([0, 1]), slice(None), np.array([0, 1])), + at_indices=((slice(0, 3),),)), + dict(testcase_name="ref_with_multiple_at", + ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32, + val_shape=(2, 2), val_dtype=jnp.float32, + idx=(np.array([0, 1]), slice(None), np.array([0, 1])), + at_indices=((slice(0, 3),), (0,))), + dict(testcase_name="ref_with_multiple_at2", + ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32, + val_shape=(2, 2), val_dtype=jnp.float32, + idx=(np.array([0, 1]), slice(None), np.array([0, 1])), + at_indices=((slice(None), slice(0, 3),), (0,))), ) def test_addupdate_abstract_eval(self, ref_shape, ref_dtype, - val_shape, val_dtype, idx, out_shape=None, out_dtype=None, - should_error=False): + val_shape, val_dtype, idx, at_indices=(), should_error=False): ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype)) val_aval = core.ShapedArray(val_shape, val_dtype) def f(x_ref, val): + for at_idx in at_indices: + x_ref = x_ref.at[at_idx] ref_addupdate(x_ref, idx, val) return [] if should_error: @@ -1595,7 +1672,6 @@ if CAN_USE_HYPOTHESIS: def test_vjp(self, data): spec = data.draw(func_spec()) - print(spec) def impl(x): return spec.call((x, jnp.zeros_like(x)))[1]