[pallas] Allow TransformedRef to be passed to pl.load / pl.store, when idx = ().

PiperOrigin-RevId: 678257485
This commit is contained in:
Chris Jones 2024-09-24 08:16:48 -07:00 committed by jax authors
parent 6860617ebf
commit 8d86a04727

View File

@ -62,10 +62,13 @@ zip, unsafe_zip = safe_zip, zip
get_p = core.Primitive("get")
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...]
Indexer = Union[int, slice, Array, types.EllipsisType]
def get_ref_and_transforms(
ref_or_view: Any, idx: Indexer | None, function_name: str
ref_or_view: Any,
idx: Indexer | tuple[Indexer, ...] | None,
function_name: str,
) -> tuple[Any, tuple[Transform, ...]]:
if isinstance(ref_or_view, TransformedRef):
ref, transforms = ref_or_view.ref, ref_or_view.transforms
@ -76,18 +79,27 @@ def get_ref_and_transforms(
raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
return ref, ()
if idx is None:
if idx is None or idx is Ellipsis:
idx = ()
elif not isinstance(idx, tuple):
idx = (idx,)
if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer):
return ref, transforms
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
return ref, (*transforms, nd_indexer)
def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
def ref_get(
ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None
) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get")
flat_transforms, tree = tree_util.tree_flatten(transforms)
return get_p.bind(ref, *flat_transforms, tree=tree)
# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
# It generalizes the setting operation for a `Ref` as we can ignore the return
@ -110,7 +122,7 @@ swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
def ref_swap(
ref_or_view: AbstractRef | TransformedRef,
idx: Indexer | None,
idx: Indexer | tuple[Indexer, ...] | None,
value: Array,
_function_name: str = "ref_swap",
) -> Array:
@ -121,11 +133,14 @@ def ref_swap(
def ref_set(
ref_or_view: AbstractRef | TransformedRef, idx: Indexer | None, value: Array
ref_or_view: AbstractRef | TransformedRef,
idx: Indexer | tuple[Indexer, ...] | None,
value: Array,
) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref_or_view, idx, value, _function_name="ref_set")
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically,
# ```
@ -141,12 +156,18 @@ addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
def ref_addupdate(
ref_or_view: AbstractRef,
idx: Indexer | tuple[Indexer, ...] | None,
x: Array,
) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate")
flat_transforms, tree = tree_util.tree_flatten(transforms)
return addupdate_p.bind(ref, x, *flat_transforms, tree=tree)
## get/set/addupdate abstract evaluation rules