Revert: Add deprecation warnings to DA, SDA and GDA.

This change is currently overly noisy for users.

PiperOrigin-RevId: 489455729
This commit is contained in:
Peter Hawkins 2022-11-18 06:05:28 -08:00 committed by jax authors
parent 7a3dbcf94e
commit 9f2a6acb61
6 changed files with 4 additions and 26 deletions

View File

@ -40,7 +40,8 @@ disable `jax.Array` and see if the issues go away.
### How can I disable jax.Array for now?
Through March 15, 2023 it will be possible to disable jax.Array by:
You can disable `jax.Array` by: (After a certain date (TBD), the option to
disable jax.Array won't exist)
* setting the shell environment variable `JAX_ARRAY` to something falsey
(e.g., `0`);

View File

@ -60,12 +60,6 @@ def make_device_array(
This is to be used only within JAX. It will return either a PythonDeviceArray
or a C++ equivalent implementation.
"""
warnings.warn(
'DeviceArray has been deprecated. '
'Please use `jax.Array`. See '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
'to migrate to `jax.Array`.', DeprecationWarning)
if isinstance(device_buffer, xc.Buffer):
if device_buffer.aval == aval and device_buffer._device == device:

View File

@ -15,7 +15,6 @@
from collections import Counter
import dataclasses
import functools
import warnings
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
@ -264,12 +263,6 @@ class GlobalDeviceArray:
device_buffers: Union[xb.ShardedBuffer, Sequence[DeviceArray]],
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
_enable_checks: bool = True):
warnings.warn(
'GlobalDeviceArray has been deprecated. '
'Please use `jax.Array`. See '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
'to migrate to `jax.Array`.', DeprecationWarning)
self._global_shape = global_shape
self._global_mesh = global_mesh
self._mesh_axes = mesh_axes

View File

@ -41,7 +41,6 @@ import operator as op
import sys
import threading
import types
import warnings
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
@ -666,12 +665,6 @@ def make_sharded_device_array(
be returned, for JAX extensions not implementing the C++ API.
indices: For caching purposes, will be computed if `None`.
"""
warnings.warn(
'ShardedDeviceArray has been deprecated. '
'Please use `jax.Array`. See '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
'to migrate to `jax.Array`.', DeprecationWarning)
if sharding_spec is None:
sharding_spec = _create_pmap_sharding_spec(aval)

View File

@ -19,11 +19,8 @@ filterwarnings =
# numpy uses distutils which is deprecated
ignore:The distutils.* is deprecated.*:DeprecationWarning
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
default:Error reading persistent compilation cache entry for 'jit__lambda_'
default:Error writing persistent compilation cache entry for 'jit__lambda_'
ignore:DeviceArray has been deprecated.*:DeprecationWarning
ignore:ShardedDeviceArray has been deprecated.*:DeprecationWarning
ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"

View File

@ -1814,7 +1814,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertGreaterEqual(len(w), 1)
self.assertIn("The jitted function foo includes a pmap",
str(w[0].message))
str(w[-1].message))
def testPsumZeroCotangents(self):
# https://github.com/google/jax/issues/3651