rocm_jax/tests/attrs_test.py
Matthew Johnson 6fba4ecc58 PR #27576: [attrs] experimental appendattr
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576

This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.

This PR also includes some fixes for getattr/setattr.
Copybara import of the project:

--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:

[attrs] experimental appendattr

Merging this change closes #27576

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
2025-03-28 15:21:12 -07:00

927 lines
25 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
import itertools as it
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.util import safe_zip, safe_map
from jax.experimental import attrs
from jax.experimental.attrs import jax_setattr, jax_getattr, jax_appendattr
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@dataclass
class Thing:
x: float
__hash__ = object.__hash__
__eq__ = object.__eq__
attrs.register(Thing) # enables passing as arg into jitted function
class AttrsTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_jit_basic(self, jit: bool):
thing = Thing(1.0)
def double_it() -> None:
cur_x = jax_getattr(thing, "x")
jax_setattr(thing, "x", cur_x * 2)
if jit:
double_it = jax.jit(double_it)
self.assertEqual(thing.x, 1.0)
double_it()
self.assertEqual(thing.x, 2.0)
double_it()
self.assertEqual(thing.x, 4.0)
double_it()
self.assertEqual(thing.x, 8.0)
double_it()
self.assertEqual(thing.x, 16.0)
def test_setattr_doesnt_leak(self):
thing = Thing(1.0)
@jax.jit
def f(x):
jax_setattr(thing, 'x', x)
raise Exception
try: f(1.)
except: pass
self.assertNotIsInstance(thing.x, jax.core.Tracer)
@parameterized.parameters([True, False])
def test_jit_basic_tree(self, jit: bool):
thing = Thing((1.0, 2.0))
def double_it() -> None:
(cur_x, cur_y) = jax_getattr(thing, "x")
jax_setattr(thing, "x", (cur_x * 2, cur_y * 2))
if jit:
double_it = jax.jit(double_it)
self.assertEqual(thing.x, (1.0, 2.0))
double_it()
self.assertEqual(thing.x, (2.0, 4.0))
double_it()
self.assertEqual(thing.x, (4.0, 8.0))
double_it()
self.assertEqual(thing.x, (8.0, 16.0))
double_it()
self.assertEqual(thing.x, (16.0, 32.0))
@parameterized.parameters([True, False])
def test_jit_basic_tree_changes(self, jit: bool):
thing = Thing(None)
count = 0
def double_it() -> None:
nonlocal count
count += 1
maybe_x = jax_getattr(thing, "x")
x = 1.0 if maybe_x is None else maybe_x
jax_setattr(thing, "x", 2 * x)
if jit:
double_it = jax.jit(double_it)
self.assertEqual(thing.x, None)
double_it()
self.assertEqual(thing.x, 2.0)
self.assertEqual(count, 1)
double_it()
self.assertEqual(thing.x, 4.0)
self.assertEqual(count, 2)
double_it()
self.assertEqual(thing.x, 8.0)
self.assertEqual(count, 2 + (not jit))
def test_jit_basic_tree_changes_multiple(self):
thing1 = Thing(None)
thing2 = Thing(0)
count = 0
@jax.jit
def double_it() -> None:
nonlocal count
count += 1
x1 = jax_getattr(thing1, "x")
if x1 is None:
jax_setattr(thing1, 'x', (None,))
elif isinstance(x1, tuple):
# depend on a new value
jax_setattr(thing1, 'x', jax_getattr(thing2, 'x') + 1)
else:
jax_setattr(thing2, 'x', jax_getattr(thing1, 'x'))
jax_setattr(thing1, 'x', None)
self.assertEqual(thing1.x, None)
self.assertEqual(thing2.x, 0)
double_it()
self.assertEqual(thing1.x, (None,))
self.assertEqual(thing2.x, 0)
self.assertEqual(count, 1)
double_it()
self.assertEqual(thing1.x, 1)
self.assertEqual(thing2.x, 0)
self.assertEqual(count, 2)
double_it()
self.assertEqual(thing1.x, None)
self.assertEqual(thing2.x, 1)
self.assertEqual(count, 3)
double_it()
self.assertEqual(thing1.x, (None,))
self.assertEqual(thing2.x, 1)
self.assertEqual(count, 3)
double_it()
self.assertEqual(thing1.x, 2)
self.assertEqual(thing2.x, 1)
self.assertEqual(count, 3)
double_it()
self.assertEqual(thing1.x, None)
self.assertEqual(thing2.x, 2)
self.assertEqual(count, 3)
def test_jit_nesting_basic(self):
thing = Thing(1.0)
@jax.jit
@jax.jit
def double_it() -> None:
cur_x = jax_getattr(thing, "x")
jax_setattr(thing, "x", cur_x * 2)
self.assertEqual(thing.x, 1.0)
double_it()
self.assertEqual(thing.x, 2.0)
double_it()
self.assertEqual(thing.x, 4.0)
double_it()
self.assertEqual(thing.x, 8.0)
double_it()
self.assertEqual(thing.x, 16.0)
def test_jit_consts_and_args(self):
thing = Thing(1.0)
@jax.jit
def double_it(y) -> None:
cur_x = jax_getattr(thing, "x")
jax_setattr(thing, "x", cur_x * 2)
return jnp.cos(np.arange(3.) * cur_x * y)
self.assertEqual(thing.x, 1.0)
double_it(2.)
self.assertEqual(thing.x, 2.0)
double_it(2.)
self.assertEqual(thing.x, 4.0)
double_it(2.)
self.assertEqual(thing.x, 8.0)
double_it(2.)
self.assertEqual(thing.x, 16.0)
def test_jit_transpose_basic(self):
thing = Thing(jnp.array(2.0))
@jax.custom_vjp
def foo(x):
return x
def foo_fwd(x):
return x, None
def foo_bwd(x, g):
jax_setattr(thing, 'x', g)
return g,
foo.defvjp(foo_fwd, foo_bwd)
foo(3.14)
self.assertEqual(thing.x, 2.0)
jax.grad(foo)(3.14)
self.assertEqual(thing.x, 1.0)
thing.x = jnp.array(3.14)
self.assertEqual(thing.x, 3.14)
jax.jit(jax.grad(foo))(3.14)
self.assertEqual(thing.x, 1.0)
thing.x = jnp.array(2.718)
self.assertEqual(thing.x, 2.718)
jax.grad(jax.jit(lambda x: jnp.sin(foo(x))))(3.0)
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
thing.x = jnp.array(3.14)
self.assertEqual(thing.x, 3.14)
def bar(x):
out = jnp.sin(foo(x))
jax_setattr(thing, 'x', 5.0)
return out
jax.grad(jax.jit(bar))(3.0)
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_basic(self, jit: bool):
thing = Thing(1.0)
def double_it_10():
def body(_, __):
cur_x = jax_getattr(thing ,"x")
jax_setattr(thing, "x", cur_x * 2.0)
return None, None
_, _ = jax.lax.scan(body, None, None, length=10)
if jit:
double_it_10 = jax.jit(double_it_10)
double_it_10()
self.assertAllClose(thing.x, 1024., check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_basic_pytree(self, jit):
class Thing: ...
thing = Thing()
thing.x = (1.0, 1.0)
def double_it_10():
def body(_, __):
cur_x, _ = jax_getattr(thing ,"x")
jax_setattr(thing, "x", (cur_x * 2.0, 3.0))
return None, None
_, _ = jax.lax.scan(body, None, None, length=10)
if jit:
double_it_10 = jax.jit(double_it_10)
double_it_10()
self.assertAllClose(thing.x[0], 1024., check_dtypes=False)
self.assertAllClose(thing.x[1], 3., check_dtypes=False)
def test_scan_basic_consts_and_args(self):
thing = Thing(1.0)
def double_it_10(y):
def body(i, x):
cur_x = jax_getattr(thing ,"x")
jax_setattr(thing, "x", cur_x * 2.0)
return i + 1, (y, y)
_, _ = jax.lax.scan(body, 0, jnp.arange(10))
jax.jit(double_it_10)(jnp.arange(3.))
self.assertAllClose(thing.x, 1024., check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_transpose_basic(self, jit: bool):
thing = Thing(1.0)
@jax.custom_vjp
def foo(x):
return x
def foo_fwd(x):
return x, None
def foo_bwd(x, g):
jax_setattr(thing, 'x', 2 * jax_getattr(thing, 'x') * g)
return g,
foo.defvjp(foo_fwd, foo_bwd)
def double_it_10(x):
def body(x, __):
return foo(x), None
x, _ = jax.lax.scan(body, x, None, length=10)
return x
if jit:
double_it_10 = jax.jit(double_it_10)
double_it_10(1.0)
self.assertAllClose(thing.x, 1., check_dtypes=False)
jax.grad(double_it_10)(1.0)
self.assertAllClose(thing.x, 1024., check_dtypes=False)
def test_arg_to_jit(self):
self.skipTest("regressed this experimental feature") # TODO(mattjj)
thing = Thing(1.0)
count = 0
@jax.jit
def f(obj, x):
nonlocal count
count += 1
jax_setattr(obj, 'x', x)
f(thing, 2.0) # don't crash!
self.assertAllClose(thing.x, 2.0, check_dtypes=False)
f(thing, 3.0)
self.assertAllClose(thing.x, 3.0, check_dtypes=False)
self.assertEqual(count, 1)
def test_tracer_lifetime_bug(self):
# regression test for https://github.com/jax-ml/jax/issues/20082
class StatefulRNG:
key: jax.Array
def __init__(self, key: jax.Array):
self.key = key
def split(self) -> jax.Array:
key = jax_getattr(self, "key")
new_key, returned_key = jax.random.split(key)
jax_setattr(self, "key", new_key)
return returned_key
rng = StatefulRNG(jax.random.key(0))
def jitted():
rng.split()
rng.split()
jax.jit(jitted)() # don't crash
def test_scan_carry(self):
class A:
...
a = A()
jax_setattr(a, 'x', jnp.zeros(3))
def body(i, _):
x = jax_getattr(a, 'x')
x = x.at[i].set(x[i] + 1)
jax_setattr(a, 'x', x)
return i + 1, None
_, _ = jax.lax.scan(body, 0, None, length=3) # don't crash
@parameterized.parameters([True, False])
def test_setattr_doesnt_exist(self, jit):
class Thing:
...
thing = Thing()
def f(x):
assert (not jit) or tracing_is_ok
jax_setattr(thing, 'x', x)
if jit:
f = jax.jit(f)
tracing_is_ok = True
self.assertFalse(hasattr(thing, 'x'))
f(1.0)
self.assertEqual(thing.x, 1.0)
f(2.0)
self.assertEqual(thing.x, 2.0)
tracing_is_ok = False
f(3.0)
self.assertEqual(thing.x, 3.0)
del thing.x
f(4.0)
self.assertEqual(thing.x, 4.0)
tracing_is_ok = True
f(5)
self.assertEqual(thing.x, 5)
def test_setattr_doesnt_exist_doesnt_leave_sentinel_around(self):
class Thing:
...
thing = Thing()
def f(x):
jax_setattr(thing, 'x', x)
jax.make_jaxpr(f)(3.)
self.assertFalse(hasattr(thing, 'x'))
tracing_ok = True
f(0.0)
self.assertAllClose(thing.x, 0.)
tracing_ok = False
f(1.0)
self.assertAllClose(thing.x, 1.)
@parameterized.parameters(it.product([False, True], repeat=2))
def test_appendattr_basic(self, jit, initialized):
class Thing:
...
thing = Thing()
if initialized:
thing.x = jnp.arange(0.)
def f(x):
assert (not jit) or tracing_ok
jax_appendattr(thing, 'x', x)
jax_appendattr(thing, 'x', x + 1)
if jit:
f = jax.jit(f)
tracing_ok = True
f(0.0)
self.assertAllClose(thing.x, jnp.array([0., 1.]))
tracing_ok = False
f(2.0)
self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3.]))
f(4.0)
self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.]))
@parameterized.parameters(it.product([False, True], repeat=2))
def test_appendattr_constant(self, jit, initialized):
class Thing: ...
thing = Thing()
if initialized:
thing.x = jnp.arange(0.)
def f():
assert (not jit) or tracing_ok
jax_appendattr(thing, 'x', 0.0)
jax_appendattr(thing, 'x', 1.0)
if jit:
f = jax.jit(f)
tracing_ok = True
f()
self.assertAllClose(thing.x, jnp.array([0., 1.]))
tracing_ok = False
f()
self.assertAllClose(thing.x, jnp.array([0., 1., 0., 1.]))
@parameterized.parameters([True, False])
def test_appendattr_getattr_errors(self, initialized):
class Thing: ...
thing = Thing()
if initialized:
thing.x = jnp.arange(0.)
@jax.jit
def f(x):
jax_appendattr(thing, 'x', x)
jax_getattr(thing, 'x')
with self.assertRaisesRegex(TypeError, "can't read/write"):
f(1.0)
@jax.jit
def g(x):
jax_setattr(thing, 'x', x)
jax_appendattr(thing, 'x', x)
with self.assertRaisesRegex(TypeError, "can't append"):
g(1.0)
if initialized:
self.assertNotIsInstance(thing.x, jax.core.Tracer)
else:
self.assertFalse(hasattr(thing, 'x'))
@parameterized.parameters(it.product([False, True], repeat=2))
def test_appendattr_dtype_disagreement(self, jit, initialized):
class Thing: ...
thing = Thing()
if initialized:
thing.x = jnp.array([], 'float32')
def f(x):
jax_appendattr(thing, 'x', x)
jax_appendattr(thing, 'x', x.astype('complex64'))
if jit:
f = jax.jit(f)
msg = "can only append to attr x with values of trailing shape "
msg += "float32" if initialized else "int32"
with self.assertRaisesRegex(TypeError, msg):
f(jnp.array(1, 'int32'))
@parameterized.parameters(it.product([False, True], repeat=2))
def test_appendattr_shape_disagreement(self, jit, initialized):
class Thing: ...
thing = Thing()
if initialized:
thing.x = jnp.array([])
def f(x):
jax_appendattr(thing, 'x', x)
jax_appendattr(thing, 'x', jnp.stack([x, x]))
if jit:
f = jax.jit(f)
msg = "can only append to attr x with values of trailing shape"
with self.assertRaisesRegex(TypeError, msg):
f(1)
@parameterized.parameters(it.product([False, True], repeat=2))
def test_appendattr_scan(self, jit, initialized):
class Thing: ...
thing = Thing()
if initialized:
thing.x = jnp.array([])
def f():
def body(c, x):
jax_appendattr(thing, 'x', 2 * x)
jax_appendattr(thing, 'x', 2 * x + 1)
return c, ()
_, () = jax.lax.scan(body, 0, jnp.arange(3.))
if jit:
f = jax.jit(f)
f()
self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.]))
@parameterized.parameters(it.product([False, True], repeat=2))
def test_appendattr_scan_vjp(self, jit, initialized):
class Thing: ...
thing = Thing()
if initialized:
thing.y_bar = jnp.array([])
def f(x):
def body(c, _):
return 0.5 * g(2 * c), ()
y, _ = jax.lax.scan(body, x, (), length=5)
return y
if jit:
f = jax.jit(f)
@jax.custom_vjp
def g(x):
return x
def g_fwd(x):
return g(x), None
def g_bwd(_, y_bar):
jax_appendattr(thing, 'y_bar', y_bar)
return y_bar,
g.defvjp(g_fwd, g_bwd)
jax.grad(f)(3.)
self.assertAllClose(thing.y_bar, jnp.array([0.5] * 5))
class AttrsJVPTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_jvp_basic(self, jit):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
x = jnp.sin(x)
jax_setattr(thing, 'x', x)
if jit:
f = jax.jit(f)
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False)
(thing_, attr_, tangent_), = attr_tangents
self.assertIs(thing, thing_)
self.assertEqual(attr_, 'x')
self.assertAllClose(tangent_, jnp.cos(2.0), check_dtypes=False)
@parameterized.parameters([True, False])
def test_jvp_clobber(self, jit):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
x = jnp.sin(2.0)
jax_setattr(thing, 'x', x)
if jit:
f = jax.jit(f)
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False)
self.assertEmpty(attr_tangents)
@parameterized.parameters([True, False])
def test_jvp_nowrite(self, jit):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
if jit:
f = jax.jit(f)
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
self.assertAllClose(thing.x, 2.0, check_dtypes=False)
(thing_, attr_, tangent_), = attr_tangents
self.assertIs(thing, thing_)
self.assertEqual(attr_, 'x')
self.assertAllClose(tangent_, 1.0, check_dtypes=False)
def test_jit_of_jvp(self):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
x = jnp.sin(x)
jax_setattr(thing, 'x', x)
@jax.jit
def g():
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
(thing_, attr_, tangent_), = attr_tangents
self.assertIs(thing, thing_)
self.assertEqual(attr_, 'x')
return jax_getattr(thing, 'x'), tangent_
x, tangent = g()
self.assertAllClose(x, jnp.sin(2.0), check_dtypes=False)
self.assertAllClose(tangent, jnp.cos(2.0), check_dtypes=False)
@parameterized.parameters([True, False])
def test_jvp_higher_order(self, jit):
thing = Thing(2.0)
def f(y):
x = jax_getattr(thing, 'x')
w = jnp.tan(jnp.sin(y) * jnp.cos(x))
z = jnp.tan(jnp.cos(y) * jnp.sin(x))
jax_setattr(thing, 'x', z)
return w
if jit:
f = jax.jit(f)
def f_ref(x, y):
w = jnp.tan(jnp.sin(y) * jnp.cos(x))
z = jnp.tan(jnp.cos(y) * jnp.sin(x))
return w, z
x = jax.random.normal(jax.random.key(0), (3,))
x_dot = jax.random.normal(jax.random.key(1), (3,))
y = jax.random.normal(jax.random.key(2), (3,))
y_dot = jax.random.normal(jax.random.key(3), (3,))
setattr(thing, 'x', x)
w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)])
z = getattr(thing, 'x')
(w_, z_), (w_dot_, z_dot_) = jax.jvp(f_ref, (x, y), (x_dot, y_dot))
self.assertAllClose(w, w_, check_dtypes=False)
self.assertAllClose(z, z_, check_dtypes=False)
self.assertAllClose(w_dot, w_dot_, check_dtypes=False)
self.assertAllClose(z_dot, z_dot_, check_dtypes=False)
def g(x_dot, y, y_dot):
w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)])
return w, w_dot, z_dot
def g_ref(x, x_dot, y, y_dot):
(w, z), (w_dot, z_dot) = jax.jvp(f_ref, (x, y), (x_dot, y_dot))
return w, w_dot, z, z_dot
x_dot2 = jax.random.normal(jax.random.key(3), (3,))
x_ddot = jax.random.normal(jax.random.key(4), (3,))
y_dot2 = jax.random.normal(jax.random.key(5), (3,))
y_ddot = jax.random.normal(jax.random.key(6), (3,))
setattr(thing, 'x', x)
(w, w_dot, z_dot), (w_dot2, w_ddot, z_ddot), [(_, _, z_dot2)] = \
attrs.jvp(g, (x_dot, y, y_dot), (x_ddot, y_dot2, y_ddot),
[(thing, 'x', x_dot2)])
z = getattr(thing, 'x')
(w_, w_dot_, z_, z_dot_), (w_dot2_, w_ddot_, z_dot2_, z_ddot_) = \
jax.jvp(g_ref, (x, x_dot, y, y_dot), (x_dot2, x_ddot, y_dot2, y_ddot))
self.assertAllClose( w, w_, check_dtypes=False)
self.assertAllClose( z, z_, check_dtypes=False)
self.assertAllClose( w_dot, w_dot_, check_dtypes=False)
self.assertAllClose( z_dot, z_dot_, check_dtypes=False)
self.assertAllClose(w_dot2, w_dot2_, check_dtypes=False)
self.assertAllClose(z_dot2, z_dot2_, check_dtypes=False)
self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False)
self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False)
class AttrsLinTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_attr_output(self, jit):
thing = Thing(1.0)
def f(x, _):
y = jnp.sin(x)
jax_setattr(thing, 'x', y)
if jit:
f = jax.jit(f)
out, f_lin = attrs.linearize(f, 3.0, 4.0)
self.assertIsNone(out)
self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False)
out_dot, attr_tangents = f_lin(1.0, 2.0, attr_tangents={})
self.assertIsNone(out_dot)
self.assertAllClose(thing.x, jnp.sin(3.0)) # didn't change
self.assertLen(attr_tangents, 1)
self.assertAllClose(attr_tangents[(thing, 'x')], jnp.cos(3.0),
check_dtypes=False)
@parameterized.parameters([True, False])
def test_attr_input(self, jit):
thing = Thing(1.0)
def f():
x = jax_getattr(thing, 'x')
return jnp.sin(x)
if jit:
f = jax.jit(f)
out, f_lin = attrs.linearize(f, attrs=[(thing, 'x')])
self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False)
out_dot, attr_tangents = f_lin(attr_tangents={(thing, 'x'): 2.0})
self.assertAllClose(out_dot, 2. * jnp.cos(1.0), check_dtypes=False)
self.assertLen(attr_tangents, 1)
self.assertAllClose(attr_tangents[(thing, 'x')], 2.0, check_dtypes=False)
@parameterized.parameters([True, False])
def test_attr_inout(self, jit):
thing1 = Thing(1.0)
thing2 = Thing(2.0)
def f(x, y):
z = jax_getattr(thing1, 'x')
w = jax_getattr(thing2, 'x')
out = jnp.sin(x * y * z * w)
jax_setattr(thing1, 'x', out)
jax_setattr(thing2, 'x', 2 * out)
return 3 * out, 4 * out
if jit:
f = jax.jit(f)
def f_ref(x, y, z, w):
out = jnp.sin(x * y * z * w)
return (3 * out, 4 * out), (out, 2 * out)
out, f_lin = attrs.linearize(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')])
expected = (3 * jnp.sin(1. * 2. * 3. * 4.),
4 * jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(out, expected, check_dtypes=False)
self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.))
(out_ref, state_out_ref), f_lin_ref = jax.linearize(f_ref, 3., 4., 1., 2.)
self.assertAllClose(out, out_ref, check_dtypes=False)
self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False)
out_dot, attr_tangents = f_lin(1., 2.,
attr_tangents={(thing1, 'x'): 5.,
(thing2, 'x'): 6.})
self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.))
(out_dot_ref, state_dot_ref) = f_lin_ref(1., 2., 5., 6.)
self.assertAllClose(out_dot, out_dot_ref, check_dtypes=False)
self.assertLen(attr_tangents, 2)
self.assertAllClose(attr_tangents[(thing1, 'x')], state_dot_ref[0],
check_dtypes=False)
self.assertAllClose(attr_tangents[(thing2, 'x')], state_dot_ref[1],
check_dtypes=False)
class AttrsVJPTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_attr_input(self, jit):
thing = Thing(1.0)
def f():
x = jax_getattr(thing, 'x')
return jnp.sin(x)
if jit:
f = jax.jit(f)
out, f_vjp = attrs.vjp(f, attrs=[(thing, 'x')])
self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False)
arg_cts, attr_cotangents = f_vjp(1.0)
self.assertEqual(arg_cts, ())
self.assertLen(attr_cotangents, 1)
self.assertAllClose(attr_cotangents[(thing, 'x')], jnp.cos(1.0),
check_dtypes=False)
@parameterized.parameters([True, False])
def test_attr_output(self, jit):
thing = Thing(1.0)
def f(x, _):
y = jnp.sin(x)
jax_setattr(thing, 'x', y)
if jit:
f = jax.jit(f)
out, f_vjp = attrs.vjp(f, 3.0, 4.0)
self.assertIsNone(out)
self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False)
arg_cts, attr_cotangents = f_vjp(None, attr_cotangents={(thing, 'x'): 2.0})
self.assertAllClose(arg_cts, (2 * jnp.cos(3.0), 0.), check_dtypes=False)
self.assertLen(attr_cotangents, 0)
@parameterized.parameters([True, False])
def test_attr_inout(self, jit):
thing1 = Thing(1.0)
thing2 = Thing(2.0)
def f(x, y):
z = jax_getattr(thing1, 'x')
w = jax_getattr(thing2, 'x')
out = jnp.sin(x * y * z * w)
jax_setattr(thing1, 'x', out)
jax_setattr(thing2, 'x', 2 * out)
return 3 * out, 4 * out
if jit:
f = jax.jit(f)
def f_ref(x, y, z, w):
out = jnp.sin(x * y * z * w)
return (3 * out, 4 * out), (out, 2 * out)
out, f_vjp = attrs.vjp(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')])
(out_ref, state_out_ref), f_vjp_ref = jax.vjp(f_ref, 3., 4., 1., 2.)
self.assertAllClose(out, out_ref, check_dtypes=False)
self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False)
in_bar, attr_cotangents = f_vjp((1., 2.),
attr_cotangents={(thing1, 'x'): 5.,
(thing2, 'x'): 6.})
in_bar_ref_ = f_vjp_ref(((1., 2.), (5., 6.)))
in_bar_ref, attr_cotangents_ref = in_bar_ref_[:2], in_bar_ref_[2:]
self.assertAllClose(in_bar, in_bar_ref, check_dtypes=False)
self.assertLen(attr_cotangents, 2)
self.assertAllClose(attr_cotangents[(thing1, 'x')], attr_cotangents_ref[0],
check_dtypes=False)
self.assertAllClose(attr_cotangents[(thing2, 'x')], attr_cotangents_ref[1],
check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())