Avoid 'from jax import config' imports

In some environments this appears to import the config module rather than
the config object.
This commit is contained in:
Jake VanderPlas 2024-04-11 13:23:27 -07:00
parent 301c3518d8
commit f090074d86
83 changed files with 162 additions and 224 deletions

View File

@ -33,9 +33,7 @@ from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
partial = functools.partial

View File

@ -15,12 +15,12 @@
import google_benchmark as benchmark
from jax import config
import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax.experimental import export
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
@benchmark.register

View File

@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily.
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
### Example(s)
```python
from jax import config
config.update("jax_debug_nans", True)
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
You can disable JIT-compilation by:
* setting the `JAX_DISABLE_JIT=True` environment variable;
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
### Examples
```python
from jax import config
config.update("jax_disable_jit", True)
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)

View File

@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more!
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
```python
from jax import config
config.update("jax_debug_nans", True)
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y

View File

@ -1946,9 +1946,9 @@
"\n",
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
"\n",
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"\n",
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"\n",
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
"\n",
@ -2141,24 +2141,24 @@
"\n",
" ```python\n",
" # again, this only works on startup!\n",
" from jax import config\n",
" config.update(\"jax_enable_x64\", True)\n",
" import jax\n",
" jax.config.update(\"jax_enable_x64\", True)\n",
" ```\n",
"\n",
"3. You can parse command-line flags with `absl.app.run(main)`\n",
"\n",
" ```python\n",
" from jax import config\n",
" config.config_with_absl()\n",
" import jax\n",
" jax.config.config_with_absl()\n",
" ```\n",
"\n",
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
"\n",
" ```python\n",
" from jax import config\n",
" import jax\n",
" if __name__ == '__main__':\n",
" # calls config.config_with_absl() *and* runs absl parsing\n",
" config.parse_flags_with_absl()\n",
" # calls jax.config.config_with_absl() *and* runs absl parsing\n",
" jax.config.parse_flags_with_absl()\n",
" ```\n",
"\n",
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",

View File

@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
@ -1087,24 +1087,24 @@ There are a few ways to do this:
```python
# again, this only works on startup!
from jax import config
config.update("jax_enable_x64", True)
import jax
jax.config.update("jax_enable_x64", True)
```
3. You can parse command-line flags with `absl.app.run(main)`
```python
from jax import config
config.config_with_absl()
import jax
jax.config.config_with_absl()
```
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
```python
from jax import config
import jax
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()
# calls jax.config.config_with_absl() *and* runs absl parsing
jax.config.parse_flags_with_absl()
```
Note that #2-#4 work for _any_ of JAX's configuration options.

View File

@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code:
.. code-block:: python
from jax import config
config.update("jax_numpy_rank_promotion", "warn")
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
You can also set the option using the environment variable
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as

View File

@ -22,6 +22,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
@ -30,8 +31,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from examples import kernel_lsq
sys.path.pop()
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):

View File

@ -17,10 +17,11 @@
from absl import app
from functools import partial
import jax
from jax import grad
from jax import jit
from jax import vmap
from jax import config
import jax.numpy as jnp
import jax.random as random
import jax.scipy as scipy
@ -125,5 +126,5 @@ def main(unused_argv):
mu.flatten() - std * 2, mu.flatten() + std * 2)
if __name__ == "__main__":
config.config_with_absl()
jax.config.config_with_absl()
app.run(main)

View File

@ -23,6 +23,7 @@ import collections
import itertools
from typing import Union, cast
import jax
from jax import lax
from jax._src import dtypes
from jax._src import test_util
@ -30,8 +31,7 @@ from jax._src.util import safe_map, safe_zip
import numpy as np
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -24,16 +24,14 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import config
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.array_serialization import serialization
import numpy as np
import tensorstore as ts
import unittest
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -16,13 +16,13 @@ import os
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import config
from jax.experimental.jax2tf.examples import keras_reuse_main
from jax.experimental.jax2tf.tests import tf_test_util
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
FLAGS = flags.FLAGS

View File

@ -27,7 +27,7 @@ import tarfile
from typing import Callable, Optional
from absl.testing import absltest
from jax import config
import jax
from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.lib import xla_extension
@ -37,7 +37,7 @@ import jax.numpy as jnp
import tensorflow as tf
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def serialize_directory(directory_path):

