From f090074d865a4f98fba32bd88460373438cb4f24 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas <jakevdp@google.com>
Date: Thu, 11 Apr 2024 13:23:27 -0700
Subject: [PATCH] Avoid 'from jax import config' imports

In some environments this appears to import the config module rather than
the config object.
---
 benchmarks/api_benchmark.py                   |  4 +---
 benchmarks/shape_poly_benchmark.py            |  4 ++--
 docs/debugging/flags.md                       | 16 +++++++-------
 docs/debugging/index.md                       |  4 ++--
 docs/notebooks/Common_Gotchas_in_JAX.ipynb    | 18 ++++++++--------
 docs/notebooks/Common_Gotchas_in_JAX.md       | 18 ++++++++--------
 docs/rank_promotion_warning.rst               |  4 ++--
 examples/examples_test.py                     |  4 ++--
 examples/gaussian_process_regression.py       |  5 +++--
 jax/_src/internal_test_util/lax_test_util.py  |  4 ++--
 .../array_serialization/serialization_test.py |  4 +---
 .../jax2tf/examples/keras_reuse_main_test.py  |  4 ++--
 .../jax2tf/tests/back_compat_tf_test.py       |  4 ++--
 jax/experimental/jax2tf/tests/call_tf_test.py | 21 +++++++++----------
 .../jax2tf/tests/control_flow_ops_test.py     |  3 +--
 .../jax2tf/tests/cross_compilation_check.py   |  5 ++---
 .../jax2tf/tests/savedmodel_test.py           |  3 +--
 tests/ann_test.py                             |  4 +---
 tests/aot_test.py                             |  3 +--
 tests/api_util_test.py                        |  4 ++--
 tests/array_test.py                           |  3 +--
 tests/batching_test.py                        |  3 +--
 tests/clear_backends_test.py                  |  3 +--
 tests/custom_linear_solve_test.py             |  3 +--
 tests/custom_object_test.py                   |  4 ++--
 tests/custom_root_test.py                     |  3 +--
 tests/debug_nans_test.py                      | 17 +++++++--------
 tests/debugger_test.py                        |  3 +--
 tests/debugging_primitives_test.py            |  3 +--
 tests/dynamic_api_test.py                     |  3 +--
 tests/extend_test.py                          |  3 +--
 tests/for_loop_test.py                        |  3 +--
 tests/generated_fun_test.py                   |  4 ++--
 tests/heap_profiler_test.py                   |  3 +--
 tests/host_callback_test.py                   |  3 +--
 tests/image_test.py                           |  4 +---
 tests/infeed_test.py                          |  3 +--
 tests/jet_test.py                             |  3 +--
 tests/key_reuse_test.py                       |  3 +--
 tests/lax_autodiff_test.py                    |  3 +--
 tests/lax_control_flow_test.py                |  3 +--
 tests/lax_numpy_einsum_test.py                |  3 +--
 tests/lax_numpy_ufuncs_test.py                |  3 +--
 tests/lax_numpy_vectorize_test.py             |  3 +--
 tests/lax_scipy_special_functions_test.py     |  3 +--
 tests/lax_scipy_spectral_dac_test.py          |  4 ++--
 tests/lax_scipy_test.py                       |  3 +--
 tests/lax_vmap_op_test.py                     |  3 +--
 tests/lax_vmap_test.py                        |  3 +--
 tests/lobpcg_test.py                          |  3 +--
 tests/logging_test.py                         | 11 +++++-----
 tests/metadata_test.py                        |  3 +--
 tests/mock_gpu_test.py                        |  3 +--
 tests/mosaic_test.py                          |  4 ++--
 tests/multi_device_test.py                    |  3 +--
 tests/multibackend_test.py                    |  3 +--
 tests/multiprocess_gpu_test.py                |  3 +--
 tests/name_stack_test.py                      |  3 +--
 tests/ode_test.py                             |  3 +--
 tests/optimizers_test.py                      |  3 +--
 tests/pgle_test.py                            |  3 +--
 tests/pickle_test.py                          |  3 +--
 tests/polynomial_test.py                      |  4 ++--
 tests/profiler_test.py                        |  3 +--
 tests/scipy_fft_test.py                       |  5 ++---
 tests/scipy_interpolate_test.py               |  4 ++--
 tests/scipy_ndimage_test.py                   |  4 ++--
 tests/scipy_optimize_test.py                  |  4 ++--
 tests/scipy_signal_test.py                    |  4 ++--
 tests/scipy_spatial_test.py                   |  3 +--
 tests/scipy_stats_test.py                     |  3 +--
 tests/shard_alike_test.py                     |  3 +--
 tests/source_info_test.py                     |  3 +--
 tests/sparse_bcoo_bcsr_test.py                |  5 ++---
 tests/sparse_test.py                          |  3 +--
 tests/sparsify_test.py                        |  4 ++--
 tests/stack_test.py                           |  4 ++--
 tests/stax_test.py                            |  4 ++--
 tests/third_party/scipy/line_search_test.py   |  3 +--
 tests/transfer_guard_test.py                  |  4 +---
 tests/util_test.py                            |  4 ++--
 tests/x64_context_test.py                     |  7 +++----
 tests/xmap_test.py                            | 19 ++++++++---------
 83 files changed, 162 insertions(+), 224 deletions(-)

diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py
index 30fb04ace..75cd38d10 100644
--- a/benchmarks/api_benchmark.py
+++ b/benchmarks/api_benchmark.py
@@ -33,9 +33,7 @@ from jax.experimental import multihost_utils
 import jax.numpy as jnp
 import numpy as np
 
-from jax import config
-
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 partial = functools.partial
diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py
index bd8dd42d1..b1b6b625c 100644
--- a/benchmarks/shape_poly_benchmark.py
+++ b/benchmarks/shape_poly_benchmark.py
@@ -15,12 +15,12 @@
 
 import google_benchmark as benchmark
 
-from jax import config
+import jax
 from jax import core
 from jax._src.numpy import lax_numpy
 from jax.experimental import export
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 @benchmark.register
diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md
index 8434384c4..90a6cb3bb 100644
--- a/docs/debugging/flags.md
+++ b/docs/debugging/flags.md
@@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily.
 
 If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
 * setting the `JAX_DEBUG_NANS=True` environment variable;
-* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
-* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
+* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
+* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
 
 ### Example(s)
 
 ```python
-from jax import config
-config.update("jax_debug_nans", True)
+import jax
+jax.config.update("jax_debug_nans", True)
 
 def f(x, y):
   return x / y
@@ -47,14 +47,14 @@ jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
 
 You can disable JIT-compilation by:
 * setting the `JAX_DISABLE_JIT=True` environment variable;
-* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
-* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
+* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file;
+* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
 
 ### Examples
 
 ```python
-from jax import config
-config.update("jax_disable_jit", True)
+import jax
+jax.config.update("jax_disable_jit", True)
 
 def f(x):
   y = jnp.log(x)
diff --git a/docs/debugging/index.md b/docs/debugging/index.md
index 9a020b360..35e0f6895 100644
--- a/docs/debugging/index.md
+++ b/docs/debugging/index.md
@@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more!
 **TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
 
 ```python
-from jax import config
-config.update("jax_debug_nans", True)
+import jax
+jax.config.update("jax_debug_nans", True)
 
 def f(x, y):
   return x / y
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index 13fdd572b..d8dffdb8a 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -1946,9 +1946,9 @@
     "\n",
     "* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
     "\n",
-    "* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
+    "* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
     "\n",
-    "* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
+    "* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
     "\n",
     "This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
     "\n",
@@ -2141,24 +2141,24 @@
     "\n",
     "   ```python\n",
     "   # again, this only works on startup!\n",
-    "   from jax import config\n",
-    "   config.update(\"jax_enable_x64\", True)\n",
+    "   import jax\n",
+    "   jax.config.update(\"jax_enable_x64\", True)\n",
     "   ```\n",
     "\n",
     "3. You can parse command-line flags with `absl.app.run(main)`\n",
     "\n",
     "   ```python\n",
-    "   from jax import config\n",
-    "   config.config_with_absl()\n",
+    "   import jax\n",
+    "   jax.config.config_with_absl()\n",
     "   ```\n",
     "\n",
     "4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
     "\n",
     "   ```python\n",
-    "   from jax import config\n",
+    "   import jax\n",
     "   if __name__ == '__main__':\n",
-    "     # calls config.config_with_absl() *and* runs absl parsing\n",
-    "     config.parse_flags_with_absl()\n",
+    "     # calls jax.config.config_with_absl() *and* runs absl parsing\n",
+    "     jax.config.parse_flags_with_absl()\n",
     "   ```\n",
     "\n",
     "Note that #2-#4 work for _any_ of JAX's configuration options.\n",
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md
index 0e5af8b04..e63d64d94 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.md
+++ b/docs/notebooks/Common_Gotchas_in_JAX.md
@@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
 
 * setting the `JAX_DEBUG_NANS=True` environment variable;
 
-* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
+* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
 
-* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
+* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
 
 This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
 
@@ -1087,24 +1087,24 @@ There are a few ways to do this:
 
    ```python
    # again, this only works on startup!
-   from jax import config
-   config.update("jax_enable_x64", True)
+   import jax
+   jax.config.update("jax_enable_x64", True)
    ```
 
 3. You can parse command-line flags with `absl.app.run(main)`
 
    ```python
-   from jax import config
-   config.config_with_absl()
+   import jax
+   jax.config.config_with_absl()
    ```
 
 4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
 
    ```python
-   from jax import config
+   import jax
    if __name__ == '__main__':
-     # calls config.config_with_absl() *and* runs absl parsing
-     config.parse_flags_with_absl()
+     # calls jax.config.config_with_absl() *and* runs absl parsing
+     jax.config.parse_flags_with_absl()
    ```
 
 Note that #2-#4 work for _any_ of JAX's configuration options.
diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst
index e81509e2a..5e4e7ec65 100644
--- a/docs/rank_promotion_warning.rst
+++ b/docs/rank_promotion_warning.rst
@@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code:
 
 .. code-block:: python
 
-  from jax import config
-  config.update("jax_numpy_rank_promotion", "warn")
+  import jax
+  jax.config.update("jax_numpy_rank_promotion", "warn")
 
 You can also set the option using the environment variable
 :code:`JAX_NUMPY_RANK_PROMOTION`, for example as
diff --git a/examples/examples_test.py b/examples/examples_test.py
index b8b4d11e2..c9cb2991c 100644
--- a/examples/examples_test.py
+++ b/examples/examples_test.py
@@ -22,6 +22,7 @@ from absl.testing import parameterized
 
 import numpy as np
 
+import jax
 from jax import lax
 from jax import random
 import jax.numpy as jnp
@@ -30,8 +31,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from examples import kernel_lsq
 sys.path.pop()
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py
index c42a024d4..75f7398d1 100644
--- a/examples/gaussian_process_regression.py
+++ b/examples/gaussian_process_regression.py
@@ -17,10 +17,11 @@
 
 from absl import app
 from functools import partial
+
+import jax
 from jax import grad
 from jax import jit
 from jax import vmap
-from jax import config
 import jax.numpy as jnp
 import jax.random as random
 import jax.scipy as scipy
@@ -125,5 +126,5 @@ def main(unused_argv):
                     mu.flatten() - std * 2, mu.flatten() + std * 2)
 
 if __name__ == "__main__":
-  config.config_with_absl()
+  jax.config.config_with_absl()
   app.run(main)
diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py
index 5cc08eb05..b57b7d085 100644
--- a/jax/_src/internal_test_util/lax_test_util.py
+++ b/jax/_src/internal_test_util/lax_test_util.py
@@ -23,6 +23,7 @@ import collections
 import itertools
 from typing import Union, cast
 
+import jax
 from jax import lax
 from jax._src import dtypes
 from jax._src import test_util
@@ -30,8 +31,7 @@ from jax._src.util import safe_map, safe_zip
 
 import numpy as np
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py
index b9e3192f9..6cb5347e9 100644
--- a/jax/experimental/array_serialization/serialization_test.py
+++ b/jax/experimental/array_serialization/serialization_test.py
@@ -24,16 +24,14 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 from jax._src import test_util as jtu
-from jax import config
 from jax._src import array
 from jax.sharding import NamedSharding, GSPMDSharding
 from jax.sharding import PartitionSpec as P
 from jax.experimental.array_serialization import serialization
 import numpy as np
 import tensorstore as ts
-import unittest
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py
index 562369cdb..293484291 100644
--- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py
+++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py
@@ -16,13 +16,13 @@ import os
 from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
 from jax._src import test_util as jtu
