mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
671 lines
19 KiB
Python
671 lines
19 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
from jax._src import config
|
|
from jax._src import test_util as jtu
|
|
from jax._src.util import safe_zip, safe_map
|
|
|
|
from jax.experimental import attrs
|
|
from jax.experimental.attrs import jax_setattr, jax_getattr
|
|
|
|
config.parse_flags_with_absl()

# Rebind map/zip to jax's safe_map/safe_zip for the rest of this module,
# keeping the builtin versions reachable as unsafe_map/unsafe_zip.
unsafe_map = map
unsafe_zip = zip
map = safe_map
zip = safe_zip
|
|
|
|
@dataclass
|
|
class Thing:
|
|
x: float
|
|
__hash__ = object.__hash__
|
|
__eq__ = object.__eq__
|
|
|
|
attrs.register(Thing) # enables passing as arg into jitted function
|
|
|
|
class AttrsTest(jtu.JaxTestCase):
  """Tests for jax_getattr/jax_setattr under jit, nested jit, custom_vjp and scan."""

  @parameterized.parameters([True, False])
  def test_jit_basic(self, jit: bool):
    # Each call reads thing.x and writes back twice its value; the update must
    # be visible on the Python object after every call, with or without jit.
    thing = Thing(1.0)

    def double_it() -> None:
      cur_x = jax_getattr(thing, "x")
      jax_setattr(thing, "x", cur_x * 2)

    if jit:
      double_it = jax.jit(double_it)

    self.assertEqual(thing.x, 1.0)
    double_it()
    self.assertEqual(thing.x, 2.0)
    double_it()
    self.assertEqual(thing.x, 4.0)
    double_it()
    self.assertEqual(thing.x, 8.0)
    double_it()
    self.assertEqual(thing.x, 16.0)

  @parameterized.parameters([True, False])
  def test_jit_basic_tree(self, jit: bool):
    # Same as test_jit_basic, but the attribute holds a tuple (a pytree)
    # instead of a scalar.
    thing = Thing((1.0, 2.0))

    def double_it() -> None:
      (cur_x, cur_y) = jax_getattr(thing, "x")
      jax_setattr(thing, "x", (cur_x * 2, cur_y * 2))

    if jit:
      double_it = jax.jit(double_it)

    self.assertEqual(thing.x, (1.0, 2.0))
    double_it()
    self.assertEqual(thing.x, (2.0, 4.0))
    double_it()
    self.assertEqual(thing.x, (4.0, 8.0))
    double_it()
    self.assertEqual(thing.x, (8.0, 16.0))
    double_it()
    self.assertEqual(thing.x, (16.0, 32.0))

  @parameterized.parameters([True, False])
  def test_jit_basic_tree_changes(self, jit: bool):
    # The attribute starts as None and holds a number after the first call,
    # so its pytree structure changes between calls. `count` records how many
    # times the Python body of double_it executes.
    thing = Thing(None)
    count = 0

    def double_it() -> None:
      nonlocal count
      count += 1
      maybe_x = jax_getattr(thing, "x")
      x = 1.0 if maybe_x is None else maybe_x
      jax_setattr(thing, "x", 2 * x)

    if jit:
      double_it = jax.jit(double_it)

    self.assertEqual(thing.x, None)
    double_it()
    self.assertEqual(thing.x, 2.0)
    self.assertEqual(count, 1)
    double_it()
    self.assertEqual(thing.x, 4.0)
    self.assertEqual(count, 2)
    double_it()
    self.assertEqual(thing.x, 8.0)
    # Under jit the third call is expected to reuse the second call's trace
    # (the body does not run again); without jit the body runs every time.
    self.assertEqual(count, 2 + (not jit))

  def test_jit_basic_tree_changes_multiple(self):
    # thing1.x cycles None -> (None,) -> number -> None, and thing2.x picks up
    # the number on the third step. The first cycle takes three traces; the
    # test then expects later calls to reuse those traces, so count stays 3.
    thing1 = Thing(None)
    thing2 = Thing(0)
    count = 0

    @jax.jit
    def double_it() -> None:
      nonlocal count
      count += 1

      x1 = jax_getattr(thing1, "x")
      if x1 is None:
        jax_setattr(thing1, 'x', (None,))
      elif isinstance(x1, tuple):
        # depend on a new value
        jax_setattr(thing1, 'x', jax_getattr(thing2, 'x') + 1)
      else:
        jax_setattr(thing2, 'x', jax_getattr(thing1, 'x'))
        jax_setattr(thing1, 'x', None)

    self.assertEqual(thing1.x, None)
    self.assertEqual(thing2.x, 0)
    double_it()
    self.assertEqual(thing1.x, (None,))
    self.assertEqual(thing2.x, 0)
    self.assertEqual(count, 1)
    double_it()
    self.assertEqual(thing1.x, 1)
    self.assertEqual(thing2.x, 0)
    self.assertEqual(count, 2)
    double_it()
    self.assertEqual(thing1.x, None)
    self.assertEqual(thing2.x, 1)
    self.assertEqual(count, 3)
    double_it()
    self.assertEqual(thing1.x, (None,))
    self.assertEqual(thing2.x, 1)
    self.assertEqual(count, 3)
    double_it()
    self.assertEqual(thing1.x, 2)
    self.assertEqual(thing2.x, 1)
    self.assertEqual(count, 3)
    double_it()
    self.assertEqual(thing1.x, None)
    self.assertEqual(thing2.x, 2)
    self.assertEqual(count, 3)

  def test_jit_nesting_basic(self):
    # Attribute updates inside a doubly-nested jit still reach the Python
    # object.
    thing = Thing(1.0)

    @jax.jit
    @jax.jit
    def double_it() -> None:
      cur_x = jax_getattr(thing, "x")
      jax_setattr(thing, "x", cur_x * 2)

    self.assertEqual(thing.x, 1.0)
    double_it()
    self.assertEqual(thing.x, 2.0)
    double_it()
    self.assertEqual(thing.x, 4.0)
    double_it()
    self.assertEqual(thing.x, 8.0)
    double_it()
    self.assertEqual(thing.x, 16.0)

  def test_jit_consts_and_args(self):
    # The jitted function also takes an argument and uses a NumPy constant in
    # its return value; only the attribute update is asserted.
    thing = Thing(1.0)

    @jax.jit
    def double_it(y) -> None:
      cur_x = jax_getattr(thing, "x")
      jax_setattr(thing, "x", cur_x * 2)
      return jnp.cos(np.arange(3.) * cur_x * y)

    self.assertEqual(thing.x, 1.0)
    double_it(2.)
    self.assertEqual(thing.x, 2.0)
    double_it(2.)
    self.assertEqual(thing.x, 4.0)
    double_it(2.)
    self.assertEqual(thing.x, 8.0)
    double_it(2.)
    self.assertEqual(thing.x, 16.0)

  def test_jit_transpose_basic(self):
    # foo is an identity with a custom VJP whose backward rule writes the
    # incoming cotangent `g` into thing.x, so thing.x records what foo_bwd saw.
    thing = Thing(jnp.array(2.0))

    @jax.custom_vjp
    def foo(x):
      return x

    def foo_fwd(x):
      return x, None

    def foo_bwd(x, g):
      # `x` here is the residual returned by foo_fwd (None); it is unused.
      jax_setattr(thing, 'x', g)
      return g,

    foo.defvjp(foo_fwd, foo_bwd)

    # Plain evaluation does not run foo_bwd, so thing.x is unchanged.
    foo(3.14)
    self.assertEqual(thing.x, 2.0)

    # d/dx foo(x): the cotangent reaching foo_bwd is 1.0.
    jax.grad(foo)(3.14)
    self.assertEqual(thing.x, 1.0)

    thing.x = jnp.array(3.14)
    self.assertEqual(thing.x, 3.14)

    jax.jit(jax.grad(foo))(3.14)
    self.assertEqual(thing.x, 1.0)

    thing.x = jnp.array(2.718)
    self.assertEqual(thing.x, 2.718)

    # d/dx sin(foo(x)) at 3.0: the cotangent reaching foo_bwd is cos(3.0).
    jax.grad(jax.jit(lambda x: jnp.sin(foo(x))))(3.0)
    self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)

    thing.x = jnp.array(3.14)
    self.assertEqual(thing.x, 3.14)

    def bar(x):
      out = jnp.sin(foo(x))
      jax_setattr(thing, 'x', 5.0)
      return out

    # bar writes 5.0 in its forward pass; the test expects the value written
    # by the backward pass (cos(3.0)) to be the one left on thing.x.
    jax.grad(jax.jit(bar))(3.0)
    self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_basic(self, jit: bool):
    # The scan body doubles thing.x on each of 10 iterations: 2**10 == 1024.
    thing = Thing(1.0)

    def double_it_10():
      def body(_, __):
        cur_x = jax_getattr(thing ,"x")
        jax_setattr(thing, "x", cur_x * 2.0)
        return None, None
      _, _ = jax.lax.scan(body, None, None, length=10)

    if jit:
      double_it_10 = jax.jit(double_it_10)

    double_it_10()
    self.assertAllClose(thing.x, 1024., check_dtypes=False)

  def test_scan_basic_consts_and_args(self):
    # Like test_scan_basic, but the scan has an integer carry, a scanned-over
    # input, and stacked outputs built from the closed-over argument y.
    thing = Thing(1.0)

    def double_it_10(y):
      def body(i, x):
        cur_x = jax_getattr(thing ,"x")
        jax_setattr(thing, "x", cur_x * 2.0)
        return i + 1, (y, y)
      _, _ = jax.lax.scan(body, 0, jnp.arange(10))

    jax.jit(double_it_10)(jnp.arange(3.))
    self.assertAllClose(thing.x, 1024., check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_transpose_basic(self, jit: bool):
    # foo_bwd multiplies thing.x by 2 * g. The forward scan must leave thing.x
    # alone; grad through 10 scanned applications of foo is expected to run
    # foo_bwd 10 times, giving 2**10 == 1024.
    thing = Thing(1.0)

    @jax.custom_vjp
    def foo(x):
      return x

    def foo_fwd(x):
      return x, None

    def foo_bwd(x, g):
      jax_setattr(thing, 'x', 2 * jax_getattr(thing, 'x') * g)
      return g,

    foo.defvjp(foo_fwd, foo_bwd)


    def double_it_10(x):
      def body(x, __):
        return foo(x), None
      x, _ = jax.lax.scan(body, x, None, length=10)
      return x

    if jit:
      double_it_10 = jax.jit(double_it_10)

    double_it_10(1.0)
    self.assertAllClose(thing.x, 1., check_dtypes=False)

    jax.grad(double_it_10)(1.0)
    self.assertAllClose(thing.x, 1024., check_dtypes=False)

  def test_arg_to_jit(self):
    # Currently skipped (see TODO). Checks passing a registered Thing as a
    # jit argument: the attribute is updated, and the second call is expected
    # not to retrace (count stays 1).
    self.skipTest("regressed this experimental feature")  # TODO(mattjj)
    thing = Thing(1.0)
    count = 0

    @jax.jit
    def f(obj, x):
      nonlocal count
      count += 1
      jax_setattr(obj, 'x', x)

    f(thing, 2.0)  # don't crash!
    self.assertAllClose(thing.x, 2.0, check_dtypes=False)
    f(thing, 3.0)
    self.assertAllClose(thing.x, 3.0, check_dtypes=False)
    self.assertEqual(count, 1)

  def test_tracer_lifetime_bug(self):
    # regression test for https://github.com/jax-ml/jax/issues/20082
    # Two stateful key splits inside one jit must not crash.
    class StatefulRNG:
      key: jax.Array

      def __init__(self, key: jax.Array):
        self.key = key

      def split(self) -> jax.Array:
        key = jax_getattr(self, "key")
        new_key, returned_key = jax.random.split(key)
        jax_setattr(self, "key", new_key)
        return returned_key

    rng = StatefulRNG(jax.random.key(0))

    def jitted():
      rng.split()
      rng.split()

    jax.jit(jitted)()  # don't crash

  def test_scan_carry(self):
    # The attribute is created with jax_setattr on a plain (unregistered)
    # object outside any transformation, then read and updated in a scan body.
    class A:
      ...

    a = A()

    jax_setattr(a, 'x', jnp.zeros(3))

    def body(i, _):
      x = jax_getattr(a, 'x')
      x = x.at[i].set(x[i] + 1)
      jax_setattr(a, 'x', x)
      return i + 1, None
    _, _ = jax.lax.scan(body, 0, None, length=3)  # don't crash
|
|
|
|
|
|
class AttrsJVPTest(jtu.JaxTestCase):
  """Forward-mode tests for attrs.jvp with attribute tangents.

  attrs.jvp(f, primals, tangents, attr_tangents) is called with
  attr_tangents given as a list of (obj, attr_name, tangent) triples. It
  returns (primal_out, tangent_out, out_attr_tangents), where the last item
  is again a list of (obj, attr_name, tangent) triples.
  """

  @parameterized.parameters([True, False])
  def test_jvp_basic(self, jit):
    # thing.x <- sin(thing.x): the output attribute tangent is cos(2.0) * 1.0.
    thing = Thing(2.0)

    def f():
      x = jax_getattr(thing, 'x')
      x = jnp.sin(x)
      jax_setattr(thing, 'x', x)

    if jit:
      f = jax.jit(f)

    _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
    self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False)
    (thing_, attr_, tangent_), = attr_tangents
    self.assertIs(thing, thing_)
    self.assertEqual(attr_, 'x')
    self.assertAllClose(tangent_, jnp.cos(2.0), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_jvp_clobber(self, jit):
    # thing.x is read, then overwritten with a constant. The test expects no
    # attribute tangent to be reported for it.
    thing = Thing(2.0)

    def f():
      x = jax_getattr(thing, 'x')
      x = jnp.sin(2.0)
      jax_setattr(thing, 'x', x)

    if jit:
      f = jax.jit(f)

    _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
    self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False)
    self.assertEmpty(attr_tangents)

  @parameterized.parameters([True, False])
  def test_jvp_nowrite(self, jit):
    # thing.x is only read. The value is unchanged, and the input tangent
    # (1.0) is expected to come back unchanged.
    thing = Thing(2.0)

    def f():
      jax_getattr(thing, 'x')

    if jit:
      f = jax.jit(f)

    _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
    self.assertAllClose(thing.x, 2.0, check_dtypes=False)
    (thing_, attr_, tangent_), = attr_tangents
    self.assertIs(thing, thing_)
    self.assertEqual(attr_, 'x')
    self.assertAllClose(tangent_, 1.0, check_dtypes=False)

  def test_jit_of_jvp(self):
    # attrs.jvp used inside an outer jit. The asserts inside g run at trace
    # time; the attribute value and tangent are returned out of the jit.
    thing = Thing(2.0)

    def f():
      x = jax_getattr(thing, 'x')
      x = jnp.sin(x)
      jax_setattr(thing, 'x', x)

    @jax.jit
    def g():
      _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
      (thing_, attr_, tangent_), = attr_tangents
      self.assertIs(thing, thing_)
      self.assertEqual(attr_, 'x')
      return jax_getattr(thing, 'x'), tangent_

    x, tangent = g()
    self.assertAllClose(x, jnp.sin(2.0), check_dtypes=False)
    self.assertAllClose(tangent, jnp.cos(2.0), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_jvp_higher_order(self, jit):
    # f reads thing.x (as `x`), writes thing.x (as `z`) and returns `w`.
    # f_ref is the same computation with the attribute threaded explicitly,
    # so attrs.jvp(f) is checked against jax.jvp(f_ref) at first and second
    # order.
    thing = Thing(2.0)

    def f(y):
      x = jax_getattr(thing, 'x')
      w = jnp.tan(jnp.sin(y) * jnp.cos(x))
      z = jnp.tan(jnp.cos(y) * jnp.sin(x))
      jax_setattr(thing, 'x', z)
      return w

    if jit:
      f = jax.jit(f)

    def f_ref(x, y):
      w = jnp.tan(jnp.sin(y) * jnp.cos(x))
      z = jnp.tan(jnp.cos(y) * jnp.sin(x))
      return w, z

    x = jax.random.normal(jax.random.key(0), (3,))
    x_dot = jax.random.normal(jax.random.key(1), (3,))
    y = jax.random.normal(jax.random.key(2), (3,))
    y_dot = jax.random.normal(jax.random.key(3), (3,))

    # First order.
    thing.x = x
    w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)])
    z = thing.x

    (w_, z_), (w_dot_, z_dot_) = jax.jvp(f_ref, (x, y), (x_dot, y_dot))

    self.assertAllClose(w, w_, check_dtypes=False)
    self.assertAllClose(z, z_, check_dtypes=False)
    self.assertAllClose(w_dot, w_dot_, check_dtypes=False)
    self.assertAllClose(z_dot, z_dot_, check_dtypes=False)

    # Second order: differentiate g, which itself calls attrs.jvp.
    def g(x_dot, y, y_dot):
      w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)])
      return w, w_dot, z_dot

    def g_ref(x, x_dot, y, y_dot):
      (w, z), (w_dot, z_dot) = jax.jvp(f_ref, (x, y), (x_dot, y_dot))
      return w, w_dot, z, z_dot

    # Each input gets its own PRNG key. In particular x_dot2 must not reuse
    # key(3) (used for y_dot above): otherwise the attribute tangent would be
    # numerically identical to the primal y_dot, which could hide a mix-up
    # between the two.
    x_dot2 = jax.random.normal(jax.random.key(7), (3,))
    x_ddot = jax.random.normal(jax.random.key(4), (3,))
    y_dot2 = jax.random.normal(jax.random.key(5), (3,))
    y_ddot = jax.random.normal(jax.random.key(6), (3,))

    # Reset the attribute: the first-order call above overwrote it with z.
    thing.x = x
    (w, w_dot, z_dot), (w_dot2, w_ddot, z_ddot), [(_, _, z_dot2)] = \
        attrs.jvp(g, (x_dot, y, y_dot), (x_ddot, y_dot2, y_ddot),
                  [(thing, 'x', x_dot2)])
    z = thing.x

    (w_, w_dot_, z_, z_dot_), (w_dot2_, w_ddot_, z_dot2_, z_ddot_) = \
        jax.jvp(g_ref, (x, x_dot, y, y_dot), (x_dot2, x_ddot, y_dot2, y_ddot))

    self.assertAllClose(     w,      w_, check_dtypes=False)
    self.assertAllClose(     z,      z_, check_dtypes=False)
    self.assertAllClose( w_dot,  w_dot_, check_dtypes=False)
    self.assertAllClose( z_dot,  z_dot_, check_dtypes=False)
    self.assertAllClose(w_dot2, w_dot2_, check_dtypes=False)
    self.assertAllClose(z_dot2, z_dot2_, check_dtypes=False)
    self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False)
    self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False)
