mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Revert import path changes to examples/ and benchmarks/
PiperOrigin-RevId: 352911869
This commit is contained in:
parent
ffa05d1cc8
commit
160dfd343a
@ -26,7 +26,7 @@ from jax import pmap
|
||||
from jax.config import config
|
||||
from jax._src.util import prod
|
||||
|
||||
from . import benchmark
|
||||
from benchmarks import benchmark
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -30,7 +30,7 @@ from jax import jit, grad, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
from . import datasets
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
|
@ -22,10 +22,10 @@ import time
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
from jax.api import jit, grad
|
||||
from jax import jit, grad
|
||||
from jax.scipy.special import logsumexp
|
||||
import jax.numpy as jnp
|
||||
from . import datasets
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
|
||||
|
@ -30,7 +30,7 @@ from jax import jit, grad, lax, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
from . import datasets
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def gaussian_kl(mu, sigmasq):
|
||||
|
Loading…
x
Reference in New Issue
Block a user