mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Use fastcache for LRU caches in JAX.
fastcache is both a faster cache implementation and is also thread-safe.
This commit is contained in:
parent
f31d58fcd8
commit
08013954a4
@ -20,7 +20,7 @@ before_install:
|
||||
- conda update --yes conda
|
||||
- conda config --add channels conda-forge
|
||||
install:
|
||||
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip six protobuf>=3.6.0 absl-py opt_einsum numpy scipy pytest-xdist
|
||||
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip six protobuf>=3.6.0 absl-py opt_einsum numpy scipy pytest-xdist fastcache
|
||||
- pip install jaxlib
|
||||
- pip install -v .
|
||||
script:
|
||||
|
@ -28,6 +28,7 @@ import warnings
|
||||
from distutils.util import strtobool
|
||||
|
||||
from ..config import flags
|
||||
from .. import util
|
||||
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
|
||||
import six
|
||||
|
||||
@ -91,14 +92,6 @@ def get_compile_options(num_replicas=None):
|
||||
return compile_options
|
||||
|
||||
|
||||
def memoize(func):
|
||||
class memodict(dict):
|
||||
def __missing__(self, key):
|
||||
val = self[key] = func(key)
|
||||
return val
|
||||
return memodict().__getitem__
|
||||
|
||||
|
||||
def memoize_thunk(func):
|
||||
cached = []
|
||||
return lambda: cached[0] if cached else (cached.append(func()) or cached[0])
|
||||
@ -173,12 +166,12 @@ Shape = xla_client.Shape # pylint: disable=invalid-name
|
||||
|
||||
### utility functions
|
||||
|
||||
@memoize
|
||||
@util.memoize_unary
|
||||
def dtype_to_etype(dtype):
|
||||
"""Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
|
||||
return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[canonicalize_dtype(dtype)]
|
||||
|
||||
@memoize
|
||||
@util.memoize_unary
|
||||
def dtype_to_etype_exact(dtype):
|
||||
"""Convert from dtype to exact etype (ignoring FLAGS.jax_enable_x64)."""
|
||||
return xla_client.dtype_to_etype(dtype)
|
||||
@ -192,7 +185,7 @@ _dtype_to_32bit_dtype = {
|
||||
}
|
||||
|
||||
|
||||
@memoize
|
||||
@util.memoize_unary
|
||||
def canonicalize_dtype(dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
|
||||
dtype = onp.dtype(dtype)
|
||||
|
@ -69,7 +69,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .util import curry, partial, OrderedDict
|
||||
import fastcache
|
||||
|
||||
from .util import curry, partial
|
||||
|
||||
|
||||
def thunk(f):
|
||||
@ -195,18 +197,14 @@ def wrap_init(f, params={}):
|
||||
|
||||
|
||||
def memoize(call, max_size=4096):
|
||||
cache = OrderedDict()
|
||||
@fastcache.clru_cache(maxsize=max_size)
|
||||
def memoized_fun_body(f, args):
|
||||
return call(f, *args), f
|
||||
|
||||
def memoized_fun(f, *args):
|
||||
key = (f, args)
|
||||
if key in cache:
|
||||
ans, f_prev = cache[key]
|
||||
cache.move_to_end(key)
|
||||
ans, f_prev = memoized_fun_body(f, args)
|
||||
if id(f_prev) != id(f):
|
||||
f.populate_stores(f_prev)
|
||||
else:
|
||||
if len(cache) > max_size:
|
||||
cache.popitem(last=False)
|
||||
ans = call(f, *args)
|
||||
cache[key] = (ans, f)
|
||||
return ans
|
||||
|
||||
return memoized_fun
|
||||
|
@ -179,6 +179,8 @@ def tree_structure(tree):
|
||||
|
||||
|
||||
class PyTreeDef(object):
|
||||
__slots__ = ("node_type", "node_data", "children")
|
||||
|
||||
def __init__(self, node_type, node_data, children):
|
||||
self.node_type = node_type
|
||||
self.node_data = node_data
|
||||
@ -209,6 +211,8 @@ class PyTreeDef(object):
|
||||
|
||||
|
||||
class PyLeaf(object):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return '*'
|
||||
|
||||
|
39
jax/util.py
39
jax/util.py
@ -21,12 +21,10 @@ import functools
|
||||
import itertools as it
|
||||
from operator import mul
|
||||
import types
|
||||
|
||||
import fastcache
|
||||
import numpy as onp
|
||||
|
||||
import six
|
||||
|
||||
allow_memoize_hash_failures = False
|
||||
|
||||
|
||||
def safe_zip(*args):
|
||||
n = len(args[0])
|
||||
@ -142,39 +140,8 @@ def split_merge(predicate, xs):
|
||||
return lhs, rhs, merge
|
||||
|
||||
|
||||
if six.PY3:
|
||||
OrderedDict = collections.OrderedDict
|
||||
else:
|
||||
# Retrofits a move_to_end method to OrderedDict in Python 2 mode.
|
||||
class OrderedDict(collections.OrderedDict):
|
||||
def move_to_end(self, key):
|
||||
value = self[key]
|
||||
del self[key]
|
||||
self[key] = value
|
||||
|
||||
|
||||
_NO_MEMO_ENTRY = object()
|
||||
|
||||
def memoize(fun, max_size=4096):
|
||||
cache = OrderedDict()
|
||||
def memoized_fun(*args, **kwargs):
|
||||
key = (args, tuple(kwargs and sorted(kwargs.items())))
|
||||
try:
|
||||
ans = cache.get(key, _NO_MEMO_ENTRY)
|
||||
if ans != _NO_MEMO_ENTRY:
|
||||
cache.move_to_end(key)
|
||||
return ans
|
||||
except TypeError:
|
||||
if not allow_memoize_hash_failures:
|
||||
raise
|
||||
|
||||
if len(cache) > max_size:
|
||||
cache.popitem(last=False)
|
||||
|
||||
ans = cache[key] = fun(*args, **kwargs)
|
||||
return ans
|
||||
return memoized_fun
|
||||
|
||||
return fastcache.clru_cache(maxsize=max_size)(fun)
|
||||
|
||||
def memoize_unary(func):
|
||||
class memodict(dict):
|
||||
|
3
setup.py
3
setup.py
@ -25,7 +25,8 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=find_packages(exclude=["examples"]),
|
||||
install_requires=[
|
||||
'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum'
|
||||
'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum',
|
||||
'fastcache'
|
||||
],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
|
Loading…
x
Reference in New Issue
Block a user