26 Commits

Author SHA1 Message Date
Peter Hawkins
6ae01247f0 Fix pytest failures from compilation cache test.
The names of the functions in the compilation cache tests changed, causing warnings emitted by that test to become errors.
2024-04-29 11:08:07 -04:00
Jake VanderPlas
d33144e298 CI: avoid deprecated ruff configurations 2024-04-08 14:05:22 -07:00
George Necula
ca59971bef [host_callback] Deprecate the jax.experimental.host_callback module. 2024-03-21 09:11:17 +02:00
Sergei Lebedev
930aaa5e47 Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
2024-03-08 16:50:52 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Jake VanderPlas
053e2cff11 Update array-api-tests to most recent commit 2023-11-28 08:17:45 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
George Necula
5001a21bad Move primitive_harness.py to jax._src.internal_test_util.test_harnesses.
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.

Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.

  * E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
  * E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harnesse.
  * E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.

Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)

This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.

PiperOrigin-RevId: 581016785
2023-11-09 13:58:00 -08:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
f8a4afd1f6 Fix mypy error on Python 3.9 2023-08-16 13:27:11 -07:00
Jake VanderPlas
227eec159a Ignore numpy deprecation warning 2023-08-14 10:08:57 -07:00
Jake VanderPlas
4bb54d32d8 mypy: suppress annotation-unchecked notes 2023-07-17 11:18:48 -07:00
Jake VanderPlas
9d1f3b4dd2 pre-commit: update mypy to most recent version 2023-07-12 10:41:51 -07:00
Yash Katariya
677b0d9d3f Ignore JAX_USE_PJRT_C_API_ON_TPU=false user warning raised.
PiperOrigin-RevId: 542570578
2023-06-22 08:42:15 -07:00
Skye Wanderman-Milne
d77bfbe5c8 Ignore ml_dtypes.float8_e4m3b11 warning until new jaxlib is released 2023-06-06 11:34:01 -07:00
Jake VanderPlas
13f7291ff6 Remove obsolete warning suppression from pyproject.toml 2023-05-03 13:16:41 -07:00
Jake VanderPlas
9cfe77d5e1 Remove use of deprecated make_sharded_device_array 2023-05-03 10:11:29 -07:00
Jake VanderPlas
5310562250 mypy: type-check ml_dtypes 2023-04-17 11:35:59 -07:00
Jake VanderPlas
4f0edc08a3 [typing] ignore zstandard in mypy 2023-04-17 11:22:08 -07:00
Saurav Maheshkar
cfd8762d4a feat: move configurations to pyproject 2023-04-15 02:39:39 +01:00
Jake VanderPlas
f282c251d4 Add minimal pyproject.toml specifying build system
Replaces #15274, Fixes #15256

PiperOrigin-RevId: 520367622
2023-03-29 10:08:30 -07:00