Small cleanups to dependency structure.

PiperOrigin-RevId: 352853244
This commit is contained in:
Peter Hawkins 2021-01-20 12:43:00 -08:00 committed by jax authors
parent e217a7bd19
commit 929a684a39
13 changed files with 16 additions and 16 deletions

View File

@ -26,7 +26,7 @@ from jax import pmap
from jax.config import config
from jax._src.util import prod
from benchmarks import benchmark
from . import benchmark
import numpy as np

View File

@ -30,7 +30,7 @@ from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
from examples import datasets
from . import datasets
def loss(params, batch):

View File

@ -25,7 +25,7 @@ import numpy.random as npr
from jax.api import jit, grad
from jax.scipy.special import logsumexp
import jax.numpy as jnp
from examples import datasets
from . import datasets
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):

View File

@ -30,7 +30,7 @@ from jax import jit, grad, lax, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
from examples import datasets
from . import datasets
def gaussian_kl(mu, sigmasq):

View File

@ -20,7 +20,7 @@ licenses(["notice"])
package(default_visibility = ["//visibility:public"])
# top-level EF placeholder
exports_files(["LICENSE"])
pytype_library(
name = "jax",

View File

@ -347,13 +347,13 @@ from jax import custom_derivatives
from jax import dtypes
from jax import lax
from jax.lib import pytree
from jax.lib import xla_client
from jax.lib import xla_extension
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax._src import pprint_util as ppu
from jax._src import source_info_util
from jax._src import util
from jaxlib import xla_client
from jaxlib import xla_extension
import numpy as np

View File

@ -46,7 +46,7 @@ import tensorflow as tf # type: ignore[import]
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from jaxlib import xla_client
from jax.lib import xla_client
# The scope name need to be a valid TensorFlow name. See

View File

@ -54,7 +54,7 @@ from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow
from jax.interpreters import xla
from jaxlib import xla_client
from jax.lib import xla_client
import numpy as np

View File

@ -53,6 +53,7 @@ _check_jaxlib_version()
from jaxlib import xla_client
from jaxlib import lapack
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit

View File

@ -22,7 +22,7 @@ import numpy as np
from jaxlib import xla_client
try:
from jaxlib import cuda_prng_kernels
from . import cuda_prng_kernels
for _name, _value in cuda_prng_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:

View File

@ -21,14 +21,14 @@ import numpy as np
from jaxlib import xla_client
try:
from jaxlib import cublas_kernels
from . import cublas_kernels
for _name, _value in cublas_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
pass
try:
from jaxlib import cusolver_kernels
from . import cusolver_kernels
for _name, _value in cusolver_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:

View File

@ -14,8 +14,8 @@
from typing import List
from jaxlib import _pocketfft
from jaxlib import pocketfft_flatbuffers_py_generated as pd
from . import _pocketfft
from . import pocketfft_flatbuffers_py_generated as pd
import numpy as np
import flatbuffers

View File

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax.lib import xla_client
import jax.numpy as jnp
from jax.tools.jax_to_hlo import jax_to_hlo
from jax.lib import xla_client
from jax import test_util as jtu