-from jax import config
 
 from jax.experimental.jax2tf.examples import keras_reuse_main
 from jax.experimental.jax2tf.tests import tf_test_util
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 FLAGS = flags.FLAGS
 
 
diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py
index 47c5c8360..bd31c19ba 100644
--- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py
+++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py
@@ -27,7 +27,7 @@ import tarfile
 from typing import Callable, Optional
 
 from absl.testing import absltest
-from jax import config
+import jax
 from jax._src import test_util as jtu
 from jax._src.internal_test_util import export_back_compat_test_util as bctu
 from jax._src.lib import xla_extension
@@ -37,7 +37,7 @@ import jax.numpy as jnp
 import tensorflow as tf
 
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def serialize_directory(directory_path):
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index 57eea5f6a..3a0bdffd0 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -23,7 +23,6 @@ from absl import logging
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
-from jax import config
 from jax import dlpack
 from jax import dtypes
 from jax import lax
@@ -42,7 +41,7 @@ try:
 except ImportError:
   tf = None
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
@@ -1151,15 +1150,15 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
     super().setUp()
 
   def override_serialization_version(self, version_override: int):
-      version = config.jax_serialization_version
+      version = jax.config.jax_serialization_version
       if version != version_override:
-        self.addCleanup(partial(config.update,
+        self.addCleanup(partial(jax.config.update,
                                 "jax_serialization_version",
                                 version_override))
-        config.update("jax_serialization_version", version_override)
+        jax.config.update("jax_serialization_version", version_override)
       logging.info(
         "Using JAX serialization version %s",
-        config.jax_serialization_version)
+        jax.config.jax_serialization_version)
 
   def test_alternate(self):
     # Alternate sin/cos with sin in TF and cos in JAX
@@ -1275,7 +1274,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
 
   @_parameterized_jit
   def test_shape_poly_static_output_shape(self, with_jit=True):
-    if config.jax2tf_default_native_serialization:
+    if jax.config.jax2tf_default_native_serialization:
       raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
     x = np.array([0.7, 0.8], dtype=np.float32)
 
@@ -1289,7 +1288,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
 
   @_parameterized_jit
   def test_shape_poly(self, with_jit=False):
-    if config.jax2tf_default_native_serialization:
+    if jax.config.jax2tf_default_native_serialization:
       raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
     x = np.array([7, 8, 9, 10], dtype=np.float32)
     def fun_jax(x):
@@ -1308,7 +1307,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
 
   @_parameterized_jit
   def test_shape_poly_pytree_result(self, with_jit=True):
-    if config.jax2tf_default_native_serialization:
+    if jax.config.jax2tf_default_native_serialization:
       raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
     x = np.array([7, 8, 9, 10], dtype=np.float32)
     def fun_jax(x):
@@ -1394,7 +1393,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
     if kind == "bad_dim" and with_jit:
       # TODO: in jit more the error pops up later, at AddV2
       expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
-    if kind == "bad_dim" and config.jax2tf_default_native_serialization:
+    if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization:
       # TODO(b/268386622): call_tf with shape polymorphism and native serialization.
       expect_error = "Error compiling TensorFlow function"
     fun_tf_rt = _maybe_tf_jit(with_jit,
@@ -1432,7 +1431,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
                                f4_function=False, f4_saved_model=False):
     if (f2_saved_model and
         f4_saved_model and
-        not config.jax2tf_default_native_serialization):
+        not jax.config.jax2tf_default_native_serialization):
       # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients
       # when saving f4, but only with non-native serialization.
       raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients")
diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py
index 253a5ffc6..c66a6d696 100644
--- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py
+++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py
@@ -23,8 +23,7 @@ import numpy as np
 
 from jax.experimental.jax2tf.tests import tf_test_util
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py
index 63e8928ee..0a4bf61f8 100644
--- a/jax/experimental/jax2tf/tests/cross_compilation_check.py
+++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py
@@ -39,12 +39,11 @@ from absl import logging
 
 import numpy.random as npr
 
-import jax
-from jax import config   # Must import before TF
+import jax # Must import before TF
 from jax.experimental import jax2tf  # Defines needed flags
 from jax._src import test_util  # Defines needed flags
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 # Import after parsing flags
 from jax.experimental.jax2tf.tests import primitive_harness
diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py
index 86ae81f83..37e7eb24f 100644
--- a/jax/experimental/jax2tf/tests/savedmodel_test.py
+++ b/jax/experimental/jax2tf/tests/savedmodel_test.py
@@ -25,8 +25,7 @@ from jax.experimental import jax2tf
 from jax.experimental.jax2tf.tests import tf_test_util
 from jax._src import test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class SavedModelTest(tf_test_util.JaxToTfTestCase):
diff --git a/tests/ann_test.py b/tests/ann_test.py
index ab35ce0c5..1d704c725 100644
--- a/tests/ann_test.py
+++ b/tests/ann_test.py
@@ -23,9 +23,7 @@ import jax
 from jax import lax
 from jax._src import test_util as jtu
 
-from jax import config
-
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 ignore_jit_of_pmap_warning = partial(
     jtu.ignore_warning,message=".*jit-of-pmap.*")
diff --git a/tests/aot_test.py b/tests/aot_test.py
index dacfa620c..bca0d66ed 100644
--- a/tests/aot_test.py
+++ b/tests/aot_test.py
@@ -17,7 +17,6 @@ import contextlib
 import unittest
 from absl.testing import absltest
 import jax
-from jax import config
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src.lib import xla_client as xc
@@ -31,7 +30,7 @@ import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
diff --git a/tests/api_util_test.py b/tests/api_util_test.py
index 7b7a479db..f78b5948f 100644
--- a/tests/api_util_test.py
+++ b/tests/api_util_test.py
@@ -16,12 +16,12 @@
 import itertools as it
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
 from jax._src import api_util
 from jax import numpy as jnp
 from jax._src import test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ApiUtilTest(jtu.JaxTestCase):
diff --git a/tests/array_test.py b/tests/array_test.py
index 7c8d4c355..0d8dba0bd 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -40,8 +40,7 @@ from jax.sharding import PartitionSpec as P
 from jax._src import array
 from jax._src import prng
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 prev_xla_flags = None
diff --git a/tests/batching_test.py b/tests/batching_test.py
index afbe9cf70..36e686443 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -37,8 +37,7 @@ from jax import vmap
 from jax.interpreters import batching
 from jax.tree_util import register_pytree_node
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 # These are 'manual' tests for batching (vmap). The more exhaustive, more
diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py
index f8d5271ce..9ea9cac3a 100644
--- a/tests/clear_backends_test.py
+++ b/tests/clear_backends_test.py
@@ -15,12 +15,11 @@
 
 from absl.testing import absltest
 import jax