View File

@ -23,7 +23,6 @@ from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import dlpack
from jax import dtypes
from jax import lax
@ -42,7 +41,7 @@ try:
except ImportError:
tf = None
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
@ -1151,15 +1150,15 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
super().setUp()
def override_serialization_version(self, version_override: int):
version = config.jax_serialization_version
version = jax.config.jax_serialization_version
if version != version_override:
self.addCleanup(partial(config.update,
self.addCleanup(partial(jax.config.update,
"jax_serialization_version",
version_override))
config.update("jax_serialization_version", version_override)
jax.config.update("jax_serialization_version", version_override)
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version)
jax.config.jax_serialization_version)
def test_alternate(self):
# Alternate sin/cos with sin in TF and cos in JAX
@ -1275,7 +1274,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
@_parameterized_jit
def test_shape_poly_static_output_shape(self, with_jit=True):
if config.jax2tf_default_native_serialization:
if jax.config.jax2tf_default_native_serialization:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
x = np.array([0.7, 0.8], dtype=np.float32)
@ -1289,7 +1288,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
@_parameterized_jit
def test_shape_poly(self, with_jit=False):
if config.jax2tf_default_native_serialization:
if jax.config.jax2tf_default_native_serialization:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
@ -1308,7 +1307,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
@_parameterized_jit
def test_shape_poly_pytree_result(self, with_jit=True):
if config.jax2tf_default_native_serialization:
if jax.config.jax2tf_default_native_serialization:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
@ -1394,7 +1393,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
if kind == "bad_dim" and with_jit:
# TODO: in jit more the error pops up later, at AddV2
expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
if kind == "bad_dim" and config.jax2tf_default_native_serialization:
if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization:
# TODO(b/268386622): call_tf with shape polymorphism and native serialization.
expect_error = "Error compiling TensorFlow function"
fun_tf_rt = _maybe_tf_jit(with_jit,
@ -1432,7 +1431,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
f4_function=False, f4_saved_model=False):
if (f2_saved_model and
f4_saved_model and
not config.jax2tf_default_native_serialization):
not jax.config.jax2tf_default_native_serialization):
# TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients
# when saving f4, but only with non-native serialization.
raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients")

View File

@ -23,8 +23,7 @@ import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):

View File

@ -39,12 +39,11 @@ from absl import logging
import numpy.random as npr
import jax
from jax import config # Must import before TF
import jax # Must import before TF
from jax.experimental import jax2tf # Defines needed flags
from jax._src import test_util # Defines needed flags
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# Import after parsing flags
from jax.experimental.jax2tf.tests import primitive_harness

View File

@ -25,8 +25,7 @@ from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class SavedModelTest(tf_test_util.JaxToTfTestCase):

View File

@ -23,9 +23,7 @@ import jax
from jax import lax
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
ignore_jit_of_pmap_warning = partial(
jtu.ignore_warning,message=".*jit-of-pmap.*")

View File

@ -17,7 +17,6 @@ import contextlib
import unittest
from absl.testing import absltest
import jax
from jax import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
@ -31,7 +30,7 @@ import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -16,12 +16,12 @@
import itertools as it
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api_util
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ApiUtilTest(jtu.JaxTestCase):

View File

@ -40,8 +40,7 @@ from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import prng
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -37,8 +37,7 @@ from jax import vmap
from jax.interpreters import batching
from jax.tree_util import register_pytree_node
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# These are 'manual' tests for batching (vmap). The more exhaustive, more

View File

@ -15,12 +15,11 @@
from absl.testing import absltest
import jax
from jax import config
from jax._src import api
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ClearBackendsTest(jtu.JaxTestCase):

View File

@ -28,8 +28,7 @@ from jax._src import test_util as jtu
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def high_precision_dot(a, b):

View File

@ -18,9 +18,9 @@ import unittest
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, lax, make_jaxpr
from jax import config
from jax.interpreters import mlir
from jax.interpreters import xla
@ -34,7 +34,7 @@ from jax._src.lib import xla_client
xc = xla_client
xb = xla_bridge
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the
# dictionaries associated with the following objects.

View File

