mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Added lots of trivial jet rules.
Co-Authored-By: jessebett <jessebett@gmail.com> Co-Authored-By: Jacob Kelly <jacob.kelly@mail.utoronto.ca>
This commit is contained in:
parent
bcc5191c63
commit
ead8011837
@ -14,16 +14,17 @@
|
||||
|
||||
|
||||
from functools import partial
|
||||
from collections import Counter
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from jax import core
|
||||
from jax.util import unzip2, prod
|
||||
from jax.util import unzip2
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten)
|
||||
import jax.linear_util as lu
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax.lax import lax
|
||||
from jax.lax import lax_fft
|
||||
|
||||
def jet(fun, primals, series):
|
||||
try:
|
||||
@ -124,8 +125,33 @@ zero_series = ZeroSeries()
|
||||
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)
|
||||
|
||||
|
||||
### rule definitions
|
||||
|
||||
jet_rules = {}
|
||||
|
||||
def defzero(prim):
    """Register `prim` in `jet_rules` with a rule built from `zero_prop`.

    The stored rule is `zero_prop` with `prim` pre-bound as its first argument.
    """
    rule = partial(zero_prop, prim)
    jet_rules[prim] = rule
|
||||
|
||||
def zero_prop(prim, primals_in, series_in, **params):
    """Apply `prim` to the primals and pair the result with `zero_series`.

    `series_in` is accepted for signature compatibility with the other rules
    but is not read: the output series is always the `zero_series` sentinel.
    """
    return prim.bind(*primals_in, **params), zero_series
|
||||
|
||||
defzero(lax.le_p)
|
||||
defzero(lax.lt_p)
|
||||
defzero(lax.gt_p)
|
||||
defzero(lax.ge_p)
|
||||
defzero(lax.eq_p)
|
||||
defzero(lax.ne_p)
|
||||
defzero(lax.and_p)
|
||||
defzero(lax.or_p)
|
||||
defzero(lax.xor_p)
|
||||
defzero(lax.floor_p)
|
||||
defzero(lax.ceil_p)
|
||||
defzero(lax.round_p)
|
||||
defzero(lax.sign_p)
|
||||
defzero(lax.stop_gradient_p)
|
||||
|
||||
|
||||
def deflinear(prim):
    """Register `prim` in `jet_rules` with a rule built from `linear_prop`.

    The stored rule is `linear_prop` with `prim` pre-bound as its first
    argument.
    """
    rule = partial(linear_prop, prim)
    jet_rules[prim] = rule
|
||||
|
||||
@ -134,14 +160,6 @@ def linear_prop(prim, primals_in, series_in, **params):
|
||||
series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)]
|
||||
return primal_out, series_out
|
||||
|
||||
|
||||
### rule definitions
|
||||
|
||||
from jax.lax import lax
|
||||
|
||||
def fact(n):
|
||||
return lax.exp(lax.lgamma(n+1.))
|
||||
|
||||
deflinear(lax.neg_p)
|
||||
deflinear(lax.real_p)
|
||||
deflinear(lax.complex_p)
|
||||
@ -159,15 +177,26 @@ deflinear(lax.slice_p)
|
||||
deflinear(lax.reduce_sum_p)
|
||||
deflinear(lax.reduce_window_sum_p)
|
||||
deflinear(lax.tie_in_p)
|
||||
deflinear(lax_fft.fft_p)
|
||||
deflinear(xla.device_put_p)
|
||||
|
||||
|
||||
|
||||
### More complicated rules
|
||||
|
||||
def fact(n):
    """Return the factorial of `n`, evaluated as exp(lgamma(n + 1))."""
    log_factorial = lax.lgamma(n + 1.)
    return lax.exp(log_factorial)
|
||||
|
||||
def _scale(k, j):
    """Return 1 / ((k - j)! * (j - 1)!), using `fact` for the factorials."""
    denominator = fact(k - j) * fact(j - 1)
    return 1. / denominator
|
||||
|
||||
def _exp_taylor(primals_in, series_in):
    """Jet rule for `lax.exp_p`.

    Args:
      primals_in: one-element sequence holding the primal input `x`.
      series_in: one-element sequence holding the list of input series terms.

    Returns:
      A pair `(primal_out, series_out)` where `primal_out` is `lax.exp(x)` and
      `series_out` is a list with one output term per input term.
    """
    x, = primals_in
    series, = series_in
    u = [x] + series
    v = [lax.exp(x)] + [None] * len(series)
    # v[k] is built from v[0..k-1], so the entries must be filled in order.
    # The module-level `_scale` helper is used; the previous per-call `scale`
    # closure and the duplicate assignment to v[k] computed the same value.
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum([_scale(k, j) * v[k - j] * u[j]
                                  for j in range(1, k + 1)])
    primal_out, *series_out = v
    return primal_out, series_out
