Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 21:06:06 +00:00)
add jet tests, remove top-level files
This commit is contained in:
parent
840797d4a1
commit
668a1703bc
@@ -55,18 +55,18 @@ class JetTrace(core.Trace):
     series_in = [[onp.zeros(onp.shape(x), dtype=onp.result_type(x))
                   if t is zero_term else t for t in series]
                  for x, series in zip(primals_in, series_in)]
-    rule = prop_rules[primitive]
+    rule = jet_rules[primitive]
     primal_out, terms_out = rule(primals_in, series_in, **params)
     return JetTracer(self, primal_out, terms_out)
 
   def process_call(self, call_primitive, f, tracers, params):
-    assert False
+    assert False  # TODO
 
   def post_process_call(self, call_primitive, out_tracer, params):
-    assert False
+    assert False  # TODO
 
   def join(self, xt, yt):
-    assert False
+    assert False  # TODO?
 
 
 class ZeroTerm(object): pass
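
Every entry in the renamed jet_rules table follows the contract visible at the call site above: rule(primals_in, series_in, **params) returning (primal_out, terms_out), where series_in holds one list of higher-order terms per primal input. A minimal sketch of such a rule for a linear unary primitive, modeled on the linear_prop pattern later in this commit (the rule name and the registration line are illustrative, not upstream code):

    # Sketch: the shape of a jet rule. For a linear primitive, the primal
    # and each series term propagate independently through the same op.
    def _neg_jet_rule(primals_in, series_in, **params):
      x, = primals_in         # single primal input
      terms, = series_in      # its list of higher-order terms
      primal_out = -x
      terms_out = [-t for t in terms]  # linearity: negate term by term
      return primal_out, terms_out

    # Registration would mirror the jet_rules[prim] = ... lines in this commit:
    # jet_rules[neg_p] = _neg_jet_rule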
@@ -76,27 +76,10 @@ class ZeroSeries(object): pass
 zero_series = ZeroSeries()
 
 
-prop_rules = {}
-
-def tay_to_deriv_coeff(u_tay):
-  u_deriv = [ui * fact(i) for (i, ui) in enumerate(u_tay)]
-  return u_deriv
-
-def deriv_to_tay_coeff(u_deriv):
-  u_tay = [ui / fact(i) for (i, ui) in enumerate(u_deriv)]
-  return u_tay
-
-def taylor_tilde(u_tay):
-  u_tilde = [i * ui for (i, ui) in enumerate(u_tay)]
-  return u_tilde
-
-def taylor_untilde(u_tilde):
-  u_tay = [i * ui for (i, ui) in enumerate(u_tilde)]
-  return u_tay
+jet_rules = {}
 
 def deflinear(prim):
-  prop_rules[prim] = partial(linear_prop, prim)
+  jet_rules[prim] = partial(linear_prop, prim)
 
 def linear_prop(prim, primals_in, series_in, **params):
   primal_out = prim.bind(*primals_in, **params)
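
The hunk cuts off before linear_prop finishes. The _gather_taylor_rule below shows the pattern it follows for linear primitives: bind the primitive once for the primal and once per series term. A hedged sketch of the remainder, reconstructed from that pattern rather than copied from upstream:

    def linear_prop(prim, primals_in, series_in, **params):
      primal_out = prim.bind(*primals_in, **params)
      # Linearity makes each order independent: the k-th output term is the
      # primitive applied to the k-th term of every input series.
      series_out = [prim.bind(*terms, **params) for terms in zip(*series_in)]
      return primal_out, series_out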
@@ -1684,7 +1684,7 @@ def _exp_taylor(primals_in, series_in):
     v[k] = fact(k-1) * sum([scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
   primal_out, *series_out = v
   return primal_out, series_out
-taylor.prop_rules[exp_p] = _exp_taylor
+taylor.jet_rules[exp_p] = _exp_taylor
 
 log_p = standard_unop(_float | _complex, 'log')
 ad.defjvp(log_p, lambda g, x: div(g, x))
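
Differentiating v = e^u once gives the standard recurrence that _exp_taylor implements; the fact/scale factors in the code convert between normalized Taylor coefficients and the derivative-coefficient convention (scale itself is defined outside this hunk). With \tilde{u}_k = u^{(k)}(0)/k!:

    v'(t) = u'(t)\,v(t)
    \quad\Longrightarrow\quad
    \tilde{v}_k = \frac{1}{k}\sum_{j=1}^{k} j\,\tilde{u}_j\,\tilde{v}_{k-j},
    \qquad \tilde{v}_0 = e^{\tilde{u}_0}.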
@@ -1700,7 +1700,7 @@ def _log_taylor(primals_in, series_in):
     v[k] = (u[k] - fact(k - 1) * conv) / u[0]
   primal_out, *series_out = v
   return primal_out, series_out
-taylor.prop_rules[log_p] = _log_taylor
+taylor.jet_rules[log_p] = _log_taylor
 
 expm1_p = standard_unop(_float | _complex, 'expm1')
 ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
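
The log rule inverts the same relation: writing u = e^v for v = log u and matching coefficients isolates \tilde{v}_k by dividing out \tilde{u}_0, which matches the shape of the assignment above (conv is the convolution sum computed just before this hunk):

    \tilde{v}_k = \frac{1}{\tilde{u}_0}\Big(\tilde{u}_k
      - \frac{1}{k}\sum_{j=1}^{k-1} j\,\tilde{v}_j\,\tilde{u}_{k-j}\Big),
    \qquad \tilde{v}_0 = \log \tilde{u}_0.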
@@ -1922,7 +1922,7 @@ def _div_taylor_rule(primals_in, series_in, **params):
   primal_out, *series_out = v
   return primal_out, series_out
 ad.primitive_transposes[div_p] = _div_transpose_rule
-taylor.prop_rules[div_p] = _div_taylor_rule
+taylor.jet_rules[div_p] = _div_taylor_rule
 
 rem_p = standard_naryop([_num, _num], 'rem')
 ad.defjvp(rem_p,
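
Only the tail of _div_taylor_rule is visible here, but the recurrence behind the v it unpacks follows directly from u = v w for v = u / w: solve the Cauchy product for the highest unknown coefficient. This is the standard derivation, not a quote of the elided body:

    \tilde{u}_k = \sum_{j=0}^{k} \tilde{v}_j\,\tilde{w}_{k-j}
    \quad\Longrightarrow\quad
    \tilde{v}_k = \frac{1}{\tilde{w}_0}\Big(\tilde{u}_k
      - \sum_{j=0}^{k-1} \tilde{v}_j\,\tilde{w}_{k-j}\Big).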
@@ -2384,9 +2384,9 @@ def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
   v[k] = fact(k) * sum([scale(k, j) * op(u[j], w[k-j]) for j in range(0, k+1)])
   primal_out, *series_out = v
   return primal_out, series_out
-taylor.prop_rules[dot_general_p] = partial(_bilinear_taylor_rule, dot_general_p)
-taylor.prop_rules[mul_p] = partial(_bilinear_taylor_rule, mul_p)
-taylor.prop_rules[conv_general_dilated_p] = partial(_bilinear_taylor_rule, conv_general_dilated_p)
+taylor.jet_rules[dot_general_p] = partial(_bilinear_taylor_rule, dot_general_p)
+taylor.jet_rules[mul_p] = partial(_bilinear_taylor_rule, mul_p)
+taylor.jet_rules[conv_general_dilated_p] = partial(_bilinear_taylor_rule, conv_general_dilated_p)
 
 
 def _broadcast_shape_rule(operand, sizes):
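
_bilinear_taylor_rule is Leibniz's rule for an arbitrary bilinear op (mul, dot_general, convolution). In derivative coefficients it reads as below; the code's fact(k) * scale(k, j) plays the role of the binomial coefficient (scale is defined outside this hunk):

    v^{(k)} = \sum_{j=0}^{k} \binom{k}{j}\,\mathrm{op}\big(u^{(j)},\,w^{(k-j)}\big),
    \quad\text{equivalently}\quad
    \tilde{v}_k = \sum_{j=0}^{k} \mathrm{op}(\tilde{u}_j,\,\tilde{w}_{k-j}).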
@@ -3205,7 +3205,7 @@ def _gather_taylor_rule(primals_in, series_in, **params):
   primal_out = gather_p.bind(operand, start_indices, **params)
   series_out = [gather_p.bind(g, start_indices, **params) for g in gs]
   return primal_out, series_out
-taylor.prop_rules[gather_p] = _gather_taylor_rule
+taylor.jet_rules[gather_p] = _gather_taylor_rule
 
 ad.primitive_transposes[gather_p] = _gather_transpose_rule
 batching.primitive_batchers[gather_p] = _gather_batching_rule
@@ -3648,7 +3648,7 @@ def _reduce_max_taylor_rule(primals_in, series_in, **params):
     return div(_reduce_sum(mul(g, location_indicators), axes), counts)
   series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
   return primal_out, series_out
-taylor.prop_rules[reduce_max_p] = _reduce_max_taylor_rule
+taylor.jet_rules[reduce_max_p] = _reduce_max_taylor_rule
 
 batching.defreducer(reduce_max_p)
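
The reduce-max rule exploits the fact that, away from ties, max is locally a selection: every series term is routed through the argmax locations, with ties averaged via counts. A hedged reading of location_indicators and counts, which are computed above this hunk:

    \frac{\partial}{\partial x_j}\max_i x_i
      = \frac{\mathbb{1}[x_j = \max_i x_i]}{\#\{i : x_i = \max_i x_i\}},
    \qquad
    g \;\mapsto\; \mathrm{sum}_{\mathrm{axes}}(g \cdot \mathrm{indicators}) \,/\, \mathrm{counts}.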
jet_nn.py (deleted, 267 lines)
@@ -1,267 +0,0 @@
-"""
-Create some pared down examples to see if we can take jets through them.
-"""
-import time
-
-import haiku as hk
-import jax
-import jax.numpy as jnp
-import numpy.random as npr
-
-from jax.flatten_util import ravel_pytree
-
-import tensorflow_datasets as tfds
-from functools import reduce
-
-from scipy.special import factorial as fact
-
-
-def sigmoid(z):
-  """
-  Defined using only numpy primitives (but probably less numerically stable).
-  """
-  return 1./(1. + jnp.exp(-z))
-
-def repeated(f, n):
-  def rfun(p):
-    return reduce(lambda x, _: f(x), range(n), p)
-  return rfun
-
-
-def jvp_taylor(f, primals, series):
-  def expansion(eps):
-    tayterms = [
-        sum([eps**(i + 1) * terms[i] / fact(i + 1) for i in range(len(terms))])
-        for terms in series
-    ]
-    return f(*map(sum, zip(primals, tayterms)))
-
-  n_derivs = []
-  N = len(series[0]) + 1
-  for i in range(1, N):
-    d = repeated(jax.jacobian, i)(expansion)(0.)
-    n_derivs.append(d)
-  return f(*primals), n_derivs
-
-
-def jvp_test_jet(f, primals, series, atol=1e-5):
-  tic = time.time()
-  y, terms = jax.jet(f, primals, series)
-  print("jet done in {} sec".format(time.time() - tic))
-
-  tic = time.time()
-  y_jvp, terms_jvp = jvp_taylor(f, primals, series)
-  print("jvp done in {} sec".format(time.time() - tic))
-
-  assert jnp.allclose(y, y_jvp)
-  assert jnp.allclose(terms, terms_jvp, atol=atol)
-
-
-def softmax_cross_entropy(logits, labels):
-  one_hot = hk.one_hot(labels, logits.shape[-1])
-  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)
-  # return -logits[labels]
-
-
-rng = jax.random.PRNGKey(42)
-
-order = 4
-
-batch_size = 2
-train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
-train_ds = train_ds.cache().shuffle(1000).batch(batch_size)
-batch = next(tfds.as_numpy(train_ds))
-
-
-def test_mlp_x():
-  """
-  Test jet through a linear layer built with Haiku wrt x.
-  """
-  def loss_fn(images, labels):
-    model = hk.Sequential([
-        lambda x: x.astype(jnp.float32) / 255.,
-        hk.Linear(10)
-    ])
-    logits = model(images)
-    return jnp.mean(softmax_cross_entropy(logits, labels))
-
-  loss_obj = hk.transform(loss_fn)
-
-  # flatten for MLP
-  batch['image'] = jnp.reshape(batch['image'], (batch_size, -1))
-  images, labels = jnp.array(batch['image'], dtype=jnp.float32), jnp.array(batch['label'])
-
-  params = loss_obj.init(rng, images, labels)
-
-  flat_params, unravel = ravel_pytree(params)
-
-  loss = loss_obj.apply(unravel(flat_params), images, labels)
-  print("forward pass works")
-
-  f = lambda images: loss_obj.apply(unravel(flat_params), images, labels)
-  _ = jax.grad(f)(images)
-  terms_in = [npr.randn(*images.shape) for _ in range(order)]
-  jvp_test_jet(f, (images,), (terms_in,))
-
-
-def test_mlp1():
-  """
-  Test jet through a linear layer built with Haiku.
-  """
-  def loss_fn(images, labels):
-    model = hk.Sequential([
-        lambda x: x.astype(jnp.float32) / 255.,
-        hk.Linear(10)
-    ])
-    logits = model(images)
-    return jnp.mean(softmax_cross_entropy(logits, labels))
-
-  loss_obj = hk.transform(loss_fn)
-
-  # flatten for MLP
-  batch['image'] = jnp.reshape(batch['image'], (batch_size, -1))
-  images, labels = jnp.array(batch['image'], dtype=jnp.float32), jnp.array(batch['label'])
-
-  params = loss_obj.init(rng, images, labels)
-
-  flat_params, unravel = ravel_pytree(params)
-
-  loss = loss_obj.apply(unravel(flat_params), images, labels)
-  print("forward pass works")
-
-  f = lambda flat_params: loss_obj.apply(unravel(flat_params), images, labels)
-  terms_in = [npr.randn(*flat_params.shape) for _ in range(order)]
-  jvp_test_jet(f, (flat_params,), (terms_in,))
-
-
-def test_mlp2():
-  """
-  Test jet through a MLP with sigmoid activations.
-  """
-  def loss_fn(images, labels):
-    model = hk.Sequential([
-        lambda x: x.astype(jnp.float32) / 255.,
-        hk.Linear(100),
-        sigmoid,
-        hk.Linear(10)
-    ])
-    logits = model(images)
-    return jnp.mean(softmax_cross_entropy(logits, labels))
-
-  loss_obj = hk.transform(loss_fn)
-
-  # flatten for MLP
-  batch['image'] = jnp.reshape(batch['image'], (batch_size, -1))
-  images, labels = jnp.array(batch['image'], dtype=jnp.float32), jnp.array(batch['label'])
-
-  params = loss_obj.init(rng, images, labels)
-
-  flat_params, unravel = ravel_pytree(params)
-
-  loss = loss_obj.apply(unravel(flat_params), images, labels)
-  print("forward pass works")
-
-  f = lambda flat_params: loss_obj.apply(unravel(flat_params), images, labels)
-  terms_in = [npr.randn(*flat_params.shape) for _ in range(order)]
-  jvp_test_jet(f, (flat_params,), (terms_in,), atol=1e-4)
-
-
-def test_res1():
-  """
-  Test jet through a simple convnet with sigmoid activations.
-  """
-  def loss_fn(images, labels):
-    model = hk.Sequential([
-        lambda x: x.astype(jnp.float32) / 255.,
-        hk.Conv2D(output_channels=10,
-                  kernel_shape=3,
-                  stride=2,
-                  padding="VALID"),
-        sigmoid,
-        hk.Reshape(output_shape=(batch_size, -1)),
-        hk.Linear(10)
-    ])
-    logits = model(images)
-    return jnp.mean(softmax_cross_entropy(logits, labels))
-
-  loss_obj = hk.transform(loss_fn)
-
-  images, labels = batch['image'], batch['label']
-
-  params = loss_obj.init(rng, images, labels)
-
-  flat_params, unravel = ravel_pytree(params)
-
-  loss = loss_obj.apply(unravel(flat_params), images, labels)
-  print("forward pass works")
-
-  f = lambda flat_params: loss_obj.apply(unravel(flat_params), images, labels)
-  terms_in = [npr.randn(*flat_params.shape) for _ in range(order)]
-  jvp_test_jet(f, (flat_params,), (terms_in,))
-
-
-def test_div():
-  x = 1.
-  y = 5.
-  primals = (x, y)
-  order = 4
-  series_in = ([npr.randn() for _ in range(order)], [npr.randn() for _ in range(order)])
-  jvp_test_jet(lambda a, b: a / b, primals, series_in)
-
-
-def test_gather():
-  npr.seed(0)
-  D = 3  # dimensionality
-  N = 6  # differentiation order
-  x = npr.randn(D)
-  terms_in = list(npr.randn(N, D))
-  jvp_test_jet(lambda x: x[1:], (x,), (terms_in,))
-
-def test_reduce_max():
-  npr.seed(0)
-  D1, D2 = 3, 5  # dimensionality
-  N = 6  # differentiation order
-  x = npr.randn(D1, D2)
-  terms_in = [npr.randn(D1, D2) for _ in range(N)]
-  jvp_test_jet(lambda x: x.max(axis=1), (x,), (terms_in,))
-
-def test_sub():
-  x = 1.
-  y = 5.
-  primals = (x, y)
-  order = 4
-  series_in = ([npr.randn() for _ in range(order)], [npr.randn() for _ in range(order)])
-  jvp_test_jet(lambda a, b: a - b, primals, series_in)
-
-def test_exp():
-  npr.seed(0)
-  D1, D2 = 3, 5  # dimensionality
-  N = 6  # differentiation order
-  x = npr.randn(D1, D2)
-  terms_in = [npr.randn(D1, D2) for _ in range(N)]
-  jvp_test_jet(lambda x: jnp.exp(x), (x,), (terms_in,))
-
-def test_log():
-  npr.seed(0)
-  D1, D2 = 3, 5  # dimensionality
-  N = 4  # differentiation order
-  x = jnp.exp(npr.randn(D1, D2))
-  terms_in = [jnp.exp(npr.randn(D1, D2)) for _ in range(N)]
-  jvp_test_jet(lambda x: jnp.log(x), (x,), (terms_in,))
-
-
-# test_div()
-# test_gather()
-# test_reduce_max()
-# test_sub()
-# test_exp()
-# test_log()
-
-# test_mlp_x()
-# test_mlp1()
-# test_mlp2()
-# test_res1()
mac.py (deleted, 128 lines)
@@ -1,128 +0,0 @@
-import time
-
-import jax.numpy as np
-from jax import jet, jvp
-
-def f(x, y):
-  return x + 2 * np.exp(y)
-
-out = jet(f, (1., 2.), [(1., 0.), (1., 0.)])
-print(out)
-
-out = jvp(f, (1., 2.), (1., 1.))
-print(out)
-
-
-###
-
-from functools import reduce
-import numpy.random as npr
-from jax import jacobian
-
-from scipy.special import factorial as fact
-
-def jvp_taylor(f, primals, series):
-  def expansion(eps):
-    tayterms = [
-        sum([eps**(i + 1) * terms[i] / fact(i + 1) for i in range(len(terms))])
-        for terms in series
-    ]
-    return f(*map(sum, zip(primals, tayterms)))
-
-  n_derivs = []
-  N = len(series[0]) + 1
-  for i in range(1, N):
-    d = repeated(jacobian, i)(expansion)(0.)
-    n_derivs.append(d)
-  return f(*primals), n_derivs
-
-def repeated(f, n):
-  def rfun(p):
-    return reduce(lambda x, _: f(x), range(n), p)
-  return rfun
-
-def jvp_test_jet(f, primals, series, atol=1e-5):
-  tic = time.time()
-  y, terms = jet(f, primals, series)
-  print("jet done in {} sec".format(time.time() - tic))
-
-  tic = time.time()
-  y_jvp, terms_jvp = jvp_taylor(f, primals, series)
-  print("jvp done in {} sec".format(time.time() - tic))
-
-  assert np.allclose(y, y_jvp)
-  assert np.allclose(terms, terms_jvp, atol=atol)
-
-def test_exp():
-  npr.seed(0)
-  D = 3  # dimensionality
-  N = 6  # differentiation order
-  x = npr.randn(D)
-  terms_in = list(npr.randn(N,D))
-  jvp_test_jet(np.exp, (x,), (terms_in,), atol=1e-4)
-
-
-def test_dot():
-  M, K, N = 5, 6, 7
-  order = 4
-  x1 = npr.randn(M, K)
-  x2 = npr.randn(K, N)
-  primals = (x1, x2)
-  terms_in1 = [npr.randn(*x1.shape) for _ in range(order)]
-  terms_in2 = [npr.randn(*x2.shape) for _ in range(order)]
-  series_in = (terms_in1, terms_in2)
-  jvp_test_jet(np.dot, primals, series_in)
-
-
-def test_mlp():
-  sigm = lambda x: 1. / (1. + np.exp(-x))
-  def mlp(M1, M2, x):
-    return np.dot(sigm(np.dot(x, M1)), M2)
-  f_mlp = lambda x: mlp(M1, M2, x)
-  M1, M2 = (npr.randn(10, 10), npr.randn(10, 5))
-  x = npr.randn(2, 10)
-  terms_in = [np.ones_like(x), np.zeros_like(x), np.zeros_like(x), np.zeros_like(x)]
-  jvp_test_jet(f_mlp, (x,), [terms_in])
-
-def test_mul():
-  D = 3
-  N = 4
-  x1 = npr.randn(D)
-  x2 = npr.randn(D)
-  f = lambda a, b: a * b
-  primals = (x1, x2)
-  terms_in1 = list(npr.randn(N,D))
-  terms_in2 = list(npr.randn(N,D))
-  series_in = (terms_in1, terms_in2)
-  jvp_test_jet(f, primals, series_in)
-
-from jax.experimental import stax
-from jax import random
-from jax.tree_util import tree_map
-
-def test_conv():
-  order = 4
-  input_shape = (1, 5, 5, 1)
-  key = random.PRNGKey(0)
-  init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
-  _, (W, b) = init_fun(key, input_shape)
-
-  x = npr.randn(*input_shape)
-  primals = (W, b, x)
-
-  series_in1 = [npr.randn(*W.shape) for _ in range(order)]
-  series_in2 = [npr.randn(*b.shape) for _ in range(order)]
-  series_in3 = [npr.randn(*x.shape) for _ in range(order)]
-
-  series_in = (series_in1, series_in2, series_in3)
-
-  def f(W, b, x):
-    return apply_fun((W, b), x)
-  jvp_test_jet(f, primals, series_in)
-
-
-test_exp()
-test_dot()
-test_conv()
-test_mul()
-# test_mlp()  # TODO add div rule!
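
The two prints at the top of mac.py are the sanity check the file is built around: a jet whose series carry a single nonzero first-order term must reproduce an ordinary jvp in its first output coefficient. A minimal runnable sketch of that check, assuming a later JAX where jet lives in jax.experimental.jet rather than at the top level:

    import jax.numpy as np
    from jax import jvp
    from jax.experimental.jet import jet  # `from jax import jet` predates this move

    def f(x, y):
      return x + 2 * np.exp(y)

    # Terms beyond the first are zero, so the jet is a plain tangent push-forward.
    y_jet, terms = jet(f, (1., 2.), ([1., 0.], [1., 0.]))
    y_jvp, t_out = jvp(f, (1., 2.), (1., 1.))
    assert np.allclose(y_jet, y_jvp)
    assert np.allclose(terms[0], t_out)  # first-order jet term == jvp tangent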
tests/jet_test.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import partial, reduce
+import operator as op
+from unittest import SkipTest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import numpy as onp
+from scipy.special import factorial as fact
+
+from jax import core
+from jax import test_util as jtu
+
+import jax.numpy as np
+from jax import random
+from jax import jet, jacobian, jit
+from jax.experimental import stax
+
+from jax.config import config
+config.parse_flags_with_absl()
+
+
+def jvp_taylor(fun, primals, series):
+  order, = set(map(len, series))
+  def composition(eps):
+    taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
+                         for i in range(len(terms))]) for terms in series]
+    nudged_args = [x + t for x, t in zip(primals, taylor_terms)]
+    return fun(*nudged_args)
+  primal_out = fun(*primals)
+  terms_out = [repeated(jacobian, i+1)(composition)(0.) for i in range(order)]
+  return primal_out, terms_out
+
+
+def repeated(f, n):
+  def rfun(p):
+    return reduce(lambda x, _: f(x), range(n), p)
+  return rfun
+
+
+class JetTest(jtu.JaxTestCase):
+
+  def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5):
+    y, terms = jet(fun, primals, series)
+    expected_y, expected_terms = jvp_taylor(fun, primals, series)
+    self.assertAllClose(y, expected_y, atol=atol, rtol=rtol, check_dtypes=True)
+    self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
+                        check_dtypes=True)
+
+  @jtu.skip_on_devices("tpu")
+  def test_exp(self):
+    order, dim = 4, 3
+    rng = onp.random.RandomState(0)
+    primal_in = rng.randn(dim)
+    terms_in = [rng.randn(dim) for _ in range(order)]
+    self.check_jet(np.exp, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
+
+  @jtu.skip_on_devices("tpu")
+  def test_log(self):
+    order, dim = 4, 3
+    rng = onp.random.RandomState(0)
+    primal_in = np.exp(rng.randn(dim))
+    terms_in = [rng.randn(dim) for _ in range(order)]
+    self.check_jet(np.log, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
+
+  @jtu.skip_on_devices("tpu")
+  def test_dot(self):
+    M, K, N = 2, 3, 4
+    order = 3
+    rng = onp.random.RandomState(0)
+    x1 = rng.randn(M, K)
+    x2 = rng.randn(K, N)
+    primals = (x1, x2)
+    terms_in1 = [rng.randn(*x1.shape) for _ in range(order)]
+    terms_in2 = [rng.randn(*x2.shape) for _ in range(order)]
+    series_in = (terms_in1, terms_in2)
+    self.check_jet(np.dot, primals, series_in)
+
+  @jtu.skip_on_devices("tpu")
+  def test_conv(self):
+    order = 3
+    input_shape = (1, 5, 5, 1)
+    key = random.PRNGKey(0)
+    init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
+    _, (W, b) = init_fun(key, input_shape)
+
+    rng = onp.random.RandomState(0)
+
+    x = rng.randn(*input_shape)
+    primals = (W, b, x)
+
+    series_in1 = [rng.randn(*W.shape) for _ in range(order)]
+    series_in2 = [rng.randn(*b.shape) for _ in range(order)]
+    series_in3 = [rng.randn(*x.shape) for _ in range(order)]
+
+    series_in = (series_in1, series_in2, series_in3)
+
+    def f(W, b, x):
+      return apply_fun((W, b), x)
+
+    self.check_jet(f, primals, series_in)
+
+  @jtu.skip_on_devices("tpu")
+  def test_div(self):
+    primals = 1., 5.
+    order = 4
+    rng = onp.random.RandomState(0)
+    series_in = ([rng.randn() for _ in range(order)], [rng.randn() for _ in range(order)])
+    self.check_jet(op.truediv, primals, series_in)
+
+  @jtu.skip_on_devices("tpu")
+  def test_sub(self):
+    primals = 1., 5.
+    order = 4
+    rng = onp.random.RandomState(0)
+    series_in = ([rng.randn() for _ in range(order)], [rng.randn() for _ in range(order)])
+    self.check_jet(op.sub, primals, series_in)
+
+  @jtu.skip_on_devices("tpu")
+  def test_gather(self):
+    order, dim = 4, 3
+    rng = onp.random.RandomState(0)
+    x = rng.randn(dim)
+    terms_in = [rng.randn(dim) for _ in range(order)]
+    self.check_jet(lambda x: x[1:], (x,), (terms_in,))
+
+  @jtu.skip_on_devices("tpu")
+  def test_reduce_max(self):
+    dim1, dim2 = 3, 5
+    order = 6
+    rng = onp.random.RandomState(0)
+    x = rng.randn(dim1, dim2)
+    terms_in = [rng.randn(dim1, dim2) for _ in range(order)]
+    self.check_jet(lambda x: x.max(axis=1), (x,), (terms_in,))
+
+
+if __name__ == '__main__':
+  absltest.main()
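
jvp_taylor above pins down the convention these tests enforce: the input terms v_i perturb each primal along a one-dimensional curve in \varepsilon, and jet's k-th output term must equal the k-th derivative of the composition at zero, computed by repeated(jacobian, k):

    g(\varepsilon) = f\Big(x + \sum_{i=1}^{K} \frac{\varepsilon^{i}}{i!}\,v_i\Big),
    \qquad
    \mathrm{terms\_out}[k-1] = \frac{d^{k} g}{d\varepsilon^{k}}\Big|_{\varepsilon=0},
    \quad k = 1,\dots,K.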