263 Commits

Author SHA1 Message Date
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would previously be caught if core.skip_checks == False,
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
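
A minimal sketch of the misuse described above, assuming the `lax.tie_in` API of this era of JAX (illustrative only):

```python
import jax.numpy as jnp
from jax import lax

def f(x):
    const = jnp.ones(3)
    # Intended: tie `const` to `x` so it is staged into the traced computation.
    # Bug: the first argument is a tuple rather than a JAX array, so without
    # the argument check the tie silently disappears and this is just `const`.
    return lax.tie_in((x, x), const)
```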
2020-05-07 09:37:20 +03:00
notEvil
969ed8085c
Add decorator for performing broadcasting inside translation rules (#2468)
* Add decorator for broadcasting at the translation rule layer.

* Fix broadcasting in igamma gradients.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-05-06 10:15:17 -04:00
Srinivas Vasudevan
e51c7d7482
Add IgammaGradA (#2504) 2020-05-05 20:10:31 -04:00
tamaranorman
04102e5b9d
Allow ConvDimensionNumbers to be passed into conv_transpose (#2915) 2020-05-04 14:02:13 -04:00
James Bradbury
1cdd8f1b99
Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)
* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
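
A minimal usage sketch of the new behavior (hypothetical shapes; assumes at least one local device):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
w = jnp.ones((4, 2))              # broadcast to every device via in_axes=None
x = jnp.ones((n, 3, 4))           # mapped over the leading device axis
y = jax.pmap(lambda w, x: jnp.dot(x, w), in_axes=(None, 0))(w, x)
print(y.shape)                    # (n, 3, 2); out_axes still defaults to 0
```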
2020-05-01 14:37:13 -07:00
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
George Necula
ac023bf28f
Fixed a few places where device sticky-ness was lost. Added FAQ (#2882)
* Fixed a few places where device stickiness was lost. Added FAQ for device
placement.

I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorter than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.

* Fix mypy annotations and updates based on comments

* Undid some changes, will make another PR
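
A small sketch of the "computation follows data" behavior the FAQ covers (assuming more than one device and the 2020-era `device_buffer` attribute):

```python
import jax
import jax.numpy as jnp

if len(jax.devices()) > 1:
    x = jax.device_put(jnp.arange(4.0), jax.devices()[1])  # commit x to device 1
    y = jnp.sin(x)                                          # runs on device 1
    print(y.device_buffer.device())                         # devices()[1]
```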
2020-05-01 10:06:59 +03:00
Peter Hawkins
0557248fbd
Check for unsupported dtypes and issue a helpful error. (#2885) 2020-04-29 14:14:49 -04:00
Anselm Levskaya
e599a25422
fix sort_key_val return type annotation, docstring 2020-04-28 10:49:17 -07:00
Anselm Levskaya
ca4e396e31
Merge pull request #2853 from levskaya/topkjvp
Add top_k jvp and batching rules and tests
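With the new jvp rule, gradients flow through `lax.top_k`; a minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import lax

def sum_of_two_largest(x):
    values, _ = lax.top_k(x, 2)
    return values.sum()

x = jnp.array([1.0, 5.0, 3.0, 4.0])
print(jax.grad(sum_of_two_largest)(x))  # [0. 1. 0. 1.]: only the top-2 entries get gradient
```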
2020-04-29 00:57:29 +10:00
Anselm Levskaya
dddad2a3dc Add top_k jvp and batching rules 2020-04-28 07:19:58 -07:00
Jamie Townsend
75617be803
Add population_count primitive to lax (#2753)
* add population_count primitive (needs new jaxlib)

fixes #2263

* Add popcount docs

* Add population_count to lax_reference

* Use int prng (since we're only testing uints)

Co-authored-by: Matthew Johnson <mattjj@google.com>
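
Usage sketch (unsigned integer inputs, per the test note above):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0, 1, 7, 255], dtype=jnp.uint32)
print(lax.population_count(x))  # [0 1 3 8]: number of set bits per element
```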
2020-04-27 22:32:52 -07:00
Jon Malmaud
77901e9fa7
Fix lax.rng_uniform. (#2830) 2020-04-24 16:43:04 -04:00
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Matthew Johnson
13a17286df
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806) 2020-04-23 13:12:24 -07:00
Matthew Johnson
2e34dbc188 update travis to match min jaxlib version 2020-04-21 19:04:28 -07:00
Matthew Johnson
1bcaef142f
apply is_stable=True to sort translation rules (#2789)
fixes #2779
2020-04-21 17:47:28 -07:00
Stephan Hoyer
e6f0b8d87d
Raise an error if stop_gradient is called on non-arrays (#2750)
* Raise an error if stop_gradient is called on non-arrays

* Fix incorrect usage of stop_gradient in solve()

* fix *other* misuse of stop_gradient
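
A minimal sketch of the intended, array-valued usage; after this change, passing a non-array (e.g. a Python function) raises an error instead of silently doing nothing:

```python
import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: jnp.sum(lax.stop_gradient(x) * x)  # gradient flows only through the second factor
print(jax.grad(f)(jnp.arange(3.0)))              # [0. 1. 2.]

# lax.stop_gradient(lambda x: x)  # now an error: argument is not an array
```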
2020-04-17 12:42:53 -07:00
Peter Hawkins
9a5b8d626a
Assert that reduction computations don't have constants. (#2754)
This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.
2020-04-17 14:38:50 -04:00
Skye Wanderman-Milne
0e29bd4ba3
Fix some bugs in _reshape_sharded_device_array (#2732) 2020-04-15 18:43:46 -07:00
Skye Wanderman-Milne
07571ae4dd
Allow ShardedDeviceArrays to represent arbitrary data shardings. (#2142)
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.

This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).

**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative    mean/baseline
-------  ---------  ---------  ---------  ----------  ---------------
     10          8  0.041855    4.15223      1               1.01466
    100          8  0.129884    4.85321      3.1032          0.988543
    101          8  0.136347    6.20233      3.2576          0.967138
    500          8  0.533207    3.6815      12.7394          1.0294
   1000          8  1.10338     0.525193    26.362           0.960435
   5000          8  5.33911     0          127.562           0.963319
    100          2  0.0638619  10.7069       1.52579         1.0362
    100          4  0.0868253   6.76701      2.07443         0.967323
    100          8  0.128151    6.46004      3.06177         0.979742
    100        100  1.22631     1.94885     29.299           1.00371
    100        500  6.60746     0          157.865           0.956657
```
**pmap_shard_outputs**
```
  nouts    nshards        mean       %std    relative    mean/baseline
-------  ---------  ----------  ---------  ----------  ---------------
     10          8   0.0664526   9.49251      1               0.938466
    100          8   0.195711    2.19429      2.94512         1.04239
    500          8   0.82577     0.330864    12.4265          0.994669
   1000          8   1.68323     1.0516      25.3298          0.966915
   5000          8   8.89032     0          133.784           0.998038
    100          2   0.074806   10.1734       1.12571         0.980254
    100          4   0.121334    5.76774      1.82588         1.02033
    100          8   0.185253    5.45068      2.78775         1.01666
    100        100   2.37076     0           35.6759          1.08629
    100        500  17.0832      0          257.074           0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std    relative    mean/baseline
------------------  ----------  -------  ----------  ---------------
integer_indices      0.0603473  8.29159       1             0.359496
integer_2D_indices  18.0241     0           298.672         1.00583
```

This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
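
For reference, a sketch of the integer-indexing path that the benchmark shows speeding up (type names as of this era of JAX):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
y = jax.pmap(lambda x: 2 * x)(jnp.arange(n * 3.0).reshape(n, 3))
print(type(y))   # ShardedDeviceArray
print(y[0])      # integer indexing pulls out a single shard, shape (3,)
```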
2020-04-15 12:43:55 -07:00
Peter Hawkins
714b276b9a
Implement jax.ops.index_mul. (#2696)
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
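
Usage sketch of the new op (the `jax.ops` indexed-update API of this era):

```python
import jax.numpy as jnp
from jax import ops

x = jnp.ones((3, 3))
y = ops.index_mul(x, ops.index[1, :], 5.0)  # multiply row 1 by 5, returning a new array
```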
2020-04-13 16:16:34 -04:00
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
Peter Hawkins
a3cc9a7d32
Remove a workaround for a long-fixed XLA type conversion bug. (#2670) 2020-04-10 11:45:27 -04:00
Peter Hawkins
1bb67637ca
Add batch_group_count to conv_general_dilated. (#2635)
* Add batch_group_count to conv_general_dilated.

* Use batch_group_count for RHS grouped convolution transpose rule.

* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
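
A rough sketch of the new parameter, with shapes chosen so the lhs batch and rhs output-feature dimensions are divisible by `batch_group_count` (an assumption about the required divisibility, not a verified example):

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((4, 3, 8, 8))   # NCHW
rhs = jnp.ones((6, 3, 3, 3))   # OIHW
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='VALID',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    batch_group_count=2)
print(out.shape)
```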
2020-04-09 16:21:30 -04:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)
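
An illustrative sketch of the two constructors (internal `partial_eval` API, shown only to mirror the mapping above):

```python
import numpy as np
from jax import core
from jax.interpreters import partial_eval as pe

aval = core.ShapedArray((3,), np.float32)
unknown = pe.PartialVal.unknown(aval)   # formerly PartialVal((aval, core.unit))
known = pe.PartialVal.known(3.0)        # formerly PartialVal((None, 3.0))
```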

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue as #432.
2020-04-09 13:00:33 +03:00
Peter Hawkins
5174f6d5dc
Add type annotations to user-facing functions in lax.py (#2644)
* Add type annotations to user-facing functions in lax.py

* Remove redundant comment.
2020-04-08 14:13:15 -04:00
Peter Hawkins
fa383b4a9f
Mark primitive parameters as keyword-only arguments in rules in lax.py. (#2625)
* Mark primitive parameters as keyword-only arguments in rules in lax.py.

* Fix dynamic update slice batching rule.

* Fix dynamic slice batching rule.
2020-04-07 09:38:10 -04:00
Peter Hawkins
cf4dd84b14
cumsum is linear, so its gradient can be linear also. (#2618)
* cumsum is linear, so its gradient can be linear also.

* Rename _impl functions to _prefix_scan.
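
A small check of the linearity claim: the VJP of cumsum is itself a cumsum, run from the other end:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
g = jnp.ones(4)
_, vjp = jax.vjp(jnp.cumsum, x)
print(vjp(g)[0])                  # [4. 3. 2. 1.]
print(jnp.cumsum(g[::-1])[::-1])  # same values: reverse-cumsum-reverse
```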
2020-04-06 15:14:22 -04:00
Peter Hawkins
36c529d4e3
Handle n==0 case in TPU cumsum/cumprod. (#2617) 2020-04-06 12:33:55 -04:00
Peter Hawkins
329321b0f1
Add backend-specific lowering for cumsum/cumprod on TPU. (#2614)
* Add backend-specific lowering for cumsum/cumprod on TPU.

Make cumsum/cumprod primitives so they can have backend-specific lowerings.

* Disable cumulative reduction gradient test on TPU.
2020-04-06 11:22:01 -04:00
Sharad Vikram
99944d1204
Fix lax.broadcast_shapes returning numpy ints in shape tuple (#2471)
* Fix lax.broadcast_shapes returning numpy ints in shape tuple

* Use _canonicalize_dimension and add test
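
Usage sketch; after the fix, the returned tuple contains Python ints rather than numpy ints:

```python
from jax import lax

shape = lax.broadcast_shapes((2, 1, 3), (4, 3))
print(shape)                                # (2, 4, 3)
print(all(type(d) is int for d in shape))   # True
```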
2020-04-04 23:19:39 -07:00
Peter Hawkins
2b3befff32
Make reduce_prod differentiable to arbitrary order. (#2597)
* Make reduce_prod differentiable to arbitrary order.

The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position.

This PR instead builds an explicit reduction tree and differentiates through it, which while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order.

* Return the tree-reduction primals instead of returning the original primals in JVP rule.
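
A minimal sketch of what this enables (repeated differentiation through a product reduction):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.prod(x)
x = jnp.array([2.0, 3.0, 4.0])
print(jax.grad(f)(x))      # [12.  8.  6.]
print(jax.hessian(f)(x))   # second derivatives now work too
```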
2020-04-03 16:09:48 -04:00
Matthew Johnson
bdc0c3bf43 remove remat context check, add initial staging 2020-03-29 23:29:55 -07:00
Matthew Johnson
1b5978953b add ShardedDeviceArray to ad vspace op handlers
fixes #2529 (thanks, @dpfau !)
2020-03-28 11:56:12 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
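A minimal sketch of the new `custom_jvp` API introduced above, which replaces `custom_transforms`:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return log1pexp(x), t / (1.0 + jnp.exp(-x))  # numerically stable derivative

print(jax.grad(log1pexp)(10.0))  # ~1.0
```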
Srinivas Vasudevan
c7f211d433
Update JAX to use XLA hyperbolic functions. (#2415) 2020-03-19 10:29:37 -04:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Matthew Johnson
f1d9130f25 remove safe_mul (undo #383, also cf. #1052) 2020-03-17 22:07:53 -07:00
George Necula
0ddc2ec360 Fixed failing tests 2020-03-17 06:51:01 +01:00
George Necula
5cf82c756e Improved argument checking for lax.broadcast_in_dim
* Added checking that the output shape has higher or equal rank to input
* Added checking that the broadcast_dims are sorted (required by XLA)
* Relaxed the check to allow operand dimensions of size 1
* Added lax.broadcast_in_dim docstring
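Usage sketch consistent with the checks above (sorted `broadcast_dimensions`, output rank >= input rank, size-1 operand dimensions allowed):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 1))
y = lax.broadcast_in_dim(x, shape=(2, 3, 4), broadcast_dimensions=(1, 2))
print(y.shape)  # (2, 3, 4)
```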
2020-03-17 06:51:01 +01:00
Matthew Johnson
c0c3a4a506
Merge pull request #2401 from hawkinsp/ones
Check for invalid shapes in broadcast_in_dim and fail gracefully.
2020-03-16 19:46:52 -07:00
Roy Frostig
6545cf3421
Merge pull request #2424 from google/broadcast-shapecheck
add lax.broadcast_in_dim shape check and test
2020-03-15 22:22:24 -07:00
Roy Frostig
94832f9627 add lax.broadcast_in_dim shape check and test
Operand dimensions must equal their corresponding dimensions in the broadcast shape.
2020-03-15 20:30:44 -07:00
Matthew Johnson
a7b3be71e8 move jet into jax.experimental 2020-03-15 11:10:56 -07:00
Matthew Johnson
668a1703bc add jet tests, remove top-level files 2020-03-14 21:22:10 -07:00
Jacob Kelly
840797d4a1 refactor reduce_max jet rule 2020-03-14 18:42:51 -07:00
Jacob Kelly
b4d003d460 jet rule for log
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
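
A rough sketch of the experimental `jet` API exercising the new log rule (argument structure assumed from the jet module of this era, not verified):

```python
import jax.numpy as jnp
from jax.experimental import jet

# Propagate a 3-term truncated Taylor series of log through the point x = 2.0.
primals_out, series_out = jet.jet(jnp.log, (2.0,), ((1.0, 0.0, 0.0),))
```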
2020-03-14 18:42:51 -07:00