mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Revert: Add deprecation warnings to DA, SDA and GDA.
This change is currently overly noisy for users. PiperOrigin-RevId: 489455729
This commit is contained in:
parent
7a3dbcf94e
commit
9f2a6acb61
@ -40,7 +40,8 @@ disable `jax.Array` and see if the issues go away.
|
||||
|
||||
### How can I disable jax.Array for now?
|
||||
|
||||
Through March 15, 2023 it will be possible to disable jax.Array by:
|
||||
You can disable `jax.Array` by: (After a certain date (TBD), the option to
|
||||
disable jax.Array won't exist)
|
||||
|
||||
* setting the shell environment variable `JAX_ARRAY` to something falsey
|
||||
(e.g., `0`);
|
||||
|
@ -60,12 +60,6 @@ def make_device_array(
|
||||
This is to be used only within JAX. It will return either a PythonDeviceArray
|
||||
or a C++ equivalent implementation.
|
||||
"""
|
||||
warnings.warn(
|
||||
'DeviceArray has been deprecated. '
|
||||
'Please use `jax.Array`. See '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
|
||||
'to migrate to `jax.Array`.', DeprecationWarning)
|
||||
|
||||
if isinstance(device_buffer, xc.Buffer):
|
||||
|
||||
if device_buffer.aval == aval and device_buffer._device == device:
|
||||
|
@ -15,7 +15,6 @@
|
||||
from collections import Counter
|
||||
import dataclasses
|
||||
import functools
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
|
||||
|
||||
@ -264,12 +263,6 @@ class GlobalDeviceArray:
|
||||
device_buffers: Union[xb.ShardedBuffer, Sequence[DeviceArray]],
|
||||
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
|
||||
_enable_checks: bool = True):
|
||||
warnings.warn(
|
||||
'GlobalDeviceArray has been deprecated. '
|
||||
'Please use `jax.Array`. See '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
|
||||
'to migrate to `jax.Array`.', DeprecationWarning)
|
||||
|
||||
self._global_shape = global_shape
|
||||
self._global_mesh = global_mesh
|
||||
self._mesh_axes = mesh_axes
|
||||
|
@ -41,7 +41,6 @@ import operator as op
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import warnings
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
|
||||
TYPE_CHECKING)
|
||||
@ -666,12 +665,6 @@ def make_sharded_device_array(
|
||||
be returned, for JAX extensions not implementing the C++ API.
|
||||
indices: For caching purposes, will be computed if `None`.
|
||||
"""
|
||||
warnings.warn(
|
||||
'ShardedDeviceArray has been deprecated. '
|
||||
'Please use `jax.Array`. See '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
|
||||
'to migrate to `jax.Array`.', DeprecationWarning)
|
||||
|
||||
if sharding_spec is None:
|
||||
sharding_spec = _create_pmap_sharding_spec(aval)
|
||||
|
||||
|
@ -19,11 +19,8 @@ filterwarnings =
|
||||
# numpy uses distutils which is deprecated
|
||||
ignore:The distutils.* is deprecated.*:DeprecationWarning
|
||||
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
|
||||
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
|
||||
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
|
||||
default:Error reading persistent compilation cache entry for 'jit__lambda_'
|
||||
default:Error writing persistent compilation cache entry for 'jit__lambda_'
|
||||
ignore:DeviceArray has been deprecated.*:DeprecationWarning
|
||||
ignore:ShardedDeviceArray has been deprecated.*:DeprecationWarning
|
||||
ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
||||
|
@ -1814,7 +1814,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertGreaterEqual(len(w), 1)
|
||||
self.assertIn("The jitted function foo includes a pmap",
|
||||
str(w[0].message))
|
||||
str(w[-1].message))
|
||||
|
||||
def testPsumZeroCotangents(self):
|
||||
# https://github.com/google/jax/issues/3651
|
||||
|
Loading…
x
Reference in New Issue
Block a user