mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added build rule for generated_fun_test (formerly quickish_check)
This commit is contained in:
parent
29113dd606
commit
c3374a9d5f
@ -91,3 +91,8 @@ jax_test(
|
||||
name = "lapax_test",
|
||||
srcs = ["lapax_test.py"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "generated_fun_test",
|
||||
srcs = ["generated_fun_test.py"],
|
||||
)
|
||||
|
@ -15,10 +15,13 @@
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
import numpy.random as npr
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, jvp, vjp
|
||||
import itertools as it
|
||||
import sys
|
||||
from jax.test_util as jtu
|
||||
|
||||
npr.seed(0)
|
||||
|
||||
@ -208,14 +211,6 @@ def partial_argnums(f, args, dyn_argnums):
|
||||
dyn_args = [args[i] for i in dyn_argnums]
|
||||
return f_, dyn_args
|
||||
|
||||
def jit_is_identity(fun):
|
||||
vals = gen_vals(fun.in_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
ans = fun(*vals)
|
||||
static_argnums = thin(range(len(vals)), 0.5)
|
||||
ans_jitted = jit(fun, static_argnums=static_argnums)(*vals)
|
||||
check_all_close(ans, ans_jitted)
|
||||
|
||||
def jvp_matches_fd(fun):
|
||||
vals = gen_vals(fun.in_vars)
|
||||
tangents = gen_vals(fun.in_vars)
|
||||
@ -245,27 +240,28 @@ def vjp_matches_fd(fun):
|
||||
inner_prod_ad = inner_prod(in_tangents, out_cotangents)
|
||||
check_close(inner_prod_fd, inner_prod_ad)
|
||||
|
||||
properties = [
|
||||
jit_is_identity,
|
||||
jvp_matches_fd,
|
||||
vjp_matches_fd,
|
||||
]
|
||||
# vmap_matches_map ]
|
||||
counter = it.count()
|
||||
fresh = counter.next
|
||||
|
||||
def run_tests():
|
||||
sizes = [3, 10]
|
||||
num_examples = 50
|
||||
cases = it.product(sizes, range(num_examples), properties)
|
||||
for i, (size, _, check_prop) in enumerate(cases):
|
||||
sys.stderr.write('\rTested: {}'.format(i))
|
||||
class GeneratedFunTest(jtu.JaxTestCase):
|
||||
"""Tests of transformations on randomly generated functions."""
|
||||
|
||||
@parameterized.named_parameters(take(
|
||||
{"testcase_name": 'rand_fun_jit_test_{}'.format(fresh()),
|
||||
"fun" : gen_fun_and_types(size) }
|
||||
for _ in it.count()))
|
||||
def testJitIsIdentity(self, fun):
|
||||
vals = gen_vals(fun.in_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
ans = fun(*vals)
|
||||
static_argnums = thin(range(len(vals)), 0.5)
|
||||
ans_jitted = jit(fun, static_argnums=static_argnums)(*vals)
|
||||
try:
|
||||
fun = gen_fun_and_types(size)
|
||||
check_prop(fun)
|
||||
check_all_close(ans, ans_jitted)
|
||||
except:
|
||||
print fun
|
||||
raise
|
||||
|
||||
print "\nok"
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user