171 Commits

Author SHA1 Message Date
Peter Hawkins
c84cb34ac7 Readd coding declaration to api.py 2019-08-13 15:55:06 -04:00
Peter Hawkins
f898ac1a00 Remove spurious coding line from api.py 2019-08-13 15:39:50 -04:00
Matthew Johnson
f2fe49f99a
Merge pull request #1149 from shoyer/custom-implicit-solve
Helper function for defining differentiable solves.
2019-08-13 10:25:41 -07:00
Peter Hawkins
cc618f4bf4 Don't perform static/dynamic argument splitting if static_argnums was not passed. 2019-08-13 10:32:13 -04:00
Matthew Johnson
4a4304e08c make pmap read axis size from kwargs
fixes #1170
2019-08-12 18:03:25 -07:00
Stephan Hoyer
3604293ba4 changes per review 2019-08-12 11:21:10 -07:00
Roy Frostig
a657922309
Merge pull request #1132 from necula01/update_doc
Update documentation for jax.vmap.
2019-08-12 07:53:33 -07:00
George Necula
35ca6bae4e Fixed vmap documentation 2019-08-12 12:04:14 +02:00
Peter Hawkins
335dbe8285 Remove device_values=False support from jit.
This is no longer needed now that `device_get` is not implemented via `jit`.
2019-08-10 15:17:24 -04:00
Peter Hawkins
8fd6e19ef6 More concurrency fixes. 2019-08-09 15:08:26 -04:00
Stephan Hoyer
fb610a1304 more descriptive name for dummy variable 2019-08-09 09:24:55 -07:00
Stephan Hoyer
eb1516e992 Fix math in custom_implicit_solve 2019-08-09 09:15:24 -07:00
Stephan Hoyer
eb4bfb4370 Helper function for defining differentiable solves.
``custom_implicit_solve`` is a helper function designed to help library authors
(most notably for ``jax.scipy.optimize``) define derivatives for functions that
perform an implicit solve.
2019-08-08 18:34:03 -07:00
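The derivative such a helper provides follows from the implicit function theorem. A sketch of the math (the notation here is illustrative, not taken from the commit): if ``x*(theta)`` satisfies the solve ``F(x*(theta), theta) = 0``, then differentiating both sides in ``theta`` gives

```latex
\frac{\partial F}{\partial x}\,\frac{\partial x^\star}{\partial \theta}
  + \frac{\partial F}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{\partial x^\star}{\partial \theta}
  = -\left(\frac{\partial F}{\partial x}\right)^{-1}
     \frac{\partial F}{\partial \theta},
```

so derivatives of the solution require only a linear solve against the Jacobian at the solution point, never differentiation through the solver's iterations.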
George Necula
9007d67109 Update documentation for jax.vmap. 2019-08-07 09:26:46 +02:00
Peter Hawkins
6fd597bd13 Don't call _check_args in jit/pmap.
Instead, improve the error from xla.abstractify to match the one from _check_args.
This saves abstractifying values twice.
2019-08-05 15:03:50 -04:00
Peter Hawkins
c302e38880 Fix test failures. 2019-08-01 17:20:27 -04:00
Peter Hawkins
c41677fac7
Merge pull request #1073 from hawkinsp/deviceget
Avoid building an identity computation in jax.device_get().
2019-08-01 12:47:04 -04:00
Peter Hawkins
8256836622 Avoid building an identity computation in jax.device_get().
Instead, directly copy values to the host.
2019-07-29 11:55:10 -04:00
Peter Hawkins
476dc3db64 Python changes in preparation for adding a C++ implementation of the PyTree utilities. 2019-07-29 10:57:27 -04:00
Matthew Johnson
0546c94992 speed up pmap axis-size getting
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-07-25 12:41:31 -07:00
Matthew Johnson
dbb907c8bc add warning that device_assignment api is unstable 2019-07-24 22:03:19 +03:00
Matthew Johnson
75150bf335 document int device_assignment argument of jit
fix int/long bug
2019-07-24 21:36:13 +03:00
Matthew Johnson
94f2b60de2
Merge branch 'master' into jit-device-placement-api 2019-07-24 18:35:24 +01:00
Matthew Johnson
c42665c5e9 first cut at jit device_assignment api
make execute_primitive put args on correct device
2019-07-24 20:24:47 +03:00
Peter Hawkins
cfeb20d290 Check that function arguments to APIs like jit are callable. 2019-07-23 17:03:28 -04:00
Matthew Johnson
eecb4d9fbe
Merge pull request #1035 from shoyer/restore-jit-is-disabled
Always restore _jit_is_disabled
2019-07-19 06:44:04 +01:00
Stephan Hoyer
88589a5e4d Always restore _jit_is_disabled
Otherwise if you get an exception inside a ``disable_jit()`` context (not
uncommon if debugging), ``jit`` is disabled permanently!
2019-07-18 19:47:49 -07:00
Matthew Johnson
4c34541c00 raise error when vmap used with kwargs (#912) 2019-07-17 23:25:55 -07:00
Roy Frostig
d2c5fd2c8c
Merge pull request #1018 from shoyer/defjvp_all-doc
DOC: fix docstring for defjvp_all
2019-07-15 10:11:50 -07:00
Stephan Hoyer
5a782b081b DOC: fix docstring for defjvp_all
The form of the arguments for custom_jvp was described inconsistently with the
examples (and the code).
2019-07-13 20:08:46 -07:00
Matthew Johnson
fb1e2124ff enable staging out more multi-replica computations
There are two real changes here:

1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmap's included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).

2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.

The other changes are unrelated miscellania.
2019-07-09 15:12:02 -07:00
Matthew Johnson
7eb168928c fix typo 2019-07-05 17:15:01 -07:00
Matthew Johnson
705d49f519 enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:

1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives

By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!

The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has

* the pmap impl, particularly the dispatching logic for top-level pmaps,
  including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
  and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap

These things moved over to xla.py

* the logic for lowering pmap primitives, especially the static axis
  environment used during xla lowering

This change refactors the translation rule tables a bit. Instead of just
having one table, there are now five, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
  i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
  strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
  parallel collectives, and so it has rules with signature
  `CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
  control flow primitives (like `scan`), for which the translation rules
  themselves lower jaxprs to XLA computations and thus require the static axis
  env to be passed in; the rules there have signature
  `CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is used for `xla_call` and `xla_pmap`,
  i.e. the primitives underlying `jit` and `pmap` respectively, and has
  rules with signature
  `CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`

Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).

This change fixes #804 and also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.

This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-05 16:39:46 -07:00
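The case analysis over the rule tables described above can be sketched as a simple ordered lookup (table names follow the commit message; the registries and lookup function here are illustrative stand-ins, and the lookup order, backend-specific before generic, matches the text):

```python
# Illustrative registries keyed by primitive name; in the real code the
# values are translation rules with the signatures listed above.
backend_specific_translations = {"cpu": {}, "gpu": {}, "tpu": {}}
translations = {}
parallel_translations = {}
initial_style_translations = {}
call_translations = {}


def lookup_translation_rule(prim, platform):
    """Case analysis: backend-specific rules shadow generic ones."""
    if prim in backend_specific_translations.get(platform, {}):
        return backend_specific_translations[platform][prim]
    if prim in translations:
        return translations[prim]
    if prim in parallel_translations:
        return parallel_translations[prim]
    if prim in initial_style_translations:
        return initial_style_translations[prim]
    if prim in call_translations:
        return call_translations[prim]
    raise NotImplementedError(f"no translation rule for {prim}")
```

Keeping the tables separate lets each hold rules of a different type signature while the lowering function stays a flat dispatch.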
Matthew Johnson
527fe14838 fix simple static_argnums bug 2019-07-05 07:47:38 -07:00
Matthew Johnson
c13e816f6c
Merge pull request #976 from zhongwen/master
Convert int input to static_argnums to a tuple
2019-07-04 10:57:58 -07:00
Zhongwen Xu
f3a741ef28
Update api.py 2019-07-04 18:43:19 +01:00
Zhongwen Xu
d04954d31f
Convert int input to static_argnums to a tuple
Users could mistakenly pass an int to static_argnums; this helps avoid an unnecessary error.
2019-07-04 18:29:35 +01:00
Jamie Townsend
bf367d3a56 Set instantiate=True in custom_transforms translation rule 2019-07-03 08:13:34 +01:00
Jamie Townsend
ffa43b895b Update signature of lower_fun 2019-07-03 08:00:00 +01:00
Jamie Townsend
f320e23b31 Merge branch 'master' into custom-transforms 2019-07-03 07:51:05 +01:00
Peter Hawkins
8432bff3d4 Implement device_put as a primitive.
Uses the common dispatch logic rather than an explicit isinstance(..., Tracer) test.
2019-07-02 12:18:47 -04:00
Jamie Townsend
03d6d0a5bc Add note to custom_transforms docstring 2019-07-02 13:47:59 +01:00
Peter Hawkins
6647c505aa Add a more direct implementation of device_put.
This implementation copies tensors to device without building an XLA computation. XLA compilation may take time superlinear in the number of arguments, but there's no good reason for us to build a computation at all; it was merely a convenient way to implement `device_put` for free. Instead, when the argument to `device_put` isn't a tracer, call `xla.device_put` directly. If it is a tracer, fall back to the old implementation.

Timing for benchmark in #947:
In [4]: %time x = jax.device_put([onp.random.randn(10,5) for _ in range(700)])
CPU times: user 45.3 ms, sys: 3.79 ms, total: 49.1 ms
Wall time: 33.8 ms
where the timing was previously 43.4s.

Fixes #947.
2019-07-01 17:02:17 -04:00
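The dispatch described above, copy directly unless the value is a tracer, can be sketched like this (``Tracer`` and the two helpers are stand-ins for the real JAX internals; this is an assumption-laden illustration, not the actual implementation):

```python
class Tracer:
    """Stand-in for jax.core.Tracer (an abstract value during tracing)."""


def _copy_to_device(x):
    # Stand-in for the direct host-to-device copy (xla.device_put);
    # here it just returns the value unchanged.
    return x


def _device_put_via_primitive(x):
    # Stand-in for the old path through the primitive dispatch machinery,
    # which knows how to handle tracers.
    return x


def device_put(x):
    """Copy x to device directly, unless it is a tracer."""
    if isinstance(x, Tracer):
        return _device_put_via_primitive(x)
    return _copy_to_device(x)
```

The fast path skips XLA compilation entirely, which is what turns the 43.4 s benchmark above into tens of milliseconds.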
Jamie Townsend
f76a1c9639 Add out of scope error to defvjp 2019-06-27 17:35:34 +01:00
Jamie Townsend
323d9f51dc Raise error when differentiating w.r.t. outer variable with defjvp_all 2019-06-27 15:35:12 +01:00
Jamie Townsend
31fa0412ea Used shaped aval for custom_transforms jaxpr 2019-06-27 14:13:20 +01:00
Jamie Townsend
3d1ba30f2e A few custom_transforms fix-ups 2019-06-26 16:49:04 +01:00
Jamie Townsend
3cdcd9ce93 Draft fix for custom_transform of closure 2019-06-26 16:22:21 +01:00
Matthew Johnson
d64188bcb6 del serial_pmap, simpler papply, add parallelize
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is now
only needed for testing, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.

This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
  doesn't need to perform communication)
* misc bugs encountered
* a few lines cherry-picked from frostig@ branch, namely the fixed
  broadcasting_papply rule and plumbing the `size` argument to papply
  rules
* remove the psplit and psplit_like primitives and replace them with
  calls to all_to_all where needed
2019-06-24 19:38:26 -07:00
Matthew Johnson
fe7329e808 initial soft_pmap implementation 2019-06-24 19:34:48 -07:00