mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add jet primitives, refactor tests (#3468)
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
This commit is contained in:
parent
e9ce700b06
commit
575216e094
@ -18,6 +18,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax.util import unzip2
|
||||
from jax import ad_util
|
||||
@ -180,6 +181,7 @@ defzero(lax.gt_p)
|
||||
defzero(lax.ge_p)
|
||||
defzero(lax.eq_p)
|
||||
defzero(lax.ne_p)
|
||||
defzero(lax.not_p)
|
||||
defzero(lax.and_p)
|
||||
defzero(lax.or_p)
|
||||
defzero(lax.xor_p)
|
||||
@ -188,6 +190,11 @@ defzero(lax.ceil_p)
|
||||
defzero(lax.round_p)
|
||||
defzero(lax.sign_p)
|
||||
defzero(ad_util.stop_gradient_p)
|
||||
defzero(lax.is_finite_p)
|
||||
defzero(lax.shift_left_p)
|
||||
defzero(lax.shift_right_arithmetic_p)
|
||||
defzero(lax.shift_right_logical_p)
|
||||
defzero(lax.bitcast_convert_type_p)
|
||||
|
||||
|
||||
def deflinear(prim):
|
||||
@ -201,6 +208,8 @@ def linear_prop(prim, primals_in, series_in, **params):
|
||||
deflinear(lax.neg_p)
|
||||
deflinear(lax.real_p)
|
||||
deflinear(lax.complex_p)
|
||||
deflinear(lax.conj_p)
|
||||
deflinear(lax.imag_p)
|
||||
deflinear(lax.add_p)
|
||||
deflinear(lax.sub_p)
|
||||
deflinear(lax.convert_element_type_p)
|
||||
@ -218,12 +227,14 @@ deflinear(lax.tie_in_p)
|
||||
deflinear(lax_fft.fft_p)
|
||||
deflinear(xla.device_put_p)
|
||||
|
||||
|
||||
def def_deriv(prim, deriv):
|
||||
"""
|
||||
Define the jet rule for a primitive in terms of its first derivative.
|
||||
"""
|
||||
jet_rules[prim] = partial(deriv_prop, prim, deriv)
|
||||
|
||||
|
||||
def deriv_prop(prim, deriv, primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
@ -240,6 +251,7 @@ def deriv_prop(prim, deriv, primals_in, series_in):
|
||||
|
||||
def_deriv(lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x)))))
|
||||
|
||||
|
||||
def def_comp(prim, comp):
|
||||
"""
|
||||
Define the jet rule for a primitive in terms of a composition of simpler primitives.
|
||||
@ -247,7 +259,16 @@ def def_comp(prim, comp):
|
||||
jet_rules[prim] = partial(jet, comp)
|
||||
|
||||
|
||||
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
|
||||
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
|
||||
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
|
||||
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
|
||||
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
|
||||
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
|
||||
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
|
||||
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
|
||||
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
|
||||
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))
|
||||
|
||||
|
||||
def _erf_inv_rule(primals_in, series_in):
|
||||
@ -314,17 +335,6 @@ def _exp_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.exp_p] = _exp_taylor
|
||||
|
||||
def _expm1_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
v = [lax.exp(x)] + [None] * len(series)
|
||||
for k in range(1,len(v)):
|
||||
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
|
||||
primal_out, *series_out = v
|
||||
return lax.expm1(x), series_out
|
||||
jet_rules[lax.expm1_p] = _expm1_taylor
|
||||
|
||||
def _pow_taylor(primals_in, series_in):
|
||||
u_, r_ = primals_in
|
||||
|
||||
@ -340,7 +350,11 @@ def _pow_taylor(primals_in, series_in):
|
||||
jet_rules[lax.pow_p] = _pow_taylor
|
||||
|
||||
def _integer_pow_taylor(primals_in, series_in, *, y):
|
||||
if y == 2:
|
||||
if y == 0:
|
||||
return jet(jnp.ones_like, primals_in, series_in)
|
||||
elif y == 1:
|
||||
return jet(lambda x: x, primals_in, series_in)
|
||||
elif y == 2:
|
||||
return jet(lambda x: x * x, primals_in, series_in)
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
@ -391,26 +405,6 @@ def _log_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.log_p] = _log_taylor
|
||||
|
||||
def _sqrt_taylor(primals_in, series_in):
|
||||
return jet(lambda x: x ** 0.5, primals_in, series_in)
|
||||
jet_rules[lax.sqrt_p] = _sqrt_taylor
|
||||
|
||||
def _rsqrt_taylor(primals_in, series_in):
|
||||
return jet(lambda x: x ** -0.5, primals_in, series_in)
|
||||
jet_rules[lax.rsqrt_p] = _rsqrt_taylor
|
||||
|
||||
def _asinh_taylor(primals_in, series_in):
|
||||
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)), primals_in, series_in)
|
||||
jet_rules[lax.asinh_p] = _asinh_taylor
|
||||
|
||||
def _acosh_taylor(primals_in, series_in):
|
||||
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)), primals_in, series_in)
|
||||
jet_rules[lax.acosh_p] = _acosh_taylor
|
||||
|
||||
def _atanh_taylor(primals_in, series_in):
|
||||
return jet(lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)), primals_in, series_in)
|
||||
jet_rules[lax.atanh_p] = _atanh_taylor
|
||||
|
||||
def _atan2_taylor(primals_in, series_in):
|
||||
x, y = primals_in
|
||||
primal_out = lax.atan2(x, y)
|
||||
@ -426,19 +420,7 @@ def _atan2_taylor(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.atan2_p] = _atan2_taylor
|
||||
|
||||
def _log1p_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x + 1] + series
|
||||
v = [lax.log(x + 1)] + [None] * len(series)
|
||||
for k in range(1, len(v)):
|
||||
conv = sum([_scale(k, j) * v[j] * u[k-j] for j in range(1, k)])
|
||||
v[k] = (u[k] - fact(k - 1) * conv) / u[0]
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.log1p_p] = _log1p_taylor
|
||||
|
||||
def _div_taylor_rule(primals_in, series_in, **params):
|
||||
def _div_taylor_rule(primals_in, series_in):
|
||||
x, y = primals_in
|
||||
x_terms, y_terms = series_in
|
||||
u = [x] + x_terms
|
||||
@ -531,3 +513,37 @@ def _select_taylor_rule(primal_in, series_in, **params):
|
||||
series_out = [sel(*terms_in, **params) for terms_in in zip(*series_in)]
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.select_p] = _select_taylor_rule
|
||||
|
||||
|
||||
def _lax_max_taylor_rule(primal_in, series_in):
|
||||
x, y = primal_in
|
||||
|
||||
xgy = x > y # greater than mask
|
||||
xey = x == y # equal to mask
|
||||
primal_out = lax.select(xgy, x, y)
|
||||
|
||||
def select_max_and_avg_eq(x_i, y_i):
|
||||
"""Select x where x>y or average when x==y"""
|
||||
max_i = lax.select(xgy, x_i, y_i)
|
||||
max_i = lax.select(xey, (x_i + y_i)/2, max_i)
|
||||
return max_i
|
||||
|
||||
series_out = [select_max_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.max_p] = _lax_max_taylor_rule
|
||||
|
||||
def _lax_min_taylor_rule(primal_in, series_in):
|
||||
x, y = primal_in
|
||||
xgy = x < y # less than mask
|
||||
xey = x == y # equal to mask
|
||||
primal_out = lax.select(xgy, x, y)
|
||||
|
||||
def select_min_and_avg_eq(x_i, y_i):
|
||||
"""Select x where x>y or average when x==y"""
|
||||
min_i = lax.select(xgy, x_i, y_i)
|
||||
min_i = lax.select(xey, (x_i + y_i)/2, min_i)
|
||||
return min_i
|
||||
|
||||
series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.min_p] = _lax_min_taylor_rule
|
||||
|
@ -2116,6 +2116,7 @@ ad.defjvp(integer_pow_p, _integer_pow_jvp)
|
||||
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
||||
|
||||
not_p = standard_unop(_bool_or_int, 'not')
|
||||
ad.defjvp_zero(not_p)
|
||||
|
||||
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
|
||||
ad.defjvp_zero(and_p)
|
||||
|
@ -36,7 +36,7 @@ def jvp_taylor(fun, primals, series):
|
||||
def composition(eps):
|
||||
taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
|
||||
for i in range(len(terms))]) for terms in series]
|
||||
nudged_args = [x + t for x, t in zip(primals, taylor_terms)]
|
||||
nudged_args = [(x + t).astype(x.dtype) for x, t in zip(primals, taylor_terms)]
|
||||
return fun(*nudged_args)
|
||||
primal_out = fun(*primals)
|
||||
terms_out = [repeated(jacfwd, i+1)(composition)(0.) for i in range(order)]
|
||||
@ -122,24 +122,36 @@ class JetTest(jtu.JaxTestCase):
|
||||
|
||||
self.check_jet(f, primals, series_in, check_dtypes=False)
|
||||
|
||||
def unary_check(self, fun, lims=[-2, 2], order=3):
|
||||
def unary_check(self, fun, lims=[-2, 2], order=3, dtype=None):
|
||||
dims = 2, 3
|
||||
rng = np.random.RandomState(0)
|
||||
primal_in = transform(lims, rng.rand(*dims))
|
||||
terms_in = [rng.randn(*dims) for _ in range(order)]
|
||||
if dtype is None:
|
||||
primal_in = transform(lims, rng.rand(*dims))
|
||||
terms_in = [rng.randn(*dims) for _ in range(order)]
|
||||
else:
|
||||
rng = jtu.rand_uniform(rng, *lims)
|
||||
primal_in = rng(dims, dtype)
|
||||
terms_in = [rng(dims, dtype) for _ in range(order)]
|
||||
self.check_jet(fun, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def binary_check(self, fun, lims=[-2, 2], order=3, finite=True):
|
||||
def binary_check(self, fun, lims=[-2, 2], order=3, finite=True, dtype=None):
|
||||
dims = 2, 3
|
||||
rng = np.random.RandomState(0)
|
||||
if isinstance(lims, tuple):
|
||||
x_lims, y_lims = lims
|
||||
else:
|
||||
x_lims, y_lims = lims, lims
|
||||
primal_in = (transform(x_lims, rng.rand(*dims)),
|
||||
transform(y_lims, rng.rand(*dims)))
|
||||
series_in = ([rng.randn(*dims) for _ in range(order)],
|
||||
[rng.randn(*dims) for _ in range(order)])
|
||||
if dtype is None:
|
||||
primal_in = (transform(x_lims, rng.rand(*dims)),
|
||||
transform(y_lims, rng.rand(*dims)))
|
||||
series_in = ([rng.randn(*dims) for _ in range(order)],
|
||||
[rng.randn(*dims) for _ in range(order)])
|
||||
else:
|
||||
rng = jtu.rand_uniform(rng, *lims)
|
||||
primal_in = (rng(dims, dtype),
|
||||
rng(dims, dtype))
|
||||
series_in = ([rng(dims, dtype) for _ in range(order)],
|
||||
[rng(dims, dtype) for _ in range(order)])
|
||||
if finite:
|
||||
self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
else:
|
||||
@ -165,7 +177,7 @@ class JetTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_int_pow(self):
|
||||
for p in range(2, 6):
|
||||
for p in range(6):
|
||||
self.unary_check(lambda x: x ** p, lims=[-2, 2])
|
||||
self.unary_check(lambda x: x ** 10, lims=[0, 0])
|
||||
|
||||
@ -179,9 +191,19 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ceil(self): self.unary_check(jnp.ceil)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_round(self): self.unary_check(jnp.round)
|
||||
def test_round(self): self.unary_check(lax.round)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sign(self): self.unary_check(jnp.sign)
|
||||
def test_sign(self): self.unary_check(lax.sign)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_real(self): self.unary_check(lax.real, dtype=np.complex64)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_conj(self): self.unary_check(lax.conj, dtype=np.complex64)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_imag(self): self.unary_check(lax.imag, dtype=np.complex64)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_not(self): self.unary_check(lax.bitwise_not, dtype=np.bool_)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_is_finite(self): self.unary_check(lax.is_finite)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_log(self): self.unary_check(jnp.log, lims=[0.8, 4.0])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -236,36 +258,65 @@ class JetTest(jtu.JaxTestCase):
|
||||
def test_erf_inv(self): self.unary_check(lax.erf_inv, lims=[-1, 1])
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
|
||||
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sub(self): self.binary_check(lambda x, y: x - y)
|
||||
def test_rem(self): self.binary_check(lax.rem, lims=[0.8, 4.0])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_add(self): self.binary_check(lambda x, y: x + y)
|
||||
def test_complex(self): self.binary_check(lax.complex)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_mul(self): self.binary_check(lambda x, y: x * y)
|
||||
def test_sub(self): self.binary_check(lambda x, y: x - y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_le(self): self.binary_check(lambda x, y: x <= y)
|
||||
def test_add(self): self.binary_check(lambda x, y: x + y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_gt(self): self.binary_check(lambda x, y: x > y)
|
||||
def test_mul(self): self.binary_check(lambda x, y: x * y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_lt(self): self.binary_check(lambda x, y: x < y)
|
||||
def test_le(self): self.binary_check(lambda x, y: x <= y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ge(self): self.binary_check(lambda x, y: x >= y)
|
||||
def test_gt(self): self.binary_check(lambda x, y: x > y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_eq(self): self.binary_check(lambda x, y: x == y)
|
||||
def test_lt(self): self.binary_check(lambda x, y: x < y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ne(self): self.binary_check(lambda x, y: x != y)
|
||||
def test_ge(self): self.binary_check(lambda x, y: x >= y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_and(self): self.binary_check(lambda x, y: jnp.logical_and(x, y))
|
||||
def test_eq(self): self.binary_check(lambda x, y: x == y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_or(self): self.binary_check(lambda x, y: jnp.logical_or(x, y))
|
||||
def test_ne(self): self.binary_check(lambda x, y: x != y)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_xor(self): self.binary_check(lambda x, y: jnp.logical_xor(x, y))
|
||||
def test_max(self): self.binary_check(lax.max)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_min(self): self.binary_check(lax.min)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_and(self): self.binary_check(lax.bitwise_and, dtype=np.bool_)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_or(self): self.binary_check(lax.bitwise_or, dtype=np.bool_)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_xor(self): self.binary_check(jnp.bitwise_xor, dtype=np.bool_)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_shift_left(self): self.binary_check(lax.shift_left, dtype=np.int32)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_shift_right_a(self): self.binary_check(lax.shift_right_arithmetic, dtype=np.int32)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_shift_right_l(self): self.binary_check(lax.shift_right_logical, dtype=np.int32)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jtu.ignore_warning(message="overflow encountered in power")
|
||||
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
|
||||
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_atan2(self): self.binary_check(lax.atan2, lims=[-40, 40])
|
||||
def test_atan2(self): self.binary_check(lax.atan2, lims=[-40, 40])
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_clamp(self):
|
||||
lims = [-2, 2]
|
||||
order = 3
|
||||
dims = 2, 3
|
||||
rng = np.random.RandomState(0)
|
||||
primal_in = (transform(lims, rng.rand(*dims)),
|
||||
transform(lims, rng.rand(*dims)),
|
||||
transform(lims, rng.rand(*dims)))
|
||||
series_in = ([rng.randn(*dims) for _ in range(order)],
|
||||
[rng.randn(*dims) for _ in range(order)],
|
||||
[rng.randn(*dims) for _ in range(order)])
|
||||
|
||||
self.check_jet(lax.clamp, primal_in, series_in, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_process_call(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user