add jet tests, remove top-level files

Matthew Johnson 2020-03-14 21:21:27 -07:00
parent 840797d4a1
commit 668a1703bc
5 changed files with 164 additions and 426 deletions

@@ -55,18 +55,18 @@ class JetTrace(core.Trace):
series_in = [[onp.zeros(onp.shape(x), dtype=onp.result_type(x))
if t is zero_term else t for t in series]
for x, series in zip(primals_in, series_in)]
-rule = prop_rules[primitive]
+rule = jet_rules[primitive]
primal_out, terms_out = rule(primals_in, series_in, **params)
return JetTracer(self, primal_out, terms_out)
def process_call(self, call_primitive, f, tracers, params):
-assert False
+assert False # TODO
def post_process_call(self, call_primitive, out_tracer, params):
-assert False
+assert False # TODO
def join(self, xt, yt):
-assert False
+assert False # TODO?
class ZeroTerm(object): pass
@@ -76,27 +76,10 @@ class ZeroSeries(object): pass
zero_series = ZeroSeries()
-prop_rules = {}
-def tay_to_deriv_coeff(u_tay):
-u_deriv = [ui * fact(i) for (i, ui) in enumerate(u_tay)]
-return u_deriv
-def deriv_to_tay_coeff(u_deriv):
-u_tay = [ui / fact(i) for (i, ui) in enumerate(u_deriv)]
-return u_tay
-def taylor_tilde(u_tay):
-u_tilde = [i * ui for (i, ui) in enumerate(u_tay)]
-return u_tilde
-def taylor_untilde(u_tilde):
-u_tay = [i * ui for (i, ui) in enumerate(u_tilde)]
-return u_tay
+jet_rules = {}
def deflinear(prim):
-prop_rules[prim] = partial(linear_prop, prim)
+jet_rules[prim] = partial(linear_prop, prim)
def linear_prop(prim, primals_in, series_in, **params):
primal_out = prim.bind(*primals_in, **params)
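
The `linear_prop` rule above works because linearity commutes with differentiation: a linear primitive applied along a Taylor-expanded input just applies to the primal and to each series term independently. A minimal standalone sketch of that idea outside the tracer machinery, with a hypothetical linear map `f`:

import numpy as onp

def linear_prop_sketch(f, primal, series):
  # Since f is linear, the k-th derivative of f(x(t)) is f applied to the
  # k-th derivative of x(t), so every term passes straight through f.
  return f(primal), [f(t) for t in series]

A = onp.arange(6.).reshape(2, 3)
f = lambda x: A @ x
primal_out, series_out = linear_prop_sketch(f, onp.ones(3), [onp.ones(3), onp.zeros(3)])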

@@ -1684,7 +1684,7 @@ def _exp_taylor(primals_in, series_in):
v[k] = fact(k-1) * sum([scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
primal_out, *series_out = v
return primal_out, series_out
-taylor.prop_rules[exp_p] = _exp_taylor
+taylor.jet_rules[exp_p] = _exp_taylor
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
@@ -1700,7 +1700,7 @@ def _log_taylor(primals_in, series_in):
v[k] = (u[k] - fact(k - 1) * conv) / u[0]
primal_out, *series_out = v
return primal_out, series_out
-taylor.prop_rules[log_p] = _log_taylor
+taylor.jet_rules[log_p] = _log_taylor
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
@@ -1922,7 +1922,7 @@ def _div_taylor_rule(primals_in, series_in, **params):
primal_out, *series_out = v
return primal_out, series_out
ad.primitive_transposes[div_p] = _div_transpose_rule
-taylor.prop_rules[div_p] = _div_taylor_rule
+taylor.jet_rules[div_p] = _div_taylor_rule
rem_p = standard_naryop([_num, _num], 'rem')
ad.defjvp(rem_p,
@@ -2384,9 +2384,9 @@ def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
v[k] = fact(k) * sum([scale(k, j) * op(u[j], w[k-j]) for j in range(0, k+1)])
primal_out, *series_out = v
return primal_out, series_out
-taylor.prop_rules[dot_general_p] = partial(_bilinear_taylor_rule, dot_general_p)
-taylor.prop_rules[mul_p] = partial(_bilinear_taylor_rule, mul_p)
-taylor.prop_rules[conv_general_dilated_p] = partial(_bilinear_taylor_rule, conv_general_dilated_p)
+taylor.jet_rules[dot_general_p] = partial(_bilinear_taylor_rule, dot_general_p)
+taylor.jet_rules[mul_p] = partial(_bilinear_taylor_rule, mul_p)
+taylor.jet_rules[conv_general_dilated_p] = partial(_bilinear_taylor_rule, conv_general_dilated_p)
def _broadcast_shape_rule(operand, sizes):
@@ -3205,7 +3205,7 @@ def _gather_taylor_rule(primals_in, series_in, **params):
primal_out = gather_p.bind(operand, start_indices, **params)
series_out = [gather_p.bind(g, start_indices, **params) for g in gs]
return primal_out, series_out
-taylor.prop_rules[gather_p] = _gather_taylor_rule
+taylor.jet_rules[gather_p] = _gather_taylor_rule
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
@@ -3648,7 +3648,7 @@ def _reduce_max_taylor_rule(primals_in, series_in, **params):
return div(_reduce_sum(mul(g, location_indicators), axes), counts)
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
return primal_out, series_out
-taylor.prop_rules[reduce_max_p] = _reduce_max_taylor_rule
+taylor.jet_rules[reduce_max_p] = _reduce_max_taylor_rule
batching.defreducer(reduce_max_p)
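
The renamed registrations above all follow the same pattern. For orientation, a minimal standalone sketch of the truncated-Taylor recurrence that `_exp_taylor` implements, assuming `scale(k, j) = 1 / (fact(k - j) * fact(j - 1))` (that helper's definition sits outside this diff), with series terms holding derivatives of the input path as in the tests added below:

import numpy as onp
from scipy.special import factorial as fact

def scale(k, j):
  return 1. / (fact(k - j) * fact(j - 1))

def exp_taylor(primal, series):
  # series[k-1] holds the k-th derivative of the input path at t = 0.
  u = [primal] + series
  v = [onp.exp(primal)] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k - 1) * sum(scale(k, j) * v[k - j] * u[j]
                             for j in range(1, k + 1))
  return v[0], v[1:]

# Along the path x0 + t, every derivative of exp is exp(x0), so all
# output terms should equal the primal.
primal_out, terms = exp_taylor(0.3, [1., 0., 0.])
print(primal_out, terms)  # exp(0.3), then approximately [exp(0.3)] * 3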

jet_nn.py (deleted)

@@ -1,267 +0,0 @@
"""
Some pared-down examples to check that we can take jets through them.
"""
import time
import haiku as hk
import jax
import jax.numpy as jnp
import numpy.random as npr
from jax.flatten_util import ravel_pytree
import tensorflow_datasets as tfds
from functools import reduce
from scipy.special import factorial as fact
def sigmoid(z):
"""
Defined using only numpy primitives (but probably less numerically stable).
"""
return 1./(1. + jnp.exp(-z))
def repeated(f, n):
def rfun(p):
return reduce(lambda x, _: f(x), range(n), p)
return rfun
def jvp_taylor(f, primals, series):
def expansion(eps):
tayterms = [
sum([eps**(i + 1) * terms[i] / fact(i + 1) for i in range(len(terms))])
for terms in series
]
return f(*map(sum, zip(primals, tayterms)))
n_derivs = []
N = len(series[0]) + 1
for i in range(1, N):
d = repeated(jax.jacobian, i)(expansion)(0.)
n_derivs.append(d)
return f(*primals), n_derivs
def jvp_test_jet(f, primals, series, atol=1e-5):
tic = time.time()
y, terms = jax.jet(f, primals, series)
print("jet done in {} sec".format(time.time() - tic))
tic = time.time()
y_jvp, terms_jvp = jvp_taylor(f, primals, series)
print("jvp done in {} sec".format(time.time() - tic))
assert jnp.allclose(y, y_jvp)
assert jnp.allclose(terms, terms_jvp, atol=atol)
def softmax_cross_entropy(logits, labels):
one_hot = hk.one_hot(labels, logits.shape[-1])
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)
# return -logits[labels]
rng = jax.random.PRNGKey(42)
order = 4
batch_size = 2
train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
train_ds = train_ds.cache().shuffle(1000).batch(batch_size)
batch = next(iter(tfds.as_numpy(train_ds)))
def test_mlp_x():
"""
Test jet through a linear layer built with Haiku wrt x.
"""
def loss_fn(images, labels):
model = hk.Sequential([
lambda x: x.astype(jnp.float32) / 255.,
hk.Linear(10)
])
logits = model(images)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_obj = hk.transform(loss_fn)
# flatten for MLP
batch['image'] = jnp.reshape(batch['image'], (batch_size, -1))
images, labels = jnp.array(batch['image'], dtype=jnp.float32), jnp.array(batch['label'])
params = loss_obj.init(rng, images, labels)
flat_params, unravel = ravel_pytree(params)
loss = loss_obj.apply(unravel(flat_params), images, labels)
print("forward pass works")
f = lambda images: loss_obj.apply(unravel(flat_params), images, labels)
_ = jax.grad(f)(images)
terms_in = [npr.randn(*images.shape) for _ in range(order)]
jvp_test_jet(f, (images,), (terms_in,))
def test_mlp1():
"""
Test jet through a linear layer built with Haiku.
"""
def loss_fn(images, labels):
model = hk.Sequential([
lambda x: x.astype(jnp.float32) / 255.,
hk.Linear(10)
])
logits = model(images)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_obj = hk.transform(loss_fn)
# flatten for MLP
batch['image'] = jnp.reshape(batch['image'], (batch_size, -1))
images, labels = jnp.array(batch['image'], dtype=jnp.float32), jnp.array(batch['label'])
params = loss_obj.init(rng, images, labels)
flat_params, unravel = ravel_pytree(params)
loss = loss_obj.apply(unravel(flat_params), images, labels)
print("forward pass works")
f = lambda flat_params: loss_obj.apply(unravel(flat_params), images, labels)
terms_in = [npr.randn(*flat_params.shape) for _ in range(order)]
jvp_test_jet(f, (flat_params,), (terms_in,))
def test_mlp2():
"""
Test jet through a MLP with sigmoid activations.
"""
def loss_fn(images, labels):
model = hk.Sequential([
lambda x: x.astype(jnp.float32) / 255.,
hk.Linear(100),
sigmoid,
hk.Linear(10)
])
logits = model(images)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_obj = hk.transform(loss_fn)
# flatten for MLP
batch['image'] = jnp.reshape(batch['image'], (batch_size, -1))
images, labels = jnp.array(batch['image'], dtype=jnp.float32), jnp.array(batch['label'])
params = loss_obj.init(rng, images, labels)
flat_params, unravel = ravel_pytree(params)
loss = loss_obj.apply(unravel(flat_params), images, labels)
print("forward pass works")
f = lambda flat_params: loss_obj.apply(unravel(flat_params), images, labels)
terms_in = [npr.randn(*flat_params.shape) for _ in range(order)]
jvp_test_jet(f, (flat_params,), (terms_in,), atol=1e-4)
def test_res1():
"""
Test jet through a simple convnet with sigmoid activations.
"""
def loss_fn(images, labels):
model = hk.Sequential([
lambda x: x.astype(jnp.float32) / 255.,
hk.Conv2D(output_channels=10,
kernel_shape=3,
stride=2,
padding="VALID"),
sigmoid,
hk.Reshape(output_shape=(batch_size, -1)),
hk.Linear(10)
])
logits = model(images)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_obj = hk.transform(loss_fn)
images, labels = batch['image'], batch['label']
params = loss_obj.init(rng, images, labels)
flat_params, unravel = ravel_pytree(params)
loss = loss_obj.apply(unravel(flat_params), images, labels)
print("forward pass works")
f = lambda flat_params: loss_obj.apply(unravel(flat_params), images, labels)
terms_in = [npr.randn(*flat_params.shape) for _ in range(order)]
jvp_test_jet(f, (flat_params,), (terms_in,))
def test_div():
x = 1.
y = 5.
primals = (x, y)
order = 4
series_in = ([npr.randn() for _ in range(order)], [npr.randn() for _ in range(order)])
jvp_test_jet(lambda a, b: a / b, primals, series_in)
def test_gather():
npr.seed(0)
D = 3 # dimensionality
N = 6 # differentiation order
x = npr.randn(D)
terms_in = list(npr.randn(N, D))
jvp_test_jet(lambda x: x[1:], (x,), (terms_in,))
def test_reduce_max():
npr.seed(0)
D1, D2 = 3, 5 # dimensionality
N = 6 # differentiation order
x = npr.randn(D1, D2)
terms_in = [npr.randn(D1, D2) for _ in range(N)]
jvp_test_jet(lambda x: x.max(axis=1), (x,), (terms_in,))
def test_sub():
x = 1.
y = 5.
primals = (x, y)
order = 4
series_in = ([npr.randn() for _ in range(order)], [npr.randn() for _ in range(order)])
jvp_test_jet(lambda a, b: a - b, primals, series_in)
def test_exp():
npr.seed(0)
D1, D2 = 3, 5 # dimensionality
N = 6 # differentiation order
x = npr.randn(D1, D2)
terms_in = [npr.randn(D1, D2) for _ in range(N)]
jvp_test_jet(lambda x: jnp.exp(x), (x,), (terms_in,))
def test_log():
npr.seed(0)
D1, D2 = 3, 5 # dimensionality
N = 4 # differentiation order
x = jnp.exp(npr.randn(D1, D2))
terms_in = [jnp.exp(npr.randn(D1, D2)) for _ in range(N)]
jvp_test_jet(lambda x: jnp.log(x), (x,), (terms_in,))
# test_div()
# test_gather()
# test_reduce_max()
# test_sub()
# test_exp()
# test_log()
# test_mlp_x()
# test_mlp1()
# test_mlp2()
# test_res1()
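
All four Haiku examples above rely on one pattern worth keeping in mind: `ravel_pytree` flattens a parameter pytree into a single 1-D array plus an inverse function, so a loss over structured parameters becomes a function of one flat vector that `jet` can perturb with plain arrays. A minimal sketch of that pattern in isolation, with hypothetical toy parameters:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}  # hypothetical params
flat_params, unravel = ravel_pytree(params)

def loss(flat):
  p = unravel(flat)  # rebuild the original pytree from the flat vector
  return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'])

print(loss(flat_params))  # equals the loss computed on the structured params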

mac.py (deleted)

@@ -1,128 +0,0 @@
import time
import jax.numpy as np
from jax import jet, jvp
def f(x, y):
return x + 2 * np.exp(y)
out = jet(f, (1., 2.), [(1., 0.), (1., 0.)])
print(out)
out = jvp(f, (1., 2.), (1., 1.))
print(out)
###
from functools import reduce
import numpy.random as npr
from jax import jacobian
from scipy.special import factorial as fact
def jvp_taylor(f, primals, series):
def expansion(eps):
tayterms = [
sum([eps**(i + 1) * terms[i] / fact(i + 1) for i in range(len(terms))])
for terms in series
]
return f(*map(sum, zip(primals, tayterms)))
n_derivs = []
N = len(series[0]) + 1
for i in range(1, N):
d = repeated(jacobian, i)(expansion)(0.)
n_derivs.append(d)
return f(*primals), n_derivs
def repeated(f, n):
def rfun(p):
return reduce(lambda x, _: f(x), range(n), p)
return rfun
def jvp_test_jet(f, primals, series, atol=1e-5):
tic = time.time()
y, terms = jet(f, primals, series)
print("jet done in {} sec".format(time.time() - tic))
tic = time.time()
y_jvp, terms_jvp = jvp_taylor(f, primals, series)
print("jvp done in {} sec".format(time.time() - tic))
assert np.allclose(y, y_jvp)
assert np.allclose(terms, terms_jvp, atol=atol)
def test_exp():
npr.seed(0)
D = 3 # dimensionality
N = 6 # differentiation order
x = npr.randn(D)
terms_in = list(npr.randn(N, D))
jvp_test_jet(np.exp, (x,), (terms_in,), atol=1e-4)
def test_dot():
M, K, N = 5, 6, 7
order = 4
x1 = npr.randn(M, K)
x2 = npr.randn(K, N)
primals = (x1, x2)
terms_in1 = [npr.randn(*x1.shape) for _ in range(order)]
terms_in2 = [npr.randn(*x2.shape) for _ in range(order)]
series_in = (terms_in1, terms_in2)
jvp_test_jet(np.dot, primals, series_in)
def test_mlp():
sigm = lambda x: 1. / (1. + np.exp(-x))
def mlp(M1, M2, x):
return np.dot(sigm(np.dot(x, M1)), M2)
f_mlp = lambda x: mlp(M1, M2, x)
M1, M2 = npr.randn(10, 10), npr.randn(10, 5)
x = npr.randn(2, 10)
terms_in = [np.ones_like(x), np.zeros_like(x), np.zeros_like(x), np.zeros_like(x)]
jvp_test_jet(f_mlp, (x,), [terms_in])
def test_mul():
D = 3
N = 4
x1 = npr.randn(D)
x2 = npr.randn(D)
f = lambda a, b: a * b
primals = (x1, x2)
terms_in1 = list(npr.randn(N, D))
terms_in2 = list(npr.randn(N, D))
series_in = (terms_in1, terms_in2)
jvp_test_jet(f, primals, series_in)
from jax.experimental import stax
from jax import random
from jax.tree_util import tree_map
def test_conv():
order = 4
input_shape = (1, 5, 5, 1)
key = random.PRNGKey(0)
init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
_, (W, b) = init_fun(key, input_shape)
x = npr.randn(*input_shape)
primals = (W, b, x)
series_in1 = [npr.randn(*W.shape) for _ in range(order)]
series_in2 = [npr.randn(*b.shape) for _ in range(order)]
series_in3 = [npr.randn(*x.shape) for _ in range(order)]
series_in = (series_in1, series_in2, series_in3)
def f(W, b, x):
return apply_fun((W, b), x)
jvp_test_jet(f, primals, series_in)
test_exp()
test_dot()
test_conv()
test_mul()
# test_mlp() # TODO add div rule!
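
A quick hand check on the demo at the top of this deleted file: with series [(1., 0.), (1., 0.)], both inputs move along unit-velocity straight paths, so the composition is g(t) = (1 + t) + 2*exp(2 + t). The jet call should therefore print primal 1 + 2e^2 with terms [1 + 2e^2, 2e^2], and the jvp call should agree on the first term:

import numpy as onp

e2 = onp.exp(2.)
print(1. + 2. * e2)  # g(0) and g'(0): expected primal and first term
print(2. * e2)       # g''(0): expected second term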

tests/jet_test.py (new file)

@@ -0,0 +1,150 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial, reduce
import operator as op
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
from scipy.special import factorial as fact
from jax import core
from jax import test_util as jtu
import jax.numpy as np
from jax import random
from jax import jet, jacobian, jit
from jax.experimental import stax
from jax.config import config
config.parse_flags_with_absl()
def jvp_taylor(fun, primals, series):
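# Reference implementation: perturb each primal along the polynomial path
# x + sum_i terms[i] * eps**(i+1) / (i+1)!, then read off the output terms
# as the first `order` derivatives of `fun` along that path at eps = 0.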
order, = set(map(len, series))
def composition(eps):
taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
for i in range(len(terms))]) for terms in series]
nudged_args = [x + t for x, t in zip(primals, taylor_terms)]
return fun(*nudged_args)
primal_out = fun(*primals)
terms_out = [repeated(jacobian, i+1)(composition)(0.) for i in range(order)]
return primal_out, terms_out
def repeated(f, n):
def rfun(p):
return reduce(lambda x, _: f(x), range(n), p)
return rfun
class JetTest(jtu.JaxTestCase):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5):
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol, check_dtypes=True)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=True)
@jtu.skip_on_devices("tpu")
def test_exp(self):
order, dim = 4, 3
rng = onp.random.RandomState(0)
primal_in = rng.randn(dim)
terms_in = [rng.randn(dim) for _ in range(order)]
self.check_jet(np.exp, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
@jtu.skip_on_devices("tpu")
def test_log(self):
order, dim = 4, 3
rng = onp.random.RandomState(0)
primal_in = np.exp(rng.randn(dim))
terms_in = [rng.randn(dim) for _ in range(order)]
self.check_jet(np.log, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
@jtu.skip_on_devices("tpu")
def test_dot(self):
M, K, N = 2, 3, 4
order = 3
rng = onp.random.RandomState(0)
x1 = rng.randn(M, K)
x2 = rng.randn(K, N)
primals = (x1, x2)
terms_in1 = [rng.randn(*x1.shape) for _ in range(order)]
terms_in2 = [rng.randn(*x2.shape) for _ in range(order)]
series_in = (terms_in1, terms_in2)
self.check_jet(np.dot, primals, series_in)
@jtu.skip_on_devices("tpu")
def test_conv(self):
order = 3
input_shape = (1, 5, 5, 1)
key = random.PRNGKey(0)
init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
_, (W, b) = init_fun(key, input_shape)
rng = onp.random.RandomState(0)
x = rng.randn(*input_shape)
primals = (W, b, x)
series_in1 = [rng.randn(*W.shape) for _ in range(order)]
series_in2 = [rng.randn(*b.shape) for _ in range(order)]
series_in3 = [rng.randn(*x.shape) for _ in range(order)]
series_in = (series_in1, series_in2, series_in3)
def f(W, b, x):
return apply_fun((W, b), x)
self.check_jet(f, primals, series_in)
@jtu.skip_on_devices("tpu")
def test_div(self):
primals = 1., 5.
order = 4
rng = onp.random.RandomState(0)
series_in = ([rng.randn() for _ in range(order)], [rng.randn() for _ in range(order)])
self.check_jet(op.truediv, primals, series_in)
@jtu.skip_on_devices("tpu")
def test_sub(self):
primals = 1., 5.
order = 4
rng = onp.random.RandomState(0)
series_in = ([rng.randn() for _ in range(order)], [rng.randn() for _ in range(order)])
self.check_jet(op.sub, primals, series_in)
@jtu.skip_on_devices("tpu")
def test_gather(self):
order, dim = 4, 3
rng = onp.random.RandomState(0)
x = rng.randn(dim)
terms_in = [rng.randn(dim) for _ in range(order)]
self.check_jet(lambda x: x[1:], (x,), (terms_in,))
@jtu.skip_on_devices("tpu")
def test_reduce_max(self):
dim1, dim2 = 3, 5
order = 6
rng = onp.random.RandomState(0)
x = rng.randn(dim1, dim2)
terms_in = [rng.randn(dim1, dim2) for _ in range(order)]
self.check_jet(lambda x: x.max(axis=1), (x,), (terms_in,))
if __name__ == '__main__':
absltest.main()
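
For reference, a minimal direct use of the API these tests exercise, assuming `jet` is importable as in the test file above; every derivative of exp along a unit-velocity path is exp itself, so all returned terms should match the primal:

import jax.numpy as np
from jax import jet

primal, terms = jet(np.exp, (1.0,), ([1.0, 0.0, 0.0],))
# primal ≈ e ≈ 2.71828; terms ≈ [e, e, e]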