339 Commits

Author SHA1 Message Date
Stephan Hoyer
12509c913c Remove jax.api._custom_implicit_solve 2019-09-11 21:33:48 -07:00
Matthew Johnson
772fdb8c4e move automasking prototype into jax/interpreters
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
samuela
b558afb671
Prettify code example docs for disable_jit 2019-09-02 16:30:07 -07:00
Matthew Johnson
dbe56c30ac leave todos for better error messages 2019-08-26 14:05:17 -07:00
Matthew Johnson
6a81d81d9b fix for custom-transforms vjp nones bug
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-26 13:38:08 -07:00
Skye Wanderman-Milne
ae835b747e Add jax.devices() and friends, and add devices arg to pmap.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
    will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
    should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Anselm Levskaya
8a78c92b89 Merge branch 'master' into multibackend 2019-08-22 21:59:47 -07:00
Anselm Levskaya
685ca6765e resolve merge conflicts with master 2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47 Merge branch 'master' into multibackend 2019-08-22 19:52:29 -07:00
Matthew Johnson
de6ce2a555 allow vmap in_axes to be lists 2019-08-22 12:50:47 -07:00
Matthew Johnson
8517997518 minor fixes from trax, revise eval_shape api 2019-08-21 20:36:47 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
3c37a3260a Update linearize to no-tuple version 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
c53c8bbb43 Some progress de-tupling ad.py 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
dbe4bdf944 add multibackend option to soft_pmap 2019-08-21 01:06:04 -07:00
Anselm Levskaya
f01fc35ce5 Make op-by-op work with all jit-returned devicearrays. 2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013 WIP: experimental multibackend jit 2019-08-19 23:45:36 -07:00
Peter Hawkins
c84cb34ac7 Readd coding declaration to api.py 2019-08-13 15:55:06 -04:00
Peter Hawkins
f898ac1a00 Remove spurious coding line from api.py 2019-08-13 15:39:50 -04:00
Matthew Johnson
f2fe49f99a
Merge pull request #1149 from shoyer/custom-implicit-solve
Helper function for defining differentiable solves.
2019-08-13 10:25:41 -07:00
Peter Hawkins
cc618f4bf4 Don't perform static/dynamic argument splitting if static_argnums was not passed. 2019-08-13 10:32:13 -04:00
Matthew Johnson
4a4304e08c make pmap read axis size from kwargs
fixes #1170
2019-08-12 18:03:25 -07:00
Stephan Hoyer
3604293ba4 changes per review 2019-08-12 11:21:10 -07:00
Roy Frostig
a657922309
Merge pull request #1132 from necula01/update_doc
Update documentation for jax.vmap.
2019-08-12 07:53:33 -07:00
George Necula
35ca6bae4e Fixed vmap documentation 2019-08-12 12:04:14 +02:00
Peter Hawkins
335dbe8285 Remove device_values=False support from jit.
This is no longer needed now `device_get` is not implemented via `jit`.
2019-08-10 15:17:24 -04:00
Peter Hawkins
8fd6e19ef6 More concurrency fixes. 2019-08-09 15:08:26 -04:00
Stephan Hoyer
fb610a1304 more descriptive name for dummy variable 2019-08-09 09:24:55 -07:00
Stephan Hoyer
eb1516e992 Fix math in custom_implicit_solve 2019-08-09 09:15:24 -07:00
Stephan Hoyer
eb4bfb4370 Helper function for defining differentiable solves.
``custom_implicit_solve`` is a helper function designed help library authors
(most notably for ``jax.scipy.optimize``) define derivatives for functions that
perform an implicit solve.
2019-08-08 18:34:03 -07:00
George Necula
9007d67109 Update documentation for jax.vmap. 2019-08-07 09:26:46 +02:00
Peter Hawkins
6fd597bd13 Don't call _check_args in jit/pmap.
Instead, improve the error from xla.abstractify to match the one from _check_args.
This saves abstractifying values twice.
2019-08-05 15:03:50 -04:00
Peter Hawkins
c302e38880 Fix test failures. 2019-08-01 17:20:27 -04:00
Peter Hawkins
c41677fac7
Merge pull request #1073 from hawkinsp/deviceget
Avoid building an identity computation in jax.device_get().
2019-08-01 12:47:04 -04:00
Peter Hawkins
8256836622 Avoid building an identity computation in jax.device_get().
Instead, directly copy values to the host.
2019-07-29 11:55:10 -04:00
Peter Hawkins
476dc3db64 Python changes in preparation for adding a C++ implementation of the PyTree utilities. 2019-07-29 10:57:27 -04:00
Matthew Johnson
0546c94992 speed up pmap axis-size getting
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-07-25 12:41:31 -07:00
Matthew Johnson
dbb907c8bc add warning that device_assignment api is unstable 2019-07-24 22:03:19 +03:00
Matthew Johnson
75150bf335 document int device_assignment argument of jit
fix int/long bug
2019-07-24 21:36:13 +03:00
Matthew Johnson
94f2b60de2
Merge branch 'master' into jit-device-placement-api 2019-07-24 18:35:24 +01:00
Matthew Johnson
c42665c5e9 first cut at jit device_assignment api
make execute_primitive put args on correct device
2019-07-24 20:24:47 +03:00
Peter Hawkins
cfeb20d290 Check that function arguments to APIs like jit are callable. 2019-07-23 17:03:28 -04:00
Matthew Johnson
eecb4d9fbe
Merge pull request #1035 from shoyer/restore-jit-is-disabled
Always restore _jit_is_disabled
2019-07-19 06:44:04 +01:00
Stephan Hoyer
88589a5e4d Always restore _jit_is_disabled
Otherwise if you get an exception inside a ``disable_jit()`` context (not
uncommon if debugging), ``jit`` is disabled permanently!
2019-07-18 19:47:49 -07:00
Matthew Johnson
4c34541c00 raise error when vmap used with kwargs (#912) 2019-07-17 23:25:55 -07:00
Roy Frostig
d2c5fd2c8c
Merge pull request #1018 from shoyer/defjvp_all-doc
DOC: fix docstring for defjvp_all
2019-07-15 10:11:50 -07:00
Stephan Hoyer
5a782b081b DOC: fix docstring for defjvp_all
The form of the arguments for custom_jvp was described inconsistently from the
examples (and the code).
2019-07-13 20:08:46 -07:00
Matthew Johnson
fb1e2124ff enable staging out more multi-replica computations
There are two real changes here:

1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmap's included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).

2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.

The other changes are unrelated miscellania.
2019-07-09 15:12:02 -07:00
Matthew Johnson
7eb168928c fix typo 2019-07-05 17:15:01 -07:00