Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.
For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).
Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.
PiperOrigin-RevId: 583816243
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform` had its own special case for
when a single lowering rule applied.
We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.
We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.
PiperOrigin-RevId: 583603863
Split code to determine CUDA library versions out of py_extension() module and into a cc_library(), because it fixes a linking problem in Google's build. (Long story, not worth it.)
Fixes https://github.com/google/jax/issues/8289
PiperOrigin-RevId: 583544218
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.
Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.
We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.
In a subsequent move we will move the actual backwards compatibility tests themselves out of jax2tf.
PiperOrigin-RevId: 583312085
Previously we used lax.max to evaluate core.non_negative_dim, but this is
problematic if we are in a tracing context. Then, even if the operand is
a constant we produce a tracer. Change the code to check explicitly if
the operand is a constant or if it is a symbolic expression.
When build_cuda_plugin_from_source is true, it will build cuda plugin from source, and it is used for the case of `bazel test` without preinstall jax cuda packages.
PiperOrigin-RevId: 583057751