mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19593 from trishume:patch-1
PiperOrigin-RevId: 602881185
This commit is contained in:
commit
4393f84680
@ -418,7 +418,7 @@ state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
|
||||
|
||||
|
||||
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
|
||||
eviction_policy="", volatile=False):
|
||||
eviction_policy="", volatile=False) -> jax.Array:
|
||||
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
|
||||
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
|
||||
return load_p.bind(
|
||||
|
Loading…
x
Reference in New Issue
Block a user