mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
MAINT Do not import the config object in JAX internals
The longer term goal here is to move away from having the config object as part of the public API and migrate towards module-level functions instead. Note that we can preserve the dynamic attribute lookup behavior of the config object via a module-level `__getattr__`
This commit is contained in:
parent
e7dff2c816
commit
1079304259
@ -24,7 +24,7 @@
|
||||
# uniformity
|
||||
|
||||
from contextlib import contextmanager
|
||||
from jax._src.config import enable_x64 as _jax_enable_x64
|
||||
from jax._src import config
|
||||
|
||||
@contextmanager
|
||||
def enable_x64(new_val: bool = True):
|
||||
@ -42,7 +42,7 @@ def enable_x64(new_val: bool = True):
|
||||
--------
|
||||
jax.experimental.enable_x64 : temporarily enable X64 mode.
|
||||
"""
|
||||
with _jax_enable_x64(new_val):
|
||||
with config.enable_x64(new_val):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
@ -61,5 +61,5 @@ def disable_x64():
|
||||
--------
|
||||
jax.experimental.enable_x64 : temporarily enable X64 mode.
|
||||
"""
|
||||
with _jax_enable_x64(False):
|
||||
with config.enable_x64(False):
|
||||
yield
|
||||
|
@ -23,13 +23,12 @@ import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import lax
|
||||
from jax._src import cache_key
|
||||
from jax._src import compiler
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import compilation_cache_include_metadata_in_key
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
@ -244,7 +243,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
with compilation_cache_include_metadata_in_key(include_metadata):
|
||||
with config.compilation_cache_include_metadata_in_key(include_metadata):
|
||||
key1 = cache_key.get(computation1, devices, compile_options, backend)
|
||||
key2 = cache_key.get(computation2, devices, compile_options, backend)
|
||||
self.assertEqual(include_metadata, key1 != key2)
|
||||
|
@ -16,11 +16,10 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import rnn
|
||||
|
||||
from jax._src.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
from absl.testing import absltest
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -25,7 +25,7 @@ from scipy.spatial.transform import Slerp as osp_Slerp
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user