The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
PiperOrigin-RevId: 562953488
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
Previously we were setting out-of-bound indices to zero, which works in most (but not all) cases. The problem is that if (0, 0) is a defined matrix element, these subsequent zeros effectively overwrite this element in some cusparse routines.
The fix here is to add another row or column to the matrix as necessary, and to push these undefined values into that row/col, where they can be sliced off at the end of the cusparse operation so that they will not affect the computation of interest.
PiperOrigin-RevId: 513639921
Also, remove the cusparse lowering for batched matmul, because in testing I found that it returns incorrect results for CUSPARSE_SPMM_COO_ALG4. Our tests haven't revealed that because we currently only test for a single batch. To re-land this, we can add it to the private primitives in _lowerings and add another elif clause in the GPU impl.
PiperOrigin-RevId: 513604587
This is in preparation for cleaning up our bcoo_dot_general GPU lowering rules: by creating private primitives that closely follow the API of the cusparse kernels, we will be able to better express lowered translation rules that preprocess that data appropriately.
PiperOrigin-RevId: 513212715
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397