mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
source sync
PiperOrigin-RevId: 222448341
This commit is contained in:
parent
e5b76f4fde
commit
323be694a7
@ -25,12 +25,23 @@ py_library(
|
||||
srcs = ["datasets.py"],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "mnist_classifier_fromscratch",
|
||||
srcs = ["mnist_classifier_fromscratch.py"],
|
||||
deps = [
|
||||
":datasets",
|
||||
"//jax:libjax",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "mnist_classifier",
|
||||
srcs = ["mnist_classifier.py"],
|
||||
deps = [
|
||||
":datasets",
|
||||
"//jax:libjax",
|
||||
"//jax:minmax",
|
||||
"//jax:stax",
|
||||
],
|
||||
)
|
||||
|
||||
@ -44,3 +55,14 @@ py_binary(
|
||||
"//jax:stax",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "resnet50",
|
||||
srcs = ["resnet50.py"],
|
||||
deps = [
|
||||
":datasets",
|
||||
"//jax:libjax",
|
||||
"//jax:minmax",
|
||||
"//jax:stax",
|
||||
],
|
||||
)
|
||||
|
@ -29,9 +29,9 @@ import numpy.random as npr
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad
|
||||
from jax.experimental import minmax
|
||||
import datasets
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
|
@ -27,9 +27,9 @@ from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
from jax.api import jit, grad
|
||||
import datasets
|
||||
from jax.scipy.misc import logsumexp
|
||||
import jax.numpy as np
|
||||
import datasets
|
||||
|
||||
|
||||
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
|
||||
|
@ -30,10 +30,10 @@ import matplotlib.pyplot as plt
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad, lax, random
|
||||
import datasets
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
import datasets
|
||||
|
||||
|
||||
def gaussian_kl(mu, sigmasq):
|
||||
|
Loading…
x
Reference in New Issue
Block a user