Temporary fixes an issue with `python -m build` that fails when python 3.8 is used because `matplotlib~=3.8.4` is unavailable for this python version.
We are working on creating Bazel build rule with the hermetic Python for JAX wheel ([we already have Jaxlib and plugins build rules ready](https://github.com/jax-ml/jax/pull/23276)). The required python modules are provided in requirements.in file, so when we implement Bazel build rule for JAX wheel, requirements.in will be the only source of dependencies, and test-requirements.txt won't be needed for building JAX wheel.
PiperOrigin-RevId: 692260046
This CL deprecates the device_type parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation.
The latter is more fine-grained: it is applied on `func.FuncOp` instead of the entire module, supports `tc`, `sc_scalar_subcore` and `sc_vector_subcore`.
`device_type` of the TPU CustomCall HLO is set to `sparsecore` if `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets TC.
PiperOrigin-RevId: 692212644
Also improved the corresponding test cases to ensure better coverage and accuracy.
This PR is similar to https://github.com/jax-ml/jax/pull/22283, which adds lowering for `lax.sign`.
PiperOrigin-RevId: 691988164
mlir.ir.Context has the unfortunate behavior that it loads all dialects linked into the binary, even those we have no intention of using. This is fairly benign in JAX's usual configuration, but if JAX is linked together with other MLIR-using software it can be problematic.
PiperOrigin-RevId: 691984229
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.
Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d
PiperOrigin-RevId: 691568933
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike
Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.
Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.
PiperOrigin-RevId: 691496997
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.
PiperOrigin-RevId: 691494516
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.
PiperOrigin-RevId: 691458051