mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove _ListWithW since it is not needed anymore
PiperOrigin-RevId: 510495372
This commit is contained in:
parent
031d15ed2d
commit
7a09fb98ef
@ -735,9 +735,6 @@ def pjit(
|
||||
has_explicit_sharding)
|
||||
|
||||
|
||||
class _ListWithW(list):
|
||||
__slots__ = ('__weakref__',)
|
||||
|
||||
def hashable_pytree(pytree):
|
||||
vals, treedef = tree_flatten(pytree)
|
||||
vals = tuple(vals)
|
||||
@ -941,7 +938,7 @@ def _create_pjit_jaxpr(fun, global_in_avals, api_name):
|
||||
else:
|
||||
jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
final_consts = []
|
||||
return _ListWithW([jaxpr, final_consts, global_out_avals])
|
||||
return jaxpr, final_consts, global_out_avals
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
|
Loading…
x
Reference in New Issue
Block a user