-from jax import config
 from jax._src import api
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ClearBackendsTest(jtu.JaxTestCase):
diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py
index 2c3d2a258..830526826 100644
--- a/tests/custom_linear_solve_test.py
+++ b/tests/custom_linear_solve_test.py
@@ -28,8 +28,7 @@ from jax._src import test_util as jtu
 import jax.numpy as jnp  # scan tests use numpy
 import jax.scipy as jsp
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def high_precision_dot(a, b):
diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py
index 036f912a3..75ff39630 100644
--- a/tests/custom_object_test.py
+++ b/tests/custom_object_test.py
@@ -18,9 +18,9 @@ import unittest
 
 import numpy as np
 
+import jax
 import jax.numpy as jnp
 from jax import jit, lax, make_jaxpr
-from jax import config
 from jax.interpreters import mlir
 from jax.interpreters import xla
 
@@ -34,7 +34,7 @@ from jax._src.lib import xla_client
 xc = xla_client
 xb = xla_bridge
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 # TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the
 # dictionaries associated with the following objects.
diff --git a/tests/custom_root_test.py b/tests/custom_root_test.py
index 88dee90aa..6a7eaab17 100644
--- a/tests/custom_root_test.py
+++ b/tests/custom_root_test.py
@@ -25,8 +25,7 @@ from jax._src import test_util as jtu
 import jax.numpy as jnp  # scan tests use numpy
 import jax.scipy as jsp
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def high_precision_dot(a, b):
diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py
index 8dc9818f8..e57439444 100644
--- a/tests/debug_nans_test.py
+++ b/tests/debug_nans_test.py
@@ -26,19 +26,18 @@ from jax import numpy as jnp
 from jax.experimental import pjit
 from jax._src.maps import xmap
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class DebugNaNsTest(jtu.JaxTestCase):
 
   def setUp(self):
     super().setUp()
-    self.cfg = config._read("jax_debug_nans")
-    config.update("jax_debug_nans", True)
+    self.cfg = jax.config._read("jax_debug_nans")
+    jax.config.update("jax_debug_nans", True)
 
   def tearDown(self):
-    config.update("jax_debug_nans", self.cfg)
+    jax.config.update("jax_debug_nans", self.cfg)
     super().tearDown()
 
   def testSinc(self):
@@ -67,7 +66,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
       ans.block_until_ready()
 
   def testJitComputationNaNContextManager(self):
-    config.update("jax_debug_nans", False)
+    jax.config.update("jax_debug_nans", False)
     A = jnp.array(0.)
     f = jax.jit(lambda x: 0. / x)
     ans = f(A)
@@ -210,11 +209,11 @@ class DebugInfsTest(jtu.JaxTestCase):
 
   def setUp(self):
     super().setUp()
-    self.cfg = config._read("jax_debug_infs")
-    config.update("jax_debug_infs", True)
+    self.cfg = jax.config._read("jax_debug_infs")
+    jax.config.update("jax_debug_infs", True)
 
   def tearDown(self):
-    config.update("jax_debug_infs", self.cfg)
+    jax.config.update("jax_debug_infs", self.cfg)
     super().tearDown()
 
   def testSingleResultPrimitiveNoInf(self):
diff --git a/tests/debugger_test.py b/tests/debugger_test.py
index 5e6d4388f..66488feb8 100644
--- a/tests/debugger_test.py
+++ b/tests/debugger_test.py
@@ -21,14 +21,13 @@ import unittest
 
 from absl.testing import absltest
 import jax
-from jax import config
 from jax.experimental import pjit
 from jax._src import debugger
 from jax._src import test_util as jtu
 import jax.numpy as jnp
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
   fake_stdin = io.StringIO()
diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index ce0e5e3b2..51c91d9aa 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -19,7 +19,6 @@ import unittest
 from absl.testing import absltest
 import jax
 from jax import lax
-from jax import config
 from jax.experimental import pjit
 from jax.interpreters import pxla
 from jax._src import ad_checkpoint
@@ -35,7 +34,7 @@ try:
 except ModuleNotFoundError:
   rich = None
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 debug_print = debugging.debug_print
 
diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py
index c704d7e10..13e9cc5bb 100644
--- a/tests/dynamic_api_test.py
+++ b/tests/dynamic_api_test.py
@@ -23,7 +23,6 @@ from absl.testing import parameterized
 import jax
 import jax.numpy as jnp
 from jax import lax
-from jax import config
 from jax.interpreters import batching
 
 import jax._src.lib
@@ -31,7 +30,7 @@ import jax._src.util
 from jax._src import core
 from jax._src import test_util as jtu
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
diff --git a/tests/extend_test.py b/tests/extend_test.py
index b49c1ac09..a926861eb 100644
--- a/tests/extend_test.py
+++ b/tests/extend_test.py
@@ -24,8 +24,7 @@ from jax._src import linear_util
 from jax._src import prng
 from jax._src import test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ExtendTest(jtu.JaxTestCase):
diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py
index 641f12ff0..cbbe56a62 100644
--- a/tests/for_loop_test.py
+++ b/tests/for_loop_test.py
@@ -24,8 +24,7 @@ from jax._src import test_util as jtu
 from jax._src.lax.control_flow import for_loop
 import jax.numpy as jnp
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 def remat_of_for_loop(nsteps, body, state, **kwargs):
   return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state,
diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py
index e96f100b4..a288e1a5f 100644
--- a/tests/generated_fun_test.py
+++ b/tests/generated_fun_test.py
@@ -22,11 +22,11 @@ from absl.testing import parameterized
 
 import itertools as it
 import jax.numpy as jnp
+import jax
 from jax import jit, jvp, vjp
 import jax._src.test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 npr.seed(0)
 
diff --git a/tests/heap_profiler_test.py b/tests/heap_profiler_test.py
index 6d3468e95..240eec1c8 100644
--- a/tests/heap_profiler_test.py
+++ b/tests/heap_profiler_test.py
@@ -17,11 +17,10 @@ from absl.testing import absltest
 
 import jax
 import jax._src.xla_bridge as xla_bridge
