Added lots of trivial jet rules.

Co-Authored-By: jessebett <jessebett@gmail.com>
Co-Authored-By: Jacob Kelly <jacob.kelly@mail.utoronto.ca>
This commit is contained in:
David Duvenaud 2020-03-29 16:28:17 -04:00
parent bcc5191c63
commit ead8011837
2 changed files with 152 additions and 85 deletions

View File

@ -14,16 +14,17 @@
from functools import partial
from collections import Counter
import numpy as onp
from jax import core
from jax.util import unzip2, prod
from jax.util import unzip2
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
import jax.linear_util as lu
from jax.interpreters import xla
from jax.lax import lax
from jax.lax import lax_fft
def jet(fun, primals, series):
try:
@ -124,8 +125,33 @@ zero_series = ZeroSeries()
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)
### rule definitions
jet_rules = {}
def defzero(prim):
  """Register `prim` in `jet_rules` with the zero-series rule.

  The stored rule is `zero_prop` with `prim` already bound as its first
  argument.
  """
  zero_rule = partial(zero_prop, prim)
  jet_rules[prim] = zero_rule
def zero_prop(prim, primals_in, series_in, **params):
  """Jet rule for a primitive whose output series is reported as zero.

  Only the primal inputs are used: `prim` is bound on `primals_in` (with
  `params`), `series_in` is ignored, and the `zero_series` sentinel is
  returned in place of the output terms.
  """
  return prim.bind(*primals_in, **params), zero_series
# Primitives registered with the zero rule (comparisons, boolean ops, rounding,
# sign and stop_gradient): their jet rule evaluates the primal and returns
# `zero_series` for the output terms.
defzero(lax.le_p)
defzero(lax.lt_p)
defzero(lax.gt_p)
defzero(lax.ge_p)
defzero(lax.eq_p)
defzero(lax.ne_p)
defzero(lax.and_p)
defzero(lax.or_p)
defzero(lax.xor_p)
defzero(lax.floor_p)
defzero(lax.ceil_p)
defzero(lax.round_p)
defzero(lax.sign_p)
defzero(lax.stop_gradient_p)
def deflinear(prim):
  """Register `prim` in `jet_rules` with the linear rule.

  The stored rule is `linear_prop` with `prim` already bound as its first
  argument.
  """
  linear_rule = partial(linear_prop, prim)
  jet_rules[prim] = linear_rule
@ -134,14 +160,6 @@ def linear_prop(prim, primals_in, series_in, **params):
series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)]
return primal_out, series_out
### rule definitions
from jax.lax import lax
def fact(n):
  """Return `n!`, evaluated as `exp(lgamma(n + 1))`.

  The `+ 1.` promotes `n` to floating point, so the result is a float
  approximation of the factorial rather than an exact integer.
  """
  log_factorial = lax.lgamma(n + 1.)
  return lax.exp(log_factorial)
deflinear(lax.neg_p)
deflinear(lax.real_p)
deflinear(lax.complex_p)
@ -159,15 +177,26 @@ deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.tie_in_p)
deflinear(lax_fft.fft_p)
deflinear(xla.device_put_p)
### More complicated rules
def fact(n):
  """Return `n!`, evaluated as `exp(lgamma(n + 1))`.

  The `+ 1.` promotes `n` to floating point, so the result is a float
  approximation of the factorial rather than an exact integer.
  """
  log_factorial = lax.lgamma(n + 1.)
  return lax.exp(log_factorial)
def _scale(k, j):
  """Return the weight `1 / ((k - j)! * (j - 1)!)` used in the series recurrences."""
  denominator = fact(k - j) * fact(j - 1)
  return 1. / denominator
def _exp_taylor(primals_in, series_in):
  """Jet rule for `exp`: push a truncated series through `lax.exp`.

  Args:
    primals_in: one-element sequence holding the primal input `x`.
    series_in: one-element sequence holding the list of input series terms.

  Returns:
    A pair `(primal_out, series_out)`: `primal_out` is `lax.exp(x)` and
    `series_out` is a list with one output term per input term.
  """
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [lax.exp(x)] + [None] * len(series)
  # v[k] only reads lower-order entries of v, so fill it in increasing order.
  # Each coefficient is computed once, through the shared `_scale` helper.
  for k in range(1, len(v)):
    v[k] = fact(k - 1) * sum(_scale(k, j) * v[k - j] * u[j]
                             for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.exp_p] = _exp_taylor
