mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
custom batching jvp tests
This commit is contained in:
parent
0ab93a039e
commit
ad7c7d6eab
@ -6460,6 +6460,180 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
'custom vmap rule output values must be a sequence.*',
|
||||
lambda: api.vmap(f)(xs))
|
||||
|
||||
def test_jvp_basic(self):
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [True])
|
||||
return [jnp.cos(xs)], in_batched
|
||||
|
||||
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
|
||||
|
||||
x, tx = jnp.array(1.), jnp.array(2.)
|
||||
xs, txs = jnp.arange(3.), jnp.arange(3.) * 2.
|
||||
|
||||
y, ty = f_jvp(x, tx)
|
||||
self.assertAllClose(y, jnp.sin(x))
|
||||
self.assertAllClose(ty, jnp.cos(x) * tx)
|
||||
|
||||
ys, tys = api.vmap(f_jvp)(xs, txs)
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
self.assertAllClose(tys, -jnp.sin(xs) * txs)
|
||||
|
||||
ys, tys = api.jvp(api.vmap(f), [xs], [txs])
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
self.assertAllClose(tys, -jnp.sin(xs) * txs)
|
||||
|
||||
def test_jvp_nary(self):
|
||||
@api.custom_vmap
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs, ys):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [True, True])
|
||||
return [jnp.cos(xs) + ys], [True]
|
||||
|
||||
f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty])
|
||||
|
||||
x, y, tx, ty = jnp.arange(4.)
|
||||
xs, ys, txs, tys = 4. + jnp.arange(3. * 4).reshape((4, 3))
|
||||
|
||||
zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys)
|
||||
self.assertAllClose(zs, jnp.cos(xs) + ys)
|
||||
self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys)
|
||||
|
||||
zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys])
|
||||
self.assertAllClose(zs, jnp.cos(xs) + ys)
|
||||
self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys)
|
||||
|
||||
def test_jvp_extra_batched_tangents(self):
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [False])
|
||||
return [jnp.cos(xs)], in_batched
|
||||
|
||||
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
|
||||
|
||||
x, txs = jnp.array(1.), 2. + jnp.arange(3.)
|
||||
y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)
|
||||
self.assertAllClose(y, jnp.cos(x))
|
||||
self.assertAllClose(tys, -jnp.sin(x) * txs)
|
||||
|
||||
def test_jacfwd(self):
|
||||
# jacfwd is another way to exercise extra-batched tangents
|
||||
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [False])
|
||||
return [jnp.cos(xs)], in_batched
|
||||
|
||||
x = jnp.arange(3.) + .72
|
||||
j = api.jacfwd(f)(x)
|
||||
self.assertAllClose(j, -jnp.diag(jnp.sin(x)))
|
||||
|
||||
def test_jvp_extra_batched_primals(self):
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [False])
|
||||
return [jnp.cos(xs)], in_batched
|
||||
|
||||
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
|
||||
|
||||
xs, tx = jnp.arange(3.), jnp.array(4.)
|
||||
ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx)
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
self.assertAllClose(tys, -jnp.sin(xs) * tx)
|
||||
|
||||
def test_jvp_extra_batched_primals_with_linear_vmap_rule(self):
|
||||
# When a function is linear, its Jacobian is constant. JAX's JVP
|
||||
# of linear functions takes advantage of this: when mapping over a
|
||||
# batch of primals relative to a fixed (i.e. symbolically
|
||||
# replicated) tangent, output tangents remain replicated as well
|
||||
# (i.e. JAX will not broadcast them). This is true in general, and
|
||||
# this test checks that vmapped JVPs continue to behave this way
|
||||
# when custom_vmap is involved and the custom vmap rule is linear.
|
||||
|
||||
@api.custom_vmap
|
||||
def f_linear(x): return 7. * x
|
||||
|
||||
@f_linear.def_vmap
|
||||
def linear_rule(axis_size, in_batched, xs):
|
||||
return [11. * xs], in_batched
|
||||
|
||||
@api.custom_vmap
|
||||
def f_nonlinear(x): return jnp.sin(x)
|
||||
|
||||
@f_nonlinear.def_vmap
|
||||
def nonlinear_rule(axis_size, in_batched, xs):
|
||||
return [jnp.cos(xs)], in_batched
|
||||
|
||||
f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx])
|
||||
f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx])
|
||||
xs, tx = jnp.arange(3.), jnp.array(4.)
|
||||
|
||||
# doesn't err
|
||||
_ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)
|
||||
|
||||
# does err
|
||||
self.assertRaisesRegex(
|
||||
ValueError, 'vmap has mapped output but out_axes is None',
|
||||
lambda: api.vmap(
|
||||
f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx))
|
||||
|
||||
def test_jvp_dataflow_violation(self):
|
||||
# The jvp-of-custom-vmap machinery should not assume the standard
|
||||
# dataflow constraint on the JVP of the custom vmap rule (primal
|
||||
# outputs independent of tangent inputs). Both jvp and vmap are
|
||||
# "forward" transformations under which, at present, we don't
|
||||
# enforce the JVP dependence diagram. Because output primals can
|
||||
# depend on input tangents, extra-batched input tangents can
|
||||
# create batched output primals, as this test checks.
|
||||
|
||||
@api.custom_jvp
|
||||
def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x)
|
||||
|
||||
@cos_with_invalid_dataflow_jvp.defjvp
|
||||
def invalid_dataflow_jvp(x, tx):
|
||||
[x], [tx] = x, tx
|
||||
return jnp.cos(x * tx), tx
|
||||
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
return [cos_with_invalid_dataflow_jvp(xs)], in_batched
|
||||
|
||||
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
|
||||
x, txs = jnp.array(1.), 2. + jnp.arange(3.)
|
||||
|
||||
# doesn't err
|
||||
ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs)
|
||||
self.assertAllClose(ys, jnp.cos(x * txs))
|
||||
self.assertAllClose(tys, txs)
|
||||
|
||||
# does err
|
||||
self.assertRaisesRegex(
|
||||
ValueError, 'vmap has mapped output but out_axes is None',
|
||||
lambda: api.vmap(
|
||||
f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs))
|
||||
|
||||
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user