mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice. PiperOrigin-RevId: 398721725
This commit is contained in:
parent
b26e1e6ba6
commit
db2e91eba2
@ -18,7 +18,7 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
from examples import control
|
||||
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
1167
jax/_src/test_util.py
Normal file
1167
jax/_src/test_util.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -16,7 +16,7 @@ import os
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
from jax.experimental.jax2tf.examples import keras_reuse_main
|
||||
|
@ -17,7 +17,7 @@ import os
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
from jax.experimental.jax2tf.examples import saved_model_main
|
||||
|
@ -23,7 +23,7 @@ import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -18,7 +18,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -18,7 +18,7 @@ from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import dtypes
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
import numpy as np
|
||||
|
@ -25,7 +25,7 @@ import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -27,7 +27,7 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
import numpy as np
|
||||
|
@ -48,7 +48,7 @@ import jax
|
||||
from jax import config
|
||||
from jax import dtypes
|
||||
from jax._src import ad_util
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
|
@ -62,7 +62,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.interpreters import xla
|
||||
|
@ -23,7 +23,7 @@ import tensorflow as tf # type: ignore[import]
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -28,7 +28,7 @@ from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
import numpy as np
|
||||
|
||||
|
@ -22,7 +22,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
|
@ -15,7 +15,7 @@
|
||||
import functools
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import numpy as np
|
||||
import os
|
||||
import unittest
|
||||
|
@ -23,7 +23,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
||||
from jax.config import config
|
||||
|
1167
jax/test_util.py
1167
jax/test_util.py
File diff suppressed because it is too large
Load Diff
@ -49,7 +49,7 @@ from jax.interpreters import pxla
|
||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||
import jax._src.lib
|
||||
from jax._src.lib import xla_client
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax import linear_util as lu
|
||||
import jax._src.util
|
||||
|
@ -18,7 +18,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax._src import api_util
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -21,7 +21,7 @@ from jax.config import config
|
||||
import jax.dlpack
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -21,7 +21,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import lax
|
||||
from jax._src.lax import parallel
|
||||
from jax import random
|
||||
|
@ -16,7 +16,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.callback import find_by_value, rewrite, FoundValue
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
|
@ -28,7 +28,7 @@ from jax.experimental.pjit import pjit
|
||||
import jax
|
||||
from jax import jit, lax, pmap
|
||||
from jax._src.util import prod
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
import jax._src.lib
|
||||
import numpy as np
|
||||
|
||||
|
@ -27,7 +27,7 @@ import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.abstract_arrays import make_shaped_array
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.core import UnshapedArray, ShapedArray
|
||||
|
@ -16,7 +16,7 @@ from absl.testing import absltest, parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax import core, jit, lax, make_jaxpr
|
||||
from jax.interpreters import xla
|
||||
|
@ -21,7 +21,7 @@ import numpy as np
|
||||
from unittest import SkipTest
|
||||
|
||||
from jax._src import api
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
import jax._src.lib
|
||||
|
@ -19,7 +19,7 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.util import safe_map, safe_zip
|
||||
|
||||
from jax.experimental import djax
|
||||
|
@ -26,7 +26,7 @@ import jax
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax.config import config
|
||||
|
@ -21,7 +21,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import core, grad, jit, vmap, lax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
||||
|
@ -23,7 +23,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
@ -23,7 +23,7 @@ from absl.testing import parameterized
|
||||
import itertools as it
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, jvp, vjp
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -18,7 +18,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax._src.lib.xla_bridge
|
||||
from jax.config import config
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -35,7 +35,7 @@ from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.lib import xla_bridge
|
||||
|
||||
|
@ -27,7 +27,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import host_callback as hcb
|
||||
|
||||
import numpy as np
|
||||
|
@ -24,7 +24,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import image
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
|
||||
|
@ -21,7 +21,7 @@ from jax import lax, numpy as jnp
|
||||
from jax import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax._src.lib import xla_client
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -21,7 +21,7 @@ from jax._src import api
|
||||
from jax import dtypes
|
||||
from jax._src import lib as jaxlib
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
import numpy as np
|
||||
|
||||
|
@ -16,7 +16,7 @@ from absl.testing import absltest
|
||||
from jax._src.lib import xla_client
|
||||
import jax.numpy as jnp
|
||||
from jax.tools.jax_to_hlo import jax_to_hlo
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
|
||||
class JaxToHloTest(absltest.TestCase):
|
||||
|
@ -15,7 +15,7 @@
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
import unittest
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy.special
|
||||
from jax import random
|
||||
|
@ -26,7 +26,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.test_util import check_grads
|
||||
from jax._src.util import prod
|
||||
|
||||
|
@ -32,7 +32,7 @@ from jax import core
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.util import unzip2
|
||||
from jax.experimental import maps
|
||||
|
@ -24,7 +24,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -30,7 +30,7 @@ import jax
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import ops
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
|
||||
from jax.config import config
|
||||
|
@ -37,7 +37,7 @@ import jax
|
||||
import jax.ops
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import dtypes
|
||||
from jax import tree_util
|
||||
from jax.interpreters import xla
|
||||
|
@ -19,7 +19,7 @@ from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -23,7 +23,7 @@ import scipy.sparse.linalg
|
||||
from jax import jit
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
import jax.scipy.sparse.linalg
|
||||
import jax._src.scipy.sparse.linalg
|
||||
|
@ -29,7 +29,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import lax
|
||||
from jax import scipy as jsp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
import jax._src.scipy.eigh
|
||||
|
||||
|
@ -29,7 +29,7 @@ import jax
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src import lax_reference
|
||||
from jax.test_util import check_grads
|
||||
|
@ -26,7 +26,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
|
@ -29,7 +29,7 @@ from jax import jit, grad, jvp, vmap
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import scipy as jsp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -22,7 +22,7 @@ import re
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import loops
|
||||
|
||||
from jax.config import config
|
||||
|
@ -20,7 +20,7 @@ from absl.testing import absltest, parameterized
|
||||
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax.tree_util import tree_flatten
|
||||
|
@ -15,7 +15,7 @@
|
||||
from unittest import SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
|
@ -21,7 +21,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax.interpreters import xla
|
||||
|
||||
|
@ -23,7 +23,7 @@ import numpy.random as npr
|
||||
from unittest import SkipTest
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
|
@ -24,7 +24,7 @@ from absl.testing import parameterized
|
||||
import scipy.stats
|
||||
|
||||
from jax import core
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.test_util import check_grads
|
||||
from jax import nn
|
||||
from jax import random
|
||||
|
@ -18,7 +18,7 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.ode import odeint
|
||||
from jax.tree_util import tree_map
|
||||
|
@ -20,7 +20,7 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
from jax import jit, grad, jacfwd, jacrev
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
|
@ -26,7 +26,7 @@ except ImportError:
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.lib
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -24,7 +24,7 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.errors import JAXTypeError
|
||||
from jax import lax
|
||||
# TODO(skye): do we still wanna call this PartitionSpec?
|
||||
|
@ -30,7 +30,7 @@ from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax._src.lax import parallel
|
||||
|
@ -19,8 +19,9 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import jit
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu, jit
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -25,7 +25,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
from jax.config import config
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
|
@ -32,7 +32,7 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import prng
|
||||
from jax import random
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax.interpreters import xla
|
||||
import jax._src.random
|
||||
|
@ -15,7 +15,7 @@ import itertools
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax.scipy.fft as jsp_fft
|
||||
import scipy.fftpack as osp_fft # TODO use scipy.fft once scipy>=1.4.0 is used
|
||||
|
||||
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
||||
import scipy.ndimage as osp_ndimage
|
||||
|
||||
from jax import grad
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax.scipy import ndimage as lsp_ndimage
|
||||
from jax._src.util import prod
|
||||
|
@ -17,7 +17,7 @@ import numpy as np
|
||||
import scipy.optimize
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import jit
|
||||
from jax.config import config
|
||||
import jax.scipy.optimize
|
||||
|
@ -20,7 +20,7 @@ from absl.testing import absltest, parameterized
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
import jax.scipy.signal as jsp_signal
|
||||
import scipy.signal as osp_signal
|
||||
|
||||
|
@ -22,7 +22,7 @@ import scipy as osp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import stats as lsp_stats
|
||||
from jax.scipy.special import expit
|
||||
|
||||
|
@ -27,7 +27,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import jit, pmap, vjp
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax.experimental import (sharded_jit, with_sharding_constraint,
|
||||
PartitionSpec as P)
|
||||
|
@ -28,7 +28,7 @@ from jax import lax
|
||||
from jax._src.lib import cusparse
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax import jit
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import xla
|
||||
import jax.numpy as jnp
|
||||
from jax import jvp
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
|
||||
from jax import config, core, jit, lax
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
from jax.experimental.sparse import BCOO, sparsify
|
||||
from jax.experimental.sparse.transform import (
|
||||
arrays_to_argspecs, argspecs_to_arrays, sparsify_raw, ArgSpec, SparseEnv)
|
||||
|
@ -19,7 +19,7 @@ from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import random
|
||||
from jax.experimental import stax
|
||||
from jax import dtypes
|
||||
|
2
tests/third_party/scipy/line_search_test.py
vendored
2
tests/third_party/scipy/line_search_test.py
vendored
@ -4,7 +4,7 @@ from absl.testing import absltest, parameterized
|
||||
from jax import grad
|
||||
from jax.config import config
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src.scipy.optimize.line_search import line_search
|
||||
from scipy.optimize.linesearch import line_search_wolfe2
|
||||
|
||||
|
@ -20,7 +20,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.tree_util import _process_pytree
|
||||
from jax import flatten_util
|
||||
|
@ -15,7 +15,7 @@
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -27,7 +27,7 @@ from jax import random
|
||||
from jax.config import config
|
||||
from jax.experimental import enable_x64, disable_x64
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
@ -16,7 +16,7 @@ import time
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax.interpreters import xla
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ from functools import partial
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jscipy
|
||||
from jax import test_util as jtu
|
||||
from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax import lax
|
||||
from jax import core
|
||||
|
Loading…
x
Reference in New Issue
Block a user