source sync

PiperOrigin-RevId: 222448341
This commit is contained in:
Matthew Johnson 2018-11-21 12:50:47 -08:00 committed by Roy Frostig
parent e5b76f4fde
commit 323be694a7
4 changed files with 25 additions and 3 deletions

View File

@ -25,12 +25,23 @@ py_library(
srcs = ["datasets.py"],
)
py_binary(
name = "mnist_classifier_fromscratch",
srcs = ["mnist_classifier_fromscratch.py"],
deps = [
":datasets",
"//jax:libjax",
],
)
py_binary(
name = "mnist_classifier",
srcs = ["mnist_classifier.py"],
deps = [
":datasets",
"//jax:libjax",
"//jax:minmax",
"//jax:stax",
],
)
@ -44,3 +55,14 @@ py_binary(
"//jax:stax",
],
)
py_binary(
name = "resnet50",
srcs = ["resnet50.py"],
deps = [
":datasets",
"//jax:libjax",
"//jax:minmax",
"//jax:stax",
],
)

View File

@ -29,9 +29,9 @@ import numpy.random as npr
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
import datasets
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets
def loss(params, batch):

View File

@ -27,9 +27,9 @@ from absl import app
import numpy.random as npr
from jax.api import jit, grad
import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np
import datasets
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):

View File

@ -30,10 +30,10 @@ import matplotlib.pyplot as plt
import jax.numpy as np
from jax import jit, grad, lax, random
import datasets
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
import datasets
def gaussian_kl(mu, sigmasq):