When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.
We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.
Fixes https://github.com/jax-ml/jax/issues/26888
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.
This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
PiperOrigin-RevId: 697631402
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.
PiperOrigin-RevId: 682656411
The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286
This is to fix Mac arm64 pytests on CI. The tests started failing after integrating ml-dtypes-0.5.0. Ignoring warnings is probably Ok, as it is inspired by a similar PR in ml-dtypes repo itself: https://github.com/jax-ml/ml_dtypes/pull/186
PiperOrigin-RevId: 676458202
simpler bitwise_right_shift implementation
to match previous PR
updating bitwise_right_shift_doc as an alias
readded jnp.bitwise_left_shift, jnp.bitwise_right_shift
Update sharded-computation doc to use make_mesh()
Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
better true_divide and divide docs
doc wording update
[Mosaic TPU] Fix mosaic alignment check in concatenate rule.
PiperOrigin-RevId: 670837792
Fix pytype errors and args for jax.Array methods
Add docker builds for ubu22 and 24
Better docs for jax.numpy: log and log1p
random.key_impl: improve repr of output
Remove unused docstring addition: _PRECISION_DOC
update example optimizers library docstring
* JAXopt is being merged into Optax, so point only to Optax
* Update Optax's github repository URL
fixing merge duplication
updating tests to skip bitwise shift if numpy major version < 2
removed whitespace 659
keep non-bitwise tests for numpy < 2.0.0
more readable edit