Move jax.test_util to jax._src.test_util.

Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
This commit is contained in:
Peter Hawkins 2021-09-24 07:02:08 -07:00 committed by jax authors
parent b26e1e6ba6
commit db2e91eba2
83 changed files with 1264 additions and 1233 deletions

View File

@ -18,7 +18,7 @@ from absl.testing import absltest
import numpy as np
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax.numpy as jnp
from examples import control

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
import numpy as np
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import random
import jax.numpy as jnp

1167
jax/_src/test_util.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -16,7 +16,7 @@ import os
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax.experimental.jax2tf.examples import keras_reuse_main

View File

@ -17,7 +17,7 @@ import os
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax.experimental.jax2tf.examples import saved_model_main

View File

@ -23,7 +23,7 @@ import jax
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -18,7 +18,7 @@ from absl.testing import absltest
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -18,7 +18,7 @@ from typing import Any, Callable, Optional, Sequence, Union
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src import dtypes
from jax.experimental.jax2tf.tests import primitive_harness
import numpy as np

View File

@ -25,7 +25,7 @@ import jax
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -27,7 +27,7 @@ import unittest
from absl.testing import absltest
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
import numpy as np

View File

@ -48,7 +48,7 @@ import jax
from jax import config
from jax import dtypes
from jax._src import ad_util
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import lax
from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow

View File

@ -62,7 +62,7 @@ from absl.testing import parameterized
import jax
from jax import dtypes
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.interpreters import xla

View File

@ -23,7 +23,7 @@ import tensorflow as tf # type: ignore[import]
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -28,7 +28,7 @@ from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
from jax import lax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src.lax import control_flow as lax_control_flow
import numpy as np

View File

@ -22,7 +22,7 @@ import unittest
from absl.testing import absltest
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf

View File

@ -15,7 +15,7 @@
import functools
from absl.testing import absltest
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
import numpy as np
import os
import unittest

View File

@ -23,7 +23,7 @@ from absl.testing import absltest
import jax
from jax import dtypes
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax.config import config

File diff suppressed because it is too large Load Diff

View File

@ -49,7 +49,7 @@ from jax.interpreters import pxla
from jax.interpreters.sharded_jit import PartitionSpec as P
import jax._src.lib
from jax._src.lib import xla_client
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
import jax._src.util

View File

@ -18,7 +18,7 @@ from absl.testing import absltest
from absl.testing import parameterized
from jax._src import api_util
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -21,7 +21,7 @@ from jax.config import config
import jax.dlpack
from jax._src.lib import xla_bridge, xla_client
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
import numpy as np

View File

@ -21,7 +21,7 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import lax
from jax._src.lax import parallel
from jax import random

View File

@ -16,7 +16,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.experimental.callback import find_by_value, rewrite, FoundValue
import jax.numpy as jnp
from jax import lax

View File

@ -28,7 +28,7 @@ from jax.experimental.pjit import pjit
import jax
from jax import jit, lax, pmap
from jax._src.util import prod
import jax.test_util as jtu
import jax._src.test_util as jtu
import jax._src.lib
import numpy as np

View File

@ -27,7 +27,7 @@ import jax
from jax import core
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src.abstract_arrays import make_shaped_array
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray

View File

@ -16,7 +16,7 @@ from absl.testing import absltest, parameterized
import numpy as np
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax.interpreters import xla

View File

@ -21,7 +21,7 @@ import numpy as np
from unittest import SkipTest
from jax._src import api
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
import jax._src.lib

View File

@ -19,7 +19,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.util import safe_map, safe_zip
from jax.experimental import djax

View File

@ -26,7 +26,7 @@ import jax
from jax._src import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.interpreters import xla
from jax.config import config

View File

@ -21,7 +21,7 @@ from absl.testing import parameterized
import jax
from jax import core, grad, jit, vmap, lax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src import source_info_util
from jax._src import traceback_util

View File

@ -23,7 +23,7 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -14,7 +14,7 @@
from absl.testing import absltest
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
import jax.test_util as jtu
import jax._src.test_util as jtu
import tempfile
import threading
import time

View File

