The `custom_vmap` primitive stages out its wrapped function at call
time. It might extract closed-over or otherwise constant values
("consts") in doing so. To handle these, we can reduce back to the
empty closure setting: convert the consts to formal arguments, both in
the target function and in the custom vmap rule, and ignore them in
the latter.
We only need to play this trick once, on initial entry. After that, we
can resume in assuming an empty closure.
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.
PiperOrigin-RevId: 490249959
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.
This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
This results in a 5x speedup!
Before:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 3.03 ms 3.02 ms 220
```
After:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 0.673 ms 0.671 ms 985
```
PiperOrigin-RevId: 489880547
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
The TFRT CPU client is better in every way and the SE CPU client is unmaintained and has not been used by JAX in many months.
PiperOrigin-RevId: 489246256