@ -25,8 +25,7 @@ from jax._src import test_util as jtu
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def high_precision_dot(a, b):

View File

@ -26,19 +26,18 @@ from jax import numpy as jnp
from jax.experimental import pjit
from jax._src.maps import xmap
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class DebugNaNsTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.cfg = config._read("jax_debug_nans")
config.update("jax_debug_nans", True)
self.cfg = jax.config._read("jax_debug_nans")
jax.config.update("jax_debug_nans", True)
def tearDown(self):
config.update("jax_debug_nans", self.cfg)
jax.config.update("jax_debug_nans", self.cfg)
super().tearDown()
def testSinc(self):
@ -67,7 +66,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
ans.block_until_ready()
def testJitComputationNaNContextManager(self):
config.update("jax_debug_nans", False)
jax.config.update("jax_debug_nans", False)
A = jnp.array(0.)
f = jax.jit(lambda x: 0. / x)
ans = f(A)
@ -210,11 +209,11 @@ class DebugInfsTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.cfg = config._read("jax_debug_infs")
config.update("jax_debug_infs", True)
self.cfg = jax.config._read("jax_debug_infs")
jax.config.update("jax_debug_infs", True)
def tearDown(self):
config.update("jax_debug_infs", self.cfg)
jax.config.update("jax_debug_infs", self.cfg)
super().tearDown()
def testSingleResultPrimitiveNoInf(self):

View File

@ -21,14 +21,13 @@ import unittest
from absl.testing import absltest
import jax
from jax import config
from jax.experimental import pjit
from jax._src import debugger
from jax._src import test_util as jtu
import jax.numpy as jnp
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
fake_stdin = io.StringIO()

View File

@ -19,7 +19,6 @@ import unittest
from absl.testing import absltest
import jax
from jax import lax
from jax import config
from jax.experimental import pjit
from jax.interpreters import pxla
from jax._src import ad_checkpoint
@ -35,7 +34,7 @@ try:
except ModuleNotFoundError:
rich = None
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
debug_print = debugging.debug_print

View File

@ -23,7 +23,6 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax import lax
from jax import config
from jax.interpreters import batching
import jax._src.lib
@ -31,7 +30,7 @@ import jax._src.util
from jax._src import core
from jax._src import test_util as jtu
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")

View File

@ -24,8 +24,7 @@ from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ExtendTest(jtu.JaxTestCase):

View File

@ -24,8 +24,7 @@ from jax._src import test_util as jtu
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def remat_of_for_loop(nsteps, body, state, **kwargs):
return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state,

View File

@ -22,11 +22,11 @@ from absl.testing import parameterized
import itertools as it
import jax.numpy as jnp
import jax
from jax import jit, jvp, vjp
import jax._src.test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
npr.seed(0)

View File

@ -17,11 +17,10 @@ from absl.testing import absltest
import jax
import jax._src.xla_bridge as xla_bridge
from jax import config
import jax._src.test_util as jtu
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class HeapProfilerTest(unittest.TestCase):

View File

@ -30,7 +30,6 @@ from absl.testing import absltest
import jax
from jax import ad_checkpoint
from jax import config
from jax import dtypes
from jax import lax
from jax import numpy as jnp
@ -46,7 +45,7 @@ xops = xla_client.ops
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class _TestingOutputStream:

View File

@ -24,8 +24,6 @@ from jax import image
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import config
# We use TensorFlow and PIL as reference implementations.
try:
import tensorflow as tf
@ -37,7 +35,7 @@ try:
except ImportError:
PIL_Image = None
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.all_floating
inexact_dtypes = jtu.dtypes.inexact

View File

@ -19,7 +19,6 @@ from unittest import SkipTest
from absl.testing import absltest
import jax
from jax import lax, numpy as jnp
from jax import config
from jax.experimental import host_callback as hcb
from jax._src import core
from jax._src import xla_bridge
@ -27,7 +26,7 @@ from jax._src.lib import xla_client
import jax._src.test_util as jtu
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class InfeedTest(jtu.JaxTestCase):

View File

@ -29,8 +29,7 @@ from jax.example_libraries import stax
from jax.experimental.jet import jet, fact, zero_series
from jax import lax
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def jvp_taylor(fun, primals, series):
# Computes the Taylor series the slow way, with nested jvp.

