248 Commits

Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Matthew Johnson
13a17286df
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806) 2020-04-23 13:12:24 -07:00
Matthew Johnson
2e34dbc188 update travis to match min jaxlib version 2020-04-21 19:04:28 -07:00
Matthew Johnson
1bcaef142f
apply is_stable=True to sort translation rules (#2789)
fixes #2779
2020-04-21 17:47:28 -07:00
Stephan Hoyer
e6f0b8d87d
Raise an error if stop_gradient is called on non-arrays (#2750)
* Raise an error if stop_gradient is called on non-arrays

* Fix incorrect usage of stop_gradient in solve()

* fix *other* misuse of stop_gradient
2020-04-17 12:42:53 -07:00
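A hedged sketch of the effect on user code (the pytree workaround is illustrative, not part of the change itself): stop_gradient now expects array-like arguments, so containers should be mapped over explicitly.
```
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
y = jax.lax.stop_gradient(x)   # fine: x is an array

# Passing a container now raises instead of silently misbehaving;
# map over pytrees explicitly instead.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
frozen = jax.tree_util.tree_map(jax.lax.stop_gradient, params)
```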
Peter Hawkins
9a5b8d626a
Assert that reduction computations don't have constants. (#2754)
This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.
2020-04-17 14:38:50 -04:00
Skye Wanderman-Milne
0e29bd4ba3
Fix some bugs in _reshape_sharded_device_array (#2732) 2020-04-15 18:43:46 -07:00
Skye Wanderman-Milne
07571ae4dd
Allow ShardedDeviceArrays to represent arbitrary data shardings. (#2142)
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.

This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).

**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative    mean/baseline
-------  ---------  ---------  ---------  ----------  ---------------
     10          8  0.041855    4.15223      1               1.01466
    100          8  0.129884    4.85321      3.1032          0.988543
    101          8  0.136347    6.20233      3.2576          0.967138
    500          8  0.533207    3.6815      12.7394          1.0294
   1000          8  1.10338     0.525193    26.362           0.960435
   5000          8  5.33911     0          127.562           0.963319
    100          2  0.0638619  10.7069       1.52579         1.0362
    100          4  0.0868253   6.76701      2.07443         0.967323
    100          8  0.128151    6.46004      3.06177         0.979742
    100        100  1.22631     1.94885     29.299           1.00371
    100        500  6.60746     0          157.865           0.956657
```
**pmap_shard_outputs**
```
  nouts    nshards        mean       %std    relative    mean/baseline
-------  ---------  ----------  ---------  ----------  ---------------
     10          8   0.0664526   9.49251      1               0.938466
    100          8   0.195711    2.19429      2.94512         1.04239
    500          8   0.82577     0.330864    12.4265          0.994669
   1000          8   1.68323     1.0516      25.3298          0.966915
   5000          8   8.89032     0          133.784           0.998038
    100          2   0.074806   10.1734       1.12571         0.980254
    100          4   0.121334    5.76774      1.82588         1.02033
    100          8   0.185253    5.45068      2.78775         1.01666
    100        100   2.37076     0           35.6759          1.08629
    100        500  17.0832      0          257.074           0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std    relative    mean/baseline
------------------  ----------  -------  ----------  ---------------
integer_indices      0.0603473  8.29159       1             0.359496
integer_2D_indices  18.0241     0           298.672         1.00583
```

This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
2020-04-15 12:43:55 -07:00
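For reference, a minimal sketch of the "pmap-style" case described above (assuming at least one local device): the mapped axis is split across devices and does not appear in the per-device shape.
```
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape((n, 4))
# y is a ShardedDeviceArray holding one (4,) shard per device.
y = jax.pmap(lambda v: v * 2)(x)
```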
Peter Hawkins
714b276b9a
Implement jax.ops.index_mul. (#2696)
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
2020-04-13 16:16:34 -04:00
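A short usage sketch of the new operator (values are illustrative):
```
import jax.numpy as jnp
from jax.ops import index, index_mul

x = jnp.ones((4,))
# Functional analogue of the in-place x[1:3] *= 10. on an immutable array.
y = index_mul(x, index[1:3], 10.)   # [1., 10., 10., 1.]
```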
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
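A minimal sketch of such a pytest.ini; the specific ignore entry is a hypothetical example of a whitelisted warning, not taken from the actual file.
```
[pytest]
filterwarnings =
    error
    ignore:Some known deprecation message:DeprecationWarning
```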
Peter Hawkins
a3cc9a7d32
Remove a workaround for a long-fixed XLA type conversion bug. (#2670) 2020-04-10 11:45:27 -04:00
Peter Hawkins
1bb67637ca
Add batch_group_count to conv_general_dilated. (#2635)
* Add batch_group_count to conv_general_dilated.

* Use batch_group_count for RHS grouped convolution transpose rule.

* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
2020-04-09 16:21:30 -04:00
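A hedged usage sketch, with shapes chosen to satisfy the divisibility constraints as I understand them (both the batch size and the kernel's output-feature count divisible by batch_group_count):
```
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3, 8, 8))   # NCHW, batch divisible by batch_group_count
rhs = jnp.ones((2, 3, 3, 3))   # OIHW, output features divisible by batch_group_count
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    batch_group_count=2)        # groups the convolution over the batch dimension
```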
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit)) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue as #432.
2020-04-09 13:00:33 +03:00
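A hedged sketch of what the two constructors amount to, assuming PartialVal is the usual (abstract value, constant) pair from jax.interpreters.partial_eval:
```
from jax import core

class PartialVal(tuple):
    @classmethod
    def unknown(cls, aval):
        # Only the abstract value (shape/dtype) is known; the constant is a unit placeholder.
        return cls((aval, core.unit))

    @classmethod
    def known(cls, const):
        # The value is fully known, so no abstract value is needed.
        return cls((None, const))
```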
Peter Hawkins
5174f6d5dc
Add type annotations to user-facing functions in lax.py (#2644)
* Add type annotations to user-facing functions in lax.py

* Remove redundant comment.
2020-04-08 14:13:15 -04:00
Peter Hawkins
fa383b4a9f
Mark primitive parameters as keyword-only arguments in rules in lax.py. (#2625)
* Mark primitive parameters as keyword-only arguments in rules in lax.py.

* Fix dynamic update slice batching rule.

* Fix dynamic slice batching rule.
2020-04-07 09:38:10 -04:00
Peter Hawkins
cf4dd84b14
cumsum is linear, so its gradient can be linear also. (#2618)
* cumsum is linear, so its gradient can be linear also.

* Rename _impl functions to _prefix_scan.
2020-04-06 15:14:22 -04:00
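A small check of the underlying identity (not the lax implementation itself): because cumsum is linear, its VJP is just a reversed cumulative sum of the cotangent.
```
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
ct = jnp.array([1.0, 2.0, 3.0, 4.0])
_, vjp = jax.vjp(jnp.cumsum, x)
# The pullback of cumsum is cumsum run in the opposite direction.
assert jnp.allclose(vjp(ct)[0], jnp.cumsum(ct[::-1])[::-1])
```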
Peter Hawkins
36c529d4e3
Handle n==0 case in TPU cumsum/cumprod. (#2617) 2020-04-06 12:33:55 -04:00
Peter Hawkins
329321b0f1
Add backend-specific lowering for cumsum/cumprod on TPU. (#2614)
* Add backend-specific lowering for cumsum/cumprod on TPU.

Make cumsum/cumprod primitives so they can have backend-specific lowerings.

* Disable cumulative reduction gradient test on TPU.
2020-04-06 11:22:01 -04:00
Sharad Vikram
99944d1204
Fix lax.broadcast_shapes returning numpy ints in shape tuple (#2471)
* Fix lax.broadcast_shapes returning numpy ints in shape tuple

* Use _canonicalize_dimension and add test
2020-04-04 23:19:39 -07:00
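A quick illustration of the fixed behavior (the assertion is illustrative):
```
from jax import lax

shape = lax.broadcast_shapes((2, 1), (1, 3))
assert shape == (2, 3)
assert all(type(d) is int for d in shape)   # plain Python ints, not numpy ints
```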
Peter Hawkins
2b3befff32
Make reduce_prod differentiable to arbitrary order. (#2597)
* Make reduce_prod differentiable to arbitrary order.

The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position.

This PR instead builds an explicit reduction tree and differentiates through it, which, while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order.

* Return the tree-reduction primals instead of the original primals in the JVP rule.
2020-04-03 16:09:48 -04:00
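A hedged sketch of the idea rather than the lax implementation: reduce a product by repeated pairwise multiplications, so every step is ordinary multiplication and the whole thing differentiates to arbitrary order.
```
import jax
import jax.numpy as jnp

def prod_tree(x):
    # 1D product via an explicit binary tree; pads odd lengths with 1.
    while x.size > 1:
        if x.size % 2:
            x = jnp.concatenate([x, jnp.ones((1,), x.dtype)])
        x = x[0::2] * x[1::2]
    return x[0]

x = jnp.array([2.0, 3.0, 4.0])
print(prod_tree(x))               # 24.0
print(jax.hessian(prod_tree)(x))  # higher-order derivatives work
```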
Matthew Johnson
bdc0c3bf43 remove remat context check, add initial staging 2020-03-29 23:29:55 -07:00
Matthew Johnson
1b5978953b add ShardedDeviceArray to ad vspace op handlers
fixes #2529 (thanks, @dpfau !)
2020-03-28 11:56:12 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Srinivas Vasudevan
c7f211d433
Update JAX to use XLA hyperbolic functions. (#2415) 2020-03-19 10:29:37 -04:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Matthew Johnson
f1d9130f25 remove safe_mul (undo #383, also cf. #1052) 2020-03-17 22:07:53 -07:00
George Necula
0ddc2ec360 Fixed failing tests 2020-03-17 06:51:01 +01:00
George Necula
5cf82c756e Improved argument checking for lax.broadcast_in_dim
* Added a check that the output shape has rank greater than or equal to the input's
* Added a check that broadcast_dims are sorted (required by XLA)
* Relaxed the check so that an operand dimension size may be 1
* Added a lax.broadcast_in_dim docstring
2020-03-17 06:51:01 +01:00
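For reference, a small usage sketch of lax.broadcast_in_dim consistent with these checks (values are illustrative):
```
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3,))
# broadcast_dimensions maps each operand axis to an output axis; it must be
# sorted, and each mapped operand dimension must be 1 or match the output.
y = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))   # shape (2, 3)
```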
Matthew Johnson
c0c3a4a506
Merge pull request #2401 from hawkinsp/ones
Check for invalid shapes in broadcast_in_dim and fail gracefully.
2020-03-16 19:46:52 -07:00
Roy Frostig
6545cf3421
Merge pull request #2424 from google/broadcast-shapecheck
add lax.broadcast_in_dim shape check and test
2020-03-15 22:22:24 -07:00
Roy Frostig
94832f9627 add lax.broadcast_in_dim shape check and test
Operand dimensions must equal their corresponding dimensions in the broadcast shape.
2020-03-15 20:30:44 -07:00
Matthew Johnson
a7b3be71e8 move jet into jax.experimental 2020-03-15 11:10:56 -07:00
Matthew Johnson
668a1703bc add jet tests, remove top-level files 2020-03-14 21:22:10 -07:00
Jacob Kelly
840797d4a1 refactor reduce_max jet rule 2020-03-14 18:42:51 -07:00
Jacob Kelly
b4d003d460 jet rule for log
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
30830dfc25 linear rule for sub
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
dcebe50562 jet for reduce_max
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
3bcf02a191 Add gather rule
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
098aabefcd fix typo 2020-03-14 18:42:51 -07:00
Jacob Kelly
ddd52c4730 adding div and linear prims
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Matthew Johnson
7adf9fe84f add more jet rules!
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
a21fdf8669 more jet rules and tests
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
e84a621184 new jet implementation, with conv-based rules
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
7f0463e2c9
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,) ] a
    in b }

The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!

Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,)
                        input_shape=(3,) ] a
    in b }

That's not only aesthetically unpleasant, but it also meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)

But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!

That's exactly what this commit does!

Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
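A hedged illustration of the end state (the exact printed form may differ between JAX versions):
```
import jax
import jax.numpy as jnp

# With avals recorded in the jaxpr, reduce_sum no longer needs an
# input_shape parameter for its transpose rule to know what to broadcast to.
print(jax.make_jaxpr(lambda a: jnp.sum(a, axis=0))(jnp.ones((3,))))
# roughly: { lambda ; a. let b = reduce_sum[ axes=(0,) ] a in (b,) }
```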
Peter Hawkins
419961f9dd Check for invalid shapes in broadcast_in_dim and fail gracefully. 2020-03-11 09:57:20 -04:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
Skye Wanderman-Milne
efa0315c8f
[docs] Add docstring for jax.lax.tie_in (#2364) 2020-03-05 16:21:19 -08:00
Peter Hawkins
0416d2a5f2
Fix abstract evaluation rule for lax.top_k. (#2290) 2020-02-24 07:31:46 -08:00
Peter Hawkins
af0967fdbf
Add an experimental lax.top_k operator. (#2280) 2020-02-20 17:15:25 -08:00
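A brief usage sketch of the operator (values are illustrative):
```
import jax.numpy as jnp
from jax import lax

values, indices = lax.top_k(jnp.array([1.0, 5.0, 3.0]), k=2)
# values == [5., 3.], indices == [1, 2]
```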
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00