mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than the config object.
This commit is contained in:
parent
301c3518d8
commit
f090074d86
@ -33,9 +33,7 @@ from jax.experimental import multihost_utils
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
partial = functools.partial
|
||||
|
@ -15,12 +15,12 @@
|
||||
|
||||
import google_benchmark as benchmark
|
||||
|
||||
from jax import config
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax.experimental import export
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@benchmark.register
|
||||
|
@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily.
|
||||
|
||||
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
|
||||
### Example(s)
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.update("jax_debug_nans", True)
|
||||
import jax
|
||||
jax.config.update("jax_debug_nans", True)
|
||||
|
||||
def f(x, y):
|
||||
return x / y
|
||||
@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
|
||||
|
||||
You can disable JIT-compilation by:
|
||||
* setting the `JAX_DISABLE_JIT=True` environment variable;
|
||||
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
|
||||
* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
|
||||
|
||||
### Examples
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.update("jax_disable_jit", True)
|
||||
import jax
|
||||
jax.config.update("jax_disable_jit", True)
|
||||
|
||||
def f(x):
|
||||
y = jnp.log(x)
|
||||
|
@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more!
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.update("jax_debug_nans", True)
|
||||
import jax
|
||||
jax.config.update("jax_debug_nans", True)
|
||||
|
||||
def f(x, y):
|
||||
return x / y
|
||||
|
@ -1946,9 +1946,9 @@
|
||||
"\n",
|
||||
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
|
||||
"\n",
|
||||
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"\n",
|
||||
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"\n",
|
||||
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
|
||||
"\n",
|
||||
@ -2141,24 +2141,24 @@
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" # again, this only works on startup!\n",
|
||||
" from jax import config\n",
|
||||
" config.update(\"jax_enable_x64\", True)\n",
|
||||
" import jax\n",
|
||||
" jax.config.update(\"jax_enable_x64\", True)\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"3. You can parse command-line flags with `absl.app.run(main)`\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" from jax import config\n",
|
||||
" config.config_with_absl()\n",
|
||||
" import jax\n",
|
||||
" jax.config.config_with_absl()\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" from jax import config\n",
|
||||
" import jax\n",
|
||||
" if __name__ == '__main__':\n",
|
||||
" # calls config.config_with_absl() *and* runs absl parsing\n",
|
||||
" config.parse_flags_with_absl()\n",
|
||||
" # calls jax.config.config_with_absl() *and* runs absl parsing\n",
|
||||
" jax.config.parse_flags_with_absl()\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",
|
||||
|
@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
|
||||
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
|
||||
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
|
||||
* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
|
||||
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
|
||||
|
||||
@ -1087,24 +1087,24 @@ There are a few ways to do this:
|
||||
|
||||
```python
|
||||
# again, this only works on startup!
|
||||
from jax import config
|
||||
config.update("jax_enable_x64", True)
|
||||
import jax
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
```
|
||||
|
||||
3. You can parse command-line flags with `absl.app.run(main)`
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.config_with_absl()
|
||||
import jax
|
||||
jax.config.config_with_absl()
|
||||
```
|
||||
|
||||
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
import jax
|
||||
if __name__ == '__main__':
|
||||
# calls config.config_with_absl() *and* runs absl parsing
|
||||
config.parse_flags_with_absl()
|
||||
# calls jax.config.config_with_absl() *and* runs absl parsing
|
||||
jax.config.parse_flags_with_absl()
|
||||
```
|
||||
|
||||
Note that #2-#4 work for _any_ of JAX's configuration options.
|
||||
|
@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from jax import config
|
||||
config.update("jax_numpy_rank_promotion", "warn")
|
||||
import jax
|
||||
jax.config.update("jax_numpy_rank_promotion", "warn")
|
||||
|
||||
You can also set the option using the environment variable
|
||||
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as
|
||||
|
@ -22,6 +22,7 @@ from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
@ -30,8 +31,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from examples import kernel_lsq
|
||||
sys.path.pop()
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||
|
@ -17,10 +17,11 @@
|
||||
|
||||
from absl import app
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import grad
|
||||
from jax import jit
|
||||
from jax import vmap
|
||||
from jax import config
|
||||
import jax.numpy as jnp
|
||||
import jax.random as random
|
||||
import jax.scipy as scipy
|
||||
@ -125,5 +126,5 @@ def main(unused_argv):
|
||||
mu.flatten() - std * 2, mu.flatten() + std * 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
|
@ -23,6 +23,7 @@ import collections
|
||||
import itertools
|
||||
from typing import Union, cast
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util
|
||||
@ -30,8 +31,7 @@ from jax._src.util import safe_map, safe_zip
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -24,16 +24,14 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax import config
|
||||
from jax._src import array
|
||||
from jax.sharding import NamedSharding, GSPMDSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.experimental.array_serialization import serialization
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
import unittest
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
|
@ -16,13 +16,13 @@ import os
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax import config
|
||||
|
||||
from jax.experimental.jax2tf.examples import keras_reuse_main
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ import tarfile
|
||||
from typing import Callable, Optional
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax import config
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import export_back_compat_test_util as bctu
|
||||
from jax._src.lib import xla_extension
|
||||
@ -37,7 +37,7 @@ import jax.numpy as jnp
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def serialize_directory(directory_path):
|
||||
|
@ -23,7 +23,6 @@ from absl import logging
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
@ -42,7 +41,7 @@ try:
|
||||
except ImportError:
|
||||
tf = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
|
||||
@ -1151,15 +1150,15 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
super().setUp()
|
||||
|
||||
def override_serialization_version(self, version_override: int):
|
||||
version = config.jax_serialization_version
|
||||
version = jax.config.jax_serialization_version
|
||||
if version != version_override:
|
||||
self.addCleanup(partial(config.update,
|
||||
self.addCleanup(partial(jax.config.update,
|
||||
"jax_serialization_version",
|
||||
version_override))
|
||||
config.update("jax_serialization_version", version_override)
|
||||
jax.config.update("jax_serialization_version", version_override)
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version)
|
||||
jax.config.jax_serialization_version)
|
||||
|
||||
def test_alternate(self):
|
||||
# Alternate sin/cos with sin in TF and cos in JAX
|
||||
@ -1275,7 +1274,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@_parameterized_jit
|
||||
def test_shape_poly_static_output_shape(self, with_jit=True):
|
||||
if config.jax2tf_default_native_serialization:
|
||||
if jax.config.jax2tf_default_native_serialization:
|
||||
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
|
||||
x = np.array([0.7, 0.8], dtype=np.float32)
|
||||
|
||||
@ -1289,7 +1288,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@_parameterized_jit
|
||||
def test_shape_poly(self, with_jit=False):
|
||||
if config.jax2tf_default_native_serialization:
|
||||
if jax.config.jax2tf_default_native_serialization:
|
||||
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
|
||||
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
||||
def fun_jax(x):
|
||||
@ -1308,7 +1307,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@_parameterized_jit
|
||||
def test_shape_poly_pytree_result(self, with_jit=True):
|
||||
if config.jax2tf_default_native_serialization:
|
||||
if jax.config.jax2tf_default_native_serialization:
|
||||
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
|
||||
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
||||
def fun_jax(x):
|
||||
@ -1394,7 +1393,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
if kind == "bad_dim" and with_jit:
|
||||
# TODO: in jit more the error pops up later, at AddV2
|
||||
expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
|
||||
if kind == "bad_dim" and config.jax2tf_default_native_serialization:
|
||||
if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization:
|
||||
# TODO(b/268386622): call_tf with shape polymorphism and native serialization.
|
||||
expect_error = "Error compiling TensorFlow function"
|
||||
fun_tf_rt = _maybe_tf_jit(with_jit,
|
||||
@ -1432,7 +1431,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
f4_function=False, f4_saved_model=False):
|
||||
if (f2_saved_model and
|
||||
f4_saved_model and
|
||||
not config.jax2tf_default_native_serialization):
|
||||
not jax.config.jax2tf_default_native_serialization):
|
||||
# TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients
|
||||
# when saving f4, but only with non-native serialization.
|
||||
raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients")
|
||||
|
@ -23,8 +23,7 @@ import numpy as np
|
||||
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
|
@ -39,12 +39,11 @@ from absl import logging
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
import jax
|
||||
from jax import config # Must import before TF
|
||||
import jax # Must import before TF
|
||||
from jax.experimental import jax2tf # Defines needed flags
|
||||
from jax._src import test_util # Defines needed flags
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
# Import after parsing flags
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
|
@ -25,8 +25,7 @@ from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class SavedModelTest(tf_test_util.JaxToTfTestCase):
|
||||
|
@ -23,9 +23,7 @@ import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
ignore_jit_of_pmap_warning = partial(
|
||||
jtu.ignore_warning,message=".*jit-of-pmap.*")
|
||||
|
@ -17,7 +17,6 @@ import contextlib
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -31,7 +30,7 @@ import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
|
@ -16,12 +16,12 @@
|
||||
import itertools as it
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ApiUtilTest(jtu.JaxTestCase):
|
||||
|
@ -40,8 +40,7 @@ from jax.sharding import PartitionSpec as P
|
||||
from jax._src import array
|
||||
from jax._src import prng
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
prev_xla_flags = None
|
||||
|
@ -37,8 +37,7 @@ from jax import vmap
|
||||
from jax.interpreters import batching
|
||||
from jax.tree_util import register_pytree_node
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
# These are 'manual' tests for batching (vmap). The more exhaustive, more
|
||||
|
@ -15,12 +15,11 @@
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import api
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ClearBackendsTest(jtu.JaxTestCase):
|
||||
|
@ -28,8 +28,7 @@ from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp # scan tests use numpy
|
||||
import jax.scipy as jsp
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def high_precision_dot(a, b):
|
||||
|
@ -18,9 +18,9 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, lax, make_jaxpr
|
||||
from jax import config
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
@ -34,7 +34,7 @@ from jax._src.lib import xla_client
|
||||
xc = xla_client
|
||||
xb = xla_bridge
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the
|
||||
# dictionaries associated with the following objects.
|
||||
|
@ -25,8 +25,7 @@ from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp # scan tests use numpy
|
||||
import jax.scipy as jsp
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def high_precision_dot(a, b):
|
||||
|
@ -26,19 +26,18 @@ from jax import numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
from jax._src.maps import xmap
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cfg = config._read("jax_debug_nans")
|
||||
config.update("jax_debug_nans", True)
|
||||
self.cfg = jax.config._read("jax_debug_nans")
|
||||
jax.config.update("jax_debug_nans", True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_debug_nans", self.cfg)
|
||||
jax.config.update("jax_debug_nans", self.cfg)
|
||||
super().tearDown()
|
||||
|
||||
def testSinc(self):
|
||||
@ -67,7 +66,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
ans.block_until_ready()
|
||||
|
||||
def testJitComputationNaNContextManager(self):
|
||||
config.update("jax_debug_nans", False)
|
||||
jax.config.update("jax_debug_nans", False)
|
||||
A = jnp.array(0.)
|
||||
f = jax.jit(lambda x: 0. / x)
|
||||
ans = f(A)
|
||||
@ -210,11 +209,11 @@ class DebugInfsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cfg = config._read("jax_debug_infs")
|
||||
config.update("jax_debug_infs", True)
|
||||
self.cfg = jax.config._read("jax_debug_infs")
|
||||
jax.config.update("jax_debug_infs", True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_debug_infs", self.cfg)
|
||||
jax.config.update("jax_debug_infs", self.cfg)
|
||||
super().tearDown()
|
||||
|
||||
def testSingleResultPrimitiveNoInf(self):
|
||||
|
@ -21,14 +21,13 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax.experimental import pjit
|
||||
from jax._src import debugger
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
|
||||
fake_stdin = io.StringIO()
|
||||
|
@ -19,7 +19,6 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import config
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import ad_checkpoint
|
||||
@ -35,7 +34,7 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
rich = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
debug_print = debugging.debug_print
|
||||
|
||||
|
@ -23,7 +23,6 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import config
|
||||
from jax.interpreters import batching
|
||||
|
||||
import jax._src.lib
|
||||
@ -31,7 +30,7 @@ import jax._src.util
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
|
||||
|
@ -24,8 +24,7 @@ from jax._src import linear_util
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ExtendTest(jtu.JaxTestCase):
|
||||
|
@ -24,8 +24,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
def remat_of_for_loop(nsteps, body, state, **kwargs):
|
||||
return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state,
|
||||
|
@ -22,11 +22,11 @@ from absl.testing import parameterized
|
||||
|
||||
import itertools as it
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
from jax import jit, jvp, vjp
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
npr.seed(0)
|
||||
|
||||
|
@ -17,11 +17,10 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax._src.xla_bridge as xla_bridge
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class HeapProfilerTest(unittest.TestCase):
|
||||
|
@ -30,7 +30,6 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import ad_checkpoint
|
||||
from jax import config
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
@ -46,7 +45,7 @@ xops = xla_client.ops
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class _TestingOutputStream:
|
||||
|
@ -24,8 +24,6 @@ from jax import image
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
|
||||
# We use TensorFlow and PIL as reference implementations.
|
||||
try:
|
||||
import tensorflow as tf
|
||||
@ -37,7 +35,7 @@ try:
|
||||
except ImportError:
|
||||
PIL_Image = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
float_dtypes = jtu.dtypes.all_floating
|
||||
inexact_dtypes = jtu.dtypes.inexact
|
||||
|
@ -19,7 +19,6 @@ from unittest import SkipTest
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import lax, numpy as jnp
|
||||
from jax import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax._src import core
|
||||
from jax._src import xla_bridge
|
||||
@ -27,7 +26,7 @@ from jax._src.lib import xla_client
|
||||
import jax._src.test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class InfeedTest(jtu.JaxTestCase):
|
||||
|
@ -29,8 +29,7 @@ from jax.example_libraries import stax
|
||||
from jax.experimental.jet import jet, fact, zero_series
|
||||
from jax import lax
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
def jvp_taylor(fun, primals, series):
|
||||
# Computes the Taylor series the slow way, with nested jvp.
|
||||
|
@ -29,8 +29,7 @@ from jax.experimental.key_reuse._core import (
|
||||
Source, Sink, Forward, KeyReuseSignature)
|
||||
from jax.experimental.key_reuse import _core
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
key = jax.eval_shape(jax.random.key, 0)
|
||||
|
@ -31,8 +31,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
from jax.test_util import check_grads
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
compatible_shapes = [[(3,)],
|
||||
|
@ -42,8 +42,7 @@ from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
from jax._src.maps import xmap
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
# Some tests are useful for testing both lax.cond and lax.switch. This function
|
||||
|
@ -27,8 +27,7 @@ from jax import lax
|
||||
import jax.numpy as jnp
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class EinsumTest(jtu.JaxTestCase):
|
||||
|
@ -24,8 +24,7 @@ import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.numpy.ufunc_api import get_if_single_primitive
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def scalar_add(x, y):
|
||||
|
@ -21,8 +21,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class VectorizeTest(jtu.JaxTestCase):
|
||||
|
@ -26,8 +26,7 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
@ -21,8 +22,7 @@ from jax._src.lax import eigh as lax_eigh
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
linear_sizes = [16, 97, 128]
|
||||
|
@ -34,8 +34,7 @@ from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
from jax.scipy import cluster as lsp_cluster
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = jtu.parse_version(scipy.version.version)
|
||||
|
||||
|
@ -26,8 +26,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src import util
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
@ -35,8 +35,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -30,7 +30,6 @@ import scipy.linalg as sla
|
||||
import scipy.sparse as sps
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.sparse import linalg, bcoo
|
||||
import jax.numpy as jnp
|
||||
@ -433,5 +432,5 @@ class F64LobpcgTest(LobpcgTest):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -22,7 +22,6 @@ import textwrap
|
||||
import unittest
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
|
||||
@ -33,7 +32,7 @@ from jax._src import xla_bridge
|
||||
# parsing to work correctly with bazel (otherwise we could avoid importing
|
||||
# absltest/absl logging altogether).
|
||||
from absl.testing import absltest
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -96,27 +95,27 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
self.assertEmpty(log_output.getvalue())
|
||||
|
||||
# Turn on all debug logging.
|
||||
config.update("jax_debug_log_modules", "jax")
|
||||
jax.config.update("jax_debug_log_modules", "jax")
|
||||
with capture_jax_logs() as log_output:
|
||||
jax.jit(lambda x: x + 1)(1)
|
||||
self.assertIn("Finished tracing + transforming", log_output.getvalue())
|
||||
self.assertIn("Compiling <lambda>", log_output.getvalue())
|
||||
|
||||
# Turn off all debug logging.
|
||||
config.update("jax_debug_log_modules", None)
|
||||
jax.config.update("jax_debug_log_modules", None)
|
||||
with capture_jax_logs() as log_output:
|
||||
jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEmpty(log_output.getvalue())
|
||||
|
||||
# Turn on one module.
|
||||
config.update("jax_debug_log_modules", "jax._src.dispatch")
|
||||
jax.config.update("jax_debug_log_modules", "jax._src.dispatch")
|
||||
with capture_jax_logs() as log_output:
|
||||
jax.jit(lambda x: x + 1)(1)
|
||||
self.assertIn("Finished tracing + transforming", log_output.getvalue())
|
||||
self.assertNotIn("Compiling <lambda>", log_output.getvalue())
|
||||
|
||||
# Turn everything off again.
|
||||
config.update("jax_debug_log_modules", None)
|
||||
jax.config.update("jax_debug_log_modules", None)
|
||||
with capture_jax_logs() as log_output:
|
||||
jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEmpty(log_output.getvalue())
|
||||
|
@ -23,8 +23,7 @@ from jax._src import config as jax_config
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def module_to_string(module: ir.Module) -> str:
|
||||
|
@ -17,14 +17,13 @@ import math
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class MockGPUTest(jtu.JaxTestCase):
|
||||
|
@ -14,9 +14,9 @@
|
||||
from absl.testing import absltest
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
import jax
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ImportTest(jtu.JaxTestCase):
|
||||
|
@ -26,8 +26,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
|
@ -25,8 +25,7 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
npr.seed(0)
|
||||
|
||||
|
@ -26,7 +26,6 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import core
|
||||
from jax._src import distributed
|
||||
from jax._src import maps
|
||||
@ -40,7 +39,7 @@ try:
|
||||
except ImportError:
|
||||
portpicker = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@unittest.skipIf(not portpicker, "Test requires portpicker")
|
||||
class DistributedTest(jtu.JaxTestCase):
|
||||
|
@ -20,12 +20,11 @@ from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src.pjit import pjit
|
||||
from jax._src import linear_util as lu
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import ad_checkpoint
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
def _get_hlo(f):
|
||||
def wrapped(*args, **kwargs):
|
||||
|
@ -24,8 +24,7 @@ from jax.experimental.ode import odeint
|
||||
|
||||
import scipy.integrate as osp_integrate
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ODETest(jtu.JaxTestCase):
|
||||
|
@ -26,8 +26,7 @@ from jax import jit, grad, jacfwd, jacrev
|
||||
from jax import lax
|
||||
from jax.example_libraries import optimizers
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class OptimizerTests(jtu.JaxTestCase):
|
||||
|
@ -21,7 +21,6 @@ import tempfile
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.experimental import profiler as exp_profiler
|
||||
@ -29,7 +28,7 @@ import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
|
@ -26,14 +26,13 @@ except ImportError:
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import config
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _get_device_by_id(device_id: int) -> xc.Device:
|
||||
|
@ -19,12 +19,12 @@ from scipy.sparse import csgraph, csr_matrix
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
|
@ -26,7 +26,6 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src import profiler
|
||||
|
||||
@ -50,7 +49,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ProfilerTest(unittest.TestCase):
|
||||
|
@ -15,13 +15,12 @@ import itertools
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
import jax.scipy.fft as jsp_fft
|
||||
import scipy.fft as osp_fft
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
float_dtypes = jtu.dtypes.floating
|
||||
real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean
|
||||
|
@ -18,13 +18,13 @@ import operator
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
import scipy.interpolate as sp_interp
|
||||
import jax.scipy.interpolate as jsp_interp
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):
|
||||
|
@ -21,13 +21,13 @@ import numpy as np
|
||||
from absl.testing import absltest
|
||||
import scipy.ndimage as osp_ndimage
|
||||
|
||||
import jax
|
||||
from jax import grad
|
||||
from jax._src import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax.scipy import ndimage as lsp_ndimage
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
float_dtypes = jtu.dtypes.floating
|
||||
|
@ -17,13 +17,13 @@ import numpy as np
|
||||
import scipy
|
||||
import scipy.optimize
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax import jit
|
||||
from jax import config
|
||||
import jax.scipy.optimize
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def rosenbrock(np):
|
||||
|
@ -21,14 +21,14 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
import scipy.signal as osp_signal
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
import jax.scipy.signal as jsp_signal
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
onedim_shapes = [(1,), (2,), (5,), (10,)]
|
||||
twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)]
|
||||
|
@ -25,9 +25,8 @@ from scipy.spatial.transform import Slerp as osp_Slerp
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = jtu.parse_version(scipy.version.version)
|
||||
|
||||
|
@ -27,8 +27,7 @@ from jax._src import dtypes, test_util as jtu
|
||||
from jax.scipy import stats as lsp_stats
|
||||
from jax.scipy.special import expit
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = jtu.parse_version(scipy.version.version)
|
||||
|
||||
|
@ -25,8 +25,7 @@ from jax.experimental.shard_alike import shard_alike
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
||||
|
@ -19,11 +19,10 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import config
|
||||
from jax._src import source_info_util
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class SourceInfoTest(jtu.JaxTestCase):
|
||||
|
@ -22,7 +22,6 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax import vmap
|
||||
@ -40,7 +39,7 @@ import jax.random
|
||||
from jax.util import split_list
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
COMPATIBLE_SHAPE_PAIRS = [
|
||||
[(), ()],
|
||||
@ -151,7 +150,7 @@ def _is_required_cuda_version_satisfied(cuda_version):
|
||||
class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
def gpu_matmul_warning_context(self, msg):
|
||||
if config.jax_bcoo_cusparse_lowering:
|
||||
if jax.config.jax_bcoo_cusparse_lowering:
|
||||
return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
@ -22,7 +22,6 @@ from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
import jax.random
|
||||
from jax import config
|
||||
from jax import dtypes
|
||||
from jax.experimental import sparse
|
||||
from jax.experimental.sparse import coo as sparse_coo
|
||||
@ -43,7 +42,7 @@ from jax.util import split_list
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
|
||||
|
||||
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import config, jit, lax
|
||||
from jax import jit, lax
|
||||
import jax.numpy as jnp
|
||||
import jax._src.test_util as jtu
|
||||
from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer
|
||||
@ -31,7 +31,7 @@ from jax.experimental.sparse.transform import (
|
||||
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
|
||||
from jax.experimental.sparse import test_util as sptu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
|
||||
def _rand_sparse(shape, dtype, nse=nse):
|
||||
|
@ -17,13 +17,13 @@
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.stack import Stack
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class StackTest(jtu.JaxTestCase):
|
||||
|
@ -18,13 +18,13 @@ from absl.testing import absltest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax import random
|
||||
from jax.example_libraries import stax
|
||||
from jax import dtypes
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def random_inputs(rng, input_shape):
|
||||
|
3
tests/third_party/scipy/line_search_test.py
vendored
3
tests/third_party/scipy/line_search_test.py
vendored
@ -3,13 +3,12 @@ import scipy.optimize
|
||||
|
||||
import jax
|
||||
from jax import grad
|
||||
from jax import config
|
||||
import jax.numpy as jnp
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src.scipy.optimize.line_search import line_search
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class TestLineSearch(jtu.JaxTestCase):
|
||||
|
@ -25,9 +25,7 @@ import jax
|
||||
import jax._src.test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _host_to_device_funcs():
|
||||
|
@ -16,13 +16,13 @@ import operator
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
|
||||
from jax import config
|
||||
from jax._src.util import weakref_lru_cache
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
try:
|
||||
from jax._src.lib import utils as jaxlib_utils
|
||||
|
@ -24,12 +24,11 @@ import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax import config
|
||||
from jax.experimental import enable_x64, disable_x64
|
||||
import jax.numpy as jnp
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class X64ContextTests(jtu.JaxTestCase):
|
||||
@ -49,12 +48,12 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
)
|
||||
def test_correctly_capture_default(self, jit, enable_or_disable):
|
||||
# The fact we defined a jitted function with a block with a different value
|
||||
# of `config.enable_x64` has no impact on the output.
|
||||
# of `jax.config.enable_x64` has no impact on the output.
|
||||
with enable_or_disable():
|
||||
func = jit(lambda: jnp.array(np.float64(0)))
|
||||
func()
|
||||
|
||||
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
|
||||
expected_dtype = "float64" if jax.config._read("jax_enable_x64") else "float32"
|
||||
self.assertEqual(func().dtype, expected_dtype)
|
||||
|
||||
with enable_x64():
|
||||
|
@ -53,8 +53,7 @@ from jax._src.nn import initializers as nn_initializers
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src.util import unzip2
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
|
||||
@ -248,10 +247,10 @@ class SPMDTestMixin:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.spmd_lowering = maps.SPMD_LOWERING.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
jax.config.update('experimental_xmap_spmd_lowering', True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
|
||||
jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
|
||||
|
||||
|
||||
class ManualSPMDTestMixin:
|
||||
@ -261,12 +260,12 @@ class ManualSPMDTestMixin:
|
||||
super().setUp()
|
||||
self.spmd_lowering = maps.SPMD_LOWERING.value
|
||||
self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
config.update('experimental_xmap_spmd_lowering_manual', True)
|
||||
jax.config.update('experimental_xmap_spmd_lowering', True)
|
||||
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
|
||||
config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
|
||||
jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
|
||||
jax.config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
@ -845,13 +844,13 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
||||
# TODO(apaszke): Add support for extracting XLA computations generated by
|
||||
# xmap and make this less of a smoke test.
|
||||
try:
|
||||
config.update("experimental_xmap_ensure_fixed_sharding", True)
|
||||
jax.config.update("experimental_xmap_ensure_fixed_sharding", True)
|
||||
f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')),
|
||||
in_axes=['i'], out_axes={}, axis_resources={'i': 'x'})
|
||||
x = jnp.arange(20, dtype=jnp.float32)
|
||||
f(x)
|
||||
finally:
|
||||
config.update("experimental_xmap_ensure_fixed_sharding", False)
|
||||
jax.config.update("experimental_xmap_ensure_fixed_sharding", False)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testConstantsInLowering(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user