-from jax import config
 import jax._src.test_util as jtu
 
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class HeapProfilerTest(unittest.TestCase):
diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py
index 9c5ab78cb..99d34f30e 100644
--- a/tests/host_callback_test.py
+++ b/tests/host_callback_test.py
@@ -30,7 +30,6 @@ from absl.testing import absltest
 
 import jax
 from jax import ad_checkpoint
-from jax import config
 from jax import dtypes
 from jax import lax
 from jax import numpy as jnp
@@ -46,7 +45,7 @@ xops = xla_client.ops
 
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class _TestingOutputStream:
diff --git a/tests/image_test.py b/tests/image_test.py
index 6204ec91c..f3cd56ed7 100644
--- a/tests/image_test.py
+++ b/tests/image_test.py
@@ -24,8 +24,6 @@ from jax import image
 from jax import numpy as jnp
 from jax._src import test_util as jtu
 
-from jax import config
-
 # We use TensorFlow and PIL as reference implementations.
 try:
   import tensorflow as tf
@@ -37,7 +35,7 @@ try:
 except ImportError:
   PIL_Image = None
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 float_dtypes = jtu.dtypes.all_floating
 inexact_dtypes = jtu.dtypes.inexact
diff --git a/tests/infeed_test.py b/tests/infeed_test.py
index 572920fa4..ba47d2417 100644
--- a/tests/infeed_test.py
+++ b/tests/infeed_test.py
@@ -19,7 +19,6 @@ from unittest import SkipTest
 from absl.testing import absltest
 import jax
 from jax import lax, numpy as jnp
-from jax import config
 from jax.experimental import host_callback as hcb
 from jax._src import core
 from jax._src import xla_bridge
@@ -27,7 +26,7 @@ from jax._src.lib import xla_client
 import jax._src.test_util as jtu
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class InfeedTest(jtu.JaxTestCase):
diff --git a/tests/jet_test.py b/tests/jet_test.py
index c72057246..566119750 100644
--- a/tests/jet_test.py
+++ b/tests/jet_test.py
@@ -29,8 +29,7 @@ from jax.example_libraries import stax
 from jax.experimental.jet import jet, fact, zero_series
 from jax import lax
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 def jvp_taylor(fun, primals, series):
   # Computes the Taylor series the slow way, with nested jvp.
diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py
index 885b08224..d98984be5 100644
--- a/tests/key_reuse_test.py
+++ b/tests/key_reuse_test.py
@@ -29,8 +29,7 @@ from jax.experimental.key_reuse._core import (
   Source, Sink, Forward, KeyReuseSignature)
 from jax.experimental.key_reuse import _core
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 key = jax.eval_shape(jax.random.key, 0)
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index 630b08cc3..ab3a18317 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -31,8 +31,7 @@ from jax._src import test_util as jtu
 from jax._src.util import NumpyComplexWarning
 from jax.test_util import check_grads
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 compatible_shapes = [[(3,)],
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 5c737cc4d..0a17a1421 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -42,8 +42,7 @@ from jax._src.lax import control_flow as lax_control_flow
 from jax._src.lax.control_flow import for_loop
 from jax._src.maps import xmap
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 # Some tests are useful for testing both lax.cond and lax.switch. This function
diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py
index 92259c8f4..423289f3d 100644
--- a/tests/lax_numpy_einsum_test.py
+++ b/tests/lax_numpy_einsum_test.py
@@ -27,8 +27,7 @@ from jax import lax
 import jax.numpy as jnp
 import jax._src.test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class EinsumTest(jtu.JaxTestCase):
diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py
index 0f40e9d4d..40c9eb3bc 100644
--- a/tests/lax_numpy_ufuncs_test.py
+++ b/tests/lax_numpy_ufuncs_test.py
@@ -24,8 +24,7 @@ import jax.numpy as jnp
 from jax._src import test_util as jtu
 from jax._src.numpy.ufunc_api import get_if_single_primitive
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def scalar_add(x, y):
diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py
index cb0d9a0dc..edc344467 100644
--- a/tests/lax_numpy_vectorize_test.py
+++ b/tests/lax_numpy_vectorize_test.py
@@ -21,8 +21,7 @@ import jax
 from jax import numpy as jnp
 from jax._src import test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class VectorizeTest(jtu.JaxTestCase):
diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py
index 564ecca86..be10f03fb 100644
--- a/tests/lax_scipy_special_functions_test.py
+++ b/tests/lax_scipy_special_functions_test.py
@@ -26,8 +26,7 @@ import jax
 from jax._src import test_util as jtu
 from jax.scipy import special as lsp_special
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py
index 2d353d590..a09dcac53 100644
--- a/tests/lax_scipy_spectral_dac_test.py
+++ b/tests/lax_scipy_spectral_dac_test.py
@@ -14,6 +14,7 @@
 
 import unittest
 
+import jax
 from jax import lax
 from jax import numpy as jnp
 from jax._src import test_util as jtu
@@ -21,8 +22,7 @@ from jax._src.lax import eigh as lax_eigh
 
 from absl.testing import absltest
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 linear_sizes = [16, 97, 128]
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index e9f2e6bb9..cf3edbfd3 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -34,8 +34,7 @@ from jax._src import test_util as jtu
 from jax.scipy import special as lsp_special
 from jax.scipy import cluster as lsp_cluster
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 scipy_version = jtu.parse_version(scipy.version.version)
 
diff --git a/tests/lax_vmap_op_test.py b/tests/lax_vmap_op_test.py
index 5d3028132..c7059a293 100644
--- a/tests/lax_vmap_op_test.py
+++ b/tests/lax_vmap_op_test.py
@@ -26,8 +26,7 @@ from jax._src import test_util as jtu
 from jax._src.internal_test_util import lax_test_util
 from jax._src import util
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip
diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py
index 0d22d801d..37d51c04f 100644
--- a/tests/lax_vmap_test.py
+++ b/tests/lax_vmap_test.py
@@ -35,8 +35,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
 from jax._src.lib import xla_client
 from jax._src.util import safe_map, safe_zip
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py
index 3a4a2196c..1953114cb 100644
--- a/tests/lobpcg_test.py
+++ b/tests/lobpcg_test.py
@@ -30,7 +30,6 @@ import scipy.linalg as sla
 import scipy.sparse as sps
 
 import jax
-from jax import config
 from jax._src import test_util as jtu
 from jax.experimental.sparse import linalg, bcoo
 import jax.numpy as jnp
@@ -433,5 +432,5 @@ class F64LobpcgTest(LobpcgTest):
 
 
 if __name__ == '__main__':
-  config.parse_flags_with_absl()
+  jax.config.parse_flags_with_absl()
   absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/logging_test.py b/tests/logging_test.py
index 05bb31015..6b02432ce 100644
--- a/tests/logging_test.py
+++ b/tests/logging_test.py
@@ -22,7 +22,6 @@ import textwrap
 import unittest
 
 import jax
-from jax import config
 import jax._src.test_util as jtu
 from jax._src import xla_bridge
 
@@ -33,7 +32,7 @@ from jax._src import xla_bridge
 # parsing to work correctly with bazel (otherwise we could avoid importing
 # absltest/absl logging altogether).
 from absl.testing import absltest
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 @contextlib.contextmanager
@@ -96,27 +95,27 @@ class LoggingTest(jtu.JaxTestCase):
     self.assertEmpty(log_output.getvalue())
 
     # Turn on all debug logging.
-    config.update("jax_debug_log_modules", "jax")
+    jax.config.update("jax_debug_log_modules", "jax")
     with capture_jax_logs() as log_output:
       jax.jit(lambda x: x + 1)(1)
     self.assertIn("Finished tracing + transforming", log_output.getvalue())
     self.assertIn("Compiling <lambda>", log_output.getvalue())
 
     # Turn off all debug logging.
-    config.update("jax_debug_log_modules", None)
+    jax.config.update("jax_debug_log_modules", None)
     with capture_jax_logs() as log_output:
       jax.jit(lambda x: x + 1)(1)
     self.assertEmpty(log_output.getvalue())
 
     # Turn on one module.
-    config.update("jax_debug_log_modules", "jax._src.dispatch")
+    jax.config.update("jax_debug_log_modules", "jax._src.dispatch")
     with capture_jax_logs() as log_output:
       jax.jit(lambda x: x + 1)(1)
     self.assertIn("Finished tracing + transforming", log_output.getvalue())
     self.assertNotIn("Compiling <lambda>", log_output.getvalue())
 
     # Turn everything off again.
-    config.update("jax_debug_log_modules", None)
+    jax.config.update("jax_debug_log_modules", None)
     with capture_jax_logs() as log_output:
       jax.jit(lambda x: x + 1)(1)
     self.assertEmpty(log_output.getvalue())
diff --git a/tests/metadata_test.py b/tests/metadata_test.py
index 3511595d6..e01ba538b 100644
--- a/tests/metadata_test.py
+++ b/tests/metadata_test.py
@@ -23,8 +23,7 @@ from jax._src import config as jax_config
 from jax._src.lib.mlir import ir
 from jax import numpy as jnp
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def module_to_string(module: ir.Module) -> str:
diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py
index b955f0398..ba735775b 100644
--- a/tests/mock_gpu_test.py
+++ b/tests/mock_gpu_test.py
@@ -17,14 +17,13 @@ import math
 
 from absl.testing import absltest
 import jax
-from jax import config
 from jax._src import test_util as jtu
 import jax.numpy as jnp
 from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class MockGPUTest(jtu.JaxTestCase):
diff --git a/tests/mosaic_test.py b/tests/mosaic_test.py
index 518766c1e..03c8f1ce3 100644
--- a/tests/mosaic_test.py
+++ b/tests/mosaic_test.py
@@ -14,9 +14,9 @@
 from absl.testing import absltest
 from jax._src import test_util as jtu
 
-from jax import config
+import jax
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ImportTest(jtu.JaxTestCase):
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index 0060df9de..853865668 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -26,8 +26,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py
index 40cbb6630..f498d788b 100644
--- a/tests/multibackend_test.py
+++ b/tests/multibackend_test.py
@@ -25,8 +25,7 @@ import jax
 from jax._src import test_util as jtu
 from jax import numpy as jnp
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 npr.seed(0)
 
diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index bbe79ecff..76ed03890 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -26,7 +26,6 @@ from absl.testing import parameterized
 import numpy as np
 
 import jax
-from jax import config
 from jax._src import core
 from jax._src import distributed
 from jax._src import maps
@@ -40,7 +39,7 @@ try:
 except ImportError:
   portpicker = None
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 @unittest.skipIf(not portpicker, "Test requires portpicker")
 class DistributedTest(jtu.JaxTestCase):
diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py
index e6ac29e70..5f6dc95b9 100644
--- a/tests/name_stack_test.py
+++ b/tests/name_stack_test.py
@@ -20,12 +20,11 @@ from jax._src import core
 from jax import lax
 from jax._src.pjit import pjit
 from jax._src import linear_util as lu
-from jax import config
 from jax._src import test_util as jtu
 from jax._src.lib import xla_client
 from jax._src import ad_checkpoint
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 def _get_hlo(f):
   def wrapped(*args, **kwargs):
diff --git a/tests/ode_test.py b/tests/ode_test.py
index 2d2bcc971..834745e1c 100644
--- a/tests/ode_test.py
+++ b/tests/ode_test.py
@@ -24,8 +24,7 @@ from jax.experimental.ode import odeint
 
 import scipy.integrate as osp_integrate
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ODETest(jtu.JaxTestCase):
diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py
index 3fb3101c4..b7710d9b9 100644
--- a/tests/optimizers_test.py
+++ b/tests/optimizers_test.py
@@ -26,8 +26,7 @@ from jax import jit, grad, jacfwd, jacrev
 from jax import lax
 from jax.example_libraries import optimizers
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class OptimizerTests(jtu.JaxTestCase):
diff --git a/tests/pgle_test.py b/tests/pgle_test.py
index 188b56c8c..3dbf0232f 100644
--- a/tests/pgle_test.py
+++ b/tests/pgle_test.py
@@ -21,7 +21,6 @@ import tempfile
 
 from absl.testing import absltest
 import jax
-from jax import config
 from jax._src import test_util as jtu
 from jax.sharding import NamedSharding
 from jax.experimental import profiler as exp_profiler
@@ -29,7 +28,7 @@ import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 @jtu.pytest_mark_if_available('multiaccelerator')
diff --git a/tests/pickle_test.py b/tests/pickle_test.py
index 8fa6613cf..1dede34d2 100644
--- a/tests/pickle_test.py
+++ b/tests/pickle_test.py
@@ -26,14 +26,13 @@ except ImportError:
 
 import jax
 from jax import numpy as jnp
-from jax import config
 from jax.interpreters import pxla
 from jax._src import test_util as jtu
 from jax._src.lib import xla_client as xc
 
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def _get_device_by_id(device_id: int) -> xc.Device:
diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py
index ccba4c2ef..3eeaec482 100644
--- a/tests/polynomial_test.py
+++ b/tests/polynomial_test.py
@@ -19,12 +19,12 @@ from scipy.sparse import csgraph, csr_matrix
 
 from absl.testing import absltest
 
+import jax
 from jax._src import dtypes
 from jax import numpy as jnp
 from jax._src import test_util as jtu
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index c232c3afd..b67b078ae 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -26,7 +26,6 @@ from absl.testing import absltest
 import jax
 import jax.numpy as jnp
 import jax.profiler
-from jax import config
 import jax._src.test_util as jtu
 from jax._src import profiler
 
@@ -50,7 +49,7 @@ try:
 except ImportError:
   pass
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class ProfilerTest(unittest.TestCase):
diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py
index 77ee057d2..17c1e9c2d 100644
--- a/tests/scipy_fft_test.py
+++ b/tests/scipy_fft_test.py
@@ -15,13 +15,12 @@ import itertools
 
 from absl.testing import absltest
 
+import jax
 from jax._src import test_util as jtu
 import jax.scipy.fft as jsp_fft
 import scipy.fft as osp_fft
 
-from jax import config
-
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 float_dtypes = jtu.dtypes.floating
 real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean
diff --git a/tests/scipy_interpolate_test.py b/tests/scipy_interpolate_test.py
index ee905b7f0..1fead634a 100644
--- a/tests/scipy_interpolate_test.py
+++ b/tests/scipy_interpolate_test.py
@@ -18,13 +18,13 @@ import operator
 from functools import reduce
 import numpy as np
 
+import jax
 from jax._src import test_util as jtu
 import scipy.interpolate as sp_interp
 import jax.scipy.interpolate as jsp_interp
 
-from jax import config
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):
diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py
index 7ce0df873..b206c77d0 100644
--- a/tests/scipy_ndimage_test.py
+++ b/tests/scipy_ndimage_test.py
@@ -21,13 +21,13 @@ import numpy as np
 from absl.testing import absltest
 import scipy.ndimage as osp_ndimage
 