View File

@ -29,8 +29,7 @@ from jax.experimental.key_reuse._core import (
Source, Sink, Forward, KeyReuseSignature)
from jax.experimental.key_reuse import _core
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
key = jax.eval_shape(jax.random.key, 0)

View File

@ -31,8 +31,7 @@ from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
from jax.test_util import check_grads
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
compatible_shapes = [[(3,)],

View File

@ -42,8 +42,7 @@ from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax.control_flow import for_loop
from jax._src.maps import xmap
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# Some tests are useful for testing both lax.cond and lax.switch. This function

View File

@ -27,8 +27,7 @@ from jax import lax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class EinsumTest(jtu.JaxTestCase):

View File

@ -24,8 +24,7 @@ import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.numpy.ufunc_api import get_if_single_primitive
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def scalar_add(x, y):

View File

@ -21,8 +21,7 @@ import jax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class VectorizeTest(jtu.JaxTestCase):

View File

@ -26,8 +26,7 @@ import jax
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]

View File

@ -14,6 +14,7 @@
import unittest
import jax
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
@ -21,8 +22,7 @@ from jax._src.lax import eigh as lax_eigh
from absl.testing import absltest
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
linear_sizes = [16, 97, 128]

View File

@ -34,8 +34,7 @@ from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax.scipy import cluster as lsp_cluster
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
scipy_version = jtu.parse_version(scipy.version.version)

View File

@ -26,8 +26,7 @@ from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src import util
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

View File

@ -35,8 +35,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -30,7 +30,6 @@ import scipy.linalg as sla
import scipy.sparse as sps
import jax
from jax import config
from jax._src import test_util as jtu
from jax.experimental.sparse import linalg, bcoo
import jax.numpy as jnp
@ -433,5 +432,5 @@ class F64LobpcgTest(LobpcgTest):
if __name__ == '__main__':
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -22,7 +22,6 @@ import textwrap
import unittest
import jax
from jax import config
import jax._src.test_util as jtu
from jax._src import xla_bridge
@ -33,7 +32,7 @@ from jax._src import xla_bridge
# parsing to work correctly with bazel (otherwise we could avoid importing
# absltest/absl logging altogether).
from absl.testing import absltest
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
@contextlib.contextmanager
@ -96,27 +95,27 @@ class LoggingTest(jtu.JaxTestCase):
self.assertEmpty(log_output.getvalue())
# Turn on all debug logging.
config.update("jax_debug_log_modules", "jax")
jax.config.update("jax_debug_log_modules", "jax")
with capture_jax_logs() as log_output:
jax.jit(lambda x: x + 1)(1)
self.assertIn("Finished tracing + transforming", log_output.getvalue())
self.assertIn("Compiling <lambda>", log_output.getvalue())
# Turn off all debug logging.
config.update("jax_debug_log_modules", None)
jax.config.update("jax_debug_log_modules", None)
with capture_jax_logs() as log_output:
jax.jit(lambda x: x + 1)(1)
self.assertEmpty(log_output.getvalue())
# Turn on one module.
config.update("jax_debug_log_modules", "jax._src.dispatch")
jax.config.update("jax_debug_log_modules", "jax._src.dispatch")
with capture_jax_logs() as log_output:
jax.jit(lambda x: x + 1)(1)
self.assertIn("Finished tracing + transforming", log_output.getvalue())
self.assertNotIn("Compiling <lambda>", log_output.getvalue())
# Turn everything off again.
config.update("jax_debug_log_modules", None)
jax.config.update("jax_debug_log_modules", None)
with capture_jax_logs() as log_output:
jax.jit(lambda x: x + 1)(1)
self.assertEmpty(log_output.getvalue())

View File

@ -23,8 +23,7 @@ from jax._src import config as jax_config
from jax._src.lib.mlir import ir
from jax import numpy as jnp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def module_to_string(module: ir.Module) -> str:

View File

@ -17,14 +17,13 @@ import math
from absl.testing import absltest
import jax
from jax import config
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class MockGPUTest(jtu.JaxTestCase):

View File

@ -14,9 +14,9 @@
from absl.testing import absltest
from jax._src import test_util as jtu
from jax import config
import jax
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ImportTest(jtu.JaxTestCase):

