* add population_count primitive (needs new jaxlib); see the usage sketch below
fixes #2263
* Add popcount docs
* Add population_count to lax_reference
* Use int prng (since we're only testing uints)
Co-authored-by: Matthew Johnson <mattjj@google.com>
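A minimal usage sketch of the new primitive, assuming it is exposed as `lax.population_count` and applied elementwise to unsigned integers (as the tests above suggest):
```
import jax.numpy as jnp
from jax import lax

x = jnp.array([0b1011, 0b11110000], dtype=jnp.uint32)
print(lax.population_count(x))  # number of set bits per element: [3 4]
```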
* Remove usage of xla_client.{Computation,ComputationBuilder}.
ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.
Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.
This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.
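To illustrate the idea (a hypothetical sketch, not the actual class added in this change), a sharding spec must record, per logical axis, whether the axis is split pmap-style (absent from the on-device shape), chunked sharded_jit-style (still present in the on-device shape), or not split at all, plus any replication:
```
from dataclasses import dataclass
from enum import Enum
from typing import Tuple

class AxisSharding(Enum):
    NO_SHARD = "no_shard"    # axis is not split across devices
    UNSTACKED = "unstacked"  # pmap-style: axis is removed from the on-device shape
    CHUNKED = "chunked"      # sharded_jit-style: axis is chunked but kept in the shape

@dataclass
class ShardingSpecSketch:
    """Hypothetical stand-in for the ShardingSpec described above."""
    axis_sharding: Tuple[AxisSharding, ...]  # one entry per logical array axis
    chunks_per_axis: Tuple[int, ...]         # number of shards along each axis
    replication_factor: int = 1              # copies of every shard

# An (8, 100) array pmapped over axis 0 across 8 devices, unreplicated:
spec = ShardingSpecSketch(
    axis_sharding=(AxisSharding.UNSTACKED, AxisSharding.NO_SHARD),
    chunks_per_axis=(8, 1),
    replication_factor=1,
)
```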
Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).
**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative    mean/baseline
-------  ---------  ---------  ---------  ----------  ---------------
     10          8   0.041855    4.15223           1          1.01466
    100          8   0.129884    4.85321      3.1032         0.988543
    101          8   0.136347    6.20233      3.2576         0.967138
    500          8   0.533207     3.6815     12.7394           1.0294
   1000          8    1.10338   0.525193      26.362         0.960435
   5000          8    5.33911          0     127.562         0.963319
    100          2  0.0638619    10.7069     1.52579           1.0362
    100          4  0.0868253    6.76701     2.07443         0.967323
    100          8   0.128151    6.46004     3.06177         0.979742
    100        100    1.22631    1.94885      29.299          1.00371
    100        500    6.60746          0     157.865         0.956657
```
**pmap_shard_outputs**
```
  nouts    nshards        mean       %std    relative    mean/baseline
-------  ---------  ----------  ---------  ----------  ---------------
     10          8   0.0664526    9.49251           1         0.938466
    100          8    0.195711    2.19429     2.94512          1.04239
    500          8     0.82577   0.330864     12.4265         0.994669
   1000          8     1.68323     1.0516     25.3298         0.966915
   5000          8     8.89032          0     133.784         0.998038
    100          2    0.074806    10.1734     1.12571         0.980254
    100          4    0.121334    5.76774     1.82588          1.02033
    100          8    0.185253    5.45068     2.78775          1.01666
    100        100     2.37076          0     35.6759          1.08629
    100        500     17.0832          0     257.074         0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std    relative    mean/baseline
------------------  ----------  -------  ----------  ---------------
integer_indices      0.0603473  8.29159           1         0.359496
integer_2D_indices     18.0241        0     298.672          1.00583
```
This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
* Implement jax.ops.index_mul.
* Add index_mul to documentation.
* Fix the RHS JVP rule for scatter_mul, and fix a test bug that meant it was not being tested.
* Fix typo in docstring.
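A short usage sketch, assuming the same calling convention as the existing `jax.ops.index_update` and `index_add` helpers:
```
import jax.numpy as jnp
from jax import ops

x = jnp.ones((3, 4))
# Multiply the first two rows by 10, leaving the rest of x untouched;
# as with all JAX update ops, this returns a new array rather than mutating x.
y = ops.index_mul(x, ops.index[:2, :], 10.)
```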
* Make the pytest run over JAX tests warning-clean, and error on warnings.
Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.
Also fix crashes on Mac related to a preexisting linear algebra bug.
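As a sketch of what that configuration looks like (the actual whitelist entries in the repository may differ), a pytest.ini along these lines turns every warning into an error except for the explicitly ignored patterns:
```
[pytest]
filterwarnings =
    error
    ignore:numpy.dtype size changed:RuntimeWarning
    ignore:Explicitly requested dtype.*:UserWarning
```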
* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
* Add batch_group_count to conv_general_dilated.
* Use batch_group_count for RHS grouped convolution transpose rule.
* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
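A hedged usage sketch of the new parameter; the shapes here are chosen as an assumption so that `batch_group_count` divides both the lhs batch size and the rhs output-feature count, which XLA requires:
```
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((4, 3, 8, 8))   # NCHW: batch=4, features=3
rhs = jnp.ones((2, 3, 3, 3))   # OIHW: out_features=2, in_features=3

out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    batch_group_count=2,  # group the convolution over the batch dimension
)
# Per XLA's semantics, the batch dimension of the result is divided by batch_group_count.
print(out.shape)
```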
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit)) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
* Add backend-specific lowering for cumsum/cumprod on TPU.
Make cumsum/cumprod primitives so they can have backend-specific lowerings.
* Disable cumulative reduction gradient test on TPU.
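For context, the NumPy-level cumulative ops are among the users of these primitives:
```
import jax.numpy as jnp

x = jnp.arange(1., 5.)  # [1. 2. 3. 4.]
print(jnp.cumsum(x))    # [ 1.  3.  6. 10.]
print(jnp.cumprod(x))   # [ 1.  2.  6. 24.]
```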
* Make reduce_prod differentiable to arbitrary order.
The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position.
This PR instead builds an explicit reduction tree and differentiates through it (as sketched below), which, while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order.
* Return the tree-reduction primals instead of returning the original primals in the JVP rule.
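A toy sketch of the tree idea (not the actual JVP rule in lax): form the product as a balanced tree of pairwise multiplies built from ordinary ops, which JAX can then differentiate as many times as needed:
```
import jax
import jax.numpy as jnp

def tree_prod(x):
    # Repeatedly multiply adjacent pairs; a leftover odd element is carried along.
    while x.shape[0] > 1:
        half = x.shape[0] // 2
        pairs = x[:2 * half].reshape(half, 2)
        x = jnp.concatenate([pairs[:, 0] * pairs[:, 1], x[2 * half:]])
    return x[0]

x = jnp.array([2., 3., 4.])
print(tree_prod(x))                        # 24.0
print(jax.grad(tree_prod)(x))              # [12.  8.  6.]
print(jax.jacfwd(jax.grad(tree_prod))(x))  # second-order derivatives also work
```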
* Added a check that the output shape has rank greater than or equal to the input's
* Added a check that the broadcast_dims are sorted (required by XLA)
* Relaxed the check so that an operand dimension size can be 1
* Added lax.broadcast_in_dim docstring
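A small usage sketch exercising those checks:
```
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.)  # shape (3,)

# Output rank (2) >= input rank (1), and broadcast_dimensions is sorted.
y = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
print(y.shape)  # (2, 3)

# An operand dimension of size 1 may broadcast against a larger output dimension.
z = lax.broadcast_in_dim(jnp.ones((1,)), shape=(4,), broadcast_dimensions=(0,))
print(z.shape)  # (4,)
```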
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com>
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
```
{ lambda ; a.
  let b = reduce_sum[ axes=(0,) ] a
  in b }
```
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition record it in their parameters, so that we'd produce
jaxprs like this one:
```
{ lambda ; a.
  let b = reduce_sum[ axes=(0,)
                      input_shape=(3,) ] a
  in b }
```
That's not only aesthetically unpleasant, but it also meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
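To see the difference, one can print the jaxpr for the example above; with avals attached to the jaxpr, the transpose rule for reduce_sum can read the input shape directly:
```
import jax
import jax.numpy as jnp

# The jaxpr's variables carry avals (shape and dtype), so the reduce_sum
# equation no longer needs an input_shape parameter for transposition.
print(jax.make_jaxpr(lambda a: jnp.sum(a, axis=0))(jnp.ones((3,))))
```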
Co-authored-by: Roy Frostig <frostig@google.com>