+import jax
 from jax import grad
 from jax._src import test_util as jtu
 from jax import dtypes
 from jax.scipy import ndimage as lsp_ndimage
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 float_dtypes = jtu.dtypes.floating
diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py
index e07455e06..70a00e14c 100644
--- a/tests/scipy_optimize_test.py
+++ b/tests/scipy_optimize_test.py
@@ -17,13 +17,13 @@ import numpy as np
 import scipy
 import scipy.optimize
 
+import jax
 from jax import numpy as jnp
 from jax._src import test_util as jtu
 from jax import jit
-from jax import config
 import jax.scipy.optimize
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def rosenbrock(np):
diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py
index 70a367a04..11923257a 100644
--- a/tests/scipy_signal_test.py
+++ b/tests/scipy_signal_test.py
@@ -21,14 +21,14 @@ from absl.testing import absltest
 import numpy as np
 import scipy.signal as osp_signal
 
+import jax
 from jax import lax
 import jax.numpy as jnp
 from jax._src import dtypes
 from jax._src import test_util as jtu
 import jax.scipy.signal as jsp_signal
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 onedim_shapes = [(1,), (2,), (5,), (10,)]
 twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)]
diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py
index f51ad49ad..5acbdc0dd 100644
--- a/tests/scipy_spatial_test.py
+++ b/tests/scipy_spatial_test.py
@@ -25,9 +25,8 @@ from scipy.spatial.transform import Slerp as osp_Slerp
 
 import jax.numpy as jnp
 import numpy as onp
-from jax import config
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 scipy_version = jtu.parse_version(scipy.version.version)
 
diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py
index 786d4ae03..1ab0bb9e5 100644
--- a/tests/scipy_stats_test.py
+++ b/tests/scipy_stats_test.py
@@ -27,8 +27,7 @@ from jax._src import dtypes, test_util as jtu
 from jax.scipy import stats as lsp_stats
 from jax.scipy.special import expit
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 scipy_version = jtu.parse_version(scipy.version.version)
 
diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py
index 291ee5360..8b7f11e31 100644
--- a/tests/shard_alike_test.py
+++ b/tests/shard_alike_test.py
@@ -25,8 +25,7 @@ from jax.experimental.shard_alike import shard_alike
 from jax.experimental.shard_map import shard_map
 from jax._src.lib import xla_extension_version
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 prev_xla_flags = None
 
diff --git a/tests/source_info_test.py b/tests/source_info_test.py
index aaa3abf55..0f876de1c 100644
--- a/tests/source_info_test.py
+++ b/tests/source_info_test.py
@@ -19,11 +19,10 @@ from absl.testing import absltest
 
 import jax
 from jax import lax
-from jax import config
 from jax._src import source_info_util
 from jax._src import test_util as jtu
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class SourceInfoTest(jtu.JaxTestCase):
diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py
index 441bee4ef..ba0ad5cb0 100644
--- a/tests/sparse_bcoo_bcsr_test.py
+++ b/tests/sparse_bcoo_bcsr_test.py
@@ -22,7 +22,6 @@ import unittest
 
 from absl.testing import absltest
 import jax
-from jax import config
 from jax import jit
 from jax import lax
 from jax import vmap