View File

@ -26,8 +26,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -25,8 +25,7 @@ import jax
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
npr.seed(0)

View File

@ -26,7 +26,6 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import config
from jax._src import core
from jax._src import distributed
from jax._src import maps
@ -40,7 +39,7 @@ try:
except ImportError:
portpicker = None
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):

View File

@ -20,12 +20,11 @@ from jax._src import core
from jax import lax
from jax._src.pjit import pjit
from jax._src import linear_util as lu
from jax import config
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src import ad_checkpoint
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _get_hlo(f):
def wrapped(*args, **kwargs):

View File

@ -24,8 +24,7 @@ from jax.experimental.ode import odeint
import scipy.integrate as osp_integrate
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ODETest(jtu.JaxTestCase):

View File

@ -26,8 +26,7 @@ from jax import jit, grad, jacfwd, jacrev
from jax import lax
from jax.example_libraries import optimizers
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class OptimizerTests(jtu.JaxTestCase):

View File

@ -21,7 +21,6 @@ import tempfile
from absl.testing import absltest
import jax
from jax import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding
from jax.experimental import profiler as exp_profiler
@ -29,7 +28,7 @@ import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('multiaccelerator')

View File

@ -26,14 +26,13 @@ except ImportError:
import jax
from jax import numpy as jnp
from jax import config
from jax.interpreters import pxla
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _get_device_by_id(device_id: int) -> xc.Device:

View File

@ -19,12 +19,12 @@ from scipy.sparse import csgraph, csr_matrix
from absl.testing import absltest
import jax
from jax._src import dtypes
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex

View File

@ -26,7 +26,6 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.profiler
from jax import config
import jax._src.test_util as jtu
from jax._src import profiler
@ -50,7 +49,7 @@ try:
except ImportError:
pass
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class ProfilerTest(unittest.TestCase):

View File

@ -15,13 +15,12 @@ import itertools
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
import jax.scipy.fft as jsp_fft
import scipy.fft as osp_fft
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.floating
real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean

View File

@ -18,13 +18,13 @@ import operator
from functools import reduce
import numpy as np
import jax
from jax._src import test_util as jtu
import scipy.interpolate as sp_interp
import jax.scipy.interpolate as jsp_interp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):

View File

@ -21,13 +21,13 @@ import numpy as np
from absl.testing import absltest
import scipy.ndimage as osp_ndimage
import jax
from jax import grad
from jax._src import test_util as jtu
from jax import dtypes
from jax.scipy import ndimage as lsp_ndimage
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.floating

View File

@ -17,13 +17,13 @@ import numpy as np
import scipy
import scipy.optimize
import jax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import jit
from jax import config
import jax.scipy.optimize
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def rosenbrock(np):

View File

@ -21,14 +21,14 @@ from absl.testing import absltest
import numpy as np
import scipy.signal as osp_signal
import jax
from jax import lax
import jax.numpy as jnp
from jax._src import dtypes
from jax._src import test_util as jtu
import jax.scipy.signal as jsp_signal
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
onedim_shapes = [(1,), (2,), (5,), (10,)]
twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)]

View File

@ -25,9 +25,8 @@ from scipy.spatial.transform import Slerp as osp_Slerp
import jax.numpy as jnp
import numpy as onp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
scipy_version = jtu.parse_version(scipy.version.version)

View File

@ -27,8 +27,7 @@ from jax._src import dtypes, test_util as jtu
from jax.scipy import stats as lsp_stats
from jax.scipy.special import expit
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
scipy_version = jtu.parse_version(scipy.version.version)

View File

@ -25,8 +25,7 @@ from jax.experimental.shard_alike import shard_alike
from jax.experimental.shard_map import shard_map
from jax._src.lib import xla_extension_version
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -19,11 +19,10 @@ from absl.testing import absltest
import jax
from jax import lax
from jax import config
from jax._src import source_info_util
from jax._src import test_util as jtu
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class SourceInfoTest(jtu.JaxTestCase):

View File

