In cl/530804516 we changed the parser for polymorphic shape
specifications and also changed the error message. This
lead to failure in the TF.js jax_conversion_test.
We improve the error message and adjust the jax_conversion_test
to match the new message.
PiperOrigin-RevId: 531238700
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.
We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
This uses an ApproxTopK custom-call, which we add support for in supported by
MHLO, by including a lowering to XLA's PartialReduce custom_call via the Client
XLA ApproxTopK function.
PiperOrigin-RevId: 530805966
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.
PiperOrigin-RevId: 530712918
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.
Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
* added support for shape polymorphism for partitionable threefry and for
random_split.
* removed footgun that was ignoring the partitionable flag in presence of
shape polymorphism.
* Replicated the PRNG tests for threefry (partitionable and non-partitionable),
and unsafe_rbg.
* Added general support for overriding jax.config flags for PolyHarness
This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
Previously, we disabled `check_result` (check that the JAX native and JAX with shape polymorphism produce the same result) for test harnesses that are created by vmap on primitive harnesses if the primitive harness has a custom assertion.
Now we enable that checking even for those harnesses, and we use the same custom assertion.
PiperOrigin-RevId: 530547784