* Added argument check to all primitives.
The issue that inspired this is that `lax.tie_in` is easy to misuse:
if the first argument is not a JAX type, it silently disappears. This
means that `lax.tie_in((x, x), const)` is the same as `const`, even
though `x` is a tracer.
Previously this error would be caught only if core.skip_checks == False,
because then `bind` checks its arguments. I have essentially added an
unconditional argument check to `bind`.
If this is considered too inefficient, we can instead add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would likely report the error in some form.
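For concreteness, here is a minimal sketch of the misuse described above, written against the `lax.tie_in` API as it existed at the time of this change:
```python
import jax.numpy as jnp
from jax import jit, lax

def f(x):
    const = jnp.ones(3)
    # Misuse: (x, x) is a tuple, not a JAX array, so it used to be silently
    # dropped and the result did not depend on x at all. The intended call
    # is lax.tie_in(x, const).
    return lax.tie_in((x, x), const)

jit(f)(jnp.arange(3.0))  # with the new check, bind rejects the tuple argument
```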
* Merged find_top_trace with check_args
* Add decorator for broadcasting at the translation rule layer.
* Fix broadcasting in igamma gradients.
Co-authored-by: Peter Hawkins <phawkins@google.com>
* allow in_axes=None for pmap in api.py (usage sketch below)
* wire in_axes=None through parallel_callable
* add test
* fix error string
* fixes
* fixes
* add test for nested pmap with in_axes
* test pmap still defaults to (implicit) out_axes=0
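A minimal sketch of the in_axes=None support added above; the mapped function and shapes are illustrative:
```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n * 3.0).reshape((n, 3))  # mapped over axis 0
w = jnp.ones(3)                           # in_axes=None: broadcast to every device

out = jax.pmap(lambda x, w: x * w, in_axes=(0, None))(xs, w)
print(out.shape)  # (n, 3)
```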
At head the following fails:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
* Fixed a few places where device stickiness was lost. Added FAQ for device
placement.
I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorter than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe test_computation_follows_data_old
adds no extra coverage beyond all the other tests we have.
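Here is a short sketch of the "computation follows data" behavior the new test documents. It assumes more than one device is available; the exact attribute for inspecting placement varies across JAX versions, so it is left out:
```python
import jax
import jax.numpy as jnp

dev = jax.devices()[-1]                # a specific, non-default device
x = jax.device_put(jnp.ones(3), dev)   # commit x to that device
y = x + 1                              # the computation follows x and stays on dev
```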
* Fix mypy annotations and updates based on comments
* Undid some changes, will make another PR
* add population_count primitive (needs new jaxlib; usage sketch below)
Fixes #2263
* Add popcount docs
* Add population_count to lax_reference
* Use int prng (since we're only testing uints)
Co-authored-by: Matthew Johnson <mattjj@google.com>
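A quick usage sketch of the new primitive (requires a jaxlib with population_count support); the values are illustrative:
```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0, 1, 2, 255], dtype=jnp.uint8)
print(lax.population_count(x))  # [0 1 1 8]
```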
* Remove usage of xla_client.{Computation,ComputationBuilder}.
ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.
Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.
This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.
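As a hedged illustration of the integer-indexing fast path measured below (shapes illustrative):
```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
y = jax.pmap(lambda x: x * 2)(jnp.arange(n * 4.0).reshape((n, 4)))  # sharded output
shard0 = y[0]  # integer indexing now only fetches the relevant per-device shard
```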
Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).
**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
nargs nshards mean %std relative mean/baseline
------- --------- --------- --------- ---------- ---------------
10 8 0.041855 4.15223 1 1.01466
100 8 0.129884 4.85321 3.1032 0.988543
101 8 0.136347 6.20233 3.2576 0.967138
500 8 0.533207 3.6815 12.7394 1.0294
1000 8 1.10338 0.525193 26.362 0.960435
5000 8 5.33911 0 127.562 0.963319
100 2 0.0638619 10.7069 1.52579 1.0362
100 4 0.0868253 6.76701 2.07443 0.967323
100 8 0.128151 6.46004 3.06177 0.979742
100 100 1.22631 1.94885 29.299 1.00371
100 500 6.60746 0 157.865 0.956657
```
**pmap_shard_outputs**
```
nouts nshards mean %std relative mean/baseline
------- --------- ---------- --------- ---------- ---------------
10 8 0.0664526 9.49251 1 0.938466
100 8 0.195711 2.19429 2.94512 1.04239
500 8 0.82577 0.330864 12.4265 0.994669
1000 8 1.68323 1.0516 25.3298 0.966915
5000 8 8.89032 0 133.784 0.998038
100 2 0.074806 10.1734 1.12571 0.980254
100 4 0.121334 5.76774 1.82588 1.02033
100 8 0.185253 5.45068 2.78775 1.01666
100 100 2.37076 0 35.6759 1.08629
100 500 17.0832 0 257.074 0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn mean %std relative mean/baseline
------------------ ---------- ------- ---------- ---------------
integer_indices 0.0603473 8.29159 1 0.359496
integer_2D_indices 18.0241 0 298.672 1.00583
```
This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
* Implement jax.ops.index_mul (usage sketch below).
* Add index_mul to documentation.
* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.
* Fix typo in docstring.
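A small usage sketch of index_mul, based on the jax.ops indexing helpers as they existed at the time of this change:
```python
import jax.numpy as jnp
from jax import ops

x = jnp.ones((3, 3))
y = ops.index_mul(x, ops.index[0, :], 2.0)  # out-of-place: row 0 multiplied by 2
```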
* Make the pytest run over the JAX tests warning-clean, and error on warnings.
Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.
Also fix crashes on Mac related to a preexisting linear algebra bug.
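A minimal sketch of locally suppressing an expected warning inside a single test, as described above (the warning category is illustrative):
```python
import warnings

def test_something_that_warns():
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=DeprecationWarning)
        ...  # code under test that emits an expected DeprecationWarning
```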
* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
* Add batch_group_count to conv_general_dilated.
* Use batch_group_count for RHS grouped convolution transpose rule.
* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
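A hedged sketch of a convolution grouped over the batch dimension using the new batch_group_count argument; the shapes and the default NCHW/OIHW layout are illustrative:
```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((4, 1, 8, 8))  # NCHW: batch of 4, divisible by batch_group_count
rhs = jnp.ones((4, 1, 3, 3))  # OIHW: output features also divisible by batch_group_count
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME', batch_group_count=4)
print(out.shape)  # (1, 4, 8, 8)
```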
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit)) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue, #432.
* Add backend-specific lowering for cumsum/cumprod on TPU.
Make cumsum/cumprod primitives so they can have backend-specific lowerings.
* Disable cumulative reduction gradient test on TPU.
* Make reduce_prod differentiable to arbitrary order.
The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position.
This PR instead builds an explicit reduction tree and differentiates through it, which, while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order.
* Return the tree-reduction primals instead of returning the original primals in JVP rule.
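A quick check of the property enabled by this change, using the product reduction through jax.numpy:
```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.prod(x)
x = jnp.array([1.0, 2.0, 3.0])
print(jax.grad(f)(x))     # first order: [6. 3. 2.]
print(jax.hessian(f)(x))  # second order, now supported via the tree-based JVP
```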
* Added a check that the output shape has rank greater than or equal to the input's
* Added a check that the broadcast_dims are sorted (required by XLA)
* Relaxed the shape check so that an operand dimension of size 1 is allowed
* Added lax.broadcast_in_dim docstring
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
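An example consistent with the broadcast_in_dim checks described above: the output rank is at least the operand rank, broadcast_dimensions is sorted, and each operand dimension either matches its output dimension or is 1:
```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)                        # shape (3,)
y = lax.broadcast_in_dim(x, (2, 3), (1,))  # shape (2, 3): operand dim 0 -> output dim 1
z = lax.broadcast_in_dim(jnp.ones((1, 3)), (4, 2, 3), (1, 2))  # size-1 dim broadcasts to 2
```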