mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Change to simpler import for jax.config
This commit is contained in:
parent
5647d5db98
commit
fbe4f10403
@ -34,7 +34,7 @@ from jax.experimental import multihost_utils
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -12,13 +12,13 @@ JAX offers flags and context managers that enable catching errors more easily.
|
||||
|
||||
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
|
||||
### Example(s)
|
||||
|
||||
```python
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.update("jax_debug_nans", True)
|
||||
|
||||
def f(x, y):
|
||||
@ -47,13 +47,13 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
|
||||
|
||||
You can disable JIT-compilation by:
|
||||
* setting the `JAX_DISABLE_JIT=True` environment variable;
|
||||
* adding `from jax.config import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
|
||||
|
||||
### Examples
|
||||
|
||||
```python
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.update("jax_disable_jit", True)
|
||||
|
||||
def f(x):
|
||||
|
@ -82,7 +82,7 @@ Click [here](checkify_guide) to learn more!
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
```python
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.update("jax_debug_nans", True)
|
||||
|
||||
def f(x, y):
|
||||
|
@ -1946,9 +1946,9 @@
|
||||
"\n",
|
||||
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
|
||||
"\n",
|
||||
"* adding `from jax.config import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"\n",
|
||||
"* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"\n",
|
||||
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
|
||||
"\n",
|
||||
@ -2141,21 +2141,21 @@
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" # again, this only works on startup!\n",
|
||||
" from jax.config import config\n",
|
||||
" from jax import config\n",
|
||||
" config.update(\"jax_enable_x64\", True)\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"3. You can parse command-line flags with `absl.app.run(main)`\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" from jax.config import config\n",
|
||||
" from jax import config\n",
|
||||
" config.config_with_absl()\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" from jax.config import config\n",
|
||||
" from jax import config\n",
|
||||
" if __name__ == '__main__':\n",
|
||||
" # calls config.config_with_absl() *and* runs absl parsing\n",
|
||||
" config.parse_flags_with_absl()\n",
|
||||
|
@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
|
||||
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
|
||||
* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
|
||||
* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
|
||||
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
|
||||
|
||||
@ -1087,21 +1087,21 @@ There are a few ways to do this:
|
||||
|
||||
```python
|
||||
# again, this only works on startup!
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.update("jax_enable_x64", True)
|
||||
```
|
||||
|
||||
3. You can parse command-line flags with `absl.app.run(main)`
|
||||
|
||||
```python
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.config_with_absl()
|
||||
```
|
||||
|
||||
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
|
||||
|
||||
```python
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
if __name__ == '__main__':
|
||||
# calls config.config_with_absl() *and* runs absl parsing
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -40,7 +40,7 @@ One is by using :code:`jax.config` in your code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.update("jax_numpy_rank_promotion", "warn")
|
||||
|
||||
You can also set the option using the environment variable
|
||||
|
@ -30,7 +30,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from examples import kernel_lsq
|
||||
sys.path.pop()
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -21,7 +21,7 @@ from functools import partial
|
||||
from jax import grad
|
||||
from jax import jit
|
||||
from jax import vmap
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax.numpy as jnp
|
||||
import jax.random as random
|
||||
import jax.scipy as scipy
|
||||
|
@ -28,7 +28,7 @@ from jax._src.util import safe_map, safe_zip
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
|
@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten,
|
||||
register_pytree_node, Partial)
|
||||
from jax._src import core
|
||||
|
@ -22,7 +22,7 @@ from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import linear_util as lu
|
||||
|
@ -21,7 +21,7 @@ import operator
|
||||
|
||||
from typing import Callable, Sequence, List, Tuple
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
|
@ -22,7 +22,7 @@ import jax
|
||||
import weakref
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
tree_map, tree_flatten_with_path, keystr)
|
||||
|
@ -21,7 +21,7 @@ import tracemalloc as tm
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import array
|
||||
from jax.sharding import NamedSharding, GSPMDSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
@ -507,7 +507,7 @@ import warnings
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
|
@ -17,7 +17,7 @@ from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
from jax.experimental.jax2tf.examples import keras_reuse_main
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -18,7 +18,7 @@ from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
from jax.experimental.jax2tf.examples import saved_model_main
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -77,7 +77,7 @@ import numpy as np
|
||||
from numpy import array, float32
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax import lax
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
|
||||
|
@ -26,7 +26,7 @@ from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax._src.lib.mlir import ir
|
||||
|
@ -23,7 +23,7 @@ import numpy as np
|
||||
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ from absl import logging
|
||||
import numpy.random as npr
|
||||
|
||||
import jax
|
||||
from jax.config import config # Must import before TF
|
||||
from jax import config # Must import before TF
|
||||
from jax.experimental import jax2tf # Defines needed flags
|
||||
from jax._src import test_util # Defines needed flags
|
||||
|
||||
|
@ -19,7 +19,7 @@ from typing import Any, Callable, Optional, Sequence, Union
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import dtypes
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
|
@ -35,7 +35,7 @@ from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.xla_bridge
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -28,7 +28,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -64,7 +64,7 @@ import jax
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
|
@ -25,7 +25,7 @@ from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -44,7 +44,7 @@ from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src.config import numpy_dtype_promotion
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -31,7 +31,7 @@ from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax import lax
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental import pjit
|
||||
|
@ -28,7 +28,7 @@ from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax._src import xla_bridge
|
||||
import numpy as np
|
||||
|
@ -28,7 +28,7 @@ import jax
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax import vmap
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.util import (
|
||||
nfold_vmap, _count_stored_elements,
|
||||
|
@ -74,5 +74,5 @@ from jax._src.interpreters.ad import (
|
||||
zeros_like_p as zeros_like_p,
|
||||
)
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import source_info_util
|
||||
|
@ -23,7 +23,7 @@ import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -31,7 +31,7 @@ from jax.experimental.serialize_executable import (
|
||||
from jax.experimental import topologies
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
@ -75,7 +75,7 @@ import jax._src.util as jax_util
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -20,7 +20,7 @@ from jax._src import api_util
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -17,7 +17,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax.dlpack
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
@ -37,7 +37,7 @@ from jax.sharding import PartitionSpec as P
|
||||
from jax._src import array
|
||||
from jax._src import prng
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -35,7 +35,7 @@ from jax import vmap
|
||||
from jax.interpreters import batching
|
||||
from jax.tree_util import register_pytree_node
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -23,7 +23,7 @@ import jax
|
||||
from jax import lax
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src.lib import xla_extension
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import checkify
|
||||
from jax.experimental import pjit
|
||||
from jax.sharding import NamedSharding
|
||||
|
@ -18,7 +18,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
|
||||
|
@ -43,7 +43,7 @@ from jax._src.lib import xla_client
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
@ -28,7 +28,7 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
tree_leaves)
|
||||
|
||||
|
@ -29,7 +29,7 @@ from jax import tree_util
|
||||
import jax.numpy as jnp # scan tests use numpy
|
||||
import jax.scipy as jsp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, lax, make_jaxpr
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
|
@ -26,7 +26,7 @@ from jax import tree_util
|
||||
import jax.numpy as jnp # scan tests use numpy
|
||||
import jax.scipy as jsp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ from jax._src import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import pjit, maps
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ from typing import IO, Sequence, Tuple
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import pjit
|
||||
from jax._src import debugger
|
||||
from jax._src import test_util as jtu
|
||||
|
@ -19,7 +19,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import pxla
|
||||
|
@ -23,7 +23,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.interpreters import batching
|
||||
|
||||
import jax._src.lib
|
||||
|
@ -28,7 +28,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -27,7 +27,7 @@ from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.numpy.util import promote_dtypes_complex
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
FFT_NORMS = [None, "ortho", "forward", "backward"]
|
||||
|
@ -24,7 +24,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def remat_of_for_loop(nsteps, body, state, **kwargs):
|
||||
|
@ -25,7 +25,7 @@ import jax.numpy as jnp
|
||||
from jax import jit, jvp, vjp
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
npr.seed(0)
|
||||
|
@ -17,7 +17,7 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax._src.xla_bridge as xla_bridge
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
from jax import ad_checkpoint
|
||||
from jax._src import core
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax import dtypes
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
@ -25,7 +25,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
|
@ -24,7 +24,7 @@ from jax import image
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
# We use TensorFlow and PIL as reference implementations.
|
||||
try:
|
||||
|
@ -24,7 +24,7 @@ from jax import dtypes
|
||||
from jax._src import lib as jaxlib
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -23,7 +23,7 @@ from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax._src.interpreters import ad
|
||||
|
@ -22,7 +22,7 @@ import jax
|
||||
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -29,7 +29,7 @@ from jax.example_libraries import stax
|
||||
from jax.experimental.jet import jet, fact, zero_series
|
||||
from jax import lax
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def jvp_taylor(fun, primals, series):
|
||||
|
@ -30,7 +30,7 @@ from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.test_util import check_grads
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -41,7 +41,7 @@ import jax.scipy as jsp
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ from jax import lax
|
||||
import jax.numpy as jnp
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -35,7 +35,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
# We disable the whitespace continuation check in this file because otherwise it
|
||||
|
@ -33,7 +33,7 @@ from jax import numpy as jnp
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -29,7 +29,7 @@ from jax import numpy as jnp
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -49,7 +49,7 @@ from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src import array
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -20,7 +20,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ import jax._src.scipy.sparse.linalg as sp_linalg
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -21,7 +21,7 @@ from jax._src.lax import eigh as lax_eigh
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -32,7 +32,7 @@ from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
from jax.scipy import cluster as lsp_cluster
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -47,7 +47,7 @@ from jax._src import lax_reference
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src import util
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
FLAGS = config.FLAGS
|
||||
|
@ -34,7 +34,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -32,7 +32,7 @@ from jax import scipy as jsp
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
@ -30,7 +30,7 @@ import scipy.linalg as sla
|
||||
import scipy.sparse as sps
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.sparse import linalg, bcoo
|
||||
import jax.numpy as jnp
|
||||
|
@ -23,7 +23,7 @@ from jax._src import config as jax_config
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
@ -25,7 +25,7 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
npr.seed(0)
|
||||
|
@ -26,7 +26,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import core
|
||||
from jax._src import distributed
|
||||
import jax.numpy as jnp
|
||||
|
@ -20,7 +20,7 @@ from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src.pjit import pjit
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
|
@ -31,7 +31,7 @@ from jax import random
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ from jax.tree_util import tree_map
|
||||
|
||||
import scipy.integrate as osp_integrate
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ from jax import tree_util
|
||||
from jax import lax
|
||||
from jax.example_libraries import optimizers
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ except ImportError:
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src import test_util as jtu
|
||||
|
@ -58,7 +58,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import curry, unzip2, safe_zip
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
@ -55,7 +55,7 @@ from jax._src import array
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
|
@ -23,7 +23,7 @@ from jax._src import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
try:
|
||||
|
@ -29,7 +29,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import mlir
|
||||
|
@ -17,7 +17,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax.dlpack
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
|
@ -16,7 +16,7 @@
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import scipy.linalg as osp_linalg
|
||||
|
@ -42,7 +42,7 @@ from jax.interpreters import xla
|
||||
from jax._src import random as jax_random
|
||||
from jax._src import prng as prng_internal
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
float_dtypes = jtu.dtypes.all_floating
|
||||
|
@ -19,7 +19,7 @@ from jax._src import test_util as jtu
|
||||
import jax.scipy.fft as jsp_fft
|
||||
import scipy.fft as osp_fft
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -22,7 +22,7 @@ from jax._src import test_util as jtu
|
||||
import scipy.interpolate as sp_interp
|
||||
import jax.scipy.interpolate as jsp_interp
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -26,7 +26,7 @@ from jax._src import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax.scipy import ndimage as lsp_ndimage
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ import scipy.optimize
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax import jit
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
import jax.scipy.optimize
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -27,7 +27,7 @@ from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
import jax.scipy.signal as jsp_signal
|
||||
|
||||
from jax.config import config
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
onedim_shapes = [(1,), (2,), (5,), (10,)]
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user