# rocm_jax/tests/checkify_test.py
# 2022-01-04 19:17:08 +00:00
#
# 418 lines
# 12 KiB
# Python

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax.experimental import checkify
import jax._src.test_util as jtu
# Runs at import time, ahead of the test classes below. Going by the method
# name, this routes JAX config flags through absl's flag parsing
# (NOTE(review): exact semantics live in jax.config -- confirm there).
config.parse_flags_with_absl()
class CheckifyTransformTests(jtu.JaxTestCase):
  """Tests for the errors reported by the `checkify.checkify` transform.

  Every test wraps a function with `checkify.checkify`, calls it, and inspects
  the returned error object through `err.get()`: `None` means no error was
  recorded, otherwise the message is matched by its prefix.
  """

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_jit={}".format(jit), "jit": jit}
      for jit in [False, True]))
  @jtu.skip_on_devices('tpu')
  def test_jit_nan(self, jit):
    """A NaN produced by `sin` is reported, with and without `jax.jit`."""
    def f(x1, x2):
      y1 = jnp.sin(x1)
      y2 = jnp.sin(x2)
      return y1 + y2

    f = jax.jit(f) if jit else f

    err, _ = checkify.checkify(f)(3., 4.)
    self.assertIsNone(err.get())

    err, _ = checkify.checkify(f)(3., jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_jit={}".format(jit), "jit": jit}
      for jit in [False, True]))
  def test_jit_oob(self, jit):
    """An out-of-bounds index is reported, with and without `jax.jit`."""
    def f(x, i):
      y = jnp.sin(x)
      z = y[i]
      w = jnp.cos(z)
      return w

    f = jax.jit(f) if jit else f

    err, _ = checkify.checkify(f)(jnp.arange(3), 2)
    self.assertIsNone(err.get())

    err, _ = checkify.checkify(f)(jnp.arange(3), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_jit={}".format(jit), "jit": jit}
      for jit in [False, True]))
  def test_jit_div_errors(self, jit):
    """Division reports both divide-by-zero and NaN (inf/inf) errors."""
    def f(x, y):
      return x/y

    f = jax.jit(f) if jit else f

    err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,)))
    self.assertIsNone(err.get())

    err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1]))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")

    err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]),
                                  jnp.array([1, jnp.inf, 1]))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive div')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_jit={}".format(jit), "jit": jit}
      for jit in [False, True]))
  @jtu.skip_on_devices('tpu')
  def test_jit_multi(self, jit):
    """One function can report either an OOB or a NaN error, per input."""
    def f(x, i):
      y = x[i]
      z = jnp.cos(y)
      return z

    f = jax.jit(f) if jit else f

    # no error
    err, _ = checkify.checkify(f)(jnp.array([0., jnp.inf, 2.]), 2)
    self.assertIsNone(err.get())

    # oob error
    err, _ = checkify.checkify(f)(jnp.array([0., 1., 2.]), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

    # nan error
    err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 2)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive cos')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_jit={}".format(jit), "jit": jit}
      for jit in [False, True]))
  def test_jit_ordering(self, jit):
    """When several errors occur, the one that happens first is reported."""
    def f(x, i):
      y = x[i]
      z = jnp.sin(x)
      return y * z

    f = jax.jit(f) if jit else f

    # both oob and nan error, but oob happens first
    err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

  @jtu.skip_on_devices('tpu')
  def test_pmap_basic(self):
    """A NaN produced inside a `jax.pmap`-ed function is reported."""
    if len(jax.devices()) < 2:
      raise unittest.SkipTest("requires at least 2 devices")

    @jax.pmap
    def f(x1, x2):
      y1 = jnp.sin(x1)
      y2 = jnp.sin(x2)
      return y1 + y2

    xs = jnp.array([0., 2.])
    err, _ = checkify.checkify(f)(xs, xs)
    self.assertIsNone(err.get())

    ys = jnp.array([3., jnp.inf])
    err, _ = checkify.checkify(f)(xs, ys)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  @jtu.skip_on_devices('tpu')
  def test_cond_basic(self):
    """Only the `lax.cond` branch that is actually taken can report an error."""
    @jax.jit
    def f(x):
      return lax.cond(x > 0,
                      lambda: jnp.sin(x),
                      lambda: x)

    err, _ = checkify.checkify(f)(3.)
    self.assertIsNone(err.get())

    err, _ = checkify.checkify(f)(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

    # -inf selects the identity branch, so `sin` must not contribute an error.
    err, _ = checkify.checkify(f)(-jnp.inf)
    self.assertIsNone(err.get())

  @jtu.skip_on_devices('tpu')
  def test_scan_map(self):
    """Errors in a carry-less `lax.scan` body are reported; outputs match."""
    def scan_body(_, x):
      return None, jnp.sin(x)

    @jax.jit
    def f(xs):
      return lax.scan(scan_body, None, xs)

    xs = jnp.array([0., 2.])
    err, (_, ch_outs) = checkify.checkify(f)(xs)
    _, outs = f(xs)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_outs, outs)

    xs = jnp.array([3., jnp.inf])
    err, (_, ch_outs) = checkify.checkify(f)(xs)
    _, outs = f(xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive sin")
    self.assertArraysEqual(ch_outs, outs)

  @jtu.skip_on_devices('tpu')
  def test_scan_carry(self):
    """Errors that depend on the `lax.scan` carry are reported on any step."""
    def scan_body(carry, x):
      carry = carry-1.
      possible_nan = jnp.sin(1./carry)
      return carry, x+possible_nan

    @jax.jit
    def f(carry, xs):
      return lax.scan(scan_body, carry, xs)

    carry, xs = 3., jnp.ones((2,))
    err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

    # error happens on first iteration
    carry, xs = 1., jnp.ones((2,))
    err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

    # error happens on second iteration
    carry, xs = 2., jnp.ones((4,))
    err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_body_error(self):
    """An error raised in the `lax.while_loop` body is reported."""
    def while_cond(val):
      i, _ = val
      return i < 2

    def while_body(val):
      i, x = val
      possible_nan = jnp.sin(1./i)
      return i+1., x+possible_nan

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, while_body, (init_val, 0.))

    init_val = 1.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_out, out)

    init_val = 0.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_out, out)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_cond_error(self):
    """An error raised in the `lax.while_loop` condition is reported."""
    def while_cond(val):
      _ = jnp.sin(1./val)
      return val < 2.

    def while_body(val):
      return val+1.

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, while_body, init_val)

    init_val = 1.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_out, out)

    init_val = 0.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_out, out)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_cond_error_and_false(self):
    """An error is still reported when the failing cond evaluation is False."""
    # Tests if an error is generated when cond returns False.
    def while_cond(val):
      possible_nan = jnp.sin(1./val)
      return jnp.logical_not(jnp.isnan(possible_nan))

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, lambda val: val-1, init_val)

    # error on first cond
    init_val = 0.
    err, _ = checkify.checkify(f)(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")

    # error on second cond
    init_val = 1.
    err, _ = checkify.checkify(f)(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")

  @jtu.skip_on_devices("tpu")
  def test_while_loop_body_and_cond_error(self):
    """With errors in both cond and body, the cond error is reported first."""
    def while_cond(val):
      i, cond_val, _ = val
      _ = jnp.sin(cond_val)
      return i < 2

    def while_body(val):
      i, cond_val, body_val = val
      possible_nan = jnp.cos(body_val)
      return i+1., cond_val, possible_nan

    @jax.jit
    def f(cond_val, body_val):
      return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))

    cond_val = jnp.inf
    body_val = 1.
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive sin")

    cond_val = 1.
    body_val = jnp.inf
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive cos")

    cond_val = jnp.inf
    body_val = jnp.inf
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    # first error which occurs is in cond
    self.assertStartsWith(err.get(), "nan generated by primitive sin")
