61 Commits

Author SHA1 Message Date
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Olli Lupton
1bba1ea2e2 Add JAX_COMPILATION_CACHE_EXPECT_PGLE option
This allows using external profiling tools, such as Nsight Systems,
with the automatic PGLE workflow supported by JAX with a simple two-step
workflow:

export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
2025-02-06 08:19:45 +00:00
Jake VanderPlas
f749fca760 [array api] use most recent version of array_api_tests 2024-11-20 14:50:06 -08:00
Jake VanderPlas
a115b2cec5 Update array-api-tests commit 2024-11-14 16:05:30 -08:00
Matthew Johnson
61150607e5 don't warn on unused type: ignore 2024-11-08 01:21:11 +00:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Jake VanderPlas
e1f280c843 CI: enable additional ruff formatting checks 2024-10-16 16:09:54 -07:00
Jake VanderPlas
b574d2ceb1 Fix aliases in jax.numpy type interface file.
This includes removing some alias declarations for functions that were
previously removed.
2024-10-16 10:40:56 -07:00
Kristian Hartikainen
1ea8e3c29d Update _cuda_path
- Remove jax-relative module path test
- Use `$CUDA_ROOT` environment variable if available
- Use `cuda_nvcc` module's path if installed
2024-10-07 20:32:05 +03:00
Alexander Pivovarov
69193aa6a4 Remove pylint sections from pyproject.toml.
use ruff instead
2024-09-26 23:29:56 +00:00
Dan Foreman-Mackey
e1a68eee5e Add FFI example project and test on CI.
This PR includes an end-to-end example project which demonstrates the
use of the FFI. This complements [the FFI
tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all
of the code in one place, as well as demonstrating how FFI extensions
can be packaged. Alongside the example project, I have also added a new
GitHub Actions workflow to test the example as part of CI.

For now, the tests only run on CPU, but once we have GPU runners for
GitHub Actions (soon!), I plan on migrating the custom call examples
from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case.

Similarly, I wanted to start small and this example project only
includes exactly the same functions as the tutorial for now, but I think
this could be a good place to showcase more advanced examples (including
custom calls with state).
2024-09-24 17:23:13 -04:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
Yu-Hang Tang
c88c3aecae add k8s cluster environment 2024-09-20 17:26:53 +00:00
Sergei Lebedev
1289640f09 Deprecated calling `jax.dlpack.from_dlpack` with a DLPack tensor
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
Jake VanderPlas
d999208863 [array API] update test suite to most recent commit 2024-08-08 12:33:30 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Jake VanderPlas
f887b66d5d Remove the unaccelerate_deprecation utility 2024-07-23 05:07:49 -07:00
Sergei Lebedev
89b5f4d151 Silenced the deprecation warning coming from TF dependencies 2024-07-05 10:57:05 +01:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Jake VanderPlas
8ebd8a3b35 CI: set ruff to Python 3.10 2024-06-26 14:58:39 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use JAX AOT APIs to get the equivalent of jax.xla_computation functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
Jake VanderPlas
3f210c63a0 avoid globally silencing the jit backend/device warning 2024-06-12 14:43:14 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
Sergei Lebedev
0786da8fd8 Removed unnecessary mypy exclusions from pyproject.toml
* 2/3 files type check just fine now
* the remaining one could be handled via a file-level directive
2024-06-07 20:07:42 +01:00
Peter Hawkins
971ab0fba2 Make CuDNN SDPA API work with JAX with a CUDA plugin configuration. 2024-06-06 12:09:19 -04:00
Jake VanderPlas
b441a09a34 CI: remove stale warning filters 2024-05-28 13:13:40 -07:00
Sergei Lebedev
2473ebf508 Removed mentions of iree from the test suite 2024-05-24 10:31:57 +01:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Jake VanderPlas
d5d2fb087f Ignore deprecation warnings locally rather than globally 2024-05-20 20:28:25 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Meekail Zain
b88e2e808b Refactor array_api namespace, relying more directly on jax.numpy 2024-05-02 18:17:45 +00:00
Peter Hawkins
6ae01247f0 Fix pytest failures from compilation cache test.
The names of the functions in the compilation cache tests changed, causing warnings emitted by that test to become errors.
2024-04-29 11:08:07 -04:00
Jake VanderPlas
d33144e298 CI: avoid deprecated ruff configurations 2024-04-08 14:05:22 -07:00
George Necula
ca59971bef [host_callback] Deprecate the jax.experimental.host_callback module. 2024-03-21 09:11:17 +02:00
Sergei Lebedev
930aaa5e47 Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
2024-03-08 16:50:52 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Jake VanderPlas
053e2cff11 Update array-api-tests to most recent commit 2023-11-28 08:17:45 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
George Necula
5001a21bad Move primitive_harness.py to jax._src.internal_test_util.test_harnesses.
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.

Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.

  * E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
  * E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harnesse.
  * E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.

Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)

This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.

PiperOrigin-RevId: 581016785
2023-11-09 13:58:00 -08:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
f8a4afd1f6 Fix mypy error on Python 3.9 2023-08-16 13:27:11 -07:00