Previously, if a SingleDeviceSharding went through `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went through `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding, which did not have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to thread this argument explicitly through all the caches so that we miss them and create the correct sharding.
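As a minimal illustration of the failure mode (not JAX's actual code), consider a memoized conversion whose cache key ignores a distinguishing attribute; `convert` and `convert_fixed` are hypothetical stand-ins for `to_gspmd_sharding`:

```python
import functools

@functools.lru_cache(maxsize=None)
def convert(name):
    # The cache key is just `name`; any other distinguishing state of
    # the input is invisible, so two logically different inputs that
    # share a name collide and return the first cached result.
    return {"name": name}

# The fix: thread the distinguishing argument through so it becomes
# part of the cache key, and logically different inputs miss the cache.
@functools.lru_cache(maxsize=None)
def convert_fixed(name, dynamic):
    return {"name": name, "dynamic": dynamic}
```

With the flag threaded through, `convert_fixed("s", True)` and `convert_fixed("s", False)` produce distinct cached results instead of aliasing each other.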
PiperOrigin-RevId: 530712918
We can make this more general in JAX, slowly and carefully, but doing so would likely require a refactor of how device_assignment is chosen.
Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
* Added support for shape polymorphism for partitionable threefry and for
  random_split.
* Removed a footgun that ignored the partitionable flag in the presence of
  shape polymorphism.
* Replicated the PRNG tests for threefry (partitionable and non-partitionable)
  and for unsafe_rbg.
* Added general support for overriding jax.config flags for PolyHarness.
This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
Previously, we disabled `check_result` (which checks that JAX native and JAX with shape polymorphism produce the same result) for test harnesses created by vmap over primitive harnesses whenever the primitive harness had a custom assertion.
Now we enable that check even for those harnesses, using the same custom assertion.
PiperOrigin-RevId: 530547784
This reverts the setup.py changes from
f28b20175f307d5a56502446a9706480126a5bd4. We actually need to fix some
more issues before releasing 0.4.9, so fix the install at HEAD in the
meantime.
Background: Currently, we pad the sub-matrices that occur during the spectral bisection algorithm to fit into a small number of buckets, in order to keep compilation time down. Each unique bucket size gives rise to a separate JIT compilation. The current strategy uses powers of two times the termination size of 256, below which we switch to a Jacobi solver. One issue is that the bisection step rarely splits the matrix into two exactly equal parts, so one of the child problems is forced to use the large bucket size of its parent, which wastes significant device cycles.
This change modifies the bucket selection strategy to not use [256, 512, 1024, ..., n], but instead include a little slack at each level, such that both sub-problems from a non-perfect split will likely fall into the smaller bucket size. Specifically, we add 4% slack and round up to the next larger multiple of 32. These heuristic values were found experimentally. As an example, for n = 2048, we get the bucket sizes [2048, 1056, 544, 288, 256].
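The bucket schedule above can be sketched as follows. This is a hypothetical helper, not JAX's implementation, and the exact rounding rule is an assumption: rounding to the *nearest* multiple of 32 is what reproduces the listed example for n = 2048.

```python
def bucket_sizes(n, termination=256, slack=0.04, multiple=32):
    """Hypothetical sketch of the padding-bucket schedule.

    Each child bucket targets half the parent size plus 4% slack,
    rounded to the nearest multiple of 32 (an assumption; rounding
    to nearest reproduces the example [2048, 1056, 544, 288, 256]).
    """
    sizes = [n]
    while True:
        target = sizes[-1] / 2 * (1 + slack)
        if target < termination:
            sizes.append(termination)  # below this, switch to Jacobi
            return sizes
        sizes.append(int(round(target / multiple)) * multiple)
```

For example, `bucket_sizes(2048)` yields the sequence quoted above.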
Measuring runtimes on random matrices of size 512, 1024, and 2048, we see significant speedups:
N    | wall time before | wall time after
=========================================
 512 | 27.8 ms          | 24.8 ms
1024 | 97.6 ms          | 79.3 ms
2048 | 414.5 ms         | 308.0 ms
PiperOrigin-RevId: 529567005
Why do we have caching in jax.remat at all? I added it in
https://github.com/google/jax/pull/11743 without much justification other than
it made some tests faster. I think I was worried that the switch to the new
remat's "initial-style" (jaxpr forming up-front) approach would regress
eager-mode performance, so I added benchmarks to measure it and then made those
fast with caching.
But the caching seems a bit too aggressive when static_argnums are involved. In
particular, I allowed caching on Tracer arguments (by object id). That seems
dangerous!
So the change here is to check whether any of the arguments marked static by
static_argnums are Tracers. If so, skip the caching. This change happens not to
affect the benchmarks at all.
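The skip-the-cache check can be sketched as below. This is a hypothetical illustration, not JAX's implementation; `Tracer` here is a stand-in for `jax.core.Tracer`, and the cache key is simplified.

```python
class Tracer:  # stand-in for jax.core.Tracer
    pass

def call_maybe_cached(fun, cache, static_argnums, *args):
    # Caching keyed on a Tracer's object id is unsafe (ids can be
    # reused after garbage collection, and Tracers are specific to a
    # tracing context), so bypass the cache entirely in that case.
    if any(isinstance(args[i], Tracer) for i in static_argnums):
        return fun(*args)
    key = (fun, args)  # simplification: assumes all args are hashable
    if key not in cache:
        cache[key] = fun(*args)
    return cache[key]
```

Calls whose static arguments are ordinary hashable values still hit the cache; calls with a Tracer in a static position always re-execute.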
PiperOrigin-RevId: 529502687
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.
PiperOrigin-RevId: 529437419
We are seeing some failures when comparing the results
for eigh with and without shape polymorphism.
Normally, shape polymorphism should not change the HLO,
so a golden comparison is not unreasonable, even though
for eigh we should really check the results for
correctness rather than for identity.
We need to investigate this further, but meanwhile we
turn off these tests. The changes introduced recently for
shape polymorphism for eigh do not affect the
code paths in the absence of shape polymorphism, so it
is appropriate to just turn off the tests and add
an error that shape polymorphism for eigh on
GPU is not ready.
PiperOrigin-RevId: 529388749
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slice
and performs `r[:] = a` instead...
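For contrast, NumPy's in-place slice assignment shows the intended semantics: a sliced update should touch only the selected elements, not the whole buffer.

```python
import numpy as np

r = np.zeros(4)
a = np.array([7.0])
r[0:1] = a  # intended semantics: only element 0 is updated
# The buggy ref behavior described above instead acted like
# r[:] = a, broadcasting `a` over the entire buffer.
```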
PiperOrigin-RevId: 529385973
This can be used to find the output shapes of a function in
the presence of shape polymorphism, and to compute the
`polymorphic_shapes` value that can be used in a subsequent
call to `jax2tf.convert`.