mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
This commit is contained in:
parent
7c7c60eabf
commit
207cc10058
@ -8,6 +8,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.7
|
||||
|
||||
* Changes
|
||||
* As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
|
||||
`jax.config.jax_array` cannot be disabled anymore.
|
||||
* `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
|
||||
|
||||
* Deprecations
|
||||
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
|
||||
for which it is an alias.
|
||||
|
@ -173,6 +173,15 @@ from jax import util as util
|
||||
# Also circular dependency.
|
||||
from jax._src.array import Shard as Shard
|
||||
|
||||
|
||||
# TODO(yashkatariya): Remove after 2 jax releases from 0.4.6
|
||||
if not config.jax_jit_pjit_api_merge:
|
||||
raise ValueError(
|
||||
'jax.config.jax_jit_pjit_api_merge cannot be disabled after jax 0.4.7'
|
||||
' release. Please downgrade to jax and jaxlib 0.4.6 if you want to'
|
||||
' disable jax.config.jax_jit_pjit_api_merge.'
|
||||
)
|
||||
|
||||
import jax.lib # TODO(phawkins): remove this export.
|
||||
|
||||
# trailer
|
||||
|
@ -19,7 +19,6 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
|
||||
|
||||
from jax._src import lib
|
||||
@ -756,20 +755,18 @@ parallel_functions_output_gda = config.define_bool_state(
|
||||
|
||||
def _update_jax_array_global(val):
|
||||
if val is not None and not val:
|
||||
warnings.warn(
|
||||
'DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been '
|
||||
'deprecated. Please use `jax.Array`. See '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
|
||||
'to migrate to `jax.Array`.', DeprecationWarning)
|
||||
raise ValueError(
|
||||
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
||||
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
||||
' jax.config.jax_array.')
|
||||
lib.jax_jit.global_state().jax_array = val
|
||||
|
||||
def _update_jax_array_thread_local(val):
|
||||
if val is not None and not val:
|
||||
warnings.warn(
|
||||
'DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been '
|
||||
'deprecated. Please use `jax.Array`. See '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html on how '
|
||||
'to migrate to `jax.Array`.', DeprecationWarning)
|
||||
raise ValueError(
|
||||
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
||||
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
||||
' jax.config.jax_array.')
|
||||
lib.jax_jit.thread_local_state().jax_array = val
|
||||
|
||||
jax_array = config.define_bool_state(
|
||||
|
@ -37,10 +37,6 @@ jax_test(
|
||||
jax_test(
|
||||
name = "dynamic_api_test",
|
||||
srcs = ["dynamic_api_test.py"],
|
||||
# TODO(https://github.com/google/jax/issues/12291): Enable when jax.Array is supported.
|
||||
env = {
|
||||
"JAX_JIT_PJIT_API_MERGE": "0",
|
||||
},
|
||||
shard_count = 2,
|
||||
)
|
||||
|
||||
@ -76,10 +72,6 @@ jax_test(
|
||||
jax_test(
|
||||
name = "custom_object_test",
|
||||
srcs = ["custom_object_test.py"],
|
||||
# TODO(https://github.com/google/jax/issues/12291): Enable when jax.Array is supported.
|
||||
env = {
|
||||
"JAX_JIT_PJIT_API_MERGE": "0",
|
||||
},
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -111,7 +103,6 @@ jax_test(
|
||||
# No need to test all other configs.
|
||||
enable_configs = [
|
||||
"cpu",
|
||||
"cpu_jit_pjit_api_merge",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user