mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10. Issue #8097
This commit is contained in:
parent
1520fa261f
commit
42e0d4e5f5
@ -92,14 +92,6 @@ def split_dict(dct, names):
|
||||
def concatenate(xs):
|
||||
return list(it.chain.from_iterable(xs))
|
||||
|
||||
class partialmethod(functools.partial):
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return partial(self.func, instance,
|
||||
*(self.args or ()), **(self.keywords or {}))
|
||||
|
||||
def curry(f):
|
||||
"""Curries arguments of f, returning a function on any remaining arguments.
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, total_ordering
|
||||
from functools import partial, partialmethod, total_ordering
|
||||
import gc
|
||||
import itertools as it
|
||||
import operator
|
||||
@ -38,8 +38,8 @@ from .errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
from . import linear_util as lu
|
||||
|
||||
from jax._src import source_info_util
|
||||
from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod,
|
||||
tuple_insert, tuple_delete, cache, as_hashable_function,
|
||||
from ._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
|
||||
tuple_delete, cache, as_hashable_function,
|
||||
HashableFunction)
|
||||
import jax._src.pretty_printer as pp
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from functools import partial
|
||||
from functools import partial, partialmethod
|
||||
import itertools as it
|
||||
import operator as op
|
||||
import re
|
||||
@ -38,9 +38,8 @@ from ..core import (ConcreteArray, ShapedArray, AbstractToken,
|
||||
abstract_token)
|
||||
from ..errors import UnexpectedTracerError
|
||||
import jax._src.pretty_printer as pp
|
||||
from .._src.util import (partialmethod, cache, prod, unzip2,
|
||||
extend_name_stack, wrap_name, safe_zip, safe_map,
|
||||
partition_list)
|
||||
from .._src.util import (cache, prod, unzip2, extend_name_stack, wrap_name,
|
||||
safe_zip, safe_map, partition_list)
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from . import partial_eval as pe
|
||||
|
Loading…
x
Reference in New Issue
Block a user