|
||||
jet_rules[lax.exp_p] = _exp_taylor
|
||||
@ -177,9 +206,8 @@ def _log_taylor(primals_in, series_in):
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
v = [lax.log(x)] + [None] * len(series)
|
||||
def scale(k, j): return 1. / (fact(k-j) * fact(j-1))
|
||||
for k in range(1, len(v)):
|
||||
conv = sum([scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
|
||||
conv = sum([_scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
|
||||
v[k] = (u[k] - fact(k - 1) * conv) / u[0]
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
@ -223,22 +251,30 @@ def _gather_taylor_rule(primals_in, series_in, **params):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.gather_p] = _gather_taylor_rule
|
||||
|
||||
def _reduce_max_taylor_rule(primals_in, series_in, **params):
|
||||
operand, = primals_in
|
||||
gs, = series_in
|
||||
primal_out = lax.reduce_max_p.bind(operand, **params)
|
||||
axes = params.pop("axes", None)
|
||||
primal_dtype = gs[0].dtype
|
||||
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
|
||||
location_indicators = lax.convert_element_type(
|
||||
lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
|
||||
counts = lax._reduce_sum(location_indicators, axes)
|
||||
def _reduce_chooser_taylor_rule(g):
|
||||
return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
|
||||
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
|
||||
def _gen_reduce_choose_taylor_rule(chooser_fun):
|
||||
def chooser_taylor_rule(primals_in, series_in, **params):
|
||||
operand, = primals_in
|
||||
gs, = series_in
|
||||
primal_out = chooser_fun(operand, **params)
|
||||
axes = params.pop("axes", None)
|
||||
primal_dtype = gs[0].dtype
|
||||
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
|
||||
location_indicators = lax.convert_element_type(
|
||||
lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
|
||||
counts = lax._reduce_sum(location_indicators, axes)
|
||||
def _reduce_chooser_taylor_rule(g):
|
||||
return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
|
||||
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
|
||||
return primal_out, series_out
|
||||
return chooser_taylor_rule
|
||||
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax.reduce_max_p.bind)
|
||||
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax.reduce_min_p.bind)
|
||||
|
||||
def _abs_taylor_rule(x, series_in, **params):
|
||||
x, = x
|
||||
primal_out = lax.abs_p.bind(x, **params)
|
||||
negs = lax.select(lax.lt(x, 0.0), lax.full_like(x, -1), lax.full_like(x, 1.0))
|
||||
fix_sign = lambda y: negs * y
|
||||
series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.reduce_max_p] = _reduce_max_taylor_rule
|
||||
|
||||
|
||||
from jax.interpreters import xla
|
||||
deflinear(xla.device_put_p)
|
||||
jet_rules[lax.abs_p] = _abs_taylor_rule
|
||||
|
@ -13,28 +13,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import partial, reduce
|
||||
import operator as op
|
||||
from unittest import SkipTest
|
||||
from functools import reduce
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as onp
|
||||
|
||||
from jax import core
|
||||
from jax import test_util as jtu
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import random
|
||||
from jax import jacobian, jit
|
||||
from jax import jacfwd
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.jet import jet, fact
|
||||
from jax.experimental.jet import jet, fact, zero_series
|
||||
from jax.tree_util import tree_map
|
||||
from jax import lax
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def jvp_taylor(fun, primals, series):
|
||||
# Computes the Taylor series the slow way, with nested jvp.
|
||||
order, = set(map(len, series))
|
||||
def composition(eps):
|
||||
taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
|
||||
@ -42,7 +39,7 @@ def jvp_taylor(fun, primals, series):
|
||||
nudged_args = [x + t for x, t in zip(primals, taylor_terms)]
|
||||
return fun(*nudged_args)
|
||||
primal_out = fun(*primals)
|
||||
terms_out = [repeated(jacobian, i+1)(composition)(0.) for i in range(order)]
|
||||
terms_out = [repeated(jacfwd, i+1)(composition)(0.) for i in range(order)]
|
||||
return primal_out, terms_out
|
||||
|
||||
def repeated(f, n):
|
||||
@ -50,6 +47,9 @@ def repeated(f, n):
|
||||
return reduce(lambda x, _: f(x), range(n), p)
|
||||
return rfun
|
||||
|
||||
def transform(lims, x):
    """Map `x` affinely so that 0 goes to `lims[0]` and 1 goes to `lims[1]`."""
    upper = lims[1]
    lower = lims[0]
    width = upper - lower
    return x * width + lower
|
||||
|
||||
class JetTest(jtu.JaxTestCase):
|
||||
|
||||
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
|
||||
@ -58,25 +58,14 @@ class JetTest(jtu.JaxTestCase):
|
||||
expected_y, expected_terms = jvp_taylor(fun, primals, series)
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
|
||||
if terms == zero_series:
|
||||
terms = tree_map(np.zeros_like, expected_terms)
|
||||
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_exp(self):
|
||||
order, dim = 4, 3
|
||||
rng = onp.random.RandomState(0)
|
||||
primal_in = rng.randn(dim)
|
||||
terms_in = [rng.randn(dim) for _ in range(order)]
|
||||
self.check_jet(np.exp, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_log(self):
|
||||
order, dim = 4, 3
|
||||
rng = onp.random.RandomState(0)
|
||||
primal_in = np.exp(rng.randn(dim))
|
||||
terms_in = [rng.randn(dim) for _ in range(order)]
|
||||
self.check_jet(np.log, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_dot(self):
|
||||
M, K, N = 2, 3, 4
|
||||
@ -95,6 +84,7 @@ class JetTest(jtu.JaxTestCase):
|
||||
order = 3
|
||||
input_shape = (1, 5, 5, 1)
|
||||
key = random.PRNGKey(0)
|
||||
# TODO(duvenaud): Check all types of padding
|
||||
init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
|
||||
_, (W, b) = init_fun(key, input_shape)
|
||||
|
||||
@ -114,38 +104,79 @@ class JetTest(jtu.JaxTestCase):
|
||||
|
||||
self.check_jet(f, primals, series_in, check_dtypes=False)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_div(self):
|
||||
primals = 1., 5.
|
||||
order = 4
|
||||
def unary_check(self, fun, lims=(-2, 2), order=3):
    """Run `self.check_jet` on a one-argument function with random inputs.

    Args:
      fun: the function under test; it is passed a single (2, 3) array.
      lims: two-element sequence; the primal is drawn with `rng.rand` and
        rescaled by `transform(lims, ...)`. A tuple default is used so the
        default cannot be mutated between calls.
      order: number of input series terms to generate.
    """
    dims = 2, 3
    # Fixed seed so every test sees the same inputs.
    rng = onp.random.RandomState(0)
    primal_in = transform(lims, rng.rand(*dims))
    terms_in = [rng.randn(*dims) for _ in range(order)]
    self.check_jet(fun, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def binary_check(self, fun, lims=(-2, 2), order=3):
    """Run `self.check_jet` on a two-argument function with random inputs.

    Args:
      fun: the function under test; it is passed two (2, 3) arrays.
      lims: two-element sequence; each primal is drawn with `rng.rand` and
        rescaled by `transform(lims, ...)`. A tuple default is used so the
        default cannot be mutated between calls.
      order: number of series terms to generate for each argument.
    """
    dims = 2, 3
    # Fixed seed so every test sees the same inputs.
    rng = onp.random.RandomState(0)
    primal_in = (transform(lims, rng.rand(*dims)),
                 transform(lims, rng.rand(*dims)))
    series_in = ([rng.randn(*dims) for _ in range(order)],
                 [rng.randn(*dims) for _ in range(order)])
    self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sub(self):
|
||||
primals = 1., 5.
|
||||
order = 4
|
||||
rng = onp.random.RandomState(0)
|
||||
series_in = ([rng.randn() for _ in range(order)], [rng.randn() for _ in range(order)])
|
||||
self.check_jet(op.sub, primals, series_in)
|
||||
def test_exp(self): self.unary_check(np.exp)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_neg(self): self.unary_check(np.negative)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_floor(self): self.unary_check(np.floor)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ceil(self): self.unary_check(np.ceil)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_round(self): self.unary_check(np.round)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sign(self): self.unary_check(np.sign)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_log(self): self.unary_check(np.log, lims=[0.8, 4.0])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_gather(self): self.unary_check(lambda x: x[1:])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_reduce_max(self): self.unary_check(lambda x: x.max(axis=1))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_reduce_min(self): self.unary_check(lambda x: x.min(axis=1))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_all_max(self): self.unary_check(np.max)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_all_min(self): self.unary_check(np.min)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_stopgrad(self): self.unary_check(lax.stop_gradient)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_abs(self): self.unary_check(np.abs)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_fft(self): self.unary_check(np.fft.fft)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_gather(self):
|
||||
order, dim = 4, 3
|
||||
rng = onp.random.RandomState(0)
|
||||
x = rng.randn(dim)
|
||||
terms_in = [rng.randn(dim) for _ in range(order)]
|
||||
self.check_jet(lambda x: x[1:], (x,), (terms_in,))
|
||||
|
||||
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_reduce_max(self):
|
||||
dim1, dim2 = 3, 5
|
||||
order = 6
|
||||
rng = onp.random.RandomState(0)
|
||||
x = rng.randn(dim1, dim2)
|
||||
terms_in = [rng.randn(dim1, dim2) for _ in range(order)]
|
||||
self.check_jet(lambda x: x.max(axis=1), (x,), (terms_in,))
|
||||
def test_sub(self): self.binary_check(lambda x, y: x - y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_add(self): self.binary_check(lambda x, y: x + y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_mul(self): self.binary_check(lambda x, y: x * y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_le(self): self.binary_check(lambda x, y: x <= y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_gt(self): self.binary_check(lambda x, y: x > y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_lt(self): self.binary_check(lambda x, y: x < y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ge(self): self.binary_check(lambda x, y: x >= y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_eq(self): self.binary_check(lambda x, y: x == y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ne(self): self.binary_check(lambda x, y: x != y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_and(self): self.binary_check(lambda x, y: np.logical_and(x, y))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_or(self): self.binary_check(lambda x, y: np.logical_or(x, y))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_xor(self): self.binary_check(lambda x, y: np.logical_xor(x, y))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user