981 Commits

Author SHA1 Message Date
Peter Hawkins
612ffd0687 Change block_until_ready() to return self rather than nothing. 2019-09-05 10:16:20 -04:00
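Returning self from a blocking method lets callers chain it onto the expression that produced the value. The sketch below shows that pattern in plain Python; `AsyncValue` and its `value` method are hypothetical stand-ins built on `concurrent.futures`, not JAX's DeviceArray.

```python
from concurrent.futures import ThreadPoolExecutor


class AsyncValue:
    """Hypothetical stand-in for a value that is computed asynchronously."""

    def __init__(self, future):
        self._future = future

    def block_until_ready(self):
        # Wait for the background computation to finish.
        self._future.result()
        # Returning self (rather than None) is what makes chaining possible.
        return self

    def value(self):
        return self._future.result()


with ThreadPoolExecutor(max_workers=1) as pool:
    # One expression: start the work, wait for it, keep the handle.
    y = AsyncValue(pool.submit(sum, range(10))).block_until_ready()
    total = y.value()
```

If `block_until_ready` returned nothing, the assignment to `y` above would bind `None`, and the wait would have to be a separate statement.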
Skye Wanderman-Milne
76156e53a1 Temporarily disable test_jit_device_assignment.
Future breaking changes to jaxlib will break this test, so disable it
until we update jaxlib and then can update jax and reenable the test.
2019-09-04 16:25:04 -07:00
Matthew Johnson
6a81d81d9b fix for custom-transforms vjp nones bug
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-26 13:38:08 -07:00
Matthew Johnson
f639b808c4 instantiate zeros for custom vjp rules 2019-08-25 19:59:50 -07:00
Matthew Johnson
afe21bafa4 address reviewer comments 2019-08-24 12:34:44 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Matthew Johnson
d700716e19 add option to disable rank-promotion broadcasting
fixes #1236
2019-08-23 18:13:18 -07:00
Matthew Johnson
8517997518 minor fixes from trax, revise eval_shape api 2019-08-21 20:36:47 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
c53c8bbb43 Some progress de-tupling ad.py 2019-08-21 07:01:07 -07:00
Matthew Johnson
f2fe49f99a
Merge pull request #1149 from shoyer/custom-implicit-solve
Helper function for defining differentiable solves.
2019-08-13 10:25:41 -07:00
Stephan Hoyer
3604293ba4 changes per review 2019-08-12 11:21:10 -07:00
Peter Hawkins
6dc730a5f4 Make JAX tracer state thread-local. Allows performing traces in separate threads.
Using threading within a traced context still won't work, but that is perhaps less important than the ability to call JIT-ted computations from separate threads.

(Revives https://github.com/google/jax/pull/734.)
2019-08-09 13:55:20 -04:00
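As a rough illustration of what thread-local state buys, here is a minimal sketch using the standard library's `threading.local`. The names `TraceState`, `trace_stack`, and `trace_in_thread` are assumptions made for this example and are not taken from the JAX source.

```python
import threading


class TraceState(threading.local):
    """Hypothetical per-thread state; __init__ runs once in each thread."""

    def __init__(self):
        self.trace_stack = []


trace_state = TraceState()


def trace_in_thread(name, results):
    # Push onto this thread's own stack; other threads never see it.
    trace_state.trace_stack.append(name)
    results[name] = list(trace_state.trace_stack)
    trace_state.trace_stack.pop()


results = {}
threads = [
    threading.Thread(target=trace_in_thread, args=(f"t{i}", results))
    for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Each thread observes a stack containing only its own entry, and the main thread's stack stays empty. With a single shared list, concurrent traces could interleave their pushes and pops.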
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
Stephan Hoyer
eb1516e992 Fix math in custom_implicit_solve 2019-08-09 09:15:24 -07:00
Stephan Hoyer
f6be9c0983 automatic derivs 2019-08-08 18:38:41 -07:00
Stephan Hoyer
eb4bfb4370 Helper function for defining differentiable solves.
``custom_implicit_solve`` is a helper function designed to help library authors
(most notably for ``jax.scipy.optimize``) define derivatives for functions that
perform an implicit solve.
2019-08-08 18:34:03 -07:00
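For context, the derivative of an implicit solve can be obtained from the implicit function theorem instead of differentiating through the solver's iterations: if x(a) satisfies g(x, a) = 0, then dx/da = -(dg/dx)^-1 * dg/da. The scalar sketch below illustrates that idea only. The function names are hypothetical, and it does not show the signature of ``custom_implicit_solve``, which this log does not give.

```python
def solve_sqrt(a, iters=50):
    """Solve g(x, a) = x**2 - a = 0 for x > 0 with Newton's method."""
    x = a if a > 1 else 1.0
    for _ in range(iters):
        x = x - (x * x - a) / (2 * x)
    return x


def dsolve_da(a):
    """Derivative of the solution with respect to a, via the implicit function theorem."""
    x = solve_sqrt(a)
    dg_dx = 2 * x   # partial derivative of g with respect to x, at the solution
    dg_da = -1.0    # partial derivative of g with respect to a
    return -dg_da / dg_dx


x_star = solve_sqrt(4.0)
slope = dsolve_da(4.0)
```

The derivative needs only the converged solution and the partial derivatives of g there, so its cost does not depend on how many iterations the solver took.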
Peter Hawkins
6fd597bd13 Don't call _check_args in jit/pmap.
Instead, improve the error from xla.abstractify to match the one from _check_args.
This saves abstractifying values twice.
2019-08-05 15:03:50 -04:00
Matthew Johnson
8b329ec44c
Merge branch 'master' into jit-device-placement-api 2019-07-24 20:35:43 +01:00
Matthew Johnson
3f9c001c33 add ShardedDeviceTuple constant handler, fixes #1062 2019-07-24 21:45:56 +03:00
Matthew Johnson
75150bf335 document int device_assignment argument of jit
fix int/long bug
2019-07-24 21:36:13 +03:00
Matthew Johnson
94f2b60de2
Merge branch 'master' into jit-device-placement-api 2019-07-24 18:35:24 +01:00
Matthew Johnson
c42665c5e9 first cut at jit device_assignment api
make execute_primitive put args on correct device
2019-07-24 20:24:47 +03:00
Peter Hawkins
cfeb20d290 Check that function arguments to APIs like jit are callable. 2019-07-23 17:03:28 -04:00
Matthew Johnson
fb1e2124ff enable staging out more multi-replica computations
There are two real changes here:

1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmaps included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).

2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.

The other changes are unrelated miscellanea.
2019-07-09 15:12:02 -07:00
Matthew Johnson
705d49f519 enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:

1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives

By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!

The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has

* the pmap impl, particularly the dispatching logic for top-level pmaps,
  including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
  and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap

These things moved over to xla.py

* the logic for lowering pmap primitives, especially the static axis
  environment used during xla lowering

This change refactors the translation rule tables a bit. Instead of just
having one table, there are now five, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
  i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
  strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
  parallel collectives, and so it has rules with signature
  `CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
  control flow primitives (like `scan`), for which the translation rules
  themselves lower jaxprs to XLA computations and thus require the static axis
  env to be passed in; the rules there have signature
  `CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is used for `xla_call` and `xla_pmap`,
  i.e. the primitives underlying `jit` and `pmap` respectively, and has
  rules with signature
  `CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`

Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).

This change fixes #804 and also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.

This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-05 16:39:46 -07:00
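The case analysis over the translation tables described in this message can be sketched in plain Python. This is an illustration under stated assumptions, not the code in xla.py: the table names follow the commit message, while `lookup_rule`, the primitive entries, and the rule bodies are hypothetical placeholders.

```python
# Hypothetical table contents; real rules build XLA ops, these return labels.
translations = {"add": lambda *args: "generic add"}
backend_specific_translations = {
    "gpu": {"add": lambda *args: "gpu add"},
}
parallel_translations = {"psum": lambda *args: "psum rule"}
initial_style_translations = {"scan": lambda *args: "scan rule"}
call_translations = {"xla_call": lambda *args: "xla_call rule"}


def lookup_rule(prim, platform):
    """Return (table_name, rule) for a primitive name on a given platform."""
    # Backend-specific rules are checked first, because a primitive may
    # also have an entry in the generic `translations` table.
    backend_table = backend_specific_translations.get(platform, {})
    if prim in backend_table:
        return "backend_specific_translations", backend_table[prim]
    for name, table in [
        ("translations", translations),
        ("parallel_translations", parallel_translations),
        ("initial_style_translations", initial_style_translations),
        ("call_translations", call_translations),
    ]:
        if prim in table:
            return name, table[prim]
    raise NotImplementedError(f"no translation rule for {prim!r}")
```

Returning the table name alongside the rule reflects the point made above that each table's rules have a different signature, so the caller needs to know which table matched in order to pass the right arguments.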
Jamie Townsend
ec3fb89d1f Test for defvjp closure error 2019-06-27 17:39:42 +01:00
Jamie Townsend
323d9f51dc Raise error when differentiating w.r.t. outer variable with defjvp_all 2019-06-27 15:35:12 +01:00
Jamie Townsend
9ad18a8a54 Add failing custom_transforms closure test 2019-06-26 14:21:03 +01:00
Peter Hawkins
f193bc7d11 Make error message test more permissive. 2019-06-24 11:29:06 -04:00
Peter Hawkins
5bdbcc42d5 Address review comments, add a test. 2019-06-24 10:45:42 -04:00
Matthew Johnson
e939e7291a lower args in JaxprTrace.process_eval 2019-06-18 21:23:52 -07:00
Matthew Johnson
eb01b8bfef improve linearize error message
fixes #871
2019-06-18 09:31:50 -07:00
Matthew Johnson
96775a4d40 add tuple simplification logic
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-06-17 18:19:55 -07:00
Matthew Johnson
7c308dca29 implement reviewer suggestions 2019-06-11 06:45:09 -07:00
Matthew Johnson
720dec4072 add custom_gradient 2019-06-05 13:48:04 -07:00
Matthew Johnson
35e5e64416 make custom_transforms handle pytrees, add api.defvjp
With advice from @dougalm!
2019-06-05 12:29:25 -07:00
Peter Hawkins
b80bbe41f6 Move block_until_ready into DeviceValue. 2019-06-03 12:37:08 -04:00
Peter Hawkins
fbe701fed0 Add a block_until_ready method to DeviceArray. 2019-06-03 12:05:28 -04:00
Matthew Johnson
fadd18b36c namedtuple subclass transparency (fixes #806) 2019-06-03 07:22:32 -07:00
Matthew Johnson
dda95df519 fix duck typing in jax.eval_shape (cf. #798) 2019-06-01 09:50:24 -07:00
Matthew Johnson
11c512a194 add jax.eval_shape, fixes #798 2019-06-01 09:36:46 -07:00
Peter Hawkins
af67447f21 Add an explicit DeviceArray.delete method that deletes both the device and host parts of a DeviceArray.
Fixes #725.
2019-05-30 09:54:52 -04:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
ca66c7693e add test for namedtuple transparency 2019-05-20 10:15:20 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with an upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accept long values in shapes.
2019-05-08 20:32:24 -04:00
Matthew Johnson
11aa1a583e
Merge branch 'master' into device-tuples 2019-05-03 08:38:42 -07:00
Matthew Johnson
7fc3f3f704 fix legacy numpy issue with DeviceArray.__repr__ 2019-05-03 08:14:03 -07:00
Matthew Johnson
7c5d683915 revise sharded result handling, misc cleanup 2019-05-03 08:06:55 -07:00
Matthew Johnson
ddd29e724e fix DeviceArray.__repr__ for complex dtypes, test
cf. #666
2019-05-02 19:27:22 -07:00