diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 30fb04ace..75cd38d10 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -33,9 +33,7 @@ from jax.experimental import multihost_utils import jax.numpy as jnp import numpy as np -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() partial = functools.partial diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index bd8dd42d1..b1b6b625c 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -15,12 +15,12 @@ import google_benchmark as benchmark -from jax import config +import jax from jax import core from jax._src.numpy import lax_numpy from jax.experimental import export -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @benchmark.register diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 8434384c4..90a6cb3bb 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily. If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by: * setting the `JAX_DEBUG_NANS=True` environment variable; -* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file; -* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; +* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; ### Example(s) ```python -from jax import config -config.update("jax_debug_nans", True) +import jax +jax.config.update("jax_debug_nans", True) def f(x, y): return x / y @@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! You can disable JIT-compilation by: * setting the `JAX_DISABLE_JIT=True` environment variable; -* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file; -* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`; +* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`; ### Examples ```python -from jax import config -config.update("jax_disable_jit", True) +import jax +jax.config.update("jax_disable_jit", True) def f(x): y = jnp.log(x) diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 9a020b360..35e0f6895 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more! **TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python -from jax import config -config.update("jax_debug_nans", True) +import jax +jax.config.update("jax_debug_nans", True) def f(x, y): return x / y diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 13fdd572b..d8dffdb8a 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -1946,9 +1946,9 @@ "\n", "* setting the `JAX_DEBUG_NANS=True` environment variable;\n", "\n", - "* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", + "* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", "\n", - "* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", + "* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", "\n", "This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n", "\n", @@ -2141,24 +2141,24 @@ "\n", " ```python\n", " # again, this only works on startup!\n", - " from jax import config\n", - " config.update(\"jax_enable_x64\", True)\n", + " import jax\n", + " jax.config.update(\"jax_enable_x64\", True)\n", " ```\n", "\n", "3. You can parse command-line flags with `absl.app.run(main)`\n", "\n", " ```python\n", - " from jax import config\n", - " config.config_with_absl()\n", + " import jax\n", + " jax.config.config_with_absl()\n", " ```\n", "\n", "4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n", "\n", " ```python\n", - " from jax import config\n", + " import jax\n", " if __name__ == '__main__':\n", - " # calls config.config_with_absl() *and* runs absl parsing\n", - " config.parse_flags_with_absl()\n", + " # calls jax.config.config_with_absl() *and* runs absl parsing\n", + " jax.config.parse_flags_with_absl()\n", " ```\n", "\n", "Note that #2-#4 work for _any_ of JAX's configuration options.\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 0e5af8b04..e63d64d94 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo * setting the `JAX_DEBUG_NANS=True` environment variable; -* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file; +* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; -* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time. @@ -1087,24 +1087,24 @@ There are a few ways to do this: ```python # again, this only works on startup! - from jax import config - config.update("jax_enable_x64", True) + import jax + jax.config.update("jax_enable_x64", True) ``` 3. You can parse command-line flags with `absl.app.run(main)` ```python - from jax import config - config.config_with_absl() + import jax + jax.config.config_with_absl() ``` 4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use ```python - from jax import config + import jax if __name__ == '__main__': - # calls config.config_with_absl() *and* runs absl parsing - config.parse_flags_with_absl() + # calls jax.config.config_with_absl() *and* runs absl parsing + jax.config.parse_flags_with_absl() ``` Note that #2-#4 work for _any_ of JAX's configuration options. diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index e81509e2a..5e4e7ec65 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code: .. code-block:: python - from jax import config - config.update("jax_numpy_rank_promotion", "warn") + import jax + jax.config.update("jax_numpy_rank_promotion", "warn") You can also set the option using the environment variable :code:`JAX_NUMPY_RANK_PROMOTION`, for example as diff --git a/examples/examples_test.py b/examples/examples_test.py index b8b4d11e2..c9cb2991c 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized import numpy as np +import jax from jax import lax from jax import random import jax.numpy as jnp @@ -30,8 +31,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from examples import kernel_lsq sys.path.pop() -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py index c42a024d4..75f7398d1 100644 --- a/examples/gaussian_process_regression.py +++ b/examples/gaussian_process_regression.py @@ -17,10 +17,11 @@ from absl import app from functools import partial + +import jax from jax import grad from jax import jit from jax import vmap -from jax import config import jax.numpy as jnp import jax.random as random import jax.scipy as scipy @@ -125,5 +126,5 @@ def main(unused_argv): mu.flatten() - std * 2, mu.flatten() + std * 2) if __name__ == "__main__": - config.config_with_absl() + jax.config.config_with_absl() app.run(main) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 5cc08eb05..b57b7d085 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -23,6 +23,7 @@ import collections import itertools from typing import Union, cast +import jax from jax import lax from jax._src import dtypes from jax._src import test_util @@ -30,8 +31,7 @@ from jax._src.util import safe_map, safe_zip import numpy as np -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index b9e3192f9..6cb5347e9 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -24,16 +24,14 @@ from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import test_util as jtu -from jax import config from jax._src import array from jax.sharding import NamedSharding, GSPMDSharding from jax.sharding import PartitionSpec as P from jax.experimental.array_serialization import serialization import numpy as np import tensorstore as ts -import unittest -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index 562369cdb..293484291 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -16,13 +16,13 @@ import os from absl import flags from absl.testing import absltest from absl.testing import parameterized +import jax from jax._src import test_util as jtu -from jax import config from jax.experimental.jax2tf.examples import keras_reuse_main from jax.experimental.jax2tf.tests import tf_test_util -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 47c5c8360..bd31c19ba 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -27,7 +27,7 @@ import tarfile from typing import Callable, Optional from absl.testing import absltest -from jax import config +import jax from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu from jax._src.lib import xla_extension @@ -37,7 +37,7 @@ import jax.numpy as jnp import tensorflow as tf -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def serialize_directory(directory_path): diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 57eea5f6a..3a0bdffd0 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -23,7 +23,6 @@ from absl import logging from absl.testing import absltest from absl.testing import parameterized import jax -from jax import config from jax import dlpack from jax import dtypes from jax import lax @@ -42,7 +41,7 @@ try: except ImportError: tf = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _maybe_jit(with_jit: bool, func: Callable) -> Callable: @@ -1151,15 +1150,15 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): super().setUp() def override_serialization_version(self, version_override: int): - version = config.jax_serialization_version + version = jax.config.jax_serialization_version if version != version_override: - self.addCleanup(partial(config.update, + self.addCleanup(partial(jax.config.update, "jax_serialization_version", version_override)) - config.update("jax_serialization_version", version_override) + jax.config.update("jax_serialization_version", version_override) logging.info( "Using JAX serialization version %s", - config.jax_serialization_version) + jax.config.jax_serialization_version) def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX @@ -1275,7 +1274,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): @_parameterized_jit def test_shape_poly_static_output_shape(self, with_jit=True): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([0.7, 0.8], dtype=np.float32) @@ -1289,7 +1288,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): @_parameterized_jit def test_shape_poly(self, with_jit=False): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([7, 8, 9, 10], dtype=np.float32) def fun_jax(x): @@ -1308,7 +1307,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): @_parameterized_jit def test_shape_poly_pytree_result(self, with_jit=True): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([7, 8, 9, 10], dtype=np.float32) def fun_jax(x): @@ -1394,7 +1393,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): if kind == "bad_dim" and with_jit: # TODO: in jit more the error pops up later, at AddV2 expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" - if kind == "bad_dim" and config.jax2tf_default_native_serialization: + if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization: # TODO(b/268386622): call_tf with shape polymorphism and native serialization. expect_error = "Error compiling TensorFlow function" fun_tf_rt = _maybe_tf_jit(with_jit, @@ -1432,7 +1431,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): f4_function=False, f4_saved_model=False): if (f2_saved_model and f4_saved_model and - not config.jax2tf_default_native_serialization): + not jax.config.jax2tf_default_native_serialization): # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients # when saving f4, but only with non-native serialization. raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients") diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py index 253a5ffc6..c66a6d696 100644 --- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py +++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py @@ -23,8 +23,7 @@ import numpy as np from jax.experimental.jax2tf.tests import tf_test_util -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase): diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index 63e8928ee..0a4bf61f8 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -39,12 +39,11 @@ from absl import logging import numpy.random as npr -import jax -from jax import config # Must import before TF +import jax # Must import before TF from jax.experimental import jax2tf # Defines needed flags from jax._src import test_util # Defines needed flags -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Import after parsing flags from jax.experimental.jax2tf.tests import primitive_harness diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index 86ae81f83..37e7eb24f 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -25,8 +25,7 @@ from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class SavedModelTest(tf_test_util.JaxToTfTestCase): diff --git a/tests/ann_test.py b/tests/ann_test.py index ab35ce0c5..1d704c725 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -23,9 +23,7 @@ import jax from jax import lax from jax._src import test_util as jtu -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() ignore_jit_of_pmap_warning = partial( jtu.ignore_warning,message=".*jit-of-pmap.*") diff --git a/tests/aot_test.py b/tests/aot_test.py index dacfa620c..bca0d66ed 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -17,7 +17,6 @@ import contextlib import unittest from absl.testing import absltest import jax -from jax import config from jax._src import core from jax._src import test_util as jtu from jax._src.lib import xla_client as xc @@ -31,7 +30,7 @@ import jax.numpy as jnp from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/api_util_test.py b/tests/api_util_test.py index 7b7a479db..f78b5948f 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -16,12 +16,12 @@ import itertools as it from absl.testing import absltest from absl.testing import parameterized +import jax from jax._src import api_util from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ApiUtilTest(jtu.JaxTestCase): diff --git a/tests/array_test.py b/tests/array_test.py index 7c8d4c355..0d8dba0bd 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -40,8 +40,7 @@ from jax.sharding import PartitionSpec as P from jax._src import array from jax._src import prng -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/batching_test.py b/tests/batching_test.py index afbe9cf70..36e686443 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -37,8 +37,7 @@ from jax import vmap from jax.interpreters import batching from jax.tree_util import register_pytree_node -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # These are 'manual' tests for batching (vmap). The more exhaustive, more diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py index f8d5271ce..9ea9cac3a 100644 --- a/tests/clear_backends_test.py +++ b/tests/clear_backends_test.py @@ -15,12 +15,11 @@ from absl.testing import absltest import jax -from jax import config from jax._src import api from jax._src import test_util as jtu from jax._src import xla_bridge as xb -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ClearBackendsTest(jtu.JaxTestCase): diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 2c3d2a258..830526826 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -28,8 +28,7 @@ from jax._src import test_util as jtu import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def high_precision_dot(a, b): diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 036f912a3..75ff39630 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -18,9 +18,9 @@ import unittest import numpy as np +import jax import jax.numpy as jnp from jax import jit, lax, make_jaxpr -from jax import config from jax.interpreters import mlir from jax.interpreters import xla @@ -34,7 +34,7 @@ from jax._src.lib import xla_client xc = xla_client xb = xla_bridge -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the # dictionaries associated with the following objects. diff --git a/tests/custom_root_test.py b/tests/custom_root_test.py index 88dee90aa..6a7eaab17 100644 --- a/tests/custom_root_test.py +++ b/tests/custom_root_test.py @@ -25,8 +25,7 @@ from jax._src import test_util as jtu import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def high_precision_dot(a, b): diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 8dc9818f8..e57439444 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -26,19 +26,18 @@ from jax import numpy as jnp from jax.experimental import pjit from jax._src.maps import xmap -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class DebugNaNsTest(jtu.JaxTestCase): def setUp(self): super().setUp() - self.cfg = config._read("jax_debug_nans") - config.update("jax_debug_nans", True) + self.cfg = jax.config._read("jax_debug_nans") + jax.config.update("jax_debug_nans", True) def tearDown(self): - config.update("jax_debug_nans", self.cfg) + jax.config.update("jax_debug_nans", self.cfg) super().tearDown() def testSinc(self): @@ -67,7 +66,7 @@ class DebugNaNsTest(jtu.JaxTestCase): ans.block_until_ready() def testJitComputationNaNContextManager(self): - config.update("jax_debug_nans", False) + jax.config.update("jax_debug_nans", False) A = jnp.array(0.) f = jax.jit(lambda x: 0. / x) ans = f(A) @@ -210,11 +209,11 @@ class DebugInfsTest(jtu.JaxTestCase): def setUp(self): super().setUp() - self.cfg = config._read("jax_debug_infs") - config.update("jax_debug_infs", True) + self.cfg = jax.config._read("jax_debug_infs") + jax.config.update("jax_debug_infs", True) def tearDown(self): - config.update("jax_debug_infs", self.cfg) + jax.config.update("jax_debug_infs", self.cfg) super().tearDown() def testSingleResultPrimitiveNoInf(self): diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 5e6d4388f..66488feb8 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -21,14 +21,13 @@ import unittest from absl.testing import absltest import jax -from jax import config from jax.experimental import pjit from jax._src import debugger from jax._src import test_util as jtu import jax.numpy as jnp import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]: fake_stdin = io.StringIO() diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index ce0e5e3b2..51c91d9aa 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -19,7 +19,6 @@ import unittest from absl.testing import absltest import jax from jax import lax -from jax import config from jax.experimental import pjit from jax.interpreters import pxla from jax._src import ad_checkpoint @@ -35,7 +34,7 @@ try: except ModuleNotFoundError: rich = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() debug_print = debugging.debug_print diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index c704d7e10..13e9cc5bb 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -23,7 +23,6 @@ from absl.testing import parameterized import jax import jax.numpy as jnp from jax import lax -from jax import config from jax.interpreters import batching import jax._src.lib @@ -31,7 +30,7 @@ import jax._src.util from jax._src import core from jax._src import test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") diff --git a/tests/extend_test.py b/tests/extend_test.py index b49c1ac09..a926861eb 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -24,8 +24,7 @@ from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ExtendTest(jtu.JaxTestCase): diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 641f12ff0..cbbe56a62 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -24,8 +24,7 @@ from jax._src import test_util as jtu from jax._src.lax.control_flow import for_loop import jax.numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def remat_of_for_loop(nsteps, body, state, **kwargs): return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state, diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index e96f100b4..a288e1a5f 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -22,11 +22,11 @@ from absl.testing import parameterized import itertools as it import jax.numpy as jnp +import jax from jax import jit, jvp, vjp import jax._src.test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() npr.seed(0) diff --git a/tests/heap_profiler_test.py b/tests/heap_profiler_test.py index 6d3468e95..240eec1c8 100644 --- a/tests/heap_profiler_test.py +++ b/tests/heap_profiler_test.py @@ -17,11 +17,10 @@ from absl.testing import absltest import jax import jax._src.xla_bridge as xla_bridge -from jax import config import jax._src.test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class HeapProfilerTest(unittest.TestCase): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 9c5ab78cb..99d34f30e 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -30,7 +30,6 @@ from absl.testing import absltest import jax from jax import ad_checkpoint -from jax import config from jax import dtypes from jax import lax from jax import numpy as jnp @@ -46,7 +45,7 @@ xops = xla_client.ops import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class _TestingOutputStream: diff --git a/tests/image_test.py b/tests/image_test.py index 6204ec91c..f3cd56ed7 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -24,8 +24,6 @@ from jax import image from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config - # We use TensorFlow and PIL as reference implementations. try: import tensorflow as tf @@ -37,7 +35,7 @@ try: except ImportError: PIL_Image = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.all_floating inexact_dtypes = jtu.dtypes.inexact diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 572920fa4..ba47d2417 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -19,7 +19,6 @@ from unittest import SkipTest from absl.testing import absltest import jax from jax import lax, numpy as jnp -from jax import config from jax.experimental import host_callback as hcb from jax._src import core from jax._src import xla_bridge @@ -27,7 +26,7 @@ from jax._src.lib import xla_client import jax._src.test_util as jtu import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class InfeedTest(jtu.JaxTestCase): diff --git a/tests/jet_test.py b/tests/jet_test.py index c72057246..566119750 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -29,8 +29,7 @@ from jax.example_libraries import stax from jax.experimental.jet import jet, fact, zero_series from jax import lax -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def jvp_taylor(fun, primals, series): # Computes the Taylor series the slow way, with nested jvp. diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 885b08224..d98984be5 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -29,8 +29,7 @@ from jax.experimental.key_reuse._core import ( Source, Sink, Forward, KeyReuseSignature) from jax.experimental.key_reuse import _core -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() key = jax.eval_shape(jax.random.key, 0) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 630b08cc3..ab3a18317 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -31,8 +31,7 @@ from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() compatible_shapes = [[(3,)], diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 5c737cc4d..0a17a1421 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -42,8 +42,7 @@ from jax._src.lax import control_flow as lax_control_flow from jax._src.lax.control_flow import for_loop from jax._src.maps import xmap -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Some tests are useful for testing both lax.cond and lax.switch. This function diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index 92259c8f4..423289f3d 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -27,8 +27,7 @@ from jax import lax import jax.numpy as jnp import jax._src.test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class EinsumTest(jtu.JaxTestCase): diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 0f40e9d4d..40c9eb3bc 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -24,8 +24,7 @@ import jax.numpy as jnp from jax._src import test_util as jtu from jax._src.numpy.ufunc_api import get_if_single_primitive -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def scalar_add(x, y): diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index cb0d9a0dc..edc344467 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -21,8 +21,7 @@ import jax from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class VectorizeTest(jtu.JaxTestCase): diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 564ecca86..be10f03fb 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -26,8 +26,7 @@ import jax from jax._src import test_util as jtu from jax.scipy import special as lsp_special -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py index 2d353d590..a09dcac53 100644 --- a/tests/lax_scipy_spectral_dac_test.py +++ b/tests/lax_scipy_spectral_dac_test.py @@ -14,6 +14,7 @@ import unittest +import jax from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu @@ -21,8 +22,7 @@ from jax._src.lax import eigh as lax_eigh from absl.testing import absltest -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() linear_sizes = [16, 97, 128] diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index e9f2e6bb9..cf3edbfd3 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -34,8 +34,7 @@ from jax._src import test_util as jtu from jax.scipy import special as lsp_special from jax.scipy import cluster as lsp_cluster -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/lax_vmap_op_test.py b/tests/lax_vmap_op_test.py index 5d3028132..c7059a293 100644 --- a/tests/lax_vmap_op_test.py +++ b/tests/lax_vmap_op_test.py @@ -26,8 +26,7 @@ from jax._src import test_util as jtu from jax._src.internal_test_util import lax_test_util from jax._src import util -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 0d22d801d..37d51c04f 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -35,8 +35,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client from jax._src.util import safe_map, safe_zip -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index 3a4a2196c..1953114cb 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -30,7 +30,6 @@ import scipy.linalg as sla import scipy.sparse as sps import jax -from jax import config from jax._src import test_util as jtu from jax.experimental.sparse import linalg, bcoo import jax.numpy as jnp @@ -433,5 +432,5 @@ class F64LobpcgTest(LobpcgTest): if __name__ == '__main__': - config.parse_flags_with_absl() + jax.config.parse_flags_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/logging_test.py b/tests/logging_test.py index 05bb31015..6b02432ce 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -22,7 +22,6 @@ import textwrap import unittest import jax -from jax import config import jax._src.test_util as jtu from jax._src import xla_bridge @@ -33,7 +32,7 @@ from jax._src import xla_bridge # parsing to work correctly with bazel (otherwise we could avoid importing # absltest/absl logging altogether). from absl.testing import absltest -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @contextlib.contextmanager @@ -96,27 +95,27 @@ class LoggingTest(jtu.JaxTestCase): self.assertEmpty(log_output.getvalue()) # Turn on all debug logging. - config.update("jax_debug_log_modules", "jax") + jax.config.update("jax_debug_log_modules", "jax") with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertIn("Finished tracing + transforming", log_output.getvalue()) self.assertIn("Compiling <lambda>", log_output.getvalue()) # Turn off all debug logging. - config.update("jax_debug_log_modules", None) + jax.config.update("jax_debug_log_modules", None) with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) # Turn on one module. - config.update("jax_debug_log_modules", "jax._src.dispatch") + jax.config.update("jax_debug_log_modules", "jax._src.dispatch") with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertIn("Finished tracing + transforming", log_output.getvalue()) self.assertNotIn("Compiling <lambda>", log_output.getvalue()) # Turn everything off again. - config.update("jax_debug_log_modules", None) + jax.config.update("jax_debug_log_modules", None) with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 3511595d6..e01ba538b 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -23,8 +23,7 @@ from jax._src import config as jax_config from jax._src.lib.mlir import ir from jax import numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def module_to_string(module: ir.Module) -> str: diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index b955f0398..ba735775b 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -17,14 +17,13 @@ import math from absl.testing import absltest import jax -from jax import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class MockGPUTest(jtu.JaxTestCase): diff --git a/tests/mosaic_test.py b/tests/mosaic_test.py index 518766c1e..03c8f1ce3 100644 --- a/tests/mosaic_test.py +++ b/tests/mosaic_test.py @@ -14,9 +14,9 @@ from absl.testing import absltest from jax._src import test_util as jtu -from jax import config +import jax -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ImportTest(jtu.JaxTestCase): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 0060df9de..853865668 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -26,8 +26,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax._src import test_util as jtu from jax._src import xla_bridge -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 40cbb6630..f498d788b 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -25,8 +25,7 @@ import jax from jax._src import test_util as jtu from jax import numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() npr.seed(0) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index bbe79ecff..76ed03890 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -26,7 +26,6 @@ from absl.testing import parameterized import numpy as np import jax -from jax import config from jax._src import core from jax._src import distributed from jax._src import maps @@ -40,7 +39,7 @@ try: except ImportError: portpicker = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index e6ac29e70..5f6dc95b9 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -20,12 +20,11 @@ from jax._src import core from jax import lax from jax._src.pjit import pjit from jax._src import linear_util as lu -from jax import config from jax._src import test_util as jtu from jax._src.lib import xla_client from jax._src import ad_checkpoint -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _get_hlo(f): def wrapped(*args, **kwargs): diff --git a/tests/ode_test.py b/tests/ode_test.py index 2d2bcc971..834745e1c 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -24,8 +24,7 @@ from jax.experimental.ode import odeint import scipy.integrate as osp_integrate -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ODETest(jtu.JaxTestCase): diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 3fb3101c4..b7710d9b9 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -26,8 +26,7 @@ from jax import jit, grad, jacfwd, jacrev from jax import lax from jax.example_libraries import optimizers -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class OptimizerTests(jtu.JaxTestCase): diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 188b56c8c..3dbf0232f 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,6 @@ import tempfile from absl.testing import absltest import jax -from jax import config from jax._src import test_util as jtu from jax.sharding import NamedSharding from jax.experimental import profiler as exp_profiler @@ -29,7 +28,7 @@ import jax.numpy as jnp from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @jtu.pytest_mark_if_available('multiaccelerator') diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 8fa6613cf..1dede34d2 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -26,14 +26,13 @@ except ImportError: import jax from jax import numpy as jnp -from jax import config from jax.interpreters import pxla from jax._src import test_util as jtu from jax._src.lib import xla_client as xc import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _get_device_by_id(device_id: int) -> xc.Device: diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index ccba4c2ef..3eeaec482 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -19,12 +19,12 @@ from scipy.sparse import csgraph, csr_matrix from absl.testing import absltest +import jax from jax._src import dtypes from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex diff --git a/tests/profiler_test.py b/tests/profiler_test.py index c232c3afd..b67b078ae 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -26,7 +26,6 @@ from absl.testing import absltest import jax import jax.numpy as jnp import jax.profiler -from jax import config import jax._src.test_util as jtu from jax._src import profiler @@ -50,7 +49,7 @@ try: except ImportError: pass -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ProfilerTest(unittest.TestCase): diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 77ee057d2..17c1e9c2d 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -15,13 +15,12 @@ import itertools from absl.testing import absltest +import jax from jax._src import test_util as jtu import jax.scipy.fft as jsp_fft import scipy.fft as osp_fft -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.floating real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean diff --git a/tests/scipy_interpolate_test.py b/tests/scipy_interpolate_test.py index ee905b7f0..1fead634a 100644 --- a/tests/scipy_interpolate_test.py +++ b/tests/scipy_interpolate_test.py @@ -18,13 +18,13 @@ import operator from functools import reduce import numpy as np +import jax from jax._src import test_util as jtu import scipy.interpolate as sp_interp import jax.scipy.interpolate as jsp_interp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class LaxBackedScipyInterpolateTests(jtu.JaxTestCase): diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 7ce0df873..b206c77d0 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -21,13 +21,13 @@ import numpy as np from absl.testing import absltest import scipy.ndimage as osp_ndimage +import jax from jax import grad from jax._src import test_util as jtu from jax import dtypes from jax.scipy import ndimage as lsp_ndimage -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.floating diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index e07455e06..70a00e14c 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -17,13 +17,13 @@ import numpy as np import scipy import scipy.optimize +import jax from jax import numpy as jnp from jax._src import test_util as jtu from jax import jit -from jax import config import jax.scipy.optimize -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def rosenbrock(np): diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 70a367a04..11923257a 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -21,14 +21,14 @@ from absl.testing import absltest import numpy as np import scipy.signal as osp_signal +import jax from jax import lax import jax.numpy as jnp from jax._src import dtypes from jax._src import test_util as jtu import jax.scipy.signal as jsp_signal -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() onedim_shapes = [(1,), (2,), (5,), (10,)] twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)] diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index f51ad49ad..5acbdc0dd 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -25,9 +25,8 @@ from scipy.spatial.transform import Slerp as osp_Slerp import jax.numpy as jnp import numpy as onp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 786d4ae03..1ab0bb9e5 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -27,8 +27,7 @@ from jax._src import dtypes, test_util as jtu from jax.scipy import stats as lsp_stats from jax.scipy.special import expit -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 291ee5360..8b7f11e31 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -25,8 +25,7 @@ from jax.experimental.shard_alike import shard_alike from jax.experimental.shard_map import shard_map from jax._src.lib import xla_extension_version -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/source_info_test.py b/tests/source_info_test.py index aaa3abf55..0f876de1c 100644 --- a/tests/source_info_test.py +++ b/tests/source_info_test.py @@ -19,11 +19,10 @@ from absl.testing import absltest import jax from jax import lax -from jax import config from jax._src import source_info_util from jax._src import test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class SourceInfoTest(jtu.JaxTestCase): diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 441bee4ef..ba0ad5cb0 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -22,7 +22,6 @@ import unittest from absl.testing import absltest import jax -from jax import config from jax import jit from jax import lax from jax import vmap @@ -40,7 +39,7 @@ import jax.random from jax.util import split_list import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() COMPATIBLE_SHAPE_PAIRS = [ [(), ()], @@ -151,7 +150,7 @@ def _is_required_cuda_version_satisfied(cuda_version): class BCOOTest(sptu.SparseTestCase): def gpu_matmul_warning_context(self, msg): - if config.jax_bcoo_cusparse_lowering: + if jax.config.jax_bcoo_cusparse_lowering: return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg) return contextlib.nullcontext() diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 2522befa9..49438f411 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized import jax import jax.random -from jax import config from jax import dtypes from jax.experimental import sparse from jax.experimental.sparse import coo as sparse_coo @@ -43,7 +42,7 @@ from jax.util import split_list import numpy as np import scipy.sparse -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 998ce1c40..46086511d 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -22,7 +22,7 @@ from absl.testing import parameterized import numpy as np import jax -from jax import config, jit, lax +from jax import jit, lax import jax.numpy as jnp import jax._src.test_util as jtu from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer @@ -31,7 +31,7 @@ from jax.experimental.sparse.transform import ( from jax.experimental.sparse.util import CuSparseEfficiencyWarning from jax.experimental.sparse import test_util as sptu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): def _rand_sparse(shape, dtype, nse=nse): diff --git a/tests/stack_test.py b/tests/stack_test.py index acefc0630..655a42571 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -17,13 +17,13 @@ from absl.testing import absltest +import jax import jax.numpy as jnp from jax._src.lax.stack import Stack from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class StackTest(jtu.JaxTestCase): diff --git a/tests/stax_test.py b/tests/stax_test.py index 351a0fdb3..6850f36a0 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -18,13 +18,13 @@ from absl.testing import absltest import numpy as np +import jax from jax._src import test_util as jtu from jax import random from jax.example_libraries import stax from jax import dtypes -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def random_inputs(rng, input_shape): diff --git a/tests/third_party/scipy/line_search_test.py b/tests/third_party/scipy/line_search_test.py index 5e7d9a943..9b2480053 100644 --- a/tests/third_party/scipy/line_search_test.py +++ b/tests/third_party/scipy/line_search_test.py @@ -3,13 +3,12 @@ import scipy.optimize import jax from jax import grad -from jax import config import jax.numpy as jnp import jax._src.test_util as jtu from jax._src.scipy.optimize.line_search import line_search -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class TestLineSearch(jtu.JaxTestCase): diff --git a/tests/transfer_guard_test.py b/tests/transfer_guard_test.py index fa08c52b6..b6d9058db 100644 --- a/tests/transfer_guard_test.py +++ b/tests/transfer_guard_test.py @@ -25,9 +25,7 @@ import jax import jax._src.test_util as jtu import jax.numpy as jnp -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _host_to_device_funcs(): diff --git a/tests/util_test.py b/tests/util_test.py index e06df8b3f..5f07d2f50 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -16,13 +16,13 @@ import operator from absl.testing import absltest +import jax from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util -from jax import config from jax._src.util import weakref_lru_cache -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() try: from jax._src.lib import utils as jaxlib_utils diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 75919de8f..58cf4a2ba 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -24,12 +24,11 @@ import numpy as np import jax from jax import lax from jax import random -from jax import config from jax.experimental import enable_x64, disable_x64 import jax.numpy as jnp import jax._src.test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class X64ContextTests(jtu.JaxTestCase): @@ -49,12 +48,12 @@ class X64ContextTests(jtu.JaxTestCase): ) def test_correctly_capture_default(self, jit, enable_or_disable): # The fact we defined a jitted function with a block with a different value - # of `config.enable_x64` has no impact on the output. + # of `jax.config.enable_x64` has no impact on the output. with enable_or_disable(): func = jit(lambda: jnp.array(np.float64(0))) func() - expected_dtype = "float64" if config._read("jax_enable_x64") else "float32" + expected_dtype = "float64" if jax.config._read("jax_enable_x64") else "float32" self.assertEqual(func().dtype, expected_dtype) with enable_x64(): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 91b63488a..0d11bb878 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -53,8 +53,7 @@ from jax._src.nn import initializers as nn_initializers from jax._src.sharding_impls import NamedSharding from jax._src.util import unzip2 -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py @@ -248,10 +247,10 @@ class SPMDTestMixin: def setUp(self): super().setUp() self.spmd_lowering = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) + jax.config.update('experimental_xmap_spmd_lowering', True) def tearDown(self): - config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) + jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) class ManualSPMDTestMixin: @@ -261,12 +260,12 @@ class ManualSPMDTestMixin: super().setUp() self.spmd_lowering = maps.SPMD_LOWERING.value self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value - config.update('experimental_xmap_spmd_lowering', True) - config.update('experimental_xmap_spmd_lowering_manual', True) + jax.config.update('experimental_xmap_spmd_lowering', True) + jax.config.update('experimental_xmap_spmd_lowering_manual', True) def tearDown(self): - config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) - config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering) + jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) + jax.config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering) @jtu.pytest_mark_if_available('multiaccelerator') @@ -845,13 +844,13 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest): # TODO(apaszke): Add support for extracting XLA computations generated by # xmap and make this less of a smoke test. try: - config.update("experimental_xmap_ensure_fixed_sharding", True) + jax.config.update("experimental_xmap_ensure_fixed_sharding", True) f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')), in_axes=['i'], out_axes={}, axis_resources={'i': 'x'}) x = jnp.arange(20, dtype=jnp.float32) f(x) finally: - config.update("experimental_xmap_ensure_fixed_sharding", False) + jax.config.update("experimental_xmap_ensure_fixed_sharding", False) @jtu.with_mesh([('x', 2)]) def testConstantsInLowering(self):