We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
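As a rough sketch of the intended user-facing control (this assumes the jax.remat decorator added by the PR above; the toy functions are illustrative, not from the PR):

import jax
import jax.numpy as jnp

# None of g's intermediates are saved on the forward pass; they are
# recomputed from g's saved inputs when the backward pass needs them.
@jax.remat
def g(x):
  return jnp.sin(jnp.sin(x))

def loss(x):
  return jnp.sum(g(x) ** 2)

print(jax.grad(loss)(jnp.ones(3)))  # reverse-mode autodiff works as usual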
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.
Before:
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
lambda ; ; a.
  let
  in [*]}
After this change:
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
lambda ; ; a.
  let
  in [1.0]}
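For illustration, a hedged sketch of inspecting that extra information (the attribute names in_avals, out_avals, and literals are what the typed jaxpr object is expected to expose; exact names may differ across versions):

import jax

typed = jax.make_jaxpr(lambda x: x * 2.0)(3.0)
print(typed.in_avals)   # abstract types of the inputs
print(typed.out_avals)  # abstract types of the outputs
print(typed.literals)   # constants captured by the jaxpr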
The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.
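A hedged sketch of the device-first style (device_put is just one API that accepts a 'device'):

import jax
import jax.numpy as jnp

# Name a device explicitly and let the backend be inferred from it,
# rather than passing a 'backend' string.
cpu0 = jax.devices('cpu')[0]
x = jax.device_put(jnp.arange(4), cpu0)
print(x)  # lives on cpu0; the backend follows from the device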
* Move internal type-related functions into a new (internal) jax.types module.
Avoid calling onp type functions directly; use the wrappers in jax.types instead. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.
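For example, a small sketch against the renamed jax.dtypes module:

import numpy as onp
from jax import dtypes

# Use the jax wrappers rather than onp's type functions directly, so that
# dtype canonicalization (e.g. float64 -> float32 when x64 is disabled)
# is applied consistently.
print(dtypes.canonicalize_dtype(onp.float64))
print(dtypes.result_type(onp.int32, onp.float16))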
The original repro @levskaya showed us was essentially this OOM:
for i in range(40):
  f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
  f().block_until_ready()
Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.
Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).
The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic is now present in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous
now.
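A stripped-down sketch of that idea (not the actual linear_util.py code):

import weakref

compiled_cache = weakref.WeakKeyDictionary()  # keyed by the raw Python callable

def memoized_compile(raw_fun, key, compile_fn):
  # key is something like (transforms, params, args); all entries for
  # raw_fun disappear automatically once raw_fun is garbage collected.
  entries = compiled_cache.setdefault(raw_fun, {})
  if key not in entries:
    entries[key] = compile_fn()
  return entries[key]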
One unintended consequence is that this implementation now avoids using
fastcache.clru_cache for the jit and pmap compilation caches. It was easier to
implement this logic in pure Python. We might want to revise this for
performance reasons.
This commit also incidentally fixed #1600.
All participating hosts are assumed to be running the same pmap
code. Conceptually, this can be considered a single pmap over an array
sharded on its leading pmapped dimension across the hosts. Each host
passes its input shard to its pmapped function call, which returns the
corresponding output shard (i.e. an array of the same leading
dimension size). However, any collective operations will be run across
the entire "global" array.
If the `devices` argument to pmap is None, the pmap is assumed to be
running across all hosts visible to XLA (as returned by
jax.host_count()). Each host can pass in an input array of leading
dimension size equal to or less than the number of devices local to
that host. Note that this doesn't change the current behavior for
single-host platforms. If `devices` are specified, the participating
hosts are dictated by the devices' host_ids, and each host must pass
in an input array of leading dim size equal to the number of local
participating devices.
Implementation-wise, each host independently compiles the computation,
which we assume yields the same executable on all hosts (follow-up
work will add more error checking). The hosts must know the global
axis size of the sharded array, e.g. to provide the correct replica
count to XLA. This is equal to the length of `devices` if specified,
but if not, pmap is recursively called (with `devices` specified) to
use `psum` to compute the global axis size.
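A hedged sketch of the per-host usage pattern described above (the shapes and the normalization are illustrative only):

import jax
import jax.numpy as np
from jax import lax, pmap

# Each host feeds only its local shard; the psum reduces over the global
# pmapped axis, i.e. across all participating devices on all hosts.
n_local = jax.local_device_count()

def normalize(x):
  return x / lax.psum(x, axis_name='i')

local_shard = np.ones((n_local, 4))
out = pmap(normalize, axis_name='i')(local_shard)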
This change adds the following APIs (a short usage sketch follows the list):
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
should be used in the replicated computation.
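A usage sketch against these APIs as described above (output values depend on the platform):

import jax
import jax.numpy as np
from jax import pmap

print(jax.devices())             # list of Device subclass instances
print(jax.host_id())             # 0 on single-host platforms
print(jax.local_device_count())  # == jax.device_count() on a single host

# pmap restricted to an explicit subset of devices via the new argument
subset = jax.devices()[:2]
out = pmap(lambda x: 2 * x, devices=subset)(np.arange(len(subset)))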