Remove deprecated jax.experimental.array_api

This commit is contained in:
Jake VanderPlas 2025-01-06 15:19:02 -08:00
parent 52cc5c7f05
commit c7b0d681bd
10 changed files with 8 additions and 85 deletions

View File

@ -140,8 +140,8 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
documentation_render:

View File

@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
* The `jax.experimental.array_api` module has been removed after being
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
the array API directly.
## jax 0.4.38 (Dec 17, 2024)

View File

@ -659,7 +659,7 @@ You can find the up-to-date command to run doctests in
E.g., you can run:
```
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
```
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in

View File

@ -1,28 +0,0 @@
``jax.experimental.array_api`` module
=====================================
.. note::
The ``jax.experimental.array_api`` module is deprecated as of JAX v0.4.32, and
importing ``jax.experimental.array_api`` is no longer necessary. {mod}`jax.numpy`
implements the array API standard directly by default. See :ref:`python-array-api`
for details.
This module includes experimental JAX support for the `Python array API standard`_.
Support for this is currently experimental and not fully complete.
Example Usage::
>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2023.12'
>>> arr = xp.arange(1000)
>>> arr.sum()
Array(499500, dtype=int32)
The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
and implements most of the API listed in the standard.
.. _Python array API standard: https://data-apis.org/array-api/

View File

@ -14,7 +14,6 @@ Experimental Modules
.. toctree::
:maxdepth: 1
jax.experimental.array_api
jax.experimental.checkify
jax.experimental.compilation_cache
jax.experimental.custom_partitioning

View File

@ -542,7 +542,8 @@ Python Array API standard
Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order
to enable the array API for JAX arrays. After JAX v0.4.32, importing this
module is no longer required, and will raise a deprecation warning.
module is no longer required, and will raise a deprecation warning. After
JAX v0.5.0, this import will raise an error.
Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible
with the `Python Array API Standard`_. You can access the Array API namespace via

View File

@ -1053,19 +1053,6 @@ pytype_library(
deps = [":jax"],
)
pytype_library(
name = "experimental_array_api",
srcs = glob(
[
"experimental/array_api/*.py",
],
),
visibility = [":internal"],
deps = [
":jax",
],
)
pytype_library(
name = "experimental_sparse",
srcs = glob(

View File

@ -1,32 +0,0 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
import sys as _sys
import warnings as _warnings
import jax.numpy as _array_api
# Added 2024-08-01
_warnings.warn(
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
"jax.numpy supports the array API by default.",
DeprecationWarning, stacklevel=2
)
_sys.modules['jax.experimental.array_api'] = _array_api
del _array_api, _sys, _warnings

View File

@ -59,7 +59,6 @@ jax_py_test(
srcs = ["array_api_test.py"],
deps = [
"//jax",
"//jax:experimental_array_api",
"//jax:test_util",
] + py_deps("absl/testing"),
)

View File

@ -246,12 +246,6 @@ class ArrayAPISmokeTest(absltest.TestCase):
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), ARRAY_API_NAMESPACE)
def test_deprecated_import(self):
msg = "jax.experimental.array_api import is no longer required"
with self.assertWarnsRegex(DeprecationWarning, msg):
import jax.experimental.array_api as nx
self.assertIs(nx, ARRAY_API_NAMESPACE)
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):