mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Remove deprecated jax.experimental.array_api
This commit is contained in:
parent
52cc5c7f05
commit
c7b0d681bd
4
.github/workflows/ci-build.yaml
vendored
4
.github/workflows/ci-build.yaml
vendored
@ -140,8 +140,8 @@ jobs:
|
||||
JAX_ARRAY: 1
|
||||
PY_COLORS: 1
|
||||
run: |
|
||||
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
|
||||
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
|
||||
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
|
||||
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
|
||||
|
||||
|
||||
documentation_render:
|
||||
|
@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
|
||||
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
|
||||
and `jax.errors.JaxRuntimeError` respectively.
|
||||
* The `jax.experimental.array_api` module has been removed after being
|
||||
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
|
||||
the array API directly.
|
||||
|
||||
## jax 0.4.38 (Dec 17, 2024)
|
||||
|
||||
|
@ -659,7 +659,7 @@ You can find the up-to-date command to run doctests in
|
||||
E.g., you can run:
|
||||
|
||||
```
|
||||
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
|
||||
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
|
||||
```
|
||||
|
||||
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
|
||||
|
@ -1,28 +0,0 @@
|
||||
``jax.experimental.array_api`` module
|
||||
=====================================
|
||||
|
||||
.. note::
|
||||
The ``jax.experimental.array_api`` module is deprecated as of JAX v0.4.32, and
|
||||
importing ``jax.experimental.array_api`` is no longer necessary. {mod}`jax.numpy`
|
||||
implements the array API standard directly by default. See :ref:`python-array-api`
|
||||
for details.
|
||||
|
||||
This module includes experimental JAX support for the `Python array API standard`_.
|
||||
Support for this is currently experimental and not fully complete.
|
||||
|
||||
Example Usage::
|
||||
|
||||
>>> from jax.experimental import array_api as xp
|
||||
|
||||
>>> xp.__array_api_version__
|
||||
'2023.12'
|
||||
|
||||
>>> arr = xp.arange(1000)
|
||||
|
||||
>>> arr.sum()
|
||||
Array(499500, dtype=int32)
|
||||
|
||||
The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
|
||||
and implements most of the API listed in the standard.
|
||||
|
||||
.. _Python array API standard: https://data-apis.org/array-api/
|
@ -14,7 +14,6 @@ Experimental Modules
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.array_api
|
||||
jax.experimental.checkify
|
||||
jax.experimental.compilation_cache
|
||||
jax.experimental.custom_partitioning
|
||||
|
@ -542,7 +542,8 @@ Python Array API standard
|
||||
|
||||
Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order
|
||||
to enable the array API for JAX arrays. After JAX v0.4.32, importing this
|
||||
module is no longer required, and will raise a deprecation warning.
|
||||
module is no longer required, and will raise a deprecation warning. After
|
||||
JAX v0.5.0, this import will raise an error.
|
||||
|
||||
Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible
|
||||
with the `Python Array API Standard`_. You can access the Array API namespace via
|
||||
|
13
jax/BUILD
13
jax/BUILD
@ -1053,19 +1053,6 @@ pytype_library(
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_array_api",
|
||||
srcs = glob(
|
||||
[
|
||||
"experimental/array_api/*.py",
|
||||
],
|
||||
),
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":jax",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_sparse",
|
||||
srcs = glob(
|
||||
|
@ -1,32 +0,0 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
import sys as _sys
|
||||
import warnings as _warnings
|
||||
|
||||
import jax.numpy as _array_api
|
||||
|
||||
# Added 2024-08-01
|
||||
_warnings.warn(
|
||||
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
|
||||
"jax.numpy supports the array API by default.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
_sys.modules['jax.experimental.array_api'] = _array_api
|
||||
|
||||
del _array_api, _sys, _warnings
|
@ -59,7 +59,6 @@ jax_py_test(
|
||||
srcs = ["array_api_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:experimental_array_api",
|
||||
"//jax:test_util",
|
||||
] + py_deps("absl/testing"),
|
||||
)
|
||||
|
@ -246,12 +246,6 @@ class ArrayAPISmokeTest(absltest.TestCase):
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
self.assertIs(x.__array_namespace__(), ARRAY_API_NAMESPACE)
|
||||
|
||||
def test_deprecated_import(self):
|
||||
msg = "jax.experimental.array_api import is no longer required"
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
import jax.experimental.array_api as nx
|
||||
self.assertIs(nx, ARRAY_API_NAMESPACE)
|
||||
|
||||
|
||||
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user