Error if jax_array or jax_jit_pjit_api_merge is set to False.

PiperOrigin-RevId: 517485597
This commit is contained in:
Yash Katariya 2023-03-17 12:57:18 -07:00 committed by jax authors
parent 7c7c60eabf
commit 207cc10058
4 changed files with 22 additions and 20 deletions

View File

@ -8,6 +8,11 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.7
* Changes
* As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
`jax.config.jax_array` cannot be disabled anymore.
* `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
* Deprecations
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
for which it is an alias.

View File

@ -173,6 +173,15 @@ from jax import util as util
# Also circular dependency.
from jax._src.array import Shard as Shard
# TODO(yashkatariya): Remove after 2 jax releases from 0.4.6
if not config.jax_jit_pjit_api_merge:
raise ValueError(
'jax.config.jax_jit_pjit_api_merge cannot be disabled after jax 0.4.7'
' release. Please downgrade to jax and jaxlib 0.4.6 if you want to'
' disable jax.config.jax_jit_pjit_api_merge.'
)
import jax.lib # TODO(phawkins): remove this export.
# trailer

View File

@ -19,7 +19,6 @@ import logging
import os
import sys
import threading
import warnings
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
from jax._src import lib
@ -756,20 +755,18 @@ parallel_functions_output_gda = config.define_bool_state(
def _update_jax_array_global(val):
if val is not None and not val:
warnings.warn(
'DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been '
'deprecated. Please use `jax.Array`. See '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
'to migrate to `jax.Array`.', DeprecationWarning)
raise ValueError(
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
lib.jax_jit.global_state().jax_array = val
def _update_jax_array_thread_local(val):
if val is not None and not val:
warnings.warn(
'DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been '
'deprecated. Please use `jax.Array`. See '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
'to migrate to `jax.Array`.', DeprecationWarning)
raise ValueError(
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
lib.jax_jit.thread_local_state().jax_array = val
jax_array = config.define_bool_state(

View File

@ -37,10 +37,6 @@ jax_test(
jax_test(
name = "dynamic_api_test",
srcs = ["dynamic_api_test.py"],
# TODO(https://github.com/google/jax/issues/12291): Enable when jax.Array is supported.
env = {
"JAX_JIT_PJIT_API_MERGE": "0",
},
shard_count = 2,
)
@ -76,10 +72,6 @@ jax_test(
jax_test(
name = "custom_object_test",
srcs = ["custom_object_test.py"],
# TODO(https://github.com/google/jax/issues/12291): Enable when jax.Array is supported.
env = {
"JAX_JIT_PJIT_API_MERGE": "0",
},
)
jax_test(
@ -111,7 +103,6 @@ jax_test(
# No need to test all other configs.
enable_configs = [
"cpu",
"cpu_jit_pjit_api_merge",
],
)