|
|
|
|
class AttrsLinTest(jtu.JaxTestCase):
  """Tests for attrs.linearize with attributes as extra inputs and outputs."""

  @parameterized.parameters([True, False])
  def test_attr_output(self, jit):
    # The attribute is only written, and no `attrs=` are declared. The linear
    # function still reports an output tangent for it.
    obj = Thing(1.0)

    def write_sin(a, _):
      jax_setattr(obj, 'x', jnp.sin(a))

    fn = jax.jit(write_sin) if jit else write_sin

    primal_out, lin_fn = attrs.linearize(fn, 3.0, 4.0)
    self.assertIsNone(primal_out)
    self.assertAllClose(obj.x, jnp.sin(3.0), check_dtypes=False)

    tangent_out, tangents_out = lin_fn(1.0, 2.0, attr_tangents={})
    self.assertIsNone(tangent_out)
    # Evaluating the linear map must leave the stored primal value unchanged.
    self.assertAllClose(obj.x, jnp.sin(3.0))
    self.assertLen(tangents_out, 1)
    self.assertAllClose(tangents_out[(obj, 'x')], jnp.cos(3.0),
                        check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_attr_input(self, jit):
    # The attribute is declared via `attrs=` and only read. Its output tangent
    # equals the tangent supplied for it (2.0).
    obj = Thing(1.0)
    attr_ref = (obj, 'x')

    def read_sin():
      return jnp.sin(jax_getattr(obj, 'x'))

    fn = jax.jit(read_sin) if jit else read_sin

    primal_out, lin_fn = attrs.linearize(fn, attrs=[attr_ref])
    self.assertAllClose(primal_out, jnp.sin(1.0), check_dtypes=False)

    tangent_out, tangents_out = lin_fn(attr_tangents={attr_ref: 2.0})
    self.assertAllClose(tangent_out, 2. * jnp.cos(1.0), check_dtypes=False)
    self.assertLen(tangents_out, 1)
    self.assertAllClose(tangents_out[attr_ref], 2.0, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_attr_inout(self, jit):
    # Two attributes are both read and written. Results are compared against
    # jax.linearize of a pure reference that threads them as arguments.
    obj_a = Thing(1.0)
    obj_b = Thing(2.0)

    def step(x, y):
      a_val = jax_getattr(obj_a, 'x')
      b_val = jax_getattr(obj_b, 'x')
      s = jnp.sin(x * y * a_val * b_val)
      jax_setattr(obj_a, 'x', s)
      jax_setattr(obj_b, 'x', 2 * s)
      return 3 * s, 4 * s

    fn = jax.jit(step) if jit else step

    def step_ref(x, y, z, w):
      s = jnp.sin(x * y * z * w)
      return (3 * s, 4 * s), (s, 2 * s)

    primal_out, lin_fn = attrs.linearize(
        fn, 3., 4., attrs=[(obj_a, 'x'), (obj_b, 'x')])
    s_expected = jnp.sin(1. * 2. * 3. * 4.)
    self.assertAllClose(primal_out, (3 * s_expected, 4 * s_expected),
                        check_dtypes=False)
    self.assertAllClose(obj_a.x, s_expected)
    self.assertAllClose(obj_b.x, 2 * s_expected)

    (ref_out, ref_state_out), ref_lin_fn = jax.linearize(
        step_ref, 3., 4., 1., 2.)
    self.assertAllClose(primal_out, ref_out, check_dtypes=False)
    self.assertAllClose((obj_a.x, obj_b.x), ref_state_out, check_dtypes=False)

    tangent_out, tangents_out = lin_fn(
        1., 2., attr_tangents={(obj_a, 'x'): 5., (obj_b, 'x'): 6.})
    # The linear map must not modify the primal attribute values.
    self.assertAllClose(obj_a.x, s_expected)
    self.assertAllClose(obj_b.x, 2 * s_expected)
    ref_tangent_out, ref_state_dot = ref_lin_fn(1., 2., 5., 6.)
    self.assertAllClose(tangent_out, ref_tangent_out, check_dtypes=False)
    self.assertLen(tangents_out, 2)
    for idx, obj in enumerate((obj_a, obj_b)):
      self.assertAllClose(tangents_out[(obj, 'x')], ref_state_dot[idx],
                          check_dtypes=False)
|
|
|
|
class AttrsVJPTest(jtu.JaxTestCase):
  """Tests for attrs.vjp with attributes as extra inputs and outputs."""

  @parameterized.parameters([True, False])
  def test_attr_input(self, jit):
    # The function takes no positional arguments. The attribute declared via
    # `attrs=` is the only input, so it alone receives a cotangent.
    obj = Thing(1.0)
    attr_ref = (obj, 'x')

    def read_sin():
      return jnp.sin(jax_getattr(obj, 'x'))

    fn = jax.jit(read_sin) if jit else read_sin

    primal_out, vjp_fn = attrs.vjp(fn, attrs=[attr_ref])
    self.assertAllClose(primal_out, jnp.sin(1.0), check_dtypes=False)

    arg_cts, attr_cts = vjp_fn(1.0)
    self.assertEqual(arg_cts, ())
    self.assertLen(attr_cts, 1)
    self.assertAllClose(attr_cts[attr_ref], jnp.cos(1.0), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_attr_output(self, jit):
    # A cotangent supplied for the written attribute flows back to the
    # positional argument that produced it; the unused argument gets 0.
    obj = Thing(1.0)

    def write_sin(a, _):
      jax_setattr(obj, 'x', jnp.sin(a))

    fn = jax.jit(write_sin) if jit else write_sin

    primal_out, vjp_fn = attrs.vjp(fn, 3.0, 4.0)
    self.assertIsNone(primal_out)
    self.assertAllClose(obj.x, jnp.sin(3.0), check_dtypes=False)

    arg_cts, attr_cts = vjp_fn(None, attr_cotangents={(obj, 'x'): 2.0})
    self.assertAllClose(arg_cts, (2 * jnp.cos(3.0), 0.), check_dtypes=False)
    # No `attrs=` inputs were declared, so no attribute cotangents come back.
    self.assertLen(attr_cts, 0)

  @parameterized.parameters([True, False])
  def test_attr_inout(self, jit):
    # Two attributes are both read and written. Results are compared against
    # jax.vjp of a pure reference that threads them as arguments.
    obj_a = Thing(1.0)
    obj_b = Thing(2.0)

    def step(x, y):
      a_val = jax_getattr(obj_a, 'x')
      b_val = jax_getattr(obj_b, 'x')
      s = jnp.sin(x * y * a_val * b_val)
      jax_setattr(obj_a, 'x', s)
      jax_setattr(obj_b, 'x', 2 * s)
      return 3 * s, 4 * s

    fn = jax.jit(step) if jit else step

    def step_ref(x, y, z, w):
      s = jnp.sin(x * y * z * w)
      return (3 * s, 4 * s), (s, 2 * s)

    primal_out, vjp_fn = attrs.vjp(
        fn, 3., 4., attrs=[(obj_a, 'x'), (obj_b, 'x')])
    (ref_out, ref_state_out), ref_vjp_fn = jax.vjp(step_ref, 3., 4., 1., 2.)
    self.assertAllClose(primal_out, ref_out, check_dtypes=False)
    self.assertAllClose((obj_a.x, obj_b.x), ref_state_out, check_dtypes=False)

    arg_cts, attr_cts = vjp_fn(
        (1., 2.), attr_cotangents={(obj_a, 'x'): 5., (obj_b, 'x'): 6.})
    # Reference cotangents for (x, y, z, w); z and w play the attributes' role.
    x_ct, y_ct, z_ct, w_ct = ref_vjp_fn(((1., 2.), (5., 6.)))
    self.assertAllClose(arg_cts, (x_ct, y_ct), check_dtypes=False)
    self.assertLen(attr_cts, 2)
    for obj, ref_ct in ((obj_a, z_ct), (obj_b, w_ct)):
      self.assertAllClose(attr_cts[(obj, 'x')], ref_ct, check_dtypes=False)
|
|
|
|
|
|
if __name__ == '__main__':
  # When run directly, execute the tests with absltest using JAX's test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())
|