rocm_jax/tests/attrs_test.py
Matthew Johnson 42ac4ca357 ref errors
2024-12-18 07:46:14 +00:00

671 lines
19 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.util import safe_zip, safe_map
from jax.experimental import attrs
from jax.experimental.attrs import jax_setattr, jax_getattr
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@dataclass
class Thing:
x: float
__hash__ = object.__hash__
__eq__ = object.__eq__
attrs.register(Thing) # enables passing as arg into jitted function
class AttrsTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_jit_basic(self, jit: bool):
thing = Thing(1.0)
def double_it() -> None:
cur_x = jax_getattr(thing, "x")
jax_setattr(thing, "x", cur_x * 2)
if jit:
double_it = jax.jit(double_it)
self.assertEqual(thing.x, 1.0)
double_it()
self.assertEqual(thing.x, 2.0)
double_it()
self.assertEqual(thing.x, 4.0)
double_it()
self.assertEqual(thing.x, 8.0)
double_it()
self.assertEqual(thing.x, 16.0)
@parameterized.parameters([True, False])
def test_jit_basic_tree(self, jit: bool):
thing = Thing((1.0, 2.0))
def double_it() -> None:
(cur_x, cur_y) = jax_getattr(thing, "x")
jax_setattr(thing, "x", (cur_x * 2, cur_y * 2))
if jit:
double_it = jax.jit(double_it)
self.assertEqual(thing.x, (1.0, 2.0))
double_it()
self.assertEqual(thing.x, (2.0, 4.0))
double_it()
self.assertEqual(thing.x, (4.0, 8.0))
double_it()
self.assertEqual(thing.x, (8.0, 16.0))
double_it()
self.assertEqual(thing.x, (16.0, 32.0))
@parameterized.parameters([True, False])
def test_jit_basic_tree_changes(self, jit: bool):
thing = Thing(None)
count = 0
def double_it() -> None:
nonlocal count
count += 1
maybe_x = jax_getattr(thing, "x")
x = 1.0 if maybe_x is None else maybe_x
jax_setattr(thing, "x", 2 * x)
if jit:
double_it = jax.jit(double_it)
self.assertEqual(thing.x, None)
double_it()
self.assertEqual(thing.x, 2.0)
self.assertEqual(count, 1)
double_it()
self.assertEqual(thing.x, 4.0)
self.assertEqual(count, 2)
double_it()
self.assertEqual(thing.x, 8.0)
self.assertEqual(count, 2 + (not jit))
def test_jit_basic_tree_changes_multiple(self):
thing1 = Thing(None)
thing2 = Thing(0)
count = 0
@jax.jit
def double_it() -> None:
nonlocal count
count += 1
x1 = jax_getattr(thing1, "x")
if x1 is None:
jax_setattr(thing1, 'x', (None,))
elif isinstance(x1, tuple):
# depend on a new value
jax_setattr(thing1, 'x', jax_getattr(thing2, 'x') + 1)
else:
jax_setattr(thing2, 'x', jax_getattr(thing1, 'x'))
jax_setattr(thing1, 'x', None)
self.assertEqual(thing1.x, None)
self.assertEqual(thing2.x, 0)
double_it()
self.assertEqual(thing1.x, (None,))
self.assertEqual(thing2.x, 0)
self.assertEqual(count, 1)
double_it()
self.assertEqual(thing1.x, 1)
self.assertEqual(thing2.x, 0)
self.assertEqual(count, 2)
double_it()
self.assertEqual(thing1.x, None)
self.assertEqual(thing2.x, 1)
self.assertEqual(count, 3)
double_it()
self.assertEqual(thing1.x, (None,))
self.assertEqual(thing2.x, 1)
self.assertEqual(count, 3)
double_it()
self.assertEqual(thing1.x, 2)
self.assertEqual(thing2.x, 1)
self.assertEqual(count, 3)
double_it()
self.assertEqual(thing1.x, None)
self.assertEqual(thing2.x, 2)
self.assertEqual(count, 3)
def test_jit_nesting_basic(self):
thing = Thing(1.0)
@jax.jit
@jax.jit
def double_it() -> None:
cur_x = jax_getattr(thing, "x")
jax_setattr(thing, "x", cur_x * 2)
self.assertEqual(thing.x, 1.0)
double_it()
self.assertEqual(thing.x, 2.0)
double_it()
self.assertEqual(thing.x, 4.0)
double_it()
self.assertEqual(thing.x, 8.0)
double_it()
self.assertEqual(thing.x, 16.0)
def test_jit_consts_and_args(self):
thing = Thing(1.0)
@jax.jit
def double_it(y) -> None:
cur_x = jax_getattr(thing, "x")
jax_setattr(thing, "x", cur_x * 2)
return jnp.cos(np.arange(3.) * cur_x * y)
self.assertEqual(thing.x, 1.0)
double_it(2.)
self.assertEqual(thing.x, 2.0)
double_it(2.)
self.assertEqual(thing.x, 4.0)
double_it(2.)
self.assertEqual(thing.x, 8.0)
double_it(2.)
self.assertEqual(thing.x, 16.0)
def test_jit_transpose_basic(self):
thing = Thing(jnp.array(2.0))
@jax.custom_vjp
def foo(x):
return x
def foo_fwd(x):
return x, None
def foo_bwd(x, g):
jax_setattr(thing, 'x', g)
return g,
foo.defvjp(foo_fwd, foo_bwd)
foo(3.14)
self.assertEqual(thing.x, 2.0)
jax.grad(foo)(3.14)
self.assertEqual(thing.x, 1.0)
thing.x = jnp.array(3.14)
self.assertEqual(thing.x, 3.14)
jax.jit(jax.grad(foo))(3.14)
self.assertEqual(thing.x, 1.0)
thing.x = jnp.array(2.718)
self.assertEqual(thing.x, 2.718)
jax.grad(jax.jit(lambda x: jnp.sin(foo(x))))(3.0)
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
thing.x = jnp.array(3.14)
self.assertEqual(thing.x, 3.14)
def bar(x):
out = jnp.sin(foo(x))
jax_setattr(thing, 'x', 5.0)
return out
jax.grad(jax.jit(bar))(3.0)
self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_basic(self, jit: bool):
thing = Thing(1.0)
def double_it_10():
def body(_, __):
cur_x = jax_getattr(thing ,"x")
jax_setattr(thing, "x", cur_x * 2.0)
return None, None
_, _ = jax.lax.scan(body, None, None, length=10)
if jit:
double_it_10 = jax.jit(double_it_10)
double_it_10()
self.assertAllClose(thing.x, 1024., check_dtypes=False)
def test_scan_basic_consts_and_args(self):
thing = Thing(1.0)
def double_it_10(y):
def body(i, x):
cur_x = jax_getattr(thing ,"x")
jax_setattr(thing, "x", cur_x * 2.0)
return i + 1, (y, y)
_, _ = jax.lax.scan(body, 0, jnp.arange(10))
jax.jit(double_it_10)(jnp.arange(3.))
self.assertAllClose(thing.x, 1024., check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_transpose_basic(self, jit: bool):
thing = Thing(1.0)
@jax.custom_vjp
def foo(x):
return x
def foo_fwd(x):
return x, None
def foo_bwd(x, g):
jax_setattr(thing, 'x', 2 * jax_getattr(thing, 'x') * g)
return g,
foo.defvjp(foo_fwd, foo_bwd)
def double_it_10(x):
def body(x, __):
return foo(x), None
x, _ = jax.lax.scan(body, x, None, length=10)
return x
if jit:
double_it_10 = jax.jit(double_it_10)
double_it_10(1.0)
self.assertAllClose(thing.x, 1., check_dtypes=False)
jax.grad(double_it_10)(1.0)
self.assertAllClose(thing.x, 1024., check_dtypes=False)
def test_arg_to_jit(self):
self.skipTest("regressed this experimental feature") # TODO(mattjj)
thing = Thing(1.0)
count = 0
@jax.jit
def f(obj, x):
nonlocal count
count += 1
jax_setattr(obj, 'x', x)
f(thing, 2.0) # don't crash!
self.assertAllClose(thing.x, 2.0, check_dtypes=False)
f(thing, 3.0)
self.assertAllClose(thing.x, 3.0, check_dtypes=False)
self.assertEqual(count, 1)
def test_tracer_lifetime_bug(self):
# regression test for https://github.com/jax-ml/jax/issues/20082
class StatefulRNG:
key: jax.Array
def __init__(self, key: jax.Array):
self.key = key
def split(self) -> jax.Array:
key = jax_getattr(self, "key")
new_key, returned_key = jax.random.split(key)
jax_setattr(self, "key", new_key)
return returned_key
rng = StatefulRNG(jax.random.key(0))
def jitted():
rng.split()
rng.split()
jax.jit(jitted)() # don't crash
def test_scan_carry(self):
class A:
...
a = A()
jax_setattr(a, 'x', jnp.zeros(3))
def body(i, _):
x = jax_getattr(a, 'x')
x = x.at[i].set(x[i] + 1)
jax_setattr(a, 'x', x)
return i + 1, None
_, _ = jax.lax.scan(body, 0, None, length=3) # don't crash
class AttrsJVPTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_jvp_basic(self, jit):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
x = jnp.sin(x)
jax_setattr(thing, 'x', x)
if jit:
f = jax.jit(f)
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False)
(thing_, attr_, tangent_), = attr_tangents
self.assertIs(thing, thing_)
self.assertEqual(attr_, 'x')
self.assertAllClose(tangent_, jnp.cos(2.0), check_dtypes=False)
@parameterized.parameters([True, False])
def test_jvp_clobber(self, jit):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
x = jnp.sin(2.0)
jax_setattr(thing, 'x', x)
if jit:
f = jax.jit(f)
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False)
self.assertEmpty(attr_tangents)
@parameterized.parameters([True, False])
def test_jvp_nowrite(self, jit):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
if jit:
f = jax.jit(f)
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
self.assertAllClose(thing.x, 2.0, check_dtypes=False)
(thing_, attr_, tangent_), = attr_tangents
self.assertIs(thing, thing_)
self.assertEqual(attr_, 'x')
self.assertAllClose(tangent_, 1.0, check_dtypes=False)
def test_jit_of_jvp(self):
thing = Thing(2.0)
def f():
x = jax_getattr(thing, 'x')
x = jnp.sin(x)
jax_setattr(thing, 'x', x)
@jax.jit
def g():
_, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)])
(thing_, attr_, tangent_), = attr_tangents
self.assertIs(thing, thing_)
self.assertEqual(attr_, 'x')
return jax_getattr(thing, 'x'), tangent_
x, tangent = g()
self.assertAllClose(x, jnp.sin(2.0), check_dtypes=False)
self.assertAllClose(tangent, jnp.cos(2.0), check_dtypes=False)
@parameterized.parameters([True, False])
def test_jvp_higher_order(self, jit):
thing = Thing(2.0)
def f(y):
x = jax_getattr(thing, 'x')
w = jnp.tan(jnp.sin(y) * jnp.cos(x))
z = jnp.tan(jnp.cos(y) * jnp.sin(x))
jax_setattr(thing, 'x', z)
return w
if jit:
f = jax.jit(f)
def f_ref(x, y):
w = jnp.tan(jnp.sin(y) * jnp.cos(x))
z = jnp.tan(jnp.cos(y) * jnp.sin(x))
return w, z
x = jax.random.normal(jax.random.key(0), (3,))
x_dot = jax.random.normal(jax.random.key(1), (3,))
y = jax.random.normal(jax.random.key(2), (3,))
y_dot = jax.random.normal(jax.random.key(3), (3,))
setattr(thing, 'x', x)
w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)])
z = getattr(thing, 'x')
(w_, z_), (w_dot_, z_dot_) = jax.jvp(f_ref, (x, y), (x_dot, y_dot))
self.assertAllClose(w, w_, check_dtypes=False)
self.assertAllClose(z, z_, check_dtypes=False)
self.assertAllClose(w_dot, w_dot_, check_dtypes=False)
self.assertAllClose(z_dot, z_dot_, check_dtypes=False)
def g(x_dot, y, y_dot):
w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)])
return w, w_dot, z_dot
def g_ref(x, x_dot, y, y_dot):
(w, z), (w_dot, z_dot) = jax.jvp(f_ref, (x, y), (x_dot, y_dot))
return w, w_dot, z, z_dot
x_dot2 = jax.random.normal(jax.random.key(3), (3,))
x_ddot = jax.random.normal(jax.random.key(4), (3,))
y_dot2 = jax.random.normal(jax.random.key(5), (3,))
y_ddot = jax.random.normal(jax.random.key(6), (3,))
setattr(thing, 'x', x)
(w, w_dot, z_dot), (w_dot2, w_ddot, z_ddot), [(_, _, z_dot2)] = \
attrs.jvp(g, (x_dot, y, y_dot), (x_ddot, y_dot2, y_ddot),
[(thing, 'x', x_dot2)])
z = getattr(thing, 'x')
(w_, w_dot_, z_, z_dot_), (w_dot2_, w_ddot_, z_dot2_, z_ddot_) = \
jax.jvp(g_ref, (x, x_dot, y, y_dot), (x_dot2, x_ddot, y_dot2, y_ddot))
self.assertAllClose( w, w_, check_dtypes=False)
self.assertAllClose( z, z_, check_dtypes=False)
self.assertAllClose( w_dot, w_dot_, check_dtypes=False)
self.assertAllClose( z_dot, z_dot_, check_dtypes=False)
self.assertAllClose(w_dot2, w_dot2_, check_dtypes=False)
self.assertAllClose(z_dot2, z_dot2_, check_dtypes=False)
self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False)
self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False)
class AttrsLinTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_attr_output(self, jit):
thing = Thing(1.0)
def f(x, _):
y = jnp.sin(x)
jax_setattr(thing, 'x', y)
if jit:
f = jax.jit(f)
out, f_lin = attrs.linearize(f, 3.0, 4.0)
self.assertIsNone(out)
self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False)
out_dot, attr_tangents = f_lin(1.0, 2.0, attr_tangents={})
self.assertIsNone(out_dot)
self.assertAllClose(thing.x, jnp.sin(3.0)) # didn't change
self.assertLen(attr_tangents, 1)
self.assertAllClose(attr_tangents[(thing, 'x')], jnp.cos(3.0),
check_dtypes=False)
@parameterized.parameters([True, False])
def test_attr_input(self, jit):
thing = Thing(1.0)
def f():
x = jax_getattr(thing, 'x')
return jnp.sin(x)
if jit:
f = jax.jit(f)
out, f_lin = attrs.linearize(f, attrs=[(thing, 'x')])
self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False)
out_dot, attr_tangents = f_lin(attr_tangents={(thing, 'x'): 2.0})
self.assertAllClose(out_dot, 2. * jnp.cos(1.0), check_dtypes=False)
self.assertLen(attr_tangents, 1)
self.assertAllClose(attr_tangents[(thing, 'x')], 2.0, check_dtypes=False)
@parameterized.parameters([True, False])
def test_attr_inout(self, jit):
thing1 = Thing(1.0)
thing2 = Thing(2.0)
def f(x, y):
z = jax_getattr(thing1, 'x')
w = jax_getattr(thing2, 'x')
out = jnp.sin(x * y * z * w)
jax_setattr(thing1, 'x', out)
jax_setattr(thing2, 'x', 2 * out)
return 3 * out, 4 * out
if jit:
f = jax.jit(f)
def f_ref(x, y, z, w):
out = jnp.sin(x * y * z * w)
return (3 * out, 4 * out), (out, 2 * out)
out, f_lin = attrs.linearize(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')])
expected = (3 * jnp.sin(1. * 2. * 3. * 4.),
4 * jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(out, expected, check_dtypes=False)
self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.))
(out_ref, state_out_ref), f_lin_ref = jax.linearize(f_ref, 3., 4., 1., 2.)
self.assertAllClose(out, out_ref, check_dtypes=False)
self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False)
out_dot, attr_tangents = f_lin(1., 2.,
attr_tangents={(thing1, 'x'): 5.,
(thing2, 'x'): 6.})
self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.))
(out_dot_ref, state_dot_ref) = f_lin_ref(1., 2., 5., 6.)
self.assertAllClose(out_dot, out_dot_ref, check_dtypes=False)
self.assertLen(attr_tangents, 2)
self.assertAllClose(attr_tangents[(thing1, 'x')], state_dot_ref[0],
check_dtypes=False)
self.assertAllClose(attr_tangents[(thing2, 'x')], state_dot_ref[1],
check_dtypes=False)
class AttrsVJPTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_attr_input(self, jit):
thing = Thing(1.0)
def f():
x = jax_getattr(thing, 'x')
return jnp.sin(x)
if jit:
f = jax.jit(f)
out, f_vjp = attrs.vjp(f, attrs=[(thing, 'x')])
self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False)
arg_cts, attr_cotangents = f_vjp(1.0)
self.assertEqual(arg_cts, ())
self.assertLen(attr_cotangents, 1)
self.assertAllClose(attr_cotangents[(thing, 'x')], jnp.cos(1.0),
check_dtypes=False)
@parameterized.parameters([True, False])
def test_attr_output(self, jit):
thing = Thing(1.0)
def f(x, _):
y = jnp.sin(x)
jax_setattr(thing, 'x', y)
if jit:
f = jax.jit(f)
out, f_vjp = attrs.vjp(f, 3.0, 4.0)
self.assertIsNone(out)
self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False)
arg_cts, attr_cotangents = f_vjp(None, attr_cotangents={(thing, 'x'): 2.0})
self.assertAllClose(arg_cts, (2 * jnp.cos(3.0), 0.), check_dtypes=False)
self.assertLen(attr_cotangents, 0)
@parameterized.parameters([True, False])
def test_attr_inout(self, jit):
thing1 = Thing(1.0)
thing2 = Thing(2.0)
def f(x, y):
z = jax_getattr(thing1, 'x')
w = jax_getattr(thing2, 'x')
out = jnp.sin(x * y * z * w)
jax_setattr(thing1, 'x', out)
jax_setattr(thing2, 'x', 2 * out)
return 3 * out, 4 * out
if jit:
f = jax.jit(f)
def f_ref(x, y, z, w):
out = jnp.sin(x * y * z * w)
return (3 * out, 4 * out), (out, 2 * out)
out, f_vjp = attrs.vjp(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')])
(out_ref, state_out_ref), f_vjp_ref = jax.vjp(f_ref, 3., 4., 1., 2.)
self.assertAllClose(out, out_ref, check_dtypes=False)
self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False)
in_bar, attr_cotangents = f_vjp((1., 2.),
attr_cotangents={(thing1, 'x'): 5.,
(thing2, 'x'): 6.})
in_bar_ref_ = f_vjp_ref(((1., 2.), (5., 6.)))
in_bar_ref, attr_cotangents_ref = in_bar_ref_[:2], in_bar_ref_[2:]
self.assertAllClose(in_bar, in_bar_ref, check_dtypes=False)
self.assertLen(attr_cotangents, 2)
self.assertAllClose(attr_cotangents[(thing1, 'x')], attr_cotangents_ref[0],
check_dtypes=False)
self.assertAllClose(attr_cotangents[(thing2, 'x')], attr_cotangents_ref[1],
check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())