Remove _ListWithW since it is not needed anymore

PiperOrigin-RevId: 510495372
This commit is contained in:
Yash Katariya 2023-02-17 12:29:29 -08:00 committed by jax authors
parent 031d15ed2d
commit 7a09fb98ef

View File

@ -735,9 +735,6 @@ def pjit(
has_explicit_sharding)
class _ListWithW(list):
__slots__ = ('__weakref__',)
def hashable_pytree(pytree):
vals, treedef = tree_flatten(pytree)
vals = tuple(vals)
@ -941,7 +938,7 @@ def _create_pjit_jaxpr(fun, global_in_avals, api_name):
else:
jaxpr = core.ClosedJaxpr(jaxpr, consts)
final_consts = []
return _ListWithW([jaxpr, final_consts, global_out_avals])
return jaxpr, final_consts, global_out_avals
@lru_cache(maxsize=4096)