Added build rule for generated_fun_test (formerly quickish_check)

This commit is contained in:
Dougal Maclaurin 2018-12-06 17:04:00 -05:00
parent 29113dd606
commit c3374a9d5f
2 changed files with 28 additions and 27 deletions

View File

@ -91,3 +91,8 @@ jax_test(
name = "lapax_test",
srcs = ["lapax_test.py"],
)
jax_test(
name = "generated_fun_test",
srcs = ["generated_fun_test.py"],
)

View File

@ -15,10 +15,13 @@
from collections import namedtuple
from functools import partial
import numpy.random as npr
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as np
from jax import jit, jvp, vjp
import itertools as it
import sys
from jax.test_util as jtu
npr.seed(0)
@ -208,14 +211,6 @@ def partial_argnums(f, args, dyn_argnums):
dyn_args = [args[i] for i in dyn_argnums]
return f_, dyn_args
def jit_is_identity(fun):
vals = gen_vals(fun.in_vars)
fun = partial(eval_fun, fun)
ans = fun(*vals)
static_argnums = thin(range(len(vals)), 0.5)
ans_jitted = jit(fun, static_argnums=static_argnums)(*vals)
check_all_close(ans, ans_jitted)
def jvp_matches_fd(fun):
vals = gen_vals(fun.in_vars)
tangents = gen_vals(fun.in_vars)
@ -245,27 +240,28 @@ def vjp_matches_fd(fun):
inner_prod_ad = inner_prod(in_tangents, out_cotangents)
check_close(inner_prod_fd, inner_prod_ad)
properties = [
jit_is_identity,
jvp_matches_fd,
vjp_matches_fd,
]
# vmap_matches_map ]
counter = it.count()
fresh = counter.next
def run_tests():
sizes = [3, 10]
num_examples = 50
cases = it.product(sizes, range(num_examples), properties)
for i, (size, _, check_prop) in enumerate(cases):
sys.stderr.write('\rTested: {}'.format(i))
class GeneratedFunTest(jtu.JaxTestCase):
"""Tests of transformations on randomly generated functions."""
@parameterized.named_parameters(take(
{"testcase_name": 'rand_fun_jit_test_{}'.format(fresh()),
"fun" : gen_fun_and_types(size) }
for _ in it.count()))
def testJitIsIdentity(self, fun):
vals = gen_vals(fun.in_vars)
fun = partial(eval_fun, fun)
ans = fun(*vals)
static_argnums = thin(range(len(vals)), 0.5)
ans_jitted = jit(fun, static_argnums=static_argnums)(*vals)
try:
fun = gen_fun_and_types(size)
check_prop(fun)
check_all_close(ans, ans_jitted)
except:
print fun
raise
print "\nok"
if __name__ == "__main__":
run_tests()
config.config_with_absl()
absltest.main()