mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
source sync
PiperOrigin-RevId: 222153576
This commit is contained in:
parent
51fc713089
commit
a3619ca89d
@ -1,5 +0,0 @@
|
||||
JAX is a research project that grew out of [Autograd](https://github.com/hips/autograd).
|
||||
Here's a [two-page abstract](https://www.sysml.cc/doc/146.pdf) about an early version.
|
||||
Watch this space for updates!
|
||||
|
||||
This is not an official Google product.
|
@ -55,7 +55,6 @@ then
|
||||
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
|
||||
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
|
||||
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
|
||||
export TF_NCCL_VERSION=2
|
||||
export TF_NEED_CUDA=1
|
||||
else
|
||||
export TF_NEED_CUDA=0
|
||||
@ -73,6 +72,7 @@ export TF_DOWNLOAD_CLANG=0
|
||||
export TF_SET_ANDROID_WORKSPACE=0
|
||||
export TF_CUDA_CLANG=0
|
||||
export TF_NEED_TENSORRT=0
|
||||
export TF_NCCL_VERSION="2"
|
||||
./configure
|
||||
popd
|
||||
|
||||
|
@ -39,8 +39,8 @@ py_binary(
|
||||
srcs = ["mnist_vae.py"],
|
||||
deps = [
|
||||
":datasets",
|
||||
":minmax",
|
||||
":stax",
|
||||
"//jax:libjax",
|
||||
"//jax:minmax",
|
||||
"//jax:stax",
|
||||
],
|
||||
)
|
||||
|
34
examples/interactive.py
Normal file
34
examples/interactive.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from absl import app
|
||||
|
||||
import IPython
|
||||
import numpy as onp
|
||||
|
||||
import jax
|
||||
import jax.numpy as np
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax import jit, grad, vmap, jacfwd, jacrev, hessian
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
IPython.embed(user_ns=dict(globals(), **locals()))
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -23,14 +23,15 @@ from __future__ import print_function
|
||||
import time
|
||||
import itertools
|
||||
|
||||
from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
import datasets
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, Softmax
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
@ -47,7 +48,7 @@ def accuracy(params, batch):
|
||||
init_random_params, predict = stax.serial(
|
||||
Dense(1024), Relu,
|
||||
Dense(1024), Relu,
|
||||
Dense(10), LogSoftmax)
|
||||
Dense(10), Softmax)
|
||||
|
||||
if __name__ == "__main__":
|
||||
step_size = 0.001
|
||||
@ -92,3 +93,7 @@ if __name__ == "__main__":
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -23,12 +23,13 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
from jax.api import jit, grad
|
||||
import datasets
|
||||
from jax.scipy.misc import logsumexp
|
||||
import jax.numpy as np
|
||||
import datasets
|
||||
|
||||
|
||||
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
|
||||
@ -92,3 +93,7 @@ if __name__ == "__main__":
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -25,14 +25,15 @@ from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad, lax, random
|
||||
import datasets
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
import datasets
|
||||
|
||||
|
||||
def gaussian_kl(mu, sigmasq):
|
||||
@ -137,3 +138,7 @@ if __name__ == "__main__":
|
||||
test_elbo, images = evaluate(opt_state, test_images)
|
||||
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
|
||||
plt.imsave(imfile.format(epoch), images, cmap=plt.cm.gray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -21,6 +21,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
@ -29,7 +31,7 @@ from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
|
||||
FanOut, Flatten, GeneralConv, Identity,
|
||||
MaxPool, Relu, LogSoftmax)
|
||||
MaxPool, Relu, Softmax)
|
||||
|
||||
|
||||
# ResNet blocks compose other layers
|
||||
@ -80,10 +82,12 @@ def ResNet50(num_classes):
|
||||
ConvBlock(3, [512, 512, 2048]),
|
||||
IdentityBlock(3, [512, 512]),
|
||||
IdentityBlock(3, [512, 512]),
|
||||
AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
|
||||
AvgPool((7, 7)), Flatten, Dense(num_classes), Softmax)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
|
||||
batch_size = 8
|
||||
num_classes = 1001
|
||||
input_shape = (224, 224, 3, batch_size)
|
||||
@ -124,3 +128,7 @@ if __name__ == "__main__":
|
||||
for i in xrange(num_steps):
|
||||
opt_state = update(i, opt_state, next(batches))
|
||||
trained_params = minmax.get_params(opt_state)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
|
@ -29,7 +29,6 @@ py_library(
|
||||
"interpreters/*.py",
|
||||
"numpy/*.py",
|
||||
"scipy/*.py",
|
||||
"scipy/stats/*.py",
|
||||
],
|
||||
exclude = [
|
||||
"*_test.py",
|
||||
|
@ -179,10 +179,10 @@ def piecewise_constant(boundaries, values):
|
||||
return values[np.sum(i > boundaries)]
|
||||
return schedule
|
||||
|
||||
def make_schedule(scalar_or_schedule_fun):
|
||||
if callable(scalar_or_schedule_fun):
|
||||
return scalar_or_schedule_fun
|
||||
elif np.ndim(scalar_or_schedule_fun) == 0:
|
||||
return constant(scalar_or_schedule_fun)
|
||||
def make_schedule(constant_scalar_or_schedule_fun):
|
||||
if np.isscalar(constant_scalar_or_schedule_fun):
|
||||
return constant(constant_scalar_or_schedule_fun)
|
||||
elif callable(constant_scalar_or_schedule_fun):
|
||||
return constant_scalar_or_schedule_fun
|
||||
else:
|
||||
raise TypeError, type(scalar_or_schedule_fun)
|
||||
raise TypeError, type(constant_scalar_or_schedule_fun)
|
||||
|
@ -42,8 +42,9 @@ import jax.numpy as np
|
||||
def relu(x): return np.maximum(x, 0.)
|
||||
def softplus(x): return np.logaddexp(x, 0.)
|
||||
|
||||
def logsoftmax(x, axis=-1):
|
||||
"""Apply log softmax to an array of logits, log-normalizing along an axis."""
|
||||
# TODO(mattjj): change this name to better fit convention
|
||||
def softmax(x, axis=-1):
|
||||
"""Apply a softmax to an array of logits, log-normalizing along an axis."""
|
||||
return x - logsumexp(x, axis, keepdims=True)
|
||||
|
||||
def fastvar(x, axis, keepdims):
|
||||
@ -145,7 +146,7 @@ def _elemwise_no_params(fun, **kwargs):
|
||||
return init_fun, apply_fun
|
||||
Tanh = _elemwise_no_params(np.tanh)
|
||||
Relu = _elemwise_no_params(relu)
|
||||
LogSoftmax = _elemwise_no_params(logsoftmax, axis=-1)
|
||||
Softmax = _elemwise_no_params(softmax, axis=-1)
|
||||
Softplus = _elemwise_no_params(softplus)
|
||||
|
||||
|
||||
|
52
jax/scipy/stats.py
Normal file
52
jax/scipy/stats.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from .. import lax
|
||||
|
||||
|
||||
def _wraps(fun):
|
||||
"""Like functools.wraps but works with numpy.ufuncs."""
|
||||
docstr = """
|
||||
LAX-backed implementation of {fun}. Corresponding Scipy docstring below.
|
||||
|
||||
{np_doc}
|
||||
""".format(fun=fun.__name__, np_doc=fun.__doc__)
|
||||
def wrap(op):
|
||||
try:
|
||||
op.__name__ = fun.__name__
|
||||
op.__doc__ = docstr
|
||||
finally:
|
||||
return op
|
||||
return wrap
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logpdf)
|
||||
def norm_logpdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.norm.logpdf, x, loc, scale)
|
||||
two = _constant_like(x, 2)
|
||||
scale_sqrd = lax.pow(scale, two)
|
||||
log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * onp.pi), scale_sqrd))
|
||||
quadratic = lax.div(lax.pow(lax.sub(x, loc), two), scale_sqrd)
|
||||
return lax.div(lax.neg(lax.add(log_normalizer, quadratic)), two)
|
||||
|
||||
|
||||
def _constant_like(x, const):
|
||||
return onp.array(const, dtype=lax._dtype(x))
|
2
setup.py
2
setup.py
@ -23,7 +23,7 @@ setup(
|
||||
author_email='jax-team@google.com',
|
||||
packages=['jax', 'jax.lib', 'jax.interpreters', 'jax.numpy', 'jax.scipy',
|
||||
'jax.experimental'],
|
||||
install_requires=['numpy>=1.12', 'six', 'protobuf', 'absl-py'],
|
||||
install_requires=['numpy>=1.12', 'six', 'protobuf'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
package_data={'jax.lib': glob('jax/lib/*.so')},
|
||||
|
30
tests/BUILD
30
tests/BUILD
@ -14,11 +14,11 @@
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
load(":build_defs.bzl", "jax_test")
|
||||
load("//third_party/py/jax/tests:build_defs.bzl", "jax_test")
|
||||
|
||||
jax_test(
|
||||
name = "core_test",
|
||||
srcs = ["core_test.py"],
|
||||
srcs = ["tests/core_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 5,
|
||||
},
|
||||
@ -26,7 +26,7 @@ jax_test(
|
||||
|
||||
jax_test(
|
||||
name = "lax_test",
|
||||
srcs = ["lax_test.py"],
|
||||
srcs = ["tests/lax_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 20,
|
||||
@ -35,7 +35,7 @@ jax_test(
|
||||
|
||||
jax_test(
|
||||
name = "lax_numpy_test",
|
||||
srcs = ["lax_numpy_test.py"],
|
||||
srcs = ["tests/lax_numpy_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 20,
|
||||
@ -44,7 +44,7 @@ jax_test(
|
||||
|
||||
jax_test(
|
||||
name = "lax_numpy_indexing_test",
|
||||
srcs = ["lax_numpy_indexing_test.py"],
|
||||
srcs = ["tests/lax_numpy_indexing_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 2,
|
||||
@ -53,7 +53,7 @@ jax_test(
|
||||
|
||||
jax_test(
|
||||
name = "lax_scipy_test",
|
||||
srcs = ["lax_scipy_test.py"],
|
||||
srcs = ["tests/lax_scipy_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 2,
|
||||
@ -62,33 +62,33 @@ jax_test(
|
||||
|
||||
jax_test(
|
||||
name = "random_test",
|
||||
srcs = ["random_test.py"],
|
||||
srcs = ["tests/random_test.py"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "api_test",
|
||||
srcs = ["api_test.py"],
|
||||
srcs = ["tests/api_test.py"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "batching_test",
|
||||
srcs = ["batching_test.py"],
|
||||
srcs = ["tests/batching_test.py"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "stax_test",
|
||||
srcs = ["stax_test.py"],
|
||||
deps = ["//jax:stax"],
|
||||
srcs = ["tests/stax_test.py"],
|
||||
deps = [":stax"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "minmax_test",
|
||||
srcs = ["minmax_test.py"],
|
||||
deps = ["//jax:minmax"],
|
||||
srcs = ["tests/minmax_test.py"],
|
||||
deps = [":minmax"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lapax_test",
|
||||
srcs = ["lapax_test.py"],
|
||||
deps = ["//jax:lapax"],
|
||||
srcs = ["tests/lapax_test.py"],
|
||||
deps = [":lapax"],
|
||||
)
|
||||
|
@ -21,9 +21,9 @@ import functools
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax.numpy as np
|
||||
import jax.test_util as jtu
|
||||
from jax import jit, grad
|
||||
import jax.numpy as np
|
||||
from jax.api import grad
|
||||
from jax.experimental import minmax
|
||||
from jax.lib import xla_bridge as xla
|
||||
|
||||
@ -150,24 +150,5 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
step_schedule = minmax.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
|
||||
self._CheckOptimizer(minmax.rmsprop, loss, x0, num_iters, step_schedule)
|
||||
|
||||
def testTracedStepSize(self):
|
||||
def loss(x, _): return np.dot(x, x)
|
||||
x0 = np.ones(2)
|
||||
num_iters = 100
|
||||
step_size = 0.1
|
||||
|
||||
init_fun, _ = minmax.sgd(step_size)
|
||||
opt_state = init_fun(x0)
|
||||
|
||||
@jit
|
||||
def update(opt_state, step_size):
|
||||
_, update_fun = minmax.sgd(step_size)
|
||||
x = minmax.get_params(opt_state)
|
||||
g = grad(loss)(x, None)
|
||||
return update_fun(0, g, opt_state)
|
||||
|
||||
update(opt_state, 0.9) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user