350 Commits

Author SHA1 Message Date
Matthew Johnson
75150bf335 document int device_assignment argument of jit
fix int/long bug
2019-07-24 21:36:13 +03:00
Matthew Johnson
94f2b60de2
Merge branch 'master' into jit-device-placement-api 2019-07-24 18:35:24 +01:00
Matthew Johnson
c42665c5e9 first cut at jit device_assignment api
make execute_primitive put args on correct device
2019-07-24 20:24:47 +03:00
Peter Hawkins
cfeb20d290 Check that function arguments to APIs like jit are callable. 2019-07-23 17:03:28 -04:00
Matthew Johnson
eecb4d9fbe
Merge pull request #1035 from shoyer/restore-jit-is-disabled
Always restore _jit_is_disabled
2019-07-19 06:44:04 +01:00
Stephan Hoyer
88589a5e4d Always restore _jit_is_disabled
Otherwise if you get an exception inside a ``disable_jit()`` context (not
uncommon if debugging), ``jit`` is disabled permanently!
2019-07-18 19:47:49 -07:00
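The fix above matters because `disable_jit()` is a context manager; if the flag is not restored in a `finally` block, an exception raised inside the context leaves jit off for the rest of the session. A minimal sketch of the intended behavior, using the current `jax.disable_jit` API (names may differ slightly from the 2019 code in this commit):

```python
import jax

@jax.jit
def f(x):
    return x * 2

try:
    with jax.disable_jit():
        # simulate an exception while debugging inside the context
        raise ValueError("simulated debugging error")
except ValueError:
    pass

# jit is re-enabled after the exception, so f still compiles and runs normally
result = float(f(3.0))
print(result)
```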
Matthew Johnson
4c34541c00 raise error when vmap used with kwargs (#912) 2019-07-17 23:25:55 -07:00
Roy Frostig
d2c5fd2c8c
Merge pull request #1018 from shoyer/defjvp_all-doc
DOC: fix docstring for defjvp_all
2019-07-15 10:11:50 -07:00
Stephan Hoyer
5a782b081b DOC: fix docstring for defjvp_all
The form of the arguments for custom_jvp was described inconsistently with the
examples (and the code).
2019-07-13 20:08:46 -07:00
Matthew Johnson
fb1e2124ff enable staging out more multi-replica computations
There are two real changes here:

1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmaps included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).

2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.

The other changes are unrelated miscellania.
2019-07-09 15:12:02 -07:00
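`jax.xla_computation` has since been superseded in current JAX, but the same idea of staging out a computation containing a pmap can be illustrated with the stable `jax.make_jaxpr`. A rough sketch (using a leading axis of size 1 so it runs on a single device, per the special case this commit adds):

```python
import jax
import jax.numpy as jnp

# Stage out a pmapped identity function without executing it.
# Axis size 1 keeps this runnable on any machine with one device.
f = jax.pmap(lambda x: x)
jaxpr = jax.make_jaxpr(f)(jnp.arange(1.0).reshape(1, 1))
print("jaxpr" in str(type(jaxpr)).lower())
```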
Matthew Johnson
7eb168928c fix typo 2019-07-05 17:15:01 -07:00
Matthew Johnson
705d49f519 enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:

1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives

By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!

The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has

* the pmap impl, particularly the dispatching logic for top-level pmaps,
  including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
  and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap

These things moved over to xla.py:

* the logic for lowering pmap primitives, especially the static axis
  environment used during xla lowering

This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
  i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
  strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
  parallel collectives, and so it has rules with signature
  `CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
  control flow primitives (like `scan`), for which the translation rules
  themselves lower jaxprs to XLA computations and thus require the static axis
  env to be passed in; the rules there have signature
  `CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is used for `xla_call` and `xla_pmap`,
  i.e. the primitives underlying `jit` and `pmap` respectively, and has
  rules with signature
  `CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`

Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).

This change fixes #804 and also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.

This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-05 16:39:46 -07:00
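The first item this merge enables, jit-of-pmap, can be sketched directly with the current JAX package (a minimal example under the assumption of a recent install; the 2019 internals differ, but the user-facing composition is the same):

```python
import jax
import jax.numpy as jnp

# jit-of-pmap: both transformations now lower through shared machinery,
# so composing them works. Size the mapped axis to the local device count
# so this runs on any machine.
n = jax.local_device_count()
f = jax.jit(jax.pmap(lambda x: x + 1))
out = f(jnp.zeros((n,)))
print(out.shape == (n,))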
Matthew Johnson
527fe14838 fix simple static_argnums bug 2019-07-05 07:47:38 -07:00
Matthew Johnson
c13e816f6c
Merge pull request #976 from zhongwen/master
Convert int input to static_argnums to a tuple
2019-07-04 10:57:58 -07:00
Zhongwen Xu
f3a741ef28
Update api.py 2019-07-04 18:43:19 +01:00
Zhongwen Xu
d04954d31f
Convert int input to static_argnums to a tuple
Users could make the mistake of passing an int to static_argnums; this helps to avoid an unnecessary error.
2019-07-04 18:29:35 +01:00
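The normalization above means a bare int works wherever a tuple of indices was expected. A small sketch against the current `jax.jit` API (which still accepts both forms):

```python
import jax
import jax.numpy as jnp

def scale(x, n):
    return x * n

# A bare int for static_argnums is accepted and treated like the tuple (1,)
scaled = jax.jit(scale, static_argnums=1)
result = float(scaled(jnp.float32(2.0), 3))
print(result)
```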
Jamie Townsend
bf367d3a56 Set instantiate=True in custom_transforms translation rule 2019-07-03 08:13:34 +01:00
Jamie Townsend
ffa43b895b Update signature of lower_fun 2019-07-03 08:00:00 +01:00
Jamie Townsend
f320e23b31 Merge branch 'master' into custom-transforms 2019-07-03 07:51:05 +01:00
Peter Hawkins
8432bff3d4 Implement device_put as a primitive.
Uses the common dispatch logic rather than an explicit isinstance(..., Tracer) test.
2019-07-02 12:18:47 -04:00
Jamie Townsend
03d6d0a5bc Add note to custom_transforms docstring 2019-07-02 13:47:59 +01:00
Peter Hawkins
6647c505aa Add a more direct implementation of device_put.
This implementation copies tensors to device without building an XLA computation. XLA compilation may take time superlinear in the number of arguments, but there's no good reason for us to build a computation at all; it was merely a convenient way to implement `device_put` for free. Instead, when the argument to `device_put` isn't a tracer, call `xla.device_put` directly. If it is a tracer, fall back to the old implementation.

Timing for benchmark in #947:
In [4]: %time x = jax.device_put([onp.random.randn(10,5) for _ in range(700)])
CPU times: user 45.3 ms, sys: 3.79 ms, total: 49.1 ms
Wall time: 33.8 ms
where the timing was previously 43.4s.

Fixes #947.
2019-07-01 17:02:17 -04:00
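The benchmark above can be reproduced in miniature with the current `jax.device_put`, which maps over pytrees, so a list of host arrays is transferred directly (a small sketch, not the original #947 benchmark):

```python
import numpy as onp
import jax

# device_put copies host arrays straight to device without building an
# XLA computation; it maps over pytrees, so a list works directly
xs = jax.device_put([onp.ones((10, 5)) for _ in range(7)])
print(len(xs), xs[0].shape)
```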
Jamie Townsend
f76a1c9639 Add out of scope error to defvjp 2019-06-27 17:35:34 +01:00
Jamie Townsend
323d9f51dc Raise error when differentiating w.r.t. outer variable with defjvp_all 2019-06-27 15:35:12 +01:00
Jamie Townsend
31fa0412ea Use shaped aval for custom_transforms jaxpr 2019-06-27 14:13:20 +01:00
Jamie Townsend
3d1ba30f2e A few custom_transforms fix-ups 2019-06-26 16:49:04 +01:00
Jamie Townsend
3cdcd9ce93 Draft fix for custom_transform of closure 2019-06-26 16:22:21 +01:00
Matthew Johnson
d64188bcb6 del serial_pmap, simpler papply, add parallelize
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is also
now only needed for testing, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.

This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
  doesn't need to perform communication)
* fixed misc bugs encountered along the way
* a few lines cherry-picked from frostig@ branch, namely the fixed
  broadcasting_papply rule and plumbing the `size` argument to papply
  rules
* remove the psplit and psplit_like primitives and replace them with
  calls to all_to_all where needed
2019-06-24 19:38:26 -07:00
Matthew Johnson
fe7329e808 initial soft_pmap implementation 2019-06-24 19:34:48 -07:00
Peter Hawkins
5bdbcc42d5 Address review comments, add a test. 2019-06-24 10:45:42 -04:00
Peter Hawkins
f2bc287865 Verify that the inputs to reverse-mode automatic differentiation are of an inexact type. 2019-06-21 11:05:52 -04:00
Matthew Johnson
8bc4e379f5 make DeviceArray.__hash__ raise an error
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.

Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
2019-06-19 10:12:13 -07:00
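Device arrays remain unhashable in current JAX, so code that tries to use one as a dict key fails loudly instead of silently misbehaving. A minimal sketch with the modern `jax.numpy` array type:

```python
import jax.numpy as jnp

x = jnp.arange(3)
try:
    hash(x)
    hashable = True
except TypeError:
    # device arrays raise TypeError on hash by design (cf. #883)
    hashable = False
print(hashable)
```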
Matthew Johnson
eb01b8bfef improve linearize error message
fixes #871
2019-06-18 09:31:50 -07:00
Matthew Johnson
7c308dca29 implement reviewer suggestions 2019-06-11 06:45:09 -07:00
Matthew Johnson
121d78129b docstring improvements from @skye comments 2019-06-06 10:12:07 -07:00
Matthew Johnson
12b90ddf68 fix typo 2019-06-05 19:19:34 -07:00
Matthew Johnson
a6c41a323c finish drafting defvjp/defjvp docstrings 2019-06-05 19:13:33 -07:00
Matthew Johnson
cfaa49f884 improve custom_gradient docstring 2019-06-05 18:02:15 -07:00
Matthew Johnson
948ec8fbe8 add docstrings for defjvp and defjvp2 2019-06-05 17:56:18 -07:00
Matthew Johnson
ab20f0292c add docstring for defjvp_all 2019-06-05 17:34:14 -07:00
Matthew Johnson
720dec4072 add custom_gradient 2019-06-05 13:48:04 -07:00
Matthew Johnson
372d60bb08 add docstring to custom_transforms 2019-06-05 13:20:44 -07:00
Matthew Johnson
35e5e64416 make custom_transforms handle pytrees, add api.defvjp
With advice from @dougalm!
2019-06-05 12:29:25 -07:00
Matthew Johnson
dda95df519 fix duck typing in jax.eval_shape (cf. #798) 2019-06-01 09:50:24 -07:00
Matthew Johnson
11c512a194 add jax.eval_shape, fixes #798 2019-06-01 09:36:46 -07:00
Matthew Johnson
93e1143373 improve disable_jit docstring 2019-06-01 08:30:25 -07:00
Matthew Johnson
5a049feea9 further improve grad-of-nonscalar error message 2019-05-28 21:10:09 -07:00
Matthew Johnson
743727325f improve error message for grad of non-scalar funs 2019-05-28 20:25:38 -07:00
Matthew Johnson
a193b3592b when xla_computation sees no kwargs, don't make () 2019-05-22 14:39:34 -07:00
Matthew Johnson
b6031ffdd7 avoid packing leaf outputs for jit/pmap funs 2019-05-17 07:36:52 -07:00