227 Commits

Author SHA1 Message Date
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Matthew Johnson
82dbf91311 add tests for #1640, adapt make_jaxpr staging 2019-12-31 11:53:02 -08:00
Pavel Sountsov
cc92bb6411 Improve the VJP structure mismatch errors. (#1854) 2019-12-13 08:41:51 -05:00
Stephan Hoyer
7bf2d77bd9
Clarify SPMD requirement for pmap (#1826) 2019-12-06 12:03:22 -08:00
Matthew Johnson
0899673363 switch xla_computation instantiate outputs default 2019-12-04 10:34:02 -08:00
Matthew Johnson
c1aeaf511c xla_computation option to instantiate const output 2019-12-04 10:34:02 -08:00
George Necula
2b0b04fcad Merge remote-tracking branch 'upstream/master' into jaxpr_pp 2019-11-28 08:56:00 +01:00
George Necula
0cb3b433b5 Change in how we print sorted params for eqns 2019-11-28 07:34:40 +01:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
e0706ff864 Relaxed check to allow both tuples and lists 2019-11-27 14:24:41 +01:00
George Necula
c1d8d3f74d Add error checking that arguments of jvp are tuples 2019-11-27 13:12:24 +01:00
George Necula
5c15dda2c9 Changed api.make_jaxpr to return a TypedJaxpr
* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.
  Before:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [*]}
  After this change:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [1.0]}
