Remove jax._src.util.partialmethod.

Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.

Issue #8097
This commit is contained in:
Peter Hawkins 2021-10-05 12:12:41 -04:00
parent 1520fa261f
commit 42e0d4e5f5
3 changed files with 6 additions and 15 deletions

View File

@ -92,14 +92,6 @@ def split_dict(dct, names):
def concatenate(xs):
return list(it.chain.from_iterable(xs))
class partialmethod(functools.partial):
def __get__(self, instance, owner):
if instance is None:
return self
else:
return partial(self.func, instance,
*(self.args or ()), **(self.keywords or {}))
def curry(f):
"""Curries arguments of f, returning a function on any remaining arguments.

View File

@ -16,7 +16,7 @@
import collections
from collections import namedtuple
from contextlib import contextmanager
from functools import partial, total_ordering
from functools import partial, partialmethod, total_ordering
import gc
import itertools as it
import operator
@ -38,8 +38,8 @@ from .errors import (ConcretizationTypeError, TracerArrayConversionError,
from . import linear_util as lu
from jax._src import source_info_util
from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod,
tuple_insert, tuple_delete, cache, as_hashable_function,
from ._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
tuple_delete, cache, as_hashable_function,
HashableFunction)
import jax._src.pretty_printer as pp

View File

@ -14,7 +14,7 @@
from collections import defaultdict, deque
from functools import partial
from functools import partial, partialmethod
import itertools as it
import operator as op
import re
@ -38,9 +38,8 @@ from ..core import (ConcreteArray, ShapedArray, AbstractToken,
abstract_token)
from ..errors import UnexpectedTracerError
import jax._src.pretty_printer as pp
from .._src.util import (partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map,
partition_list)
from .._src.util import (cache, prod, unzip2, extend_name_stack, wrap_name,
safe_zip, safe_map, partition_list)
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from . import partial_eval as pe