Merge pull request #19593 from trishume:patch-1

PiperOrigin-RevId: 602881185
This commit is contained in:
jax authors 2024-01-30 17:01:03 -08:00
commit 4393f84680

View File

@ -418,7 +418,7 @@ state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
eviction_policy="", volatile=False):
eviction_policy="", volatile=False) -> jax.Array:
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
return load_p.bind(