class AssertPrimitiveTests(jtu.JaxTestCase):
  """Tests for `checkify.assert_` / `checkify.assert2_` and their discharge."""

  def test_assert_primitive_impl(self):
    """Called eagerly, a failing `assert_` raises `AssertionError`."""
    def f():
      checkify.assert_(False, "hi")

    with self.assertRaisesRegex(AssertionError, "hi"):
      f()

  def test_assert_primitive_(self):
    """Under `jax.jit` without `checkify`, `assert_` fails to be staged."""
    @jax.jit
    def f():
      checkify.assert_(False, "hi")

    with self.assertRaisesRegex(Exception, "can't be staged"):
      f()

  def test_assert_discharging(self):
    """`checkify` turns a failing `assert_` into a returned error value."""
    @checkify.checkify
    def f(x):
      checkify.assert_(x > 0, "must be positive!")
      return jnp.log(x)

    err, _ = f(1.)
    self.assertIsNone(err.get())

    err, _ = f(0.)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "must be positive")

    # Same expectations once the checkified function is jitted.
    f = jax.jit(f)

    err, _ = f(1.)
    self.assertIsNone(err.get())

    err, _ = f(0.)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "must be positive")

  def test_assert2(self):
    """`assert2_` raises eagerly and is discharged by `checkify`."""
    def f(pred):  # note: data dependence needed!
      checkify.assert2_(pred, 0, {0: "hi"})

    with self.assertRaisesRegex(AssertionError, "hi"):
      f(False)

    f = checkify.checkify(f)
    err, none = f(False)

    # `f` has no return statement, so the checkified output is None.
    self.assertIsNone(none)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "hi")

  def test_discharge_recharge(self):
    """An error discharged under `jit` can be re-raised outside of it."""
    def ejit(f):
      # checkify first, then jit; the error is re-asserted outside the jitted
      # call via `assert2_`.
      f = checkify.checkify(f)
      f = jax.jit(f)

      def jitted_f(*args):
        err, out = f(*args)
        checkify.assert2_(~err.err, err.code, err.msgs)
        return out
      return jitted_f

    @ejit
    def f(pred):
      # Plain Python assert: it fires whenever this Python body is executed
      # while the flag below is False.
      assert python_should_be_running
      checkify.assert_(pred, "foo")

    python_should_be_running = True
    f(True)

    # From here on the Python body of `f` must not run again, otherwise the
    # plain assert above would fail (presumably the calls below are served
    # from the jit cache -- confirm).
    python_should_be_running = False
    f(True)
    with self.assertRaisesRegex(AssertionError, "foo"):
      f(False)
if __name__ == "__main__":
  # Script entry point: hand the JAX test loader to absltest's runner.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)