Stephan Hoyer
12509c913c
Remove jax.api._custom_implicit_solve
2019-09-11 21:33:48 -07:00
Matthew Johnson
772fdb8c4e
move automasking prototype into jax/interpreters
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
samuela
b558afb671
Prettify code example docs for disable_jit
2019-09-02 16:30:07 -07:00
Matthew Johnson
dbe56c30ac
leave todos for better error messages
2019-08-26 14:05:17 -07:00
Matthew Johnson
6a81d81d9b
fix for custom-transforms vjp nones bug
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-26 13:38:08 -07:00
Skye Wanderman-Milne
ae835b747e
Add jax.devices() and friends, and add devices
arg to pmap.
...
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Anselm Levskaya
8a78c92b89
Merge branch 'master' into multibackend
2019-08-22 21:59:47 -07:00
Anselm Levskaya
685ca6765e
resolve merge conflicts with master
2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47
Merge branch 'master' into multibackend
2019-08-22 19:52:29 -07:00
Matthew Johnson
de6ce2a555
allow vmap in_axes to be lists
2019-08-22 12:50:47 -07:00
Matthew Johnson
8517997518
minor fixes from trax, revise eval_shape api
2019-08-21 20:36:47 -07:00
Matthew Johnson
b702f8de3e
De-tuplify the rest of the core
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
3c37a3260a
Update linearize to no-tuple version
2019-08-21 07:01:07 -07:00
Dougal Maclaurin
c53c8bbb43
Some progress de-tupling ad.py
2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56
Start exploring jaxprs without tuples
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
dbe4bdf944
add multibackend option to soft_pmap
2019-08-21 01:06:04 -07:00
Anselm Levskaya
f01fc35ce5
Make op-by-op work with all jit-returned devicearrays.
2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013
WIP: experimental multibackend jit
2019-08-19 23:45:36 -07:00
Peter Hawkins
c84cb34ac7
Readd coding declaration to api.py
2019-08-13 15:55:06 -04:00
Peter Hawkins
f898ac1a00
Remove spurious coding line from api.py
2019-08-13 15:39:50 -04:00
Matthew Johnson
f2fe49f99a
Merge pull request #1149 from shoyer/custom-implicit-solve
...
Helper function for defining differentiable solves.
2019-08-13 10:25:41 -07:00
Peter Hawkins
cc618f4bf4
Don't perform static/dynamic argument splitting if static_argnums
was not passed.
2019-08-13 10:32:13 -04:00
Matthew Johnson
4a4304e08c
make pmap read axis size from kwargs
...
fixes #1170
2019-08-12 18:03:25 -07:00
Stephan Hoyer
3604293ba4
changes per review
2019-08-12 11:21:10 -07:00
Roy Frostig
a657922309
Merge pull request #1132 from necula01/update_doc
...
Update documentation for jax.vmap.
2019-08-12 07:53:33 -07:00
George Necula
35ca6bae4e
Fixed vmap documentation
2019-08-12 12:04:14 +02:00
Peter Hawkins
335dbe8285
Remove device_values=False support from jit
.
...
This is no longer needed now `device_get` is not implemented via `jit`.
2019-08-10 15:17:24 -04:00
Peter Hawkins
8fd6e19ef6
More concurrency fixes.
2019-08-09 15:08:26 -04:00
Stephan Hoyer
fb610a1304
more descriptive name for dummy variable
2019-08-09 09:24:55 -07:00
Stephan Hoyer
eb1516e992
Fix math in custom_implicit_solve
2019-08-09 09:15:24 -07:00
Stephan Hoyer
eb4bfb4370
Helper function for defining differentiable solves.
...
``custom_implicit_solve`` is a helper function designed help library authors
(most notably for ``jax.scipy.optimize``) define derivatives for functions that
perform an implicit solve.
2019-08-08 18:34:03 -07:00
George Necula
9007d67109
Update documentation for jax.vmap.
2019-08-07 09:26:46 +02:00
Peter Hawkins
6fd597bd13
Don't call _check_args in jit/pmap.
...
Instead, improve the error from xla.abstractify to match the one from _check_args.
This saves abstractifying values twice.
2019-08-05 15:03:50 -04:00
Peter Hawkins
c302e38880
Fix test failures.
2019-08-01 17:20:27 -04:00
Peter Hawkins
c41677fac7
Merge pull request #1073 from hawkinsp/deviceget
...
Avoid building an identity computation in jax.device_get().
2019-08-01 12:47:04 -04:00
Peter Hawkins
8256836622
Avoid building an identity computation in jax.device_get().
...
Instead, directly copy values to the host.
2019-07-29 11:55:10 -04:00
Peter Hawkins
476dc3db64
Python changes in preparation for adding a C++ implementation of the PyTree utilities.
2019-07-29 10:57:27 -04:00
Matthew Johnson
0546c94992
speed up pmap axis-size getting
...
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-07-25 12:41:31 -07:00
Matthew Johnson
dbb907c8bc
add warning that device_assignment api is unstable
2019-07-24 22:03:19 +03:00
Matthew Johnson
75150bf335
document int device_assignment argument of jit
...
fix int/long bug
2019-07-24 21:36:13 +03:00
Matthew Johnson
94f2b60de2
Merge branch 'master' into jit-device-placement-api
2019-07-24 18:35:24 +01:00
Matthew Johnson
c42665c5e9
first cut at jit device_assignment api
...
make execute_primitive put args on correct device
2019-07-24 20:24:47 +03:00
Peter Hawkins
cfeb20d290
Check that function arguments to APIs like jit
are callable.
2019-07-23 17:03:28 -04:00
Matthew Johnson
eecb4d9fbe
Merge pull request #1035 from shoyer/restore-jit-is-disabled
...
Always restore _jit_is_disabled
2019-07-19 06:44:04 +01:00
Stephan Hoyer
88589a5e4d
Always restore _jit_is_disabled
...
Otherwise if you get an exception inside a ``disable_jit()`` context (not
uncommon if debugging), ``jit`` is disabled permanently!
2019-07-18 19:47:49 -07:00
Matthew Johnson
4c34541c00
raise error when vmap used with kwargs ( #912 )
2019-07-17 23:25:55 -07:00
Roy Frostig
d2c5fd2c8c
Merge pull request #1018 from shoyer/defjvp_all-doc
...
DOC: fix docstring for defjvp_all
2019-07-15 10:11:50 -07:00
Stephan Hoyer
5a782b081b
DOC: fix docstring for defjvp_all
...
The form of the arguments for custom_jvp was described inconsistently from the
examples (and the code).
2019-07-13 20:08:46 -07:00
Matthew Johnson
fb1e2124ff
enable staging out more multi-replica computations
...
There are two real changes here:
1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmap's included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).
2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.
The other changes are unrelated miscellania.
2019-07-09 15:12:02 -07:00
Matthew Johnson
7eb168928c
fix typo
2019-07-05 17:15:01 -07:00