2019-11-26 09:17:03 +01:00
Skye Wanderman-Milne
f415f266b8
Remove 'backend' argument from device_put. (#1762)
The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.
2019-11-25 16:23:40 -08:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Matthew Johnson
728cb7fba8 improve grad error message without enough args
fixes #1696
2019-11-14 21:18:23 -08:00
Peter Hawkins
6125157db8
Add type checks that verify JVP primal inputs have the same types as tangent inputs, and JVP cotangent inputs have the same type as primal outputs. (#1690) 2019-11-14 15:37:33 -05:00
Matthew Johnson
483553ffd7 patch lax.axis_index, add warning about soft_pmap 2019-11-14 00:23:26 -08:00
Matthew Johnson
6f47ac007f fix xla.lower_fun and jax.xla_computation 2019-11-11 15:07:46 -08:00
Matthew Johnson
8bcee8d45f fix a leak where compiled results lived too long
The original repro @levskaya showed us was essentially this OOM:

  for i in range(40):
    f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
    f().block_until_ready()

Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.

Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).

The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic is now present in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous
now.

One unintended consequence is that this implementation now avoids using
fastcache.crlu_cache for the jit and pmap compilation caches. It was easier to
implement this logic in pure Python. We might want to revise this for
performance reasons.

This commit also incidentally fixed #1600.
2019-10-31 16:26:29 -07:00
Matthew Johnson
46fe76c23a tweak comment 2019-10-31 14:47:16 -07:00
Matthew Johnson
979b38352f make vmap structured axes work for any pytree 2019-10-31 14:09:12 -07:00
Matthew Johnson
9923cefe8f Merge branch 'vmap-improvements' of github.com:google/jax into vmap-improvements 2019-10-31 11:59:06 -07:00
Matthew Johnson
14acca7b51 address reviewer comments, fix test error 2019-10-31 11:57:37 -07:00
Matthew Johnson
9d94c42323
Update jax/api.py
Co-Authored-By: Stephan Hoyer <shoyer@google.com>
2019-10-31 11:22:23 -07:00
Matthew Johnson
eae47b2330 improve vmap error messages
fixes #705
2019-10-31 10:35:08 -07:00
Matthew Johnson
f5079a6281 improve vmap docstring and tree prefix errors
fixes #795
2019-10-30 15:39:58 -07:00
Matthew Johnson
585cefc8c0 document pmap with pytrees, fixes #1486 2019-10-15 23:49:15 +00:00
Skye Wanderman-Milne
5585dda9fe
Change device_put to take a device argument instead of device_num. (#1463) 2019-10-11 14:07:16 -07:00
Peter Hawkins
27afa128d2 Fix rendering of jax.eval_shape docs. 2019-10-11 09:16:02 -04:00
Skye Wanderman-Milne
eb0137be36 Add optional tuple_args argument to xla_computation.
This is useful when using JAX to create an HLO module that is compiled
and executed elsewhere.

Also fixes a bug in the `tuple_args` logic.
2019-09-30 11:10:32 -07:00
Skye Wanderman-Milne
9ac5aa1e15 Add jax.local_devices() and jax.host_ids(). 2019-09-27 16:36:45 -07:00
Matthew Johnson
c33e8cb2c3
Merge pull request #1339 from shoyer/solvers
lax.root, a primitive for differentiable root finding
2019-09-27 12:17:47 -07:00
Skye Wanderman-Milne
dc2ee0de89 Add support for multihost pmaps.
All participating hosts are assumed to be running the same pmap
code. Conceptually, this can be considered a single pmap over an array
sharded on its leading pmapped dimension across the hosts. Each host
passes its input shard to its pmapped function call, which returns the
corresponding output shard (i.e. an array of the same leading
dimension size). However, any collective operations will be run across
the entire "global" array.

If the `devices` argument to pmap is None, the pmap is assumed to be
running across all hosts visible to XLA (as returned by
jax.host_count()). Each host can pass in an input array of leading
dimension size equal to or less than the number of devices local to
that host. Note that this doesn't change the current behavior for
single-host platforms. If `devices` are specified, the participating
hosts are dictated by the devices' host_ids, and each host must pass
in an input array of leading dim size equal to the number of local
participating devices.

Implementation-wise, each host independently compiles the computation,
which we assume yields the same executable on all hosts (follow-up
work will add more error checking). The hosts must know the global
axis size of the sharded array, e.g. to provide the correct replica
count to XLA. This is equal to the length of `devices` if specified,
but if not, pmap is recursively called (with `devices` specified) to
use `psum` to compute the global axis size.
2019-09-26 14:44:16 -07:00
Stephan Hoyer
fd975b61e6 Merge branch 'master' into solvers 2019-09-24 22:46:27 -07:00
Skye Wanderman-Milne
ad03bafb9b Change jit to take Device object instead of device ordinal.
This also changes the name of the argument to `device` from `device_assignment`.
2019-09-24 18:28:08 -07:00
Roy Frostig
d1c66614e8 add a "last" symbol for vmap axis specs, use it in api.jacfwd. tests and fixes #1372
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-09-23 13:35:52 -07:00
Matthew Johnson
283299649b add a 'monomorphic dim' symbol, bug fixes 2019-09-15 11:10:05 -07:00
Matthew Johnson
78c70ecd0c add dynamic shape envs 2019-09-15 11:10:04 -07:00
Stephan Hoyer
12509c913c Remove jax.api._custom_implicit_solve 2019-09-11 21:33:48 -07:00
Matthew Johnson
772fdb8c4e move automasking prototype into jax/interpreters
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
samuela
b558afb671
Prettify code example docs for disable_jit 2019-09-02 16:30:07 -07:00
Matthew Johnson
dbe56c30ac leave todos for better error messages 2019-08-26 14:05:17 -07:00
Matthew Johnson
6a81d81d9b fix for custom-transforms vjp nones bug
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-26 13:38:08 -07:00
Skye Wanderman-Milne
ae835b747e Add jax.devices() and friends, and add devices arg to pmap.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
    will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
    should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Anselm Levskaya
8a78c92b89 Merge branch 'master' into multibackend 2019-08-22 21:59:47 -07:00
Anselm Levskaya
685ca6765e resolve merge conflicts with master 2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47 Merge branch 'master' into multibackend 2019-08-22 19:52:29 -07:00
Matthew Johnson
de6ce2a555 allow vmap in_axes to be lists 2019-08-22 12:50:47 -07:00
Matthew Johnson
8517997518 minor fixes from trax, revise eval_shape api 2019-08-21 20:36:47 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00