source sync

PiperOrigin-RevId: 222153576
Roy Frostig 2018-11-19 15:50:10 -08:00
parent 51fc713089
commit a3619ca89d
15 changed files with 148 additions and 63 deletions

View File

@@ -1,5 +0,0 @@
JAX is a research project that grew out of [Autograd](https://github.com/hips/autograd).
Here's a [two-page abstract](https://www.sysml.cc/doc/146.pdf) about an early version.
Watch this space for updates!
This is not an official Google product.

View File

@@ -55,7 +55,6 @@ then
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
export TF_NCCL_VERSION=2
export TF_NEED_CUDA=1
else
export TF_NEED_CUDA=0
@@ -73,6 +72,7 @@ export TF_DOWNLOAD_CLANG=0
export TF_SET_ANDROID_WORKSPACE=0
export TF_CUDA_CLANG=0
export TF_NEED_TENSORRT=0
export TF_NCCL_VERSION="2"
./configure
popd

View File

@@ -39,8 +39,8 @@ py_binary(
srcs = ["mnist_vae.py"],
deps = [
":datasets",
":minmax",
":stax",
"//jax:libjax",
"//jax:minmax",
"//jax:stax",
],
)

examples/interactive.py Normal file
View File

@@ -0,0 +1,34 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from absl import app
import IPython
import numpy as onp
import jax
import jax.numpy as np
from jax import lax
from jax import random
from jax import jit, grad, vmap, jacfwd, jacrev, hessian
def main(unused_argv):
IPython.embed(user_ns=dict(globals(), **locals()))
if __name__ == "__main__":
app.run(main)

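Side note on the new examples/interactive.py above: it only embeds an IPython shell with the common JAX names (np, grad, jit, vmap, lax, random, ...) preloaded. A minimal sketch of the kind of session it is meant for, assuming a working jax install; illustrative only, not part of the commit:

import jax.numpy as np
from jax import grad, jit, vmap

# d/dx sin(x) at x = 0.0 is cos(0.0) == 1.0
print(grad(np.sin)(0.0))

# vmap maps a scalar function over a batch; jit compiles the mapped function.
xs = np.array([0.0, 1.0, 2.0])
print(jit(vmap(np.sin))(xs))
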
View File

@@ -23,14 +23,15 @@ from __future__ import print_function
import time
import itertools
from absl import app
import numpy.random as npr
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, Softmax
def loss(params, batch):
@@ -47,7 +48,7 @@ def accuracy(params, batch):
init_random_params, predict = stax.serial(
Dense(1024), Relu,
Dense(1024), Relu,
Dense(10), LogSoftmax)
Dense(10), Softmax)
if __name__ == "__main__":
step_size = 0.001
@@ -92,3 +93,7 @@ if __name__ == "__main__":
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)

View File

@@ -23,12 +23,13 @@ from __future__ import print_function
import time
from absl import app
import numpy.random as npr
from jax.api import jit, grad
import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np
import datasets
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
@@ -92,3 +93,7 @@ if __name__ == "__main__":
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)

View File

@@ -25,14 +25,15 @@ from __future__ import print_function
import os
import time
from absl import app
import matplotlib.pyplot as plt
import jax.numpy as np
from jax import jit, grad, lax, random
import datasets
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
import datasets
def gaussian_kl(mu, sigmasq):
@@ -137,3 +138,7 @@ if __name__ == "__main__":
test_elbo, images = evaluate(opt_state, test_images)
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
plt.imsave(imfile.format(epoch), images, cmap=plt.cm.gray)
if __name__ == "__main__":
app.run(main)

View File

@@ -21,6 +21,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import numpy.random as npr
import jax.numpy as np
@@ -29,7 +31,7 @@ from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
FanOut, Flatten, GeneralConv, Identity,
MaxPool, Relu, LogSoftmax)
MaxPool, Relu, Softmax)
# ResNet blocks compose other layers
@@ -80,10 +82,12 @@ def ResNet50(num_classes):
ConvBlock(3, [512, 512, 2048]),
IdentityBlock(3, [512, 512]),
IdentityBlock(3, [512, 512]),
AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
AvgPool((7, 7)), Flatten, Dense(num_classes), Softmax)
if __name__ == "__main__":
def main(argv):
del argv # Unused.
batch_size = 8
num_classes = 1001
input_shape = (224, 224, 3, batch_size)
@@ -124,3 +128,7 @@ if __name__ == "__main__":
for i in xrange(num_steps):
opt_state = update(i, opt_state, next(batches))
trained_params = minmax.get_params(opt_state)
if __name__ == '__main__':
app.run(main)

View File

@@ -29,7 +29,6 @@ py_library(
"interpreters/*.py",
"numpy/*.py",
"scipy/*.py",
"scipy/stats/*.py",
],
exclude = [
"*_test.py",

View File

@@ -179,10 +179,10 @@ def piecewise_constant(boundaries, values):
return values[np.sum(i > boundaries)]
return schedule
def make_schedule(scalar_or_schedule_fun):
if callable(scalar_or_schedule_fun):
return scalar_or_schedule_fun
elif np.ndim(scalar_or_schedule_fun) == 0:
return constant(scalar_or_schedule_fun)
def make_schedule(constant_scalar_or_schedule_fun):
if np.isscalar(constant_scalar_or_schedule_fun):
return constant(constant_scalar_or_schedule_fun)
elif callable(constant_scalar_or_schedule_fun):
return constant_scalar_or_schedule_fun
else:
raise TypeError, type(scalar_or_schedule_fun)
raise TypeError, type(constant_scalar_or_schedule_fun)

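For context on the make_schedule change above: both the old and new versions normalize the optimizer's step-size argument into a schedule (a function from step index to step size), so an optimizer can be given either a plain scalar or a schedule such as piecewise_constant. A rough sketch of that call pattern, assuming the jax.experimental.minmax API shown elsewhere in this diff (sgd, get_params, piecewise_constant); illustrative only, not part of the commit:

import jax.numpy as np
from jax import grad
from jax.experimental import minmax

def loss(x):
  return np.dot(x, x)

# Either form works because the optimizer passes its argument through make_schedule.
step_size = minmax.piecewise_constant([25, 75], [1.0, 0.5, 0.1])  # or simply 0.1
init_fun, update_fun = minmax.sgd(step_size)

opt_state = init_fun(np.ones(2))
for i in range(100):
  g = grad(loss)(minmax.get_params(opt_state))
  opt_state = update_fun(i, g, opt_state)
print(minmax.get_params(opt_state))
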
View File

@@ -42,8 +42,9 @@ import jax.numpy as np
def relu(x): return np.maximum(x, 0.)
def softplus(x): return np.logaddexp(x, 0.)
def logsoftmax(x, axis=-1):
"""Apply log softmax to an array of logits, log-normalizing along an axis."""
# TODO(mattjj): change this name to better fit convention
def softmax(x, axis=-1):
"""Apply a softmax to an array of logits, log-normalizing along an axis."""
return x - logsumexp(x, axis, keepdims=True)
def fastvar(x, axis, keepdims):
@@ -145,7 +146,7 @@ def _elemwise_no_params(fun, **kwargs):
return init_fun, apply_fun
Tanh = _elemwise_no_params(np.tanh)
Relu = _elemwise_no_params(relu)
LogSoftmax = _elemwise_no_params(logsoftmax, axis=-1)
Softmax = _elemwise_no_params(softmax, axis=-1)
Softplus = _elemwise_no_params(softplus)

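A note on the logsoftmax/softmax change above: whichever name it carries, the body is x - logsumexp(x, axis, keepdims=True), which computes a log-softmax rather than a softmax, hence the TODO about the name. A small plain-NumPy check of that identity, illustrative only and not part of the commit:

import numpy as onp

x = onp.array([1.0, 2.0, 3.0])
logsumexp = onp.log(onp.sum(onp.exp(x)))

# What stax's softmax/logsoftmax computes:
log_softmax = x - logsumexp
# A true softmax would be exp(x) / sum(exp(x)); its log matches the above.
softmax = onp.exp(x) / onp.sum(onp.exp(x))
assert onp.allclose(onp.log(softmax), log_softmax)
print(log_softmax)
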
jax/scipy/stats.py Normal file
View File

@@ -0,0 +1,52 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as onp
import scipy.stats as osp_stats
from .. import lax
def _wraps(fun):
"""Like functools.wraps but works with numpy.ufuncs."""
docstr = """
LAX-backed implementation of {fun}. Corresponding Scipy docstring below.
{np_doc}
""".format(fun=fun.__name__, np_doc=fun.__doc__)
def wrap(op):
try:
op.__name__ = fun.__name__
op.__doc__ = docstr
finally:
return op
return wrap
@_wraps(osp_stats.norm.logpdf)
def norm_logpdf(x, loc=0, scale=1):
x, loc, scale = _promote_args_like(osp_stats.norm.logpdf, x, loc, scale)
two = _constant_like(x, 2)
scale_sqrd = lax.pow(scale, two)
log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * onp.pi), scale_sqrd))
quadratic = lax.div(lax.pow(lax.sub(x, loc), two), scale_sqrd)
return lax.div(lax.neg(lax.add(log_normalizer, quadratic)), two)
def _constant_like(x, const):
return onp.array(const, dtype=lax._dtype(x))

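For reference on norm_logpdf in the new jax/scipy/stats.py above: the lax calls implement the standard Gaussian log-density, -((x - loc)^2 / scale^2 + log(2*pi*scale^2)) / 2 (the _promote_args_like helper it uses is not shown in this hunk). A plain-NumPy version of the same arithmetic, cross-checked against scipy.stats.norm.logpdf; illustrative only, not part of the commit:

import numpy as onp
import scipy.stats as osp_stats

def norm_logpdf_reference(x, loc=0.0, scale=1.0):
  # Same quantity the lax-based norm_logpdf computes, in plain NumPy.
  scale_sqrd = scale ** 2
  log_normalizer = onp.log(2 * onp.pi * scale_sqrd)
  quadratic = (x - loc) ** 2 / scale_sqrd
  return -(log_normalizer + quadratic) / 2

x = onp.linspace(-2.0, 2.0, 5)
assert onp.allclose(norm_logpdf_reference(x, loc=0.5, scale=2.0),
                    osp_stats.norm.logpdf(x, loc=0.5, scale=2.0))
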
View File

@@ -23,7 +23,7 @@ setup(
author_email='jax-team@google.com',
packages=['jax', 'jax.lib', 'jax.interpreters', 'jax.numpy', 'jax.scipy',
'jax.experimental'],
install_requires=['numpy>=1.12', 'six', 'protobuf', 'absl-py'],
install_requires=['numpy>=1.12', 'six', 'protobuf'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={'jax.lib': glob('jax/lib/*.so')},

View File

@@ -14,11 +14,11 @@
licenses(["notice"]) # Apache 2
load(":build_defs.bzl", "jax_test")
load("//third_party/py/jax/tests:build_defs.bzl", "jax_test")
jax_test(
name = "core_test",
srcs = ["core_test.py"],
srcs = ["tests/core_test.py"],
shard_count = {
"cpu": 5,
},
@@ -26,7 +26,7 @@ jax_test(
jax_test(
name = "lax_test",
srcs = ["lax_test.py"],
srcs = ["tests/lax_test.py"],
shard_count = {
"cpu": 40,
"gpu": 20,
@@ -35,7 +35,7 @@ jax_test(
jax_test(
name = "lax_numpy_test",
srcs = ["lax_numpy_test.py"],
srcs = ["tests/lax_numpy_test.py"],
shard_count = {
"cpu": 40,
"gpu": 20,
@@ -44,7 +44,7 @@ jax_test(
jax_test(
name = "lax_numpy_indexing_test",
srcs = ["lax_numpy_indexing_test.py"],
srcs = ["tests/lax_numpy_indexing_test.py"],
shard_count = {
"cpu": 10,
"gpu": 2,
@@ -53,7 +53,7 @@ jax_test(
jax_test(
name = "lax_scipy_test",
srcs = ["lax_scipy_test.py"],
srcs = ["tests/lax_scipy_test.py"],
shard_count = {
"cpu": 10,
"gpu": 2,
@@ -62,33 +62,33 @@ jax_test(
jax_test(
name = "random_test",
srcs = ["random_test.py"],
srcs = ["tests/random_test.py"],
)
jax_test(
name = "api_test",
srcs = ["api_test.py"],
srcs = ["tests/api_test.py"],
)
jax_test(
name = "batching_test",
srcs = ["batching_test.py"],
srcs = ["tests/batching_test.py"],
)
jax_test(
name = "stax_test",
srcs = ["stax_test.py"],
deps = ["//jax:stax"],
srcs = ["tests/stax_test.py"],
deps = [":stax"],
)
jax_test(
name = "minmax_test",
srcs = ["minmax_test.py"],
deps = ["//jax:minmax"],
srcs = ["tests/minmax_test.py"],
deps = [":minmax"],
)
jax_test(
name = "lapax_test",
srcs = ["lapax_test.py"],
deps = ["//jax:lapax"],
srcs = ["tests/lapax_test.py"],
deps = [":lapax"],
)

View File

@@ -21,9 +21,9 @@ import functools
from absl.testing import absltest
import jax.numpy as np
import jax.test_util as jtu
from jax import jit, grad
import jax.numpy as np
from jax.api import grad
from jax.experimental import minmax
from jax.lib import xla_bridge as xla
@@ -150,24 +150,5 @@ class OptimizerTests(jtu.JaxTestCase):
step_schedule = minmax.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
self._CheckOptimizer(minmax.rmsprop, loss, x0, num_iters, step_schedule)
def testTracedStepSize(self):
def loss(x, _): return np.dot(x, x)
x0 = np.ones(2)
num_iters = 100
step_size = 0.1
init_fun, _ = minmax.sgd(step_size)
opt_state = init_fun(x0)
@jit
def update(opt_state, step_size):
_, update_fun = minmax.sgd(step_size)
x = minmax.get_params(opt_state)
g = grad(loss)(x, None)
return update_fun(0, g, opt_state)
update(opt_state, 0.9) # doesn't crash
if __name__ == '__main__':
absltest.main()