@ -22,7 +22,6 @@ import unittest
from absl.testing import absltest
import jax
from jax import config
from jax import jit
from jax import lax
from jax import vmap
@ -40,7 +39,7 @@ import jax.random
from jax.util import split_list
import numpy as np
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
COMPATIBLE_SHAPE_PAIRS = [
[(), ()],
@ -151,7 +150,7 @@ def _is_required_cuda_version_satisfied(cuda_version):
class BCOOTest(sptu.SparseTestCase):
def gpu_matmul_warning_context(self, msg):
if config.jax_bcoo_cusparse_lowering:
if jax.config.jax_bcoo_cusparse_lowering:
return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
return contextlib.nullcontext()

View File

@ -22,7 +22,6 @@ from absl.testing import parameterized
import jax
import jax.random
from jax import config
from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse import coo as sparse_coo
@ -43,7 +42,7 @@ from jax.util import split_list
import numpy as np
import scipy.sparse
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import config, jit, lax
from jax import jit, lax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer
@ -31,7 +31,7 @@ from jax.experimental.sparse.transform import (
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
from jax.experimental.sparse import test_util as sptu
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
def _rand_sparse(shape, dtype, nse=nse):

View File

@ -17,13 +17,13 @@
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax._src.lax.stack import Stack
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class StackTest(jtu.JaxTestCase):

View File

@ -18,13 +18,13 @@ from absl.testing import absltest
import numpy as np
import jax
from jax._src import test_util as jtu
from jax import random
from jax.example_libraries import stax
from jax import dtypes
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def random_inputs(rng, input_shape):

View File

@ -3,13 +3,12 @@ import scipy.optimize
import jax
from jax import grad
from jax import config
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax._src.scipy.optimize.line_search import line_search
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class TestLineSearch(jtu.JaxTestCase):

View File

@ -25,9 +25,7 @@ import jax
import jax._src.test_util as jtu
import jax.numpy as jnp
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _host_to_device_funcs():

View File

@ -16,13 +16,13 @@ import operator
from absl.testing import absltest
import jax
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import util
from jax import config
from jax._src.util import weakref_lru_cache
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
try:
from jax._src.lib import utils as jaxlib_utils

View File

@ -24,12 +24,11 @@ import numpy as np
import jax
from jax import lax
from jax import random
from jax import config
from jax.experimental import enable_x64, disable_x64
import jax.numpy as jnp
import jax._src.test_util as jtu
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
class X64ContextTests(jtu.JaxTestCase):
@ -49,12 +48,12 @@ class X64ContextTests(jtu.JaxTestCase):
)
def test_correctly_capture_default(self, jit, enable_or_disable):
# The fact we defined a jitted function with a block with a different value
# of `config.enable_x64` has no impact on the output.
# of `jax.config.enable_x64` has no impact on the output.
with enable_or_disable():
func = jit(lambda: jnp.array(np.float64(0)))
func()
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
expected_dtype = "float64" if jax.config._read("jax_enable_x64") else "float32"
self.assertEqual(func().dtype, expected_dtype)
with enable_x64():

View File

@ -53,8 +53,7 @@ from jax._src.nn import initializers as nn_initializers
from jax._src.sharding_impls import NamedSharding
from jax._src.util import unzip2
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
@ -248,10 +247,10 @@ class SPMDTestMixin:
def setUp(self):
super().setUp()
self.spmd_lowering = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering', True)
def tearDown(self):
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
class ManualSPMDTestMixin:
@ -261,12 +260,12 @@ class ManualSPMDTestMixin:
super().setUp()
self.spmd_lowering = maps.SPMD_LOWERING.value
self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def tearDown(self):
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
jax.config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
@jtu.pytest_mark_if_available('multiaccelerator')
@ -845,13 +844,13 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
# TODO(apaszke): Add support for extracting XLA computations generated by
# xmap and make this less of a smoke test.
try:
config.update("experimental_xmap_ensure_fixed_sharding", True)
jax.config.update("experimental_xmap_ensure_fixed_sharding", True)
f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')),
in_axes=['i'], out_axes={}, axis_resources={'i': 'x'})
x = jnp.arange(20, dtype=jnp.float32)
f(x)
finally:
config.update("experimental_xmap_ensure_fixed_sharding", False)
jax.config.update("experimental_xmap_ensure_fixed_sharding", False)
@jtu.with_mesh([('x', 2)])
def testConstantsInLowering(self):