@ -23,7 +23,7 @@ from absl.testing import parameterized
import itertools as it
import jax.numpy as jnp
from jax import jit, jvp, vjp
import jax.test_util as jtu
import jax._src.test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -18,7 +18,7 @@ from absl.testing import absltest
import jax
import jax._src.lib.xla_bridge
from jax.config import config
import jax.test_util as jtu
import jax._src.test_util as jtu
config.parse_flags_with_absl()

View File

@ -35,7 +35,7 @@ from jax.experimental import maps
from jax.experimental import pjit
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.lib import xla_bridge

View File

@ -27,7 +27,7 @@ from absl.testing import parameterized
import jax
from jax.config import config
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.experimental import host_callback as hcb
import numpy as np

View File

@ -24,7 +24,7 @@ from absl.testing import parameterized
import jax
from jax import image
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config

View File

@ -21,7 +21,7 @@ from jax import lax, numpy as jnp
from jax import config
from jax.experimental import host_callback as hcb
from jax._src.lib import xla_client
import jax.test_util as jtu
import jax._src.test_util as jtu
import numpy as np
config.parse_flags_with_absl()

View File

@ -21,7 +21,7 @@ from jax._src import api
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
import numpy as np

View File

@ -16,7 +16,7 @@ from absl.testing import absltest
from jax._src.lib import xla_client
import jax.numpy as jnp
from jax.tools.jax_to_hlo import jax_to_hlo
from jax import test_util as jtu
from jax._src import test_util as jtu
class JaxToHloTest(absltest.TestCase):

View File

@ -15,7 +15,7 @@
from absl.testing import absltest
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config

View File

@ -20,7 +20,7 @@ import numpy as np
import unittest
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax.numpy as jnp
import jax.scipy.special
from jax import random

View File

@ -26,7 +26,7 @@ import numpy as np
import jax
from jax import dtypes
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.test_util import check_grads
from jax._src.util import prod

View File

@ -32,7 +32,7 @@ from jax import core
from jax.errors import UnexpectedTracerError
from jax import lax
from jax import random
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps

View File

@ -24,7 +24,7 @@ from absl.testing import parameterized
import jax
from jax import lax
import jax.numpy as jnp
import jax.test_util as jtu
import jax._src.test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -30,7 +30,7 @@ import jax
from jax import dtypes
from jax import numpy as jnp
from jax import ops
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src import util
from jax.config import config

View File

@ -37,7 +37,7 @@ import jax
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src import dtypes
from jax import tree_util
from jax.interpreters import xla

View File

@ -19,7 +19,7 @@ from absl.testing import parameterized
import jax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -23,7 +23,7 @@ import scipy.sparse.linalg
from jax import jit
import jax.numpy as jnp
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.tree_util import register_pytree_node_class
import jax.scipy.sparse.linalg
import jax._src.scipy.sparse.linalg

View File

@ -29,7 +29,7 @@ import jax
from jax import numpy as jnp
from jax import lax
from jax import scipy as jsp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
import jax._src.scipy.eigh

View File

@ -29,7 +29,7 @@ import jax
from jax import core
from jax._src import dtypes
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax._src import lax_reference
from jax.test_util import check_grads

View File

@ -26,7 +26,7 @@ import numpy as np
import jax
from jax import dtypes
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip

View File

@ -29,7 +29,7 @@ from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -22,7 +22,7 @@ import re
import jax
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.experimental import loops
from jax.config import config

View File

@ -20,7 +20,7 @@ from absl.testing import absltest, parameterized
from jax import lax
from jax import core
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
from jax._src.util import safe_map, safe_zip
from jax.tree_util import tree_flatten

View File

@ -15,7 +15,7 @@
from unittest import SkipTest
from absl.testing import absltest
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax
from jax import numpy as jnp

View File

@ -21,7 +21,7 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
from jax.interpreters import xla

View File

@ -23,7 +23,7 @@ import numpy.random as npr
from unittest import SkipTest
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.config import config

View File

@ -24,7 +24,7 @@ from absl.testing import parameterized
import scipy.stats
from jax import core
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.test_util import check_grads
from jax import nn
from jax import random

View File

@ -18,7 +18,7 @@ from absl.testing import absltest
import numpy as np
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.tree_util import tree_map

View File

