diff --git a/examples/BUILD b/examples/BUILD index 230e9ceb7..63c44f98f 100644 --- a/examples/BUILD +++ b/examples/BUILD @@ -25,12 +25,23 @@ py_library( srcs = ["datasets.py"], ) +py_binary( + name = "mnist_classifier_fromscratch", + srcs = ["mnist_classifier_fromscratch.py"], + deps = [ + ":datasets", + "//jax:libjax", + ], +) + py_binary( name = "mnist_classifier", srcs = ["mnist_classifier.py"], deps = [ ":datasets", "//jax:libjax", + "//jax:minmax", + "//jax:stax", ], ) @@ -44,3 +55,14 @@ py_binary( "//jax:stax", ], ) + +py_binary( + name = "resnet50", + srcs = ["resnet50.py"], + deps = [ + ":datasets", + "//jax:libjax", + "//jax:minmax", + "//jax:stax", + ], +) diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index 52f1e3df2..544d58b4d 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -29,9 +29,9 @@ import numpy.random as npr import jax.numpy as np from jax import jit, grad from jax.experimental import minmax -import datasets from jax.experimental import stax from jax.experimental.stax import Dense, Relu, LogSoftmax +import datasets def loss(params, batch): diff --git a/examples/mnist_classifier_fromscratch.py b/examples/mnist_classifier_fromscratch.py index 03d21538a..6aa3a6cd9 100644 --- a/examples/mnist_classifier_fromscratch.py +++ b/examples/mnist_classifier_fromscratch.py @@ -27,9 +27,9 @@ from absl import app import numpy.random as npr from jax.api import jit, grad -import datasets from jax.scipy.misc import logsumexp import jax.numpy as np +import datasets def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index f15e3dd30..d3fc4680b 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -30,10 +30,10 @@ import matplotlib.pyplot as plt import jax.numpy as np from jax import jit, grad, lax, random -import datasets from jax.experimental import minmax from jax.experimental import stax from jax.experimental.stax import Dense, FanOut, Relu, Softplus +import datasets def gaussian_kl(mu, sigmasq):