Change to simpler import for jax.config

This commit is contained in:
Jake VanderPlas 2023-04-21 11:51:22 -07:00
parent 5647d5db98
commit fbe4f10403
112 changed files with 123 additions and 123 deletions

View File

@ -34,7 +34,7 @@ from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -12,13 +12,13 @@ JAX offers flags and context managers that enable catching errors more easily.
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
### Example(s)
```python
from jax.config import config
from jax import config
config.update("jax_debug_nans", True)
def f(x, y):
@ -47,13 +47,13 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
You can disable JIT-compilation by:
* setting the `JAX_DISABLE_JIT=True` environment variable;
* adding `from jax.config import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
### Examples
```python
from jax.config import config
from jax import config
config.update("jax_disable_jit", True)
def f(x):

View File

@ -82,7 +82,7 @@ Click [here](checkify_guide) to learn more!
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
```python
from jax.config import config
from jax import config
config.update("jax_debug_nans", True)
def f(x, y):

View File

@ -1946,9 +1946,9 @@
"\n",
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
"\n",
"* adding `from jax.config import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"\n",
"* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"\n",
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
"\n",
@ -2141,21 +2141,21 @@
"\n",
" ```python\n",
" # again, this only works on startup!\n",
" from jax.config import config\n",
" from jax import config\n",
" config.update(\"jax_enable_x64\", True)\n",
" ```\n",
"\n",
"3. You can parse command-line flags with `absl.app.run(main)`\n",
"\n",
" ```python\n",
" from jax.config import config\n",
" from jax import config\n",
" config.config_with_absl()\n",
" ```\n",
"\n",
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
"\n",
" ```python\n",
" from jax.config import config\n",
" from jax import config\n",
" if __name__ == '__main__':\n",
" # calls config.config_with_absl() *and* runs absl parsing\n",
" config.parse_flags_with_absl()\n",

View File

