Add type annotation to pl.load

This commit is contained in:
Tristan Hume 2024-01-30 16:32:29 -08:00 committed by GitHub
parent af2292aa4e
commit 7933acdb90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -418,7 +418,7 @@ state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
eviction_policy="", volatile=False):
eviction_policy="", volatile=False) -> jax.Array:
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
return load_p.bind(