This is useful for the remat transpose rule submitted in #3162 and, for example, allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it was unnecessarily declared as having multiple outputs).
* Implement mask for slice, conv, pad, transpose, where
* Remove tentative mask(jit)
* Add explanatory comment to dot_general masking rule
* Remove reshape from select masking rule
* Remove unnecessary check from lax slice abstract_eval rule
* Revert to standard indentation in masking_test.py
* Begin simplifying masking tests
* Finish drafting masking check function
* More progress simplifying tests
* Add conv masking in batch dim
* Finish fixing up tests
* Revert to old API, making out_shape compulsory again
* More efficient conv masking rule
* Tidy up masking_test imports
* Check that out tree is preserved by masking
* fix flake errors
Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
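As context for the masking commits above, here is a minimal sketch of how the experimental masking API was exercised around this change; the import path and the list-of-arguments / logical-shape-dict call convention reflect masking_test.py of this era and may differ in later versions:
```python
from functools import partial

import jax.numpy as jnp
from jax import mask  # experimental masking transform as exposed at the time

# Logical shape 'n' is a symbolic size; the physical argument may be padded.
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
  return jnp.sum(x)

# Only the first n=3 entries of the padded buffer contribute to the result.
padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))   # -> 8
```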
* Improve tracing performance of _dynamic_slice_indices
* More precisely preserve semantics of dynamic_slice_indices
* Use safe_map in dynamic_slice_indices
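For reference, `lax.dynamic_slice` is the operation these index helpers feed; a small illustration of its clamping semantics (not code from the PR):
```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)
# Start indices may be traced values; they are clamped so the requested slice
# stays in bounds, so a start of 7 with size 5 is clamped to start 5.
lax.dynamic_slice(x, (7,), (5,))   # -> [5, 6, 7, 8, 9]
```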
revert find_top_trace change from #3197
The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).
Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
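A rough way to reproduce the kind of dispatch-time measurement referred to above; the methodology here is illustrative, not taken from the original benchmark:
```python
import timeit

from jax import lax

lax.add(1, 2)  # warm up so compilation and caching happen outside the timed region
print(timeit.timeit(lambda: lax.add(1, 2), number=10_000) / 10_000)
```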
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
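A small illustration of the distinction being drawn here (illustrative, not code from the PR):
```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)  # shape (3,)

# reshape produces the right output shape but carries no axis correspondence:
lax.reshape(x, (1, 3, 1))

# broadcast_in_dim is explicit: input axis 0 maps to output axis 1, and the
# remaining output axes are new size-1 axes.
lax.broadcast_in_dim(x, shape=(1, 3, 1), broadcast_dimensions=(1,))

# expand_dims with a tuple of axes (NumPy 1.18 semantics) gives the same shape
# while keeping the original axis identifiable.
jnp.expand_dims(x, axis=(0, 2))
```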
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
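For reference, the resulting lax-level API looks roughly like this (a usage sketch, not code from the PR):
```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 3, 1))
lax.squeeze(x, dimensions=(0, 2))                  # shape (3,)
lax.expand_dims(jnp.ones(3), dimensions=(0, 2))    # shape (1, 3, 1)
```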
* Added argument check to all primitives.
The issue that inspired this is that `lax.tie_in` is easy to misuse: if the first argument is not a JAX type, it silently disappears. This means that `lax.tie_in((x, x), const)` is the same as `const`, even though `x` is a tracer.
Previously this error would be caught only if core.skip_checks == False, because then `bind` checks its arguments. I have essentially added an unconditional argument check to `bind`.
In case this is considered too inefficient, we can add argument checking to individual primitives, e.g., tie_in. For most primitives, if a non-JAX array is passed, the `impl` rule would fire and NumPy would likely report the error.
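A sketch of the misuse described above, using `lax.tie_in` as it existed at the time (it has since become a no-op and been removed):
```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  const = jnp.float32(3.)
  # Intended: make `const` depend on the tracer `x`. Bug: the first argument is
  # a tuple, not a JAX array, so before this change the dependency was silently
  # dropped and the result was just `const`.
  return lax.tie_in((x, x), const)

jax.make_jaxpr(f)(jnp.float32(1.))  # with the new check, bind raises an error here
```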
* Merged find_top_trace with check_args
This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255.
Co-authored-by: Alex Davies <adavies@google.com>
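A sketch of what a differentiable segment_max enables, written against the `jax.ops.segment_max` signature as it exists today (an assumption, since the exact API landed across several changes):
```python
import jax
import jax.numpy as jnp

data = jnp.array([1., 5., 2., 4., 3.])
segment_ids = jnp.array([0, 0, 1, 1, 1])

# Per-segment maxima: [5., 4.]
jax.ops.segment_max(data, segment_ids, num_segments=2)

# Differentiability is the point of this change: gradients flow back to the
# argmax element of each segment.
jax.grad(lambda x: jax.ops.segment_max(x, segment_ids, num_segments=2).sum())(data)
```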
* Add a primitive integer_pow() for values raised to fixed integer scalar.
Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal().
Fixes #3136
```
In [1]: from jax import grad, make_jaxpr
In [2]: def inv(x): return 1/x
In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312
In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda ; a.
let b = integer_pow[ y=-7 ] a
c = mul -6.0 b
d = mul -120.0 c
in (d,) }
In [5]:
```
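For completeness, a trivial direct use of the new primitive (illustrative):
```python
from jax import lax

# The exponent is a fixed Python int, baked into the primitive at trace time.
lax.integer_pow(2.0, 3)   # -> 8.0
```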
* Use x ** 3 in gelu definition.
* Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering.
Implemented using a custom comparator, since the XLA-level default comparator doesn't impose an ordering for complex values.
* Disable sort test on CPU and TPU.
* Make lax.sort support tuple arguments using a variadic sort.
Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly.
Remove sort_key_val_p, since it is redundant with a variadic sort_p.
* Fix mypy errors.
* Change JVP rule to use NumPy indexing.
Remove redundant case in batching rule.
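A sketch of the variadic `lax.sort` and complex ordering described above (parameter names follow the current signature):
```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
vals = jnp.array([30., 10., 20.])

# Sort both operands together, ordering by the first `num_keys` operands.
sorted_keys, sorted_vals = lax.sort((keys, vals), num_keys=1)
# sorted_keys -> [1, 2, 3], sorted_vals -> [10., 20., 30.]

# Complex values sort lexicographically (real part, then imaginary part).
jnp.sort(jnp.array([1 + 2j, 1 + 1j, 0 + 3j]))   # -> [0+3j, 1+1j, 1+2j]
```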
* Add decorator for broadcasting at the translation rule layer.
* Fix broadcasting in igamma gradients.
Co-authored-by: Peter Hawkins <phawkins@google.com>
* allow in_axes=None for pmap in api.py
* wire in_axes=None through parallel_callable
* add test
* fix error string
* fixes
* fixes
* add test for nested pmap with in_axes
* test pmap still defaults to (implicit) out_axes=0
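A small sketch of the in_axes=None behaviour added in this series (names are illustrative):
```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n * 3.).reshape(n, 3)   # mapped over its leading axis
w = jnp.ones(3)                         # no device axis; broadcast to all devices

# in_axes=None broadcasts `w`; out_axes still defaults to 0.
f = jax.pmap(lambda x, w: x * w, in_axes=(0, None))
f(xs, w).shape   # (n, 3)
```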
At head the following fails:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
* Fixed a few places where device stickiness was lost. Added an FAQ entry for device placement.
I have also added a new test (multi_device_test.test_computation_follows_data), written more as part of the documentation. It is shorter than the old test_computation_follows_data (which is still there, renamed to test_computation_follows_data_old). I believe there is no extra coverage in test_computation_follows_data_old relative to all the other tests we have.
* Fix mypy annotations and updates based on comments
* Undid some changes, will make another PR
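A minimal sketch of the "computation follows data" behaviour documented in the new FAQ entry (meaningful only with more than one device):
```python
import jax
import jax.numpy as jnp

# Commit the data to a specific device; subsequent computations on it run
# there and their results stay committed to that device.
x = jax.device_put(jnp.ones(3), jax.devices()[-1])
y = x + 1   # executed on jax.devices()[-1]
```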
* add population_count primitive (needs new jaxlib)
Fixes #2263
* Add popcount docs
* Add population_count to lax_reference
* Use int prng (since we're only testing uints)
Co-authored-by: Matthew Johnson <mattjj@google.com>
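Example usage of the new primitive (a sketch):
```python
import jax.numpy as jnp
from jax import lax

bits = jnp.array([0, 1, 2, 255], dtype=jnp.uint8)
lax.population_count(bits)   # counts set bits per element -> [0, 1, 1, 8]
```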
* Remove usage of xla_client.{Computation,ComputationBuilder}.
ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.
Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.
This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.
Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).
**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
nargs nshards mean %std relative mean/baseline
------- --------- --------- --------- ---------- ---------------
10 8 0.041855 4.15223 1 1.01466
100 8 0.129884 4.85321 3.1032 0.988543
101 8 0.136347 6.20233 3.2576 0.967138
500 8 0.533207 3.6815 12.7394 1.0294
1000 8 1.10338 0.525193 26.362 0.960435
5000 8 5.33911 0 127.562 0.963319
100 2 0.0638619 10.7069 1.52579 1.0362
100 4 0.0868253 6.76701 2.07443 0.967323
100 8 0.128151 6.46004 3.06177 0.979742
100 100 1.22631 1.94885 29.299 1.00371
100 500 6.60746 0 157.865 0.956657
```
**pmap_shard_outputs**
```
nouts nshards mean %std relative mean/baseline
------- --------- ---------- --------- ---------- ---------------
10 8 0.0664526 9.49251 1 0.938466
100 8 0.195711 2.19429 2.94512 1.04239
500 8 0.82577 0.330864 12.4265 0.994669
1000 8 1.68323 1.0516 25.3298 0.966915
5000 8 8.89032 0 133.784 0.998038
100 2 0.074806 10.1734 1.12571 0.980254
100 4 0.121334 5.76774 1.82588 1.02033
100 8 0.185253 5.45068 2.78775 1.01666
100 100 2.37076 0 35.6759 1.08629
100 500 17.0832 0 257.074 0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn mean %std relative mean/baseline
------------------ ---------- ------- ---------- ---------------
integer_indices 0.0603473 8.29159 1 0.359496
integer_2D_indices 18.0241 0 298.672 1.00583
```
This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
* Implement jax.ops.index_mul.
* Add index_mul to documentation.
* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.
* Fix typo in docstring.
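Sketch of the API added here, using the `jax.ops.index` syntax of that era (this interface has since been superseded by `x.at[idx].multiply(y)`):
```python
import jax
import jax.numpy as jnp

x = jnp.ones(5)
# Functionally multiply the selected entries: entries 2:4 become 10.
jax.ops.index_mul(x, jax.ops.index[2:4], 10.)   # -> [1., 1., 10., 10., 1.]
```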
* Make the pytest run over the JAX tests warning-clean, and error on warnings.
Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.
Also fix crashes on Mac related to a preexisting linear algebra bug.
* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.