287 Commits

Author SHA1 Message Date
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for the remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Jake Vanderplas
b187663a87
deflake jax/lax & add to flake8 check (#3310) 2020-06-04 13:50:44 -07:00
Roy Frostig
6015a2a689 introduce lax.switch 2020-06-03 22:19:15 -07:00
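A minimal sketch of the new primitive's usage (current-JAX syntax; the exact API at the time of this commit may have differed):

```python
from jax import lax

# lax.switch picks one of N branch functions via a runtime integer index,
# generalizing the two-way lax.cond to arbitrarily many branches.
branches = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]

def apply_branch(i, x):
    return lax.switch(i, branches, x)

apply_branch(1, 10.0)  # -> 20.0
```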
Skye Wanderman-Milne
5ad9feda5f
Fix handling of infeed token inside sharded_jit (#3313) 2020-06-03 15:23:49 -07:00
Julius Kunze
d1dbf7c7d8
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Rm reshape from select masking rule

* Rm unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-03 13:40:48 -07:00
Jake Vanderplas
0db57cb541
Fix validation code in lax.conv (#3279) 2020-06-03 10:33:19 -07:00
Peter Hawkins
dd81a8dded
Fix some type errors in lax.py found by pytype. (#3292) 2020-06-02 10:27:14 -04:00
Jamie Townsend
3909875f9d
Improve speed of tracing dynamic_update_slice (#3247)
* Improve tracing performance of _dynamic_slice_indices

* More precisely preserve semantics of dynamic_slice_indices

* Use safe_map in dynamic_slice_indices
2020-06-02 09:37:32 -04:00
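For context, a small sketch of the operation whose tracing this speeds up (illustrative usage, not from the commit):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
update = jnp.ones((2, 2))
# Writes `update` into x starting at (row, col) = (1, 1). The start
# indices may be traced values, which is why index-computation tracing
# speed matters here.
y = lax.dynamic_update_slice(x, update, (1, 1))
```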
James Bradbury
f1a7073738
pmap(in_axes=None) of sharded_jit (#3257)
* pmap(in_axes=None) of sharded_jit

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

* address comments

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-06-01 16:50:22 -07:00
Peter Hawkins
cf624196ed
Documentation fixes. (#3282)
Improve some cross-references and poorly quoted text.
2020-06-01 18:09:45 -04:00
Matthew Johnson
49a441f745
revisions to #3197 (#3264)
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
2020-06-01 13:24:40 -07:00
Skye Wanderman-Milne
f78ece0f98
Allow sharding infeed inside sharded_jit. (#3256) 2020-06-01 12:35:18 -07:00
Stephan Hoyer
cc8fbb7669
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py

`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.

This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.

I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.

I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)

* Remove unnecessary branch

* Add lax.squeeze primitive

* Changes per review

* Fix typing

* Move expand_dims into lax

* Update per review; add comments/documentation

* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
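A sketch of the axis-identity-preserving ops this PR prefers over `reshape` (current-JAX API; illustrative, not from the commit):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)                        # shape (3,)
y = lax.expand_dims(x, (0, 2))             # shape (1, 3, 1); multiple axes at once
z = lax.squeeze(y, (0, 2))                 # back to shape (3,)
# broadcast_in_dim makes the input->output axis mapping explicit:
b = lax.broadcast_in_dim(x, (2, 3), (1,))  # shape (2, 3), x mapped to axis 1
```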
Julius Kunze
02b4fd3500
Fix broadcast_shapes for polymorphic dims (#3216) (#3224)
* Fix #3216

* Simplify
2020-05-27 18:15:01 -04:00
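The fixed function follows NumPy's broadcasting rules over shape tuples (illustrative usage with concrete, non-polymorphic dims):

```python
from jax import lax

# Length-1 dimensions broadcast against the other shape's dimensions.
shape = lax.broadcast_shapes((2, 1, 3), (1, 4, 3))
# shape == (2, 4, 3)
```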
Skye Wanderman-Milne
6ffde8061d
Implement pmap of sharded_jit (#3144)
* Implement pmap of sharded_jit

* Update jax/interpreters/pxla.py

Co-authored-by: James Bradbury <jekbradbury@google.com>

* Address comments

Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-26 14:26:53 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
alexdavies
85fe5a28f1
Add gradients to the scatter_max and scatter_min operations. (#3111)
This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255

Co-authored-by: Alex Davies <adavies@google.com>
2020-05-18 23:06:32 -07:00
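A hypothetical differentiable `segment_max` enabled by this change, built on `.at[].max` (which lowers to `lax.scatter_max`); the helper name and shapes are illustrative, not from the commit:

```python
import jax
import jax.numpy as jnp

def segment_max(data, segment_ids, num_segments):
    # Scatter-max into per-segment slots; gradient support comes from
    # the scatter_max rule added here.
    init = jnp.full((num_segments,), -jnp.inf, dtype=data.dtype)
    return init.at[segment_ids].max(data)

data = jnp.array([1.0, 3.0, 2.0])
ids = jnp.array([0, 0, 1])
grads = jax.grad(lambda d: segment_max(d, ids, 2).sum())(data)
# gradient flows only to each segment's max element
```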
Skye Wanderman-Milne
888c9c77b3 Implement pmap of sharded_jit 2020-05-18 18:40:28 -07:00
Peter Hawkins
36e7fad1e2
Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140)
* Add a primitive integer_pow() for values raised to a fixed integer scalar.

Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal().

Fixes #3136

```
In [1]: from jax import grad, make_jaxpr
In [2]: def inv(x): return 1/x
In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }

In [5]:
```

* Use x ** 3 in gelu definition.
2020-05-18 17:54:20 -04:00
Ed Schmerling
510af1de64
Fix documentation for nn.elu, nn.celu, and lax.expm1. (#3116) 2020-05-15 20:51:53 -07:00
Peter Hawkins
77703b8925
Add support for sorting complex values, defaulting to a NumPy-style l… (#3096)
* Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering.

Implemented using a custom comparator, since the XLA-level default comparator doesn't impose an ordering on complex values.

* Disable sort test on CPU and TPU.
2020-05-14 19:17:44 -04:00
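A sketch of the resulting behavior (illustrative values, not from the commit):

```python
import jax.numpy as jnp

x = jnp.array([1 + 2j, 1 + 1j, 0 + 5j])
# NumPy-style lexicographic order: sort by real part, breaking ties
# by imaginary part.
s = jnp.sort(x)
# s: [0+5j, 1+1j, 1+2j]
```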
Peter Hawkins
4ce2aa2563
Make lax.sort support tuple arguments using a variadic sort. (#3085)
* Make lax.sort support tuple arguments using a variadic sort.

Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly.

Remove sort_key_val_p, since it is redundant with a variadic sort_p.

* Fix mypy errors.

* Change JVP rule to use NumPy indexing.
Remove redundant case in batching rule.
2020-05-14 11:13:15 -04:00
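The variadic form subsumes the old `sort_key_val`: pass a tuple of operands and say how many leading operands are keys (current-JAX syntax; illustrative):

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
vals = jnp.array([30.0, 10.0, 20.0])
# The first num_keys operands order the sort; the rest are permuted along.
sorted_keys, sorted_vals = lax.sort((keys, vals), num_keys=1)
# sorted_keys: [1, 2, 3], sorted_vals: [10., 20., 30.]
```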
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
notEvil
969ed8085c
Add decorator for performing broadcasting inside translation rules (#2468)
* Add decorator for broadcasting at the translation rule layer.

* Fix broadcasting in igamma gradients.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-05-06 10:15:17 -04:00
Srinivas Vasudevan
e51c7d7482
Add IgammaGradA (#2504) 2020-05-05 20:10:31 -04:00
tamaranorman
04102e5b9d
Allow ConvDimensionNumbers to be passed into conv_transpose (#2915) 2020-05-04 14:02:13 -04:00
James Bradbury
1cdd8f1b99
Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)
* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
2020-05-01 14:37:13 -07:00
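A minimal sketch of the new `in_axes=None` support (illustrative; runs on however many local devices are available):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(float(n))  # mapped over devices along axis 0
w = jnp.array(2.0)         # broadcast unchanged to every device
# in_axes=None marks w as replicated rather than split across devices.
out = jax.pmap(lambda x, w: x * w, in_axes=(0, None))(xs, w)
```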
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
George Necula
ac023bf28f
Fixed a few places where device stickiness was lost. Added FAQ (#2882)
* Fixed a few places where device stickiness was lost. Added FAQ for device
placement.

I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorted than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.

* Fix mypy annotations and updates based on comments

* Undid some changes, will make another PR
2020-05-01 10:06:59 +03:00
Peter Hawkins
0557248fbd
Check for unsupported dtypes and issue a helpful error. (#2885) 2020-04-29 14:14:49 -04:00
Anselm Levskaya
e599a25422
fix sort_key_val return type annotation, docstring 2020-04-28 10:49:17 -07:00
Anselm Levskaya
ca4e396e31
Merge pull request #2853 from levskaya/topkjvp
Add top_k jvp and batching rules and tests
2020-04-29 00:57:29 +10:00
Anselm Levskaya
dddad2a3dc Add top_k jvp and batching rules 2020-04-28 07:19:58 -07:00
Jamie Townsend
75617be803
Add population_count primitive to lax (#2753)
* add population_count primitive (needs new jaxlib)

fixes #2263

* Add popcount docs

* Add population_count to lax_reference

* Use int prng (since we're only testing uints)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 22:32:52 -07:00
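A small sketch of the new primitive (illustrative usage, not from the commit):

```python
import jax.numpy as jnp
from jax import lax

# Counts the set bits of each element; defined for integer dtypes,
# tested here on unsigned ints as in the PR.
x = jnp.array([0, 1, 2, 255], dtype=jnp.uint8)
counts = lax.population_count(x)
# counts: [0, 1, 1, 8]
```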
Jon Malmaud
77901e9fa7
Fix lax.rng_uniform. (#2830) 2020-04-24 16:43:04 -04:00
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Matthew Johnson
13a17286df
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806) 2020-04-23 13:12:24 -07:00
Matthew Johnson
2e34dbc188 update travis to match min jaxlib version 2020-04-21 19:04:28 -07:00
Matthew Johnson
1bcaef142f
apply is_stable=True to sort translation rules (#2789)
fixes #2779
2020-04-21 17:47:28 -07:00
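With `is_stable=True`, equal keys keep their original relative order, so results like argsort on duplicate keys are deterministic (illustrative, current-JAX behavior):

```python
import jax.numpy as jnp

# Two 1s at indices 1 and 3, two 2s at indices 0 and 2; a stable sort
# must keep each pair in original order.
idx = jnp.argsort(jnp.array([2, 1, 2, 1]))
# idx: [1, 3, 0, 2]
```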
Stephan Hoyer
e6f0b8d87d
Raise an error if stop_gradient is called on non-arrays (#2750)
* Raise an error if stop_gradient is called on non-arrays

* Fix incorrect usage of stop_gradient in solve()

* fix *other* misuse of stop_gradient
2020-04-17 12:42:53 -07:00
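A sketch of the intended (array-only) use of `stop_gradient`; after this change, passing a non-array such as a function raises instead of silently doing nothing:

```python
import jax
from jax import lax

def f(x):
    # Array-valued argument: allowed. The scale is treated as a
    # constant during differentiation.
    scale = lax.stop_gradient(2.0 * x)
    return scale * x

g = jax.grad(f)(3.0)  # scale == 6.0 is held fixed, so g == 6.0
```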
Peter Hawkins
9a5b8d626a
Assert that reduction computations don't have constants. (#2754)
This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.
2020-04-17 14:38:50 -04:00
Skye Wanderman-Milne
0e29bd4ba3
Fix some bugs in _reshape_sharded_device_array (#2732) 2020-04-15 18:43:46 -07:00
Skye Wanderman-Milne
07571ae4dd
Allow ShardedDeviceArrays to represent arbitrary data shardings. (#2142)
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.

This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).

**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative    mean/baseline
-------  ---------  ---------  ---------  ----------  ---------------
     10          8  0.041855    4.15223      1               1.01466
    100          8  0.129884    4.85321      3.1032          0.988543
    101          8  0.136347    6.20233      3.2576          0.967138
    500          8  0.533207    3.6815      12.7394          1.0294
   1000          8  1.10338     0.525193    26.362           0.960435
   5000          8  5.33911     0          127.562           0.963319
    100          2  0.0638619  10.7069       1.52579         1.0362
    100          4  0.0868253   6.76701      2.07443         0.967323
    100          8  0.128151    6.46004      3.06177         0.979742
    100        100  1.22631     1.94885     29.299           1.00371
    100        500  6.60746     0          157.865           0.956657
```
**pmap_shard_outputs**
```
  nouts    nshards        mean       %std    relative    mean/baseline
-------  ---------  ----------  ---------  ----------  ---------------
     10          8   0.0664526   9.49251      1               0.938466
    100          8   0.195711    2.19429      2.94512         1.04239
    500          8   0.82577     0.330864    12.4265          0.994669
   1000          8   1.68323     1.0516      25.3298          0.966915
   5000          8   8.89032     0          133.784           0.998038
    100          2   0.074806   10.1734       1.12571         0.980254
    100          4   0.121334    5.76774      1.82588         1.02033
    100          8   0.185253    5.45068      2.78775         1.01666
    100        100   2.37076     0           35.6759          1.08629
    100        500  17.0832      0          257.074           0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std    relative    mean/baseline
------------------  ----------  -------  ----------  ---------------
integer_indices      0.0603473  8.29159       1             0.359496
integer_2D_indices  18.0241     0           298.672         1.00583
```

This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
2020-04-15 12:43:55 -07:00
Peter Hawkins
714b276b9a
Implement jax.ops.index_mul. (#2696)
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
2020-04-13 16:16:34 -04:00
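The `jax.ops.index_mul` API of this era has since been removed; a sketch of the same scatter-multiply in the modern `.at[]` syntax (equivalent operation, not the commit's spelling):

```python
import jax.numpy as jnp

x = jnp.ones((4,))
# The era's jax.ops.index_mul(x, jax.ops.index[[1, 3]], 10.) is written
# today as an indexed-update multiply (same scatter_mul underneath).
y = x.at[jnp.array([1, 3])].multiply(10.0)
# y: [1., 10., 1., 10.]
```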
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
Peter Hawkins
a3cc9a7d32
Remove a workaround for a long-fixed XLA type conversion bug. (#2670) 2020-04-10 11:45:27 -04:00