mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[pallas] Allow TransformedRef
to be passed to pl.load
/ pl.store
, when idx = ()
.
PiperOrigin-RevId: 678257485
This commit is contained in:
parent
6860617ebf
commit
8d86a04727
@ -62,10 +62,13 @@ zip, unsafe_zip = safe_zip, zip
|
||||
get_p = core.Primitive("get")
|
||||
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
|
||||
|
||||
Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...]
|
||||
Indexer = Union[int, slice, Array, types.EllipsisType]
|
||||
|
||||
|
||||
def get_ref_and_transforms(
|
||||
ref_or_view: Any, idx: Indexer | None, function_name: str
|
||||
ref_or_view: Any,
|
||||
idx: Indexer | tuple[Indexer, ...] | None,
|
||||
function_name: str,
|
||||
) -> tuple[Any, tuple[Transform, ...]]:
|
||||
if isinstance(ref_or_view, TransformedRef):
|
||||
ref, transforms = ref_or_view.ref, ref_or_view.transforms
|
||||
@ -76,18 +79,27 @@ def get_ref_and_transforms(
|
||||
raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
|
||||
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
return ref, ()
|
||||
if idx is None:
|
||||
|
||||
if idx is None or idx is Ellipsis:
|
||||
idx = ()
|
||||
elif not isinstance(idx, tuple):
|
||||
idx = (idx,)
|
||||
|
||||
if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer):
|
||||
return ref, transforms
|
||||
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
|
||||
return ref, (*transforms, nd_indexer)
|
||||
|
||||
|
||||
def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
|
||||
def ref_get(
|
||||
ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None
|
||||
) -> Array:
|
||||
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
|
||||
ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get")
|
||||
flat_transforms, tree = tree_util.tree_flatten(transforms)
|
||||
return get_p.bind(ref, *flat_transforms, tree=tree)
|
||||
|
||||
|
||||
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
||||
# b = swap_p.bind(x, a)
|
||||
# It generalizes the setting operation for a `Ref` as we can ignore the return
|
||||
@ -110,7 +122,7 @@ swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
|
||||
|
||||
def ref_swap(
|
||||
ref_or_view: AbstractRef | TransformedRef,
|
||||
idx: Indexer | None,
|
||||
idx: Indexer | tuple[Indexer, ...] | None,
|
||||
value: Array,
|
||||
_function_name: str = "ref_swap",
|
||||
) -> Array:
|
||||
@ -121,11 +133,14 @@ def ref_swap(
|
||||
|
||||
|
||||
def ref_set(
|
||||
ref_or_view: AbstractRef | TransformedRef, idx: Indexer | None, value: Array
|
||||
ref_or_view: AbstractRef | TransformedRef,
|
||||
idx: Indexer | tuple[Indexer, ...] | None,
|
||||
value: Array,
|
||||
) -> None:
|
||||
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
|
||||
ref_swap(ref_or_view, idx, value, _function_name="ref_set")
|
||||
|
||||
|
||||
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
|
||||
# Semantically,
|
||||
# ```
|
||||
@ -141,12 +156,18 @@ addupdate_p = core.Primitive('addupdate')
|
||||
addupdate_p.multiple_results = True
|
||||
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))
|
||||
|
||||
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
|
||||
|
||||
def ref_addupdate(
|
||||
ref_or_view: AbstractRef,
|
||||
idx: Indexer | tuple[Indexer, ...] | None,
|
||||
x: Array,
|
||||
) -> None:
|
||||
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
||||
ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate")
|
||||
flat_transforms, tree = tree_util.tree_flatten(transforms)
|
||||
return addupdate_p.bind(ref, x, *flat_transforms, tree=tree)
|
||||
|
||||
|
||||
## get/set/addupdate abstract evaluation rules
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user