MAINT Do not import the config object in JAX internals

The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.

Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
This commit is contained in:
Sergei Lebedev 2023-10-18 10:37:19 +01:00
parent e7dff2c816
commit 1079304259
5 changed files with 8 additions and 10 deletions

View File

@ -24,7 +24,7 @@
# uniformity
from contextlib import contextmanager
from jax._src.config import enable_x64 as _jax_enable_x64
from jax._src import config
@contextmanager
def enable_x64(new_val: bool = True):
@ -42,7 +42,7 @@ def enable_x64(new_val: bool = True):
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
with _jax_enable_x64(new_val):
with config.enable_x64(new_val):
yield
@contextmanager
@ -61,5 +61,5 @@ def disable_x64():
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
with _jax_enable_x64(False):
with config.enable_x64(False):
yield

View File

@ -23,13 +23,12 @@ import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import lax
from jax._src import cache_key
from jax._src import compiler
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import compilation_cache_include_metadata_in_key
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
@ -244,7 +243,7 @@ class CacheKeyTest(jtu.JaxTestCase):
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
with compilation_cache_include_metadata_in_key(include_metadata):
with config.compilation_cache_include_metadata_in_key(include_metadata):
key1 = cache_key.get(computation1, devices, compile_options, backend)
key2 = cache_key.get(computation2, devices, compile_options, backend)
self.assertEqual(include_metadata, key1 != key2)

View File

@ -16,11 +16,10 @@ from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax.experimental import rnn
from jax._src.config import config
config.parse_flags_with_absl()

View File

@ -14,7 +14,7 @@
from absl.testing import absltest
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ from scipy.spatial.transform import Slerp as osp_Slerp
import jax.numpy as jnp
import numpy as onp
from jax.config import config
from jax import config
config.parse_flags_with_absl()