@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
@ -1087,21 +1087,21 @@ There are a few ways to do this:
```python
# again, this only works on startup!
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)
```
3. You can parse command-line flags with `absl.app.run(main)`
```python
from jax.config import config
from jax import config
config.config_with_absl()
```
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
```python
from jax.config import config
from jax import config
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()

View File

@ -40,7 +40,7 @@ One is by using :code:`jax.config` in your code:
.. code-block:: python
from jax.config import config
from jax import config
config.update("jax_numpy_rank_promotion", "warn")
You can also set the option using the environment variable

View File

@ -30,7 +30,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from examples import kernel_lsq
sys.path.pop()
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -21,7 +21,7 @@ from functools import partial
from jax import grad
from jax import jit
from jax import vmap
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as random
import jax.scipy as scipy

View File

@ -28,7 +28,7 @@ from jax._src.util import safe_map, safe_zip
import numpy as np
from jax.config import config
from jax import config
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map

View File

@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax.config import config
from jax import config
from jax.tree_util import (tree_flatten, tree_unflatten,
register_pytree_node, Partial)
from jax._src import core

View File

@ -22,7 +22,7 @@ from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
import numpy as np
import jax
from jax.config import config
from jax import config
from jax._src import core
from jax._src import source_info_util
from jax._src import linear_util as lu

View File

@ -21,7 +21,7 @@ import operator
from typing import Callable, Sequence, List, Tuple
from jax.config import config
from jax import config
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import core

View File

@ -22,7 +22,7 @@ import jax
import weakref
from jax._src import core
from jax._src import linear_util as lu
from jax.config import config
from jax import config
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr)

View File

@ -21,7 +21,7 @@ import tracemalloc as tm
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.config import config
from jax import config
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P

View File

@ -507,7 +507,7 @@ import warnings
from jax._src import api
from jax._src import core
from jax.config import config
from jax import config
from jax import custom_derivatives
from jax._src import dtypes
from jax import lax

View File

@ -17,7 +17,7 @@ from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from jax._src import test_util as jtu
from jax.config import config
from jax import config
from jax.experimental.jax2tf.examples import keras_reuse_main
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -18,7 +18,7 @@ from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from jax._src import test_util as jtu
from jax.config import config
from jax import config
from jax.experimental.jax2tf.examples import saved_model_main
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -77,7 +77,7 @@ import numpy as np
from numpy import array, float32
import jax
from jax.config import config
from jax import config
from jax import lax
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft

View File

@ -26,7 +26,7 @@ from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src.lib.mlir import ir

View File

@ -23,7 +23,7 @@ import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -36,7 +36,7 @@ from absl import logging
import numpy.random as npr
import jax
from jax.config import config # Must import before TF
from jax import config # Must import before TF
from jax.experimental import jax2tf # Defines needed flags
from jax._src import test_util # Defines needed flags

View File

@ -19,7 +19,7 @@ from typing import Any, Callable, Optional, Sequence, Union
import jax
from jax import lax
from jax import numpy as jnp
from jax.config import config
from jax import config
from jax._src import test_util as jtu
from jax._src import dtypes
from jax.experimental.jax2tf.tests import primitive_harness

View File

@ -35,7 +35,7 @@ from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
import jax._src.xla_bridge
from jax.config import config
from jax import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf import jax_export
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -28,7 +28,7 @@ import unittest
from absl.testing import absltest
from jax._src import test_util as jtu
from jax.config import config
from jax import config
import numpy as np

View File

@ -64,7 +64,7 @@ import jax
from jax import dtypes
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
from jax.experimental import jax2tf
from jax.interpreters import mlir
from jax._src.interpreters import xla

View File

@ -25,7 +25,7 @@ from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -44,7 +44,7 @@ from jax.experimental.jax2tf.tests import tf_test_util
import tensorflow as tf # type: ignore[import]
from jax.config import config
from jax import config
from jax._src.config import numpy_dtype_promotion
config.parse_flags_with_absl()

View File

@ -31,7 +31,7 @@ from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax.config import config
from jax import config
from jax import lax
from jax.experimental import jax2tf
from jax.experimental import pjit

View File

@ -28,7 +28,7 @@ from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax.config import config
from jax import config
from jax.experimental import jax2tf
from jax._src import xla_bridge
import numpy as np

View File

@ -28,7 +28,7 @@ import jax
from jax import lax
from jax import tree_util
from jax import vmap
from jax.config import config
from jax import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import (
nfold_vmap, _count_stored_elements,

View File

@ -74,5 +74,5 @@ from jax._src.interpreters.ad import (
zeros_like_p as zeros_like_p,
)
from jax.config import config
from jax import config
from jax._src import source_info_util

View File

@ -23,7 +23,7 @@ import jax
from jax import lax
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -31,7 +31,7 @@ from jax.experimental.serialize_executable import (
from jax.experimental import topologies
from jax.sharding import PartitionSpec as P
from jax.config import config
from jax import config
config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -75,7 +75,7 @@ import jax._src.util as jax_util
from jax._src.ad_checkpoint import saved_residuals
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -20,7 +20,7 @@ from jax._src import api_util
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -17,7 +17,7 @@ import unittest
from absl.testing import absltest
import jax
from jax.config import config
from jax import config
import jax.dlpack
import jax.numpy as jnp
from jax._src import test_util as jtu

View File

@ -37,7 +37,7 @@ from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import prng
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -35,7 +35,7 @@ from jax import vmap
from jax.interpreters import batching
from jax.tree_util import register_pytree_node
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -23,7 +23,7 @@ import jax
from jax import lax
import jax._src.test_util as jtu
from jax._src.lib import xla_extension
from jax.config import config
from jax import config
from jax.experimental import checkify
from jax.experimental import pjit
from jax.sharding import NamedSharding

View File

@ -18,7 +18,7 @@ import unittest
from absl.testing import absltest
import jax
from jax.config import config
from jax import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb

View File

@ -43,7 +43,7 @@ from jax._src.lib import xla_client
import numpy as np
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -28,7 +28,7 @@ from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs
from jax.config import config
from jax import config
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves)

View File

@ -29,7 +29,7 @@ from jax import tree_util
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -20,7 +20,7 @@ import numpy as np
import jax.numpy as jnp
from jax import jit, lax, make_jaxpr
from jax.config import config
from jax import config
from jax.interpreters import mlir
from jax.interpreters import xla

View File

@ -26,7 +26,7 @@ from jax import tree_util
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit, maps
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -20,7 +20,7 @@ from typing import IO, Sequence, Tuple
from absl.testing import absltest
import jax
from jax.config import config
from jax import config
from jax.experimental import pjit
from jax._src import debugger
from jax._src import test_util as jtu

View File

@ -19,7 +19,7 @@ import unittest
from absl.testing import absltest
import jax
from jax import lax
from jax.config import config
from jax import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import pxla

View File

@ -23,7 +23,7 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax import config
from jax.interpreters import batching
import jax._src.lib

View File

@ -28,7 +28,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -27,7 +27,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.util import promote_dtypes_complex
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FFT_NORMS = [None, "ortho", "forward", "backward"]

View File

@ -24,7 +24,7 @@ from jax._src import test_util as jtu
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
from jax.config import config
from jax import config
config.parse_flags_with_absl()
def remat_of_for_loop(nsteps, body, state, **kwargs):

View File

@ -25,7 +25,7 @@ import jax.numpy as jnp
from jax import jit, jvp, vjp
import jax._src.test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()
npr.seed(0)

View File

@ -17,7 +17,7 @@ from absl.testing import absltest
import jax
import jax._src.xla_bridge as xla_bridge
from jax.config import config
from jax import config
import jax._src.test_util as jtu

View File

@ -28,7 +28,7 @@ from absl.testing import absltest
import jax
from jax import ad_checkpoint
from jax._src import core
from jax.config import config
from jax import config
from jax import dtypes
from jax.experimental import host_callback as hcb
from jax.sharding import PartitionSpec as P

View File

@ -25,7 +25,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax.config import config
from jax import config
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src import xla_bridge

View File

@ -24,7 +24,7 @@ from jax import image
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
# We use TensorFlow and PIL as reference implementations.
try:

View File

@ -24,7 +24,7 @@ from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
import numpy as np
config.parse_flags_with_absl()

View File

@ -23,7 +23,7 @@ from jax._src import core
from jax import lax
from jax._src import effects
from jax._src import linear_util as lu
from jax.config import config
from jax import config
from jax.experimental import maps
from jax.experimental import pjit
from jax._src.interpreters import ad

View File

@ -22,7 +22,7 @@ import jax
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -29,7 +29,7 @@ from jax.example_libraries import stax
from jax.experimental.jet import jet, fact, zero_series
from jax import lax
from jax.config import config
from jax import config
config.parse_flags_with_absl()
def jvp_taylor(fun, primals, series):

View File

@ -30,7 +30,7 @@ from jax import lax
from jax._src import test_util as jtu
from jax.test_util import check_grads
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -41,7 +41,7 @@ import jax.scipy as jsp
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax.control_flow import for_loop
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -26,7 +26,7 @@ from jax import lax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -35,7 +35,7 @@ from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax.config import config
from jax import config
config.parse_flags_with_absl()
# We disable the whitespace continuation check in this file because otherwise it

View File

@ -33,7 +33,7 @@ from jax import numpy as jnp
from jax._src import dtypes
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -29,7 +29,7 @@ from jax import numpy as jnp
from jax._src import dtypes
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -49,7 +49,7 @@ from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import safe_zip
from jax._src import array
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -20,7 +20,7 @@ import jax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -30,7 +30,7 @@ import jax._src.scipy.sparse.linalg as sp_linalg
from jax._src import dtypes
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -26,7 +26,7 @@ import jax
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -21,7 +21,7 @@ from jax._src.lax import eigh as lax_eigh
from absl.testing import absltest
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -32,7 +32,7 @@ from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax.scipy import cluster as lsp_cluster
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -47,7 +47,7 @@ from jax._src import lax_reference
from jax._src.lax import lax as lax_internal
from jax._src.internal_test_util import lax_test_util
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -26,7 +26,7 @@ from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src import util
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -34,7 +34,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -32,7 +32,7 @@ from jax import scipy as jsp
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

View File

@ -30,7 +30,7 @@ import scipy.linalg as sla
import scipy.sparse as sps
import jax
from jax.config import config
from jax import config
from jax._src import test_util as jtu
from jax.experimental.sparse import linalg, bcoo
import jax.numpy as jnp

View File

@ -23,7 +23,7 @@ from jax._src import config as jax_config
from jax._src.lib.mlir import ir
from jax import numpy as jnp
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -24,7 +24,7 @@ from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax.config import config
from jax import config
config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -25,7 +25,7 @@ import jax
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.config import config
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
npr.seed(0)

View File

@ -26,7 +26,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax.config import config
from jax import config
from jax._src import core
from jax._src import distributed
import jax.numpy as jnp

View File

@ -20,7 +20,7 @@ from jax._src import core
from jax import lax
from jax._src.pjit import pjit
from jax._src import linear_util as lu
from jax.config import config
from jax import config
from jax._src import test_util as jtu
from jax._src.lib import xla_client

View File

@ -31,7 +31,7 @@ from jax import random
import jax
import jax.numpy as jnp
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ from jax.tree_util import tree_map
import scipy.integrate as osp_integrate
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -26,7 +26,7 @@ from jax import tree_util
from jax import lax
from jax.example_libraries import optimizers
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -26,7 +26,7 @@ except ImportError:
import jax
from jax import numpy as jnp
from jax.config import config
from jax import config
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import test_util as jtu

View File

@ -58,7 +58,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, safe_zip
from jax.config import config
from jax import config
config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -55,7 +55,7 @@ from jax._src import array
from jax._src.sharding_impls import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint
from jax.config import config
from jax import config
config.parse_flags_with_absl()
prev_xla_flags = None

View File

@ -23,7 +23,7 @@ from jax._src import dtypes
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.profiler
from jax.config import config
from jax import config
import jax._src.test_util as jtu
try:

View File

@ -29,7 +29,7 @@ from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax.config import config
from jax import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir

View File

@ -17,7 +17,7 @@ import unittest
from absl.testing import absltest
import jax
from jax.config import config
from jax import config
import jax.dlpack
from jax._src import xla_bridge
from jax._src.lib import xla_client

View File

@ -16,7 +16,7 @@
import functools
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg

View File

@ -42,7 +42,7 @@ from jax.interpreters import xla
from jax._src import random as jax_random
from jax._src import prng as prng_internal
from jax.config import config
from jax import config
config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.all_floating

View File

@ -19,7 +19,7 @@ from jax._src import test_util as jtu
import jax.scipy.fft as jsp_fft
import scipy.fft as osp_fft
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -22,7 +22,7 @@ from jax._src import test_util as jtu
import scipy.interpolate as sp_interp
import jax.scipy.interpolate as jsp_interp
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -26,7 +26,7 @@ from jax._src import test_util as jtu
from jax import dtypes
from jax.scipy import ndimage as lsp_ndimage
from jax.config import config
from jax import config
config.parse_flags_with_absl()

View File

@ -20,7 +20,7 @@ import scipy.optimize
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import jit
from jax.config import config
from jax import config
import jax.scipy.optimize
config.parse_flags_with_absl()

View File

@ -27,7 +27,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
import jax.scipy.signal as jsp_signal
from jax.config import config
from jax import config
config.parse_flags_with_absl()
onedim_shapes = [(1,), (2,), (5,), (10,)]

Some files were not shown because too many files have changed in this diff Show More