Revert import path changes to examples/ and benchmarks/

PiperOrigin-RevId: 352911869
This commit is contained in:
Peter Hawkins 2021-01-20 17:35:30 -08:00 committed by jax authors
parent ffa05d1cc8
commit 160dfd343a
4 changed files with 5 additions and 5 deletions

View File

@ -26,7 +26,7 @@ from jax import pmap
from jax.config import config
from jax._src.util import prod
from . import benchmark
from benchmarks import benchmark
import numpy as np

View File

@ -30,7 +30,7 @@ from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
from . import datasets
from examples import datasets
def loss(params, batch):

View File

@ -22,10 +22,10 @@ import time
import numpy.random as npr
from jax.api import jit, grad
from jax import jit, grad
from jax.scipy.special import logsumexp
import jax.numpy as jnp
from . import datasets
from examples import datasets
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):

View File

@ -30,7 +30,7 @@ from jax import jit, grad, lax, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
from . import datasets
from examples import datasets
def gaussian_kl(mu, sigmasq):