When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.
We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
This uses an ApproxTopK custom-call, which we add support for in supported by
MHLO, by including a lowering to XLA's PartialReduce custom_call via the Client
XLA ApproxTopK function.
PiperOrigin-RevId: 530805966
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.
PiperOrigin-RevId: 530712918
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.
Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
* added support for shape polymorphism for partitionable threefry and for
random_split.
* removed footgun that was ignoring the partitionable flag in presence of
shape polymorphism.
* Replicated the PRNG tests for threefry (partitionable and non-partitionable),
and unsafe_rbg.
* Added general support for overriding jax.config flags for PolyHarness
This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
Previously, we disabled `check_result` (check that the JAX native and JAX with shape polymorphism produce the same result) for test harnesses that are created by vmap on primitive harnesses if the primitive harness has a custom assertion.
Now we enable that checking even for those harnesses, and we use the same custom assertion.
PiperOrigin-RevId: 530547784
Previously, we used a simple regexp-based parser, which could only
parse additions of multiplications of dimension variables. Now the
symbolic dimension expressions can also contain "mod" and "floordiv"
which would break the parser in confusing ways.
Now we have a recursive-descent parser, with much better error
reporting support.
This reverts the setup.py changes from
f28b20175f307d5a56502446a9706480126a5bd4. We actually need to fix some
more issues before releasing 0.4.9, so fix the install at HEAD in the
meantime.