Use fastcache for LRU caches in JAX.

fastcache is both a faster cache implementation and thread-safe.
This commit is contained in:
Peter Hawkins 2019-07-22 17:24:10 -04:00
parent f31d58fcd8
commit 08013954a4
6 changed files with 23 additions and 60 deletions

View File

@ -20,7 +20,7 @@ before_install:
- conda update --yes conda
- conda config --add channels conda-forge
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip six protobuf>=3.6.0 absl-py opt_einsum numpy scipy pytest-xdist
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip six protobuf>=3.6.0 absl-py opt_einsum numpy scipy pytest-xdist fastcache
- pip install jaxlib
- pip install -v .
script:

View File

@ -28,6 +28,7 @@ import warnings
from distutils.util import strtobool
from ..config import flags
from .. import util
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
import six
@ -91,14 +92,6 @@ def get_compile_options(num_replicas=None):
return compile_options
def memoize(func):
class memodict(dict):
def __missing__(self, key):
val = self[key] = func(key)
return val
return memodict().__getitem__
def memoize_thunk(func):
cached = []
return lambda: cached[0] if cached else (cached.append(func()) or cached[0])
@ -173,12 +166,12 @@ Shape = xla_client.Shape # pylint: disable=invalid-name
### utility functions
@memoize
@util.memoize_unary
def dtype_to_etype(dtype):
"""Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[canonicalize_dtype(dtype)]
@memoize
@util.memoize_unary
def dtype_to_etype_exact(dtype):
"""Convert from dtype to exact etype (ignoring FLAGS.jax_enable_x64)."""
return xla_client.dtype_to_etype(dtype)
@ -192,7 +185,7 @@ _dtype_to_32bit_dtype = {
}
@memoize
@util.memoize_unary
def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
dtype = onp.dtype(dtype)

View File

@ -69,7 +69,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .util import curry, partial, OrderedDict
import fastcache
from .util import curry, partial
def thunk(f):
@ -195,18 +197,14 @@ def wrap_init(f, params={}):
def memoize(call, max_size=4096):
cache = OrderedDict()
@fastcache.clru_cache(maxsize=max_size)
def memoized_fun_body(f, args):
return call(f, *args), f
def memoized_fun(f, *args):
key = (f, args)
if key in cache:
ans, f_prev = cache[key]
cache.move_to_end(key)
ans, f_prev = memoized_fun_body(f, args)
if id(f_prev) != id(f):
f.populate_stores(f_prev)
else:
if len(cache) > max_size:
cache.popitem(last=False)
ans = call(f, *args)
cache[key] = (ans, f)
return ans
return memoized_fun

View File

@ -179,6 +179,8 @@ def tree_structure(tree):
class PyTreeDef(object):
__slots__ = ("node_type", "node_data", "children")
def __init__(self, node_type, node_data, children):
self.node_type = node_type
self.node_data = node_data
@ -209,6 +211,8 @@ class PyTreeDef(object):
class PyLeaf(object):
__slots__ = ()
def __repr__(self):
return '*'

View File

@ -21,12 +21,10 @@ import functools
import itertools as it
from operator import mul
import types
import fastcache
import numpy as onp
import six
allow_memoize_hash_failures = False
def safe_zip(*args):
n = len(args[0])
@ -142,39 +140,8 @@ def split_merge(predicate, xs):
return lhs, rhs, merge
if six.PY3:
OrderedDict = collections.OrderedDict
else:
# Retrofits a move_to_end method to OrderedDict in Python 2 mode.
class OrderedDict(collections.OrderedDict):
def move_to_end(self, key):
value = self[key]
del self[key]
self[key] = value
_NO_MEMO_ENTRY = object()
def memoize(fun, max_size=4096):
cache = OrderedDict()
def memoized_fun(*args, **kwargs):
key = (args, tuple(kwargs and sorted(kwargs.items())))
try:
ans = cache.get(key, _NO_MEMO_ENTRY)
if ans != _NO_MEMO_ENTRY:
cache.move_to_end(key)
return ans
except TypeError:
if not allow_memoize_hash_failures:
raise
if len(cache) > max_size:
cache.popitem(last=False)
ans = cache[key] = fun(*args, **kwargs)
return ans
return memoized_fun
return fastcache.clru_cache(maxsize=max_size)(fun)
def memoize_unary(func):
class memodict(dict):

View File

@ -25,7 +25,8 @@ setup(
author_email='jax-dev@google.com',
packages=find_packages(exclude=["examples"]),
install_requires=[
'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum'
'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum',
'fastcache'
],
url='https://github.com/google/jax',
license='Apache-2.0',