I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
def f(x: int) -> str: ...
def g(x: int) -> str: ...
callback = f if ... else g # has type object!
This allows using external profiling tools, such as Nsight Systems,
with the automatic PGLE workflow supported by JAX with a simple two-step
workflow:
export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
This PR includes an end-to-end example project which demonstrates the
use of the FFI. This complements [the FFI
tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all
of the code in one place, as well as demonstrating how FFI extensions
can be packaged. Alongside the example project, I have also added a new
GitHub Actions workflow to test the example as part of CI.
For now, the tests only run on CPU, but once we have GPU runners for
GitHub Actions (soon!), I plan on migrating the custom call examples
from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case.
Similarly, I wanted to start small and this example project only
includes exactly the same functions as the tutorial for now, but I think
this could be a good place to showcase more advanced examples (including
custom calls with state).
The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286
suppressions.
We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
So, instead of
pl.BlockSpec(lambda i, j: ..., (42, 24))
``pl.BlockSpec`` now expects
pl.BlockSpec((42, 24), lambda i, j: ...)
I will update Pallas tests in a follow up.
PiperOrigin-RevId: 648486321
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 592266215
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.
Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.
Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).
In the process of implementing this we have done some small cleanup of the Exported structure:
* renamed serialization_version to mlir_module_serialization_version
* renamed disabled_checks to disabled_safety_checks
This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.
There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.
PiperOrigin-RevId: 590078785
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.
Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.
* E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
* E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harnesse.
* E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.
Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)
This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.
PiperOrigin-RevId: 581016785