@ -177,9 +206,8 @@ def _log_taylor(primals_in, series_in):
series, = series_in
u = [x] + series
v = [lax.log(x)] + [None] * len(series)
def scale(k, j): return 1. / (fact(k-j) * fact(j-1))
for k in range(1, len(v)):
conv = sum([scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
conv = sum([_scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
v[k] = (u[k] - fact(k - 1) * conv) / u[0]
primal_out, *series_out = v
return primal_out, series_out
@ -223,22 +251,30 @@ def _gather_taylor_rule(primals_in, series_in, **params):
return primal_out, series_out
jet_rules[lax.gather_p] = _gather_taylor_rule
def _reduce_max_taylor_rule(primals_in, series_in, **params):
operand, = primals_in
gs, = series_in
primal_out = lax.reduce_max_p.bind(operand, **params)
axes = params.pop("axes", None)
primal_dtype = gs[0].dtype
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
location_indicators = lax.convert_element_type(
lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
counts = lax._reduce_sum(location_indicators, axes)
def _reduce_chooser_taylor_rule(g):
return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
def _gen_reduce_choose_taylor_rule(chooser_fun):
def chooser_taylor_rule(primals_in, series_in, **params):
operand, = primals_in
gs, = series_in
primal_out = chooser_fun(operand, **params)
axes = params.pop("axes", None)
primal_dtype = gs[0].dtype
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
location_indicators = lax.convert_element_type(
lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
counts = lax._reduce_sum(location_indicators, axes)
def _reduce_chooser_taylor_rule(g):
return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
return primal_out, series_out
return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax.reduce_max_p.bind)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax.reduce_min_p.bind)
def _abs_taylor_rule(x, series_in, **params):
x, = x
primal_out = lax.abs_p.bind(x, **params)
negs = lax.select(lax.lt(x, 0.0), lax.full_like(x, -1), lax.full_like(x, 1.0))
fix_sign = lambda y: negs * y
series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
return primal_out, series_out
jet_rules[lax.reduce_max_p] = _reduce_max_taylor_rule
from jax.interpreters import xla
deflinear(xla.device_put_p)
jet_rules[lax.abs_p] = _abs_taylor_rule

View File

@ -13,28 +13,25 @@
# limitations under the License.
from functools import partial, reduce
import operator as op
from unittest import SkipTest
from functools import reduce
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
from jax import core
from jax import test_util as jtu
import jax.numpy as np
from jax import random
from jax import jacobian, jit
from jax import jacfwd
from jax.experimental import stax
from jax.experimental.jet import jet, fact
from jax.experimental.jet import jet, fact, zero_series
from jax.tree_util import tree_map
from jax import lax
from jax.config import config
config.parse_flags_with_absl()
def jvp_taylor(fun, primals, series):
# Computes the Taylor series the slow way, with nested jvp.
order, = set(map(len, series))
def composition(eps):
taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
@ -42,7 +39,7 @@ def jvp_taylor(fun, primals, series):
nudged_args = [x + t for x, t in zip(primals, taylor_terms)]
return fun(*nudged_args)
primal_out = fun(*primals)
terms_out = [repeated(jacobian, i+1)(composition)(0.) for i in range(order)]
terms_out = [repeated(jacfwd, i+1)(composition)(0.) for i in range(order)]
return primal_out, terms_out
def repeated(f, n):
@ -50,6 +47,9 @@ def repeated(f, n):
return reduce(lambda x, _: f(x), range(n), p)
return rfun
def transform(lims, x):
  """Map `x` affinely so that 0 lands on `lims[0]` and 1 lands on `lims[1]`."""
  low = lims[0]
  width = lims[1] - low
  return x * width + low
class JetTest(jtu.JaxTestCase):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
@ -58,25 +58,14 @@ class JetTest(jtu.JaxTestCase):
expected_y, expected_terms = jvp_taylor(fun, primals, series)
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
if terms == zero_series:
terms = tree_map(np.zeros_like, expected_terms)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
@jtu.skip_on_devices("tpu")
def test_exp(self):
order, dim = 4, 3
rng = onp.random.RandomState(0)
primal_in = rng.randn(dim)
terms_in = [rng.randn(dim) for _ in range(order)]
self.check_jet(np.exp, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
@jtu.skip_on_devices("tpu")
def test_log(self):
order, dim = 4, 3
rng = onp.random.RandomState(0)
primal_in = np.exp(rng.randn(dim))
terms_in = [rng.randn(dim) for _ in range(order)]
self.check_jet(np.log, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
@jtu.skip_on_devices("tpu")
def test_dot(self):
M, K, N = 2, 3, 4
@ -95,6 +84,7 @@ class JetTest(jtu.JaxTestCase):
order = 3
input_shape = (1, 5, 5, 1)
key = random.PRNGKey(0)
# TODO(duvenaud): Check all types of padding
init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
_, (W, b) = init_fun(key, input_shape)
@ -114,38 +104,79 @@ class JetTest(jtu.JaxTestCase):
self.check_jet(f, primals, series_in, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_div(self):
primals = 1., 5.
order = 4
def unary_check(self, fun, lims=(-2, 2), order=3):
  """Run `check_jet` on a one-argument function with random inputs.

  Args:
    fun: function of a single array argument.
    lims: `(low, high)` pair; the primal input is `rng.rand` samples rescaled
      into this range by `transform`.
    order: number of input series terms to generate.
  """
  dims = 2, 3
  # Fixed seed so every run sees the same primal and series values.
  rng = onp.random.RandomState(0)
  primal_in = transform(lims, rng.rand(*dims))
  terms_in = [rng.randn(*dims) for _ in range(order)]
  self.check_jet(fun, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
def binary_check(self, fun, lims=(-2, 2), order=3):
  """Run `check_jet` on a two-argument function with random inputs.

  Args:
    fun: function of two array arguments.
    lims: `(low, high)` pair; both primal inputs are `rng.rand` samples
      rescaled into this range by `transform`.
    order: number of series terms to generate for each argument.
  """
  dims = 2, 3
  # Fixed seed so every run sees the same primal and series values.
  rng = onp.random.RandomState(0)
  primal_in = (transform(lims, rng.rand(*dims)),
               transform(lims, rng.rand(*dims)))
  series_in = ([rng.randn(*dims) for _ in range(order)],
               [rng.randn(*dims) for _ in range(order)])
  self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
@jtu.skip_on_devices("tpu")
def test_sub(self):
primals = 1., 5.
order = 4
rng = onp.random.RandomState(0)
series_in = ([rng.randn() for _ in range(order)], [rng.randn() for _ in range(order)])
self.check_jet(op.sub, primals, series_in)
def test_exp(self): self.unary_check(np.exp)
@jtu.skip_on_devices("tpu")
def test_neg(self): self.unary_check(np.negative)
@jtu.skip_on_devices("tpu")
def test_floor(self): self.unary_check(np.floor)
@jtu.skip_on_devices("tpu")
def test_ceil(self): self.unary_check(np.ceil)
@jtu.skip_on_devices("tpu")
def test_round(self): self.unary_check(np.round)
@jtu.skip_on_devices("tpu")
def test_sign(self): self.unary_check(np.sign)
@jtu.skip_on_devices("tpu")
def test_log(self): self.unary_check(np.log, lims=[0.8, 4.0])
@jtu.skip_on_devices("tpu")
def test_gather(self): self.unary_check(lambda x: x[1:])
@jtu.skip_on_devices("tpu")
def test_reduce_max(self): self.unary_check(lambda x: x.max(axis=1))
@jtu.skip_on_devices("tpu")
def test_reduce_min(self): self.unary_check(lambda x: x.min(axis=1))
@jtu.skip_on_devices("tpu")
def test_all_max(self): self.unary_check(np.max)
@jtu.skip_on_devices("tpu")
def test_all_min(self): self.unary_check(np.min)
@jtu.skip_on_devices("tpu")
def test_stopgrad(self): self.unary_check(lax.stop_gradient)
@jtu.skip_on_devices("tpu")
def test_abs(self): self.unary_check(np.abs)
@jtu.skip_on_devices("tpu")
def test_fft(self): self.unary_check(np.fft.fft)
@jtu.skip_on_devices("tpu")
def test_gather(self):
order, dim = 4, 3
rng = onp.random.RandomState(0)
x = rng.randn(dim)
terms_in = [rng.randn(dim) for _ in range(order)]
self.check_jet(lambda x: x[1:], (x,), (terms_in,))
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
@jtu.skip_on_devices("tpu")
def test_reduce_max(self):
dim1, dim2 = 3, 5
order = 6
rng = onp.random.RandomState(0)
x = rng.randn(dim1, dim2)
terms_in = [rng.randn(dim1, dim2) for _ in range(order)]
self.check_jet(lambda x: x.max(axis=1), (x,), (terms_in,))
def test_sub(self): self.binary_check(lambda x, y: x - y)
@jtu.skip_on_devices("tpu")
def test_add(self): self.binary_check(lambda x, y: x + y)
@jtu.skip_on_devices("tpu")
def test_mul(self): self.binary_check(lambda x, y: x * y)
@jtu.skip_on_devices("tpu")
def test_le(self): self.binary_check(lambda x, y: x <= y)
@jtu.skip_on_devices("tpu")
def test_gt(self): self.binary_check(lambda x, y: x > y)
@jtu.skip_on_devices("tpu")
def test_lt(self): self.binary_check(lambda x, y: x < y)
@jtu.skip_on_devices("tpu")
def test_ge(self): self.binary_check(lambda x, y: x >= y)
@jtu.skip_on_devices("tpu")
def test_eq(self): self.binary_check(lambda x, y: x == y)
@jtu.skip_on_devices("tpu")
def test_ne(self): self.binary_check(lambda x, y: x != y)
@jtu.skip_on_devices("tpu")
def test_and(self): self.binary_check(lambda x, y: np.logical_and(x, y))
@jtu.skip_on_devices("tpu")
def test_or(self): self.binary_check(lambda x, y: np.logical_or(x, y))
@jtu.skip_on_devices("tpu")
def test_xor(self): self.binary_check(lambda x, y: np.logical_xor(x, y))
if __name__ == '__main__':