JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 592266215
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.
Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.
Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).
In the process of implementing this we have done some small cleanup of the Exported structure:
* renamed serialization_version to mlir_module_serialization_version
* renamed disabled_checks to disabled_safety_checks
This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.
There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.
PiperOrigin-RevId: 590078785
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.
Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.
* E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
* E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harnesse.
* E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.
Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)
This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.
PiperOrigin-RevId: 581016785