Why do we have caching in jax.remat at all? I added it in
https://github.com/google/jax/pull/11743 without much justification other than
that it made some tests faster. I think I was worried that the switch to the
new remat's "initial-style" (forming the jaxpr up front) approach would regress
eager-mode performance, so I added benchmarks to measure it and then made them
fast with caching.
But the caching seems a bit too aggressive when static_argnums are involved. In
particular, I allowed caching on Tracer arguments (by object id). That seems
dangerous!
So the change here is to check whether any of the arguments marked static by
static_argnums are Tracers. If so, skip the caching. This change happens not to
affect the benchmarks at all.
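A minimal sketch of that guard, assuming a hypothetical cache-key helper; the names below are illustrative, not the actual jax.remat internals:

```python
from jax import core

def _remat_cache_key(fun, static_argnums, args):
  # Hypothetical helper: build a cache key from the arguments that
  # static_argnums marks as static.
  static_args = tuple(args[i] for i in static_argnums)
  if any(isinstance(x, core.Tracer) for x in static_args):
    # A Tracer is hashable only by object id, and ids can be reused
    # after garbage collection, so Tracers are unsafe cache keys.
    return None  # signal the caller to skip the cache
  return (fun, static_args)
```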
PiperOrigin-RevId: 529502687
Test cases are still frequently skipped due to the lack of CompileOptions
support, but the skip/run behavior does not seem to change meaningfully
compared to a clean checkout. This was verified by raising an exception in
place of unittest.SkipTest.
PiperOrigin-RevId: 529437419
We are seeing some failures when comparing the results for eigh with and
without shape polymorphism. Normally, shape polymorphism should not change
the HLO, so a golden comparison is not necessarily bad, even though for eigh
we ought to check the correctness of the results rather than their identity.

We need to investigate this further, but meanwhile we turn off these tests.
The changes introduced recently for shape polymorphism for eigh do not affect
the code paths in the absence of shape polymorphism, so it is appropriate to
just turn off the tests and add an error that shape polymorphism for eigh on
GPU is not ready.
PiperOrigin-RevId: 529388749
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slice
and performs `r[:] = a` instead...
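A sketch of the surprising behavior; the machinery that creates the ref (e.g. a stateful region or kernel) is assumed rather than shown:

```python
def store_first_row(r, a):
  # `r` stands for a JAX ref here; `a` is a value shaped like r[0:1].
  r[0:1] = a   # intended: update only the first row
  # Before this change, the slice was silently dropped and the line
  # above behaved like `r[:] = a`, clobbering the whole ref.
```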
PiperOrigin-RevId: 529385973
The use would be to find the output shapes for a function in the
presence of shape polymorphism, and to compute the
`polymorphic_shapes` value that can be used in a subsequent
call to `jax2tf.convert`.
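For context, a minimal sketch of the downstream call in which such a computed `polymorphic_shapes` value would be used; the function and shape string are illustrative:

```python
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  return jnp.sum(x, axis=0)

# "b" is a dimension variable standing for a polymorphic batch size;
# this is the kind of string the new utility would help compute.
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"])
```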
Create a new metric, '/jax/checkpoint/write/async/thread_duration_sec', to measure the savings from the async thread creation time.
PiperOrigin-RevId: 529227213
We are using the new support for dynamic shapes in hlo.CustomCallOp, where
we need to pass the output shapes as additional operands.
This allows us to enable multiple "random" tests that were previously disabled.
PiperOrigin-RevId: 528990469
The QDWH splitting step involves two orthogonal projectors,
P_plus = -0.5*(U - I) and P_minus = 0.5*(U + I), one of which has rank k and
the other rank n - k. Ideally, if we can pick the median eigenvalue as the
split point, k will be near n/2 and the ranks of the two projectors will be
similar. However, if our guess of the median eigenvalue is poor, or the matrix
is rank-deficient, k can be far from n/2, and the subspace iteration is more
expensive for the projector of higher rank, since it involves computing the QR
decomposition of an n x rank matrix.

This change makes the algorithm adaptively pick the projector of lower rank.
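A minimal sketch of the adaptive choice, assuming `U` is the unitary factor computed by QDWH; the helper name and the trace-based rank estimate are illustrative, not the actual implementation:

```python
import jax.numpy as jnp

def pick_lower_rank_projector(U):
  # Form both projectors from the QDWH unitary factor U.
  n = U.shape[0]
  eye = jnp.eye(n, dtype=U.dtype)
  p_plus = -0.5 * (U - eye)
  p_minus = 0.5 * (U + eye)
  # The rank of an orthogonal projector equals its trace. Subspace
  # iteration does a QR of an n x rank matrix, so keep whichever
  # projector has rank <= n/2.
  k = int(jnp.rint(jnp.trace(p_minus).real))
  return (p_minus, k) if k <= n - k else (p_plus, n - k)
```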
PiperOrigin-RevId: 528941151
To handle Tracers, ShapedArrays, concrete Arrays, etc., `global_array_to_host_local_array` and `host_local_array_to_global_array` are now primitives.
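For reference, a minimal usage sketch (the mesh construction and data are illustrative). Being primitives means these calls now also work on Tracers, e.g. under `jit`:

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

# Illustrative 1-D mesh over this process's devices.
mesh = Mesh(np.array(jax.devices()), ('x',))
host_local = np.arange(8.0).reshape(4, 2)  # this host's shard of the data

global_arr = multihost_utils.host_local_array_to_global_array(
    host_local, mesh, P('x'))
round_tripped = multihost_utils.global_array_to_host_local_array(
    global_arr, mesh, P('x'))
```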
PiperOrigin-RevId: 528925663
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:
[Rollback] Update required Python version to 3.9
PiperOrigin-RevId: 528905991
This is because if both OpShardings are replicated, then the ndim is not encoded in the OpSharding, and the check returns True even if the Sharding is incompatible with the output's ndim. Concretely, `NamedSharding({'x': 1, 'y': 2}, P('x'))` is not compatible with an input with `ndim == 0`.
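A sketch of that concrete case; the mesh construction is illustrative:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A mesh whose 'x' axis has size 1, so P('x') lowers to a fully
# replicated OpSharding that does not record ndim.
devices = np.array(jax.devices()[:2]).reshape(1, 2)
mesh = Mesh(devices, ('x', 'y'))
s = NamedSharding(mesh, P('x'))
# P('x') still requires an array with ndim >= 1, so despite lowering to
# a replicated OpSharding, `s` is incompatible with a rank-0 output.
```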
PiperOrigin-RevId: 528621971