@@ -40,7 +39,7 @@ import jax.random
 from jax.util import split_list
 import numpy as np
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 COMPATIBLE_SHAPE_PAIRS = [
     [(), ()],
@@ -151,7 +150,7 @@ def _is_required_cuda_version_satisfied(cuda_version):
 class BCOOTest(sptu.SparseTestCase):
 
   def gpu_matmul_warning_context(self, msg):
-    if config.jax_bcoo_cusparse_lowering:
+    if jax.config.jax_bcoo_cusparse_lowering:
       return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
     return contextlib.nullcontext()
 
diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index 2522befa9..49438f411 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -22,7 +22,6 @@ from absl.testing import parameterized
 
 import jax
 import jax.random
-from jax import config
 from jax import dtypes
 from jax.experimental import sparse
 from jax.experimental.sparse import coo as sparse_coo
@@ -43,7 +42,7 @@ from jax.util import split_list
 import numpy as np
 import scipy.sparse
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
 
diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py
index 998ce1c40..46086511d 100644
--- a/tests/sparsify_test.py
+++ b/tests/sparsify_test.py
@@ -22,7 +22,7 @@ from absl.testing import parameterized
 import numpy as np
 
 import jax
-from jax import config, jit, lax
+from jax import jit, lax
 import jax.numpy as jnp
 import jax._src.test_util as jtu
 from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer
@@ -31,7 +31,7 @@ from jax.experimental.sparse.transform import (
 from jax.experimental.sparse.util import CuSparseEfficiencyWarning
 from jax.experimental.sparse import test_util as sptu
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
   def _rand_sparse(shape, dtype, nse=nse):
diff --git a/tests/stack_test.py b/tests/stack_test.py
index acefc0630..655a42571 100644
--- a/tests/stack_test.py
+++ b/tests/stack_test.py
@@ -17,13 +17,13 @@
 
 from absl.testing import absltest
 
+import jax
 import jax.numpy as jnp
 from jax._src.lax.stack import Stack
 from jax._src import test_util as jtu
 
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class StackTest(jtu.JaxTestCase):
diff --git a/tests/stax_test.py b/tests/stax_test.py
index 351a0fdb3..6850f36a0 100644
--- a/tests/stax_test.py
+++ b/tests/stax_test.py
@@ -18,13 +18,13 @@ from absl.testing import absltest
 
 import numpy as np
 
+import jax
 from jax._src import test_util as jtu
 from jax import random
 from jax.example_libraries import stax
 from jax import dtypes
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def random_inputs(rng, input_shape):
diff --git a/tests/third_party/scipy/line_search_test.py b/tests/third_party/scipy/line_search_test.py
index 5e7d9a943..9b2480053 100644
--- a/tests/third_party/scipy/line_search_test.py
+++ b/tests/third_party/scipy/line_search_test.py
@@ -3,13 +3,12 @@ import scipy.optimize
 
 import jax
 from jax import grad
-from jax import config
 import jax.numpy as jnp
 import jax._src.test_util as jtu
 from jax._src.scipy.optimize.line_search import line_search
 
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class TestLineSearch(jtu.JaxTestCase):
diff --git a/tests/transfer_guard_test.py b/tests/transfer_guard_test.py
index fa08c52b6..b6d9058db 100644
--- a/tests/transfer_guard_test.py
+++ b/tests/transfer_guard_test.py
@@ -25,9 +25,7 @@ import jax
 import jax._src.test_util as jtu
 import jax.numpy as jnp
 
-from jax import config
-
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 def _host_to_device_funcs():
diff --git a/tests/util_test.py b/tests/util_test.py
index e06df8b3f..5f07d2f50 100644
--- a/tests/util_test.py
+++ b/tests/util_test.py
@@ -16,13 +16,13 @@ import operator
 
 from absl.testing import absltest
 
+import jax
 from jax._src import linear_util as lu
 from jax._src import test_util as jtu
 from jax._src import util
 
-from jax import config
 from jax._src.util import weakref_lru_cache
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 try:
   from jax._src.lib import utils as jaxlib_utils
diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py
index 75919de8f..58cf4a2ba 100644
--- a/tests/x64_context_test.py
+++ b/tests/x64_context_test.py
@@ -24,12 +24,11 @@ import numpy as np
 import jax
 from jax import lax
 from jax import random
-from jax import config
 from jax.experimental import enable_x64, disable_x64
 import jax.numpy as jnp
 import jax._src.test_util as jtu
 
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 class X64ContextTests(jtu.JaxTestCase):
@@ -49,12 +48,12 @@ class X64ContextTests(jtu.JaxTestCase):
   )
   def test_correctly_capture_default(self, jit, enable_or_disable):
     # The fact we defined a jitted function with a block with a different value
-    # of `config.enable_x64` has no impact on the output.
+    # of `jax.config.enable_x64` has no impact on the output.
     with enable_or_disable():
       func = jit(lambda: jnp.array(np.float64(0)))
       func()
 
-    expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
+    expected_dtype = "float64" if jax.config._read("jax_enable_x64") else "float32"
     self.assertEqual(func().dtype, expected_dtype)
 
     with enable_x64():
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index 91b63488a..0d11bb878 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -53,8 +53,7 @@ from jax._src.nn import initializers as nn_initializers
 from jax._src.sharding_impls import NamedSharding
 from jax._src.util import unzip2
 
-from jax import config
-config.parse_flags_with_absl()
+jax.config.parse_flags_with_absl()
 
 
 # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
@@ -248,10 +247,10 @@ class SPMDTestMixin:
   def setUp(self):
     super().setUp()
     self.spmd_lowering = maps.SPMD_LOWERING.value
-    config.update('experimental_xmap_spmd_lowering', True)
+    jax.config.update('experimental_xmap_spmd_lowering', True)
 
   def tearDown(self):
-    config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
+    jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
 
 
 class ManualSPMDTestMixin:
@@ -261,12 +260,12 @@ class ManualSPMDTestMixin:
     super().setUp()
     self.spmd_lowering = maps.SPMD_LOWERING.value
     self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
-    config.update('experimental_xmap_spmd_lowering', True)
-    config.update('experimental_xmap_spmd_lowering_manual', True)
+    jax.config.update('experimental_xmap_spmd_lowering', True)
+    jax.config.update('experimental_xmap_spmd_lowering_manual', True)
 
   def tearDown(self):
-    config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
-    config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
+    jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
+    jax.config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
 
 
 @jtu.pytest_mark_if_available('multiaccelerator')
@@ -845,13 +844,13 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
     # TODO(apaszke): Add support for extracting XLA computations generated by
     # xmap and make this less of a smoke test.
     try:
-      config.update("experimental_xmap_ensure_fixed_sharding", True)
+      jax.config.update("experimental_xmap_ensure_fixed_sharding", True)
       f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')),
                in_axes=['i'], out_axes={}, axis_resources={'i': 'x'})
       x = jnp.arange(20, dtype=jnp.float32)
       f(x)
     finally:
-      config.update("experimental_xmap_ensure_fixed_sharding", False)
+      jax.config.update("experimental_xmap_ensure_fixed_sharding", False)
 
   @jtu.with_mesh([('x', 2)])
   def testConstantsInLowering(self):