327 Commits

Peter Hawkins
cf624196ed
Documentation fixes. (#3282)
Improve some cross-references and poorly quoted text.
2020-06-01 18:09:45 -04:00
Matthew Johnson
49a441f745
revisions to #3197 (#3264)
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
2020-06-01 13:24:40 -07:00
Skye Wanderman-Milne
f78ece0f98
Allow sharding infeed inside sharded_jit. (#3256) 2020-06-01 12:35:18 -07:00
Stephan Hoyer
cc8fbb7669
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py

`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.

This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.

I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.

I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)

* Remove unnecessary branch

* Add lax.squeeze primitive

* Changes per review

* Fix typing

* Move expand_dims into lax

* Update per review; add comments/documentation

* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
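A minimal sketch of the axis-preserving alternatives described in the commit above, assuming the current `jax.lax` signatures for `broadcast_in_dim`, `expand_dims`, and `squeeze`; the array shapes are illustrative only:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)

# expand_dims with multiple axes (NumPy >= 1.18 semantics): (2, 3) -> (1, 2, 1, 3)
y = jnp.expand_dims(x, axis=(0, 2))

# broadcast_in_dim keeps an explicit input-axis -> output-axis mapping:
# input axis 0 maps to output axis 1, input axis 1 maps to output axis 3.
y2 = lax.broadcast_in_dim(x, shape=(1, 2, 1, 3), broadcast_dimensions=(1, 3))

# squeeze names the unit axes to drop instead of relying on a reshape.
z = lax.squeeze(y, dimensions=(0, 2))   # back to shape (2, 3)
```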
Julius Kunze
02b4fd3500
Fix broadcast_shapes for polymorphic dims (#3216) (#3224)
* Fix #3216

* Simplify
2020-05-27 18:15:01 -04:00
Skye Wanderman-Milne
6ffde8061d
Implement pmap of sharded_jit (#3144)
* Implement pmap of sharded_jit

* Update jax/interpreters/pxla.py

Co-authored-by: James Bradbury <jekbradbury@google.com>

* Address comments

Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-26 14:26:53 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would presumably report the error.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
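A small sketch of the misuse described in the commit above, written with the era's `lax.tie_in` (since removed from JAX); the functions `f` and `g` are hypothetical illustrations, not code from the change itself:

```python
import jax
from jax import lax

def f(x):
    const = 2.0
    # Intended use: tie the constant to the tracer `x` so it is staged into
    # the traced computation.
    return lax.tie_in(x, const) * x

def g(x):
    const = 2.0
    # Misuse from the commit message: the first argument is a tuple, not a JAX
    # value. Previously the tie was silently dropped; with this change the
    # argument check reports an error during abstractification instead.
    return lax.tie_in((x, x), const) * x

print(jax.make_jaxpr(f)(3.0))
# jax.make_jaxpr(g)(3.0)  # now raises instead of behaving like plain `const`
```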
alexdavies
85fe5a28f1
Add gradients to the scatter_max and scatter_min operations. (#3111)
This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255

Co-authored-by: Alex Davies <adavies@google.com>
2020-05-18 23:06:32 -07:00
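A minimal sketch of a differentiable segment max built on the scatter-max that the commit above makes differentiable; `seg_max` is a hypothetical helper using today's `.at[...].max()` indexed-update syntax, not the library's own API:

```python
import jax
import jax.numpy as jnp

def seg_max(data, segment_ids, num_segments):
    # Scatter each value into its segment, keeping the per-segment maximum.
    init = jnp.full((num_segments,), -jnp.inf, dtype=data.dtype)
    return init.at[segment_ids].max(data)

data = jnp.array([1.0, 5.0, 2.0, 4.0])
ids = jnp.array([0, 0, 1, 1])
print(seg_max(data, ids, 2))                               # [5. 4.]
print(jax.grad(lambda d: seg_max(d, ids, 2).sum())(data))  # [0. 1. 0. 1.]
```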
Skye Wanderman-Milne
888c9c77b3 Implement pmap of sharded_jit 2020-05-18 18:40:28 -07:00
Peter Hawkins
36e7fad1e2
Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140)
* Add a primitive integer_pow() for values raised to a fixed integer scalar.

Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal().

Fixes #3136

```
In [1]: from jax import grad, make_jaxpr
In [2]: def inv(x): return 1/x
In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }

In [5]:
```

* Use x ** 3 in gelu definition.
2020-05-18 17:54:20 -04:00
Ed Schmerling
510af1de64
Fix documentation for nn.elu, nn.celu, and lax.expm1. (#3116) 2020-05-15 20:51:53 -07:00
Peter Hawkins
77703b8925
Add support for sorting complex values, defaulting to a NumPy-style l… (#3096)
* Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering.

Implemented using a custom comparator, since the XLA-level default comparator doesn't impose an ordering for complex values.

* Disable sort test on CPU and TPU.
2020-05-14 19:17:44 -04:00
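A short sketch of the NumPy-style lexicographic ordering mentioned above (real part first, then imaginary part):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1 + 2j, 1 + 1j, 0 + 5j])
print(jnp.sort(x))              # [0.+5.j 1.+1.j 1.+2.j]
print(np.sort(np.asarray(x)))   # same order under NumPy's convention
```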
Peter Hawkins
4ce2aa2563
Make lax.sort support tuple arguments using a variadic sort. (#3085)
* Make lax.sort support tuple arguments using a variadic sort.

Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly.

Remove sort_key_val_p, since it is redundant with a variadic sort_p.

* Fix mypy errors.

* Change JVP rule to use NumPy indexing.
Remove redundant case in batching rule.
2020-05-14 11:13:15 -04:00
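A small sketch of the tuple-argument form, assuming the current `lax.sort` default of sorting by the first operand; `lax.sort_key_val` remains available as a thin wrapper over the variadic primitive:

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
vals = jnp.array([30.0, 10.0, 20.0])

# Variadic form: a tuple of operands, sorted along `dimension` by the keys.
sorted_keys, sorted_vals = lax.sort((keys, vals), dimension=0)

# The old key/value entry point still exists, now implemented via sort_p.
k2, v2 = lax.sort_key_val(keys, vals)
```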
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would presumably report the error.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
notEvil
969ed8085c
Add decorator for performing broadcasting inside translation rules (#2468)
* Add decorator for broadcasting at the translation rule layer.

* Fix broadcasting in igamma gradients.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-05-06 10:15:17 -04:00
Srinivas Vasudevan
e51c7d7482
Add IgammaGradA (#2504) 2020-05-05 20:10:31 -04:00
tamaranorman
04102e5b9d
Allow ConvDimensionNumbers to be passed into conv_transpose (#2915) 2020-05-04 14:02:13 -04:00
James Bradbury
1cdd8f1b99
Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)
* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
2020-05-01 14:37:13 -07:00
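A minimal sketch of the new `in_axes=None` support described above; the array shapes are illustrative:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
w = jnp.ones(3)                          # broadcast to every device (in_axes=None)
x = jnp.arange(n * 3.0).reshape(n, 3)    # split across devices along axis 0

out = jax.pmap(lambda w, x: jnp.dot(x, w), in_axes=(None, 0))(w, x)
print(out.shape)   # (n,)
```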
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
George Necula
ac023bf28f
Fixed a few places where device stickiness was lost. Added FAQ (#2882)
* Fixed a few places where device stickiness was lost. Added FAQ for device
placement.

I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorter than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.

* Fix mypy annotations and updates based on comments

* Undid some changes, will make another PR
2020-05-01 10:06:59 +03:00
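A short sketch of the "computation follows data" behavior the FAQ above documents:

```python
import jax
import jax.numpy as jnp

# Committing an array to a specific device with device_put makes later
# computations on it run on that device rather than the default one.
dev = jax.devices()[-1]
x = jax.device_put(jnp.arange(4.0), dev)
y = x * 2   # executed on `dev`; the result stays committed there
```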
Peter Hawkins
0557248fbd
Check for unsupported dtypes and issue a helpful error. (#2885) 2020-04-29 14:14:49 -04:00
Anselm Levskaya
e599a25422
fix sort_key_val return type annotation, docstring 2020-04-28 10:49:17 -07:00
Anselm Levskaya
ca4e396e31
Merge pull request #2853 from levskaya/topkjvp
Add top_k jvp and batching rules and tests
2020-04-29 00:57:29 +10:00
Anselm Levskaya
dddad2a3dc Add top_k jvp and batching rules 2020-04-28 07:19:58 -07:00
Jamie Townsend
75617be803
Add population_count primitive to lax (#2753)
* add population_count primitive (needs new jaxlib)

fixes #2263

* Add popcount docs

* Add population_count to lax_reference

* Use int prng (since we're only testing uints)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 22:32:52 -07:00
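A tiny example of the new primitive on unsigned integers:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0, 1, 2, 255], dtype=jnp.uint8)
print(lax.population_count(x))   # [0 1 1 8]
```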
Jon Malmaud
77901e9fa7
Fix lax.rng_uniform. (#2830) 2020-04-24 16:43:04 -04:00
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Matthew Johnson
13a17286df
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806) 2020-04-23 13:12:24 -07:00
Matthew Johnson
2e34dbc188 update travis to match min jaxlib version 2020-04-21 19:04:28 -07:00
Matthew Johnson
1bcaef142f
apply is_stable=True to sort translation rules (#2789)
fixes #2779
2020-04-21 17:47:28 -07:00
Stephan Hoyer
e6f0b8d87d
Raise an error if stop_gradient is called on non-arrays (#2750)
* Raise an error if stop_gradient is called on non-arrays

* Fix incorrect usage of stop_gradient in solve()

* fix *other* misuse of stop_gradient
2020-04-17 12:42:53 -07:00
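A small sketch of the intended array-only usage of `stop_gradient`; the `loss` function is a hypothetical illustration:

```python
import jax.numpy as jnp
from jax import grad, lax

def loss(x):
    target = lax.stop_gradient(2.0 * x)   # treated as a constant under grad
    return jnp.sum((x - target) ** 2)

print(grad(loss)(jnp.ones(3)))   # gradient flows only through the first `x`
```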
Peter Hawkins
9a5b8d626a
Assert that reduction computations don't have constants. (#2754)
This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.
2020-04-17 14:38:50 -04:00
Skye Wanderman-Milne
0e29bd4ba3
Fix some bugs in _reshape_sharded_device_array (#2732) 2020-04-15 18:43:46 -07:00
Skye Wanderman-Milne
07571ae4dd
Allow ShardedDeviceArrays to represent arbitrary data shardings. (#2142)
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.

This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).

**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative    mean/baseline
-------  ---------  ---------  ---------  ----------  ---------------
     10          8  0.041855    4.15223      1               1.01466
    100          8  0.129884    4.85321      3.1032          0.988543
    101          8  0.136347    6.20233      3.2576          0.967138
    500          8  0.533207    3.6815      12.7394          1.0294
   1000          8  1.10338     0.525193    26.362           0.960435
   5000          8  5.33911     0          127.562           0.963319
    100          2  0.0638619  10.7069       1.52579         1.0362
    100          4  0.0868253   6.76701      2.07443         0.967323
    100          8  0.128151    6.46004      3.06177         0.979742
    100        100  1.22631     1.94885     29.299           1.00371
    100        500  6.60746     0          157.865           0.956657
```
**pmap_shard_outputs**
```
  nouts    nshards        mean       %std    relative    mean/baseline
-------  ---------  ----------  ---------  ----------  ---------------
     10          8   0.0664526   9.49251      1               0.938466
    100          8   0.195711    2.19429      2.94512         1.04239
    500          8   0.82577     0.330864    12.4265          0.994669
   1000          8   1.68323     1.0516      25.3298          0.966915
   5000          8   8.89032     0          133.784           0.998038
    100          2   0.074806   10.1734       1.12571         0.980254
    100          4   0.121334    5.76774      1.82588         1.02033
    100          8   0.185253    5.45068      2.78775         1.01666
    100        100   2.37076     0           35.6759          1.08629
    100        500  17.0832      0          257.074           0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std    relative    mean/baseline
------------------  ----------  -------  ----------  ---------------
integer_indices      0.0603473  8.29159       1             0.359496
integer_2D_indices  18.0241     0           298.672         1.00583
```

This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
2020-04-15 12:43:55 -07:00
Peter Hawkins
714b276b9a
Implement jax.ops.index_mul. (#2696)
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
2020-04-13 16:16:34 -04:00
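A short sketch of the scatter-multiply this adds; the `jax.ops.index_mul` spelling is from this era, while the `.at[...].multiply()` form below is the current equivalent:

```python
import jax.numpy as jnp

x = jnp.ones((3, 3))

# Era-appropriate spelling added by this commit:
#   jax.ops.index_mul(x, jax.ops.index[1, :], 10.0)
# The same update written with the newer indexed-update syntax:
y = x.at[1, :].multiply(10.0)
print(y[1])   # [10. 10. 10.]
```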
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
Peter Hawkins
a3cc9a7d32
Remove a workaround for a long-fixed XLA type conversion bug. (#2670) 2020-04-10 11:45:27 -04:00
Peter Hawkins
1bb67637ca
Add batch_group_count to conv_general_dilated. (#2635)
* Add batch_group_count to conv_general_dilated.

* Use batch_group_count for RHS grouped convolution transpose rule.

* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
2020-04-09 16:21:30 -04:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit)) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Peter Hawkins
5174f6d5dc
Add type annotations to user-facing functions in lax.py (#2644)
* Add type annotations to user-facing functions in lax.py

* Remove redundant comment.
2020-04-08 14:13:15 -04:00
Peter Hawkins
fa383b4a9f
Mark primitive parameters as keyword-only arguments in rules in lax.py. (#2625)
* Mark primitive parameters as keyword-only arguments in rules in lax.py.

* Fix dynamic update slice batching rule.

* Fix dynamic slice batching rule.
2020-04-07 09:38:10 -04:00
Peter Hawkins
cf4dd84b14
cumsum is linear, so its gradient can be linear also. (#2618)
* cumsum is linear, so its gradient can be linear also.

* Rename _impl functions to _prefix_scan.
2020-04-06 15:14:22 -04:00
Peter Hawkins
36c529d4e3
Handle n==0 case in TPU cumsum/cumprod. (#2617) 2020-04-06 12:33:55 -04:00
Peter Hawkins
329321b0f1
Add backend-specific lowering for cumsum/cumprod on TPU. (#2614)
* Add backend-specific lowering for cumsum/cumprod on TPU.

Make cumsum/cumprod primitives so they can have backend-specific lowerings.

* Disable cumulative reduction gradient test on TPU.
2020-04-06 11:22:01 -04:00
Sharad Vikram
99944d1204
Fix lax.broadcast_shapes returning numpy ints in shape tuple (#2471)
* Fix lax.broadcast_shapes returning numpy ints in shape tuple

* Use _canonicalize_dimension and add test
2020-04-04 23:19:39 -07:00
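A small sketch of the fix described above; the NumPy integer types in the inputs are illustrative:

```python
import numpy as np
from jax import lax

out = lax.broadcast_shapes((2, 1), (np.int64(1), np.int64(3)))
print(out)                                # (2, 3)
print(all(type(d) is int for d in out))   # True: entries are plain Python ints
```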
Peter Hawkins
2b3befff32
Make reduce_prod differentiable to arbitrary order. (#2597)
* Make reduce_prod differentiable to arbitrary order.

The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position.

This PR instead builds an explicit reduction tree and differentiates through it, which, while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order.

* Return the tree-reduction primals instead of returning the original primals in JVP rule.
2020-04-03 16:09:48 -04:00
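A quick check of higher-order differentiation through a product reduction, which should lower to the reduce_prod reduction this commit changes:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.prod(jnp.stack([x, x, 2.0]))   # 2 * x**2
print(jax.grad(f)(3.0))             # 12.0
print(jax.grad(jax.grad(f))(3.0))   # 4.0: second derivative now works too
```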
Matthew Johnson
bdc0c3bf43 remove remat context check, add initial staging 2020-03-29 23:29:55 -07:00