Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes#1431
See https://github.com/google/jax/pull/1668 for more.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
Fixes#883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.
Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
Currently backtraces often look like this:
```
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/p/jax/jax/util.py in memoized_fun(*args, **kwargs)
133 try:
--> 134 return cache[key]
135 except KeyError:
KeyError: ((lu, ShapedArray(int32[2,2])), ())
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
~/p/jax/jax/util.py in memoized_fun(*args, **kwargs)
133 try:
--> 134 return cache[key]
135 except KeyError:
KeyError: ((lu, xla_client.Shape(_dtype=dtype('int32'), _dimensions=(2, 2), _is_tuple=False, _minor_to_major=None)), ())
During handling of the above exception, another exception occurred:
NotImplementedError Traceback (most recent call last)
<ipython-input-26-d6c00d50e3c9> in <module>
```
The "during handling of the above exception..." message is mostly a distraction for the user that occurs because we perform the memoized function evaluation inside a `catch` block. By performing the function evaluation outside the catch block, we can get better backtraces without the distraction of the KeyError exception.
```