@ -20,7 +20,7 @@ from absl.testing import absltest
import numpy as np
import jax.numpy as jnp
import jax.test_util as jtu
import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import tree_util
from jax import lax

View File

@ -26,7 +26,7 @@ except ImportError:
import jax
from jax import numpy as jnp
from jax.config import config
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax._src.lib
config.parse_flags_with_absl()

View File

@ -24,7 +24,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.errors import JAXTypeError
from jax import lax
# TODO(skye): do we still wanna call this PartitionSpec?

View File

@ -30,7 +30,7 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax import lax
from jax._src.lax import parallel

View File

@ -19,8 +19,9 @@ import unittest
from absl.testing import absltest
from absl.testing import parameterized
from jax import jit
from jax import numpy as jnp
from jax import test_util as jtu, jit
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ import jax
import jax.numpy as jnp
import jax.profiler
from jax.config import config
import jax.test_util as jtu
import jax._src.test_util as jtu
try:
import portpicker

View File

@ -32,7 +32,7 @@ from jax import lax
from jax import numpy as jnp
from jax import prng
from jax import random
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import vmap
from jax.interpreters import xla
import jax._src.random

View File

@ -15,7 +15,7 @@ import itertools
from absl.testing import absltest, parameterized
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax.scipy.fft as jsp_fft
import scipy.fftpack as osp_fft # TODO use scipy.fft once scipy>=1.4.0 is used

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
import scipy.ndimage as osp_ndimage
from jax import grad
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import dtypes
from jax.scipy import ndimage as lsp_ndimage
from jax._src.util import prod

View File

@ -17,7 +17,7 @@ import numpy as np
import scipy.optimize
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import jit
from jax.config import config
import jax.scipy.optimize

View File

@ -20,7 +20,7 @@ from absl.testing import absltest, parameterized
import numpy as np
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
import jax.scipy.signal as jsp_signal
import scipy.signal as osp_signal

View File

@ -22,7 +22,7 @@ import scipy as osp
import scipy.stats as osp_stats
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.scipy import stats as lsp_stats
from jax.scipy.special import expit

View File

@ -27,7 +27,7 @@ from absl.testing import parameterized
import jax
from jax import jit, pmap, vjp
from jax import lax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax.experimental import (sharded_jit, with_sharding_constraint,
PartitionSpec as P)

View File

@ -28,7 +28,7 @@ from jax import lax
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import xla
import jax.numpy as jnp
from jax import jvp

View File

@ -22,7 +22,7 @@ import numpy as np
from jax import config, core, jit, lax
import jax.numpy as jnp
import jax.test_util as jtu
import jax._src.test_util as jtu
from jax.experimental.sparse import BCOO, sparsify
from jax.experimental.sparse.transform import (
arrays_to_argspecs, argspecs_to_arrays, sparsify_raw, ArgSpec, SparseEnv)

View File

@ -19,7 +19,7 @@ from absl.testing import parameterized
import numpy as np
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import random
from jax.experimental import stax
from jax import dtypes

View File

@ -4,7 +4,7 @@ from absl.testing import absltest, parameterized
from jax import grad
from jax.config import config
import jax.numpy as jnp
import jax.test_util as jtu
import jax._src.test_util as jtu
from jax._src.scipy.optimize.line_search import line_search
from scipy.optimize.linesearch import line_search_wolfe2

View File

@ -20,7 +20,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.tree_util import _process_pytree
from jax import flatten_util

View File

@ -15,7 +15,7 @@
from absl.testing import absltest
from jax import linear_util as lu
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()

View File

@ -27,7 +27,7 @@ from jax import random
from jax.config import config
from jax.experimental import enable_x64, disable_x64
import jax.numpy as jnp
import jax.test_util as jtu
import jax._src.test_util as jtu
config.parse_flags_with_absl()

View File

@ -16,7 +16,7 @@ import time
import warnings
from absl.testing import absltest
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc

View File

@ -15,7 +15,7 @@
from absl.testing import absltest
import jax
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax.interpreters import xla

View File

@ -33,7 +33,7 @@ from functools import partial
import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax import test_util as jtu
from jax._src import test_util as jtu
from jax import vmap
from jax import lax
from jax import core