mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
This commit is contained in:
parent
e217a7bd19
commit
929a684a39
@ -26,7 +26,7 @@ from jax import pmap
|
||||
from jax.config import config
|
||||
from jax._src.util import prod
|
||||
|
||||
from benchmarks import benchmark
|
||||
from . import benchmark
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -30,7 +30,7 @@ from jax import jit, grad, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
from examples import datasets
|
||||
from . import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
|
@ -25,7 +25,7 @@ import numpy.random as npr
|
||||
from jax.api import jit, grad
|
||||
from jax.scipy.special import logsumexp
|
||||
import jax.numpy as jnp
|
||||
from examples import datasets
|
||||
from . import datasets
|
||||
|
||||
|
||||
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
|
||||
|
@ -30,7 +30,7 @@ from jax import jit, grad, lax, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
from examples import datasets
|
||||
from . import datasets
|
||||
|
||||
|
||||
def gaussian_kl(mu, sigmasq):
|
||||
|
@ -20,7 +20,7 @@ licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
# top-level EF placeholder
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
pytype_library(
|
||||
name = "jax",
|
||||
|
@ -347,13 +347,13 @@ from jax import custom_derivatives
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax.lib import pytree
|
||||
from jax.lib import xla_client
|
||||
from jax.lib import xla_extension
|
||||
from jax.interpreters import ad, xla, batching, masking, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import pprint_util as ppu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import xla_extension
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -46,7 +46,7 @@ import tensorflow as tf # type: ignore[import]
|
||||
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
|
||||
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
|
||||
|
||||
from jaxlib import xla_client
|
||||
from jax.lib import xla_client
|
||||
|
||||
|
||||
# The scope name need to be a valid TensorFlow name. See
|
||||
|
@ -54,7 +54,7 @@ from jax import numpy as jnp
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jaxlib import xla_client
|
||||
from jax.lib import xla_client
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -53,6 +53,7 @@ _check_jaxlib_version()
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import lapack
|
||||
|
||||
xla_extension = xla_client._xla
|
||||
pytree = xla_client._xla.pytree
|
||||
jax_jit = xla_client._xla.jax_jit
|
||||
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from jaxlib import cuda_prng_kernels
|
||||
from . import cuda_prng_kernels
|
||||
for _name, _value in cuda_prng_kernels.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
|
@ -21,14 +21,14 @@ import numpy as np
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from jaxlib import cublas_kernels
|
||||
from . import cublas_kernels
|
||||
for _name, _value in cublas_kernels.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from jaxlib import cusolver_kernels
|
||||
from . import cusolver_kernels
|
||||
for _name, _value in cusolver_kernels.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
|
@ -14,8 +14,8 @@
|
||||
|
||||
from typing import List
|
||||
|
||||
from jaxlib import _pocketfft
|
||||
from jaxlib import pocketfft_flatbuffers_py_generated as pd
|
||||
from . import _pocketfft
|
||||
from . import pocketfft_flatbuffers_py_generated as pd
|
||||
import numpy as np
|
||||
|
||||
import flatbuffers
|
||||
|
@ -12,11 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.lib import xla_client
|
||||
import jax.numpy as jnp
|
||||
from jax.tools.jax_to_hlo import jax_to_hlo
|
||||
from jax.lib import xla_client
|
||||
from jax import test_util as jtu
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user