skip checks in big randomness test

This commit is contained in:
Matthew Johnson 2020-09-23 20:15:32 -07:00
parent 96f5a3c402
commit 71f5f9972c

View File

@ -859,8 +859,9 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_eval_shape_big_random_array(self):
def f(x):
return random.normal(random.PRNGKey(x), (int(1e10),))
api.eval_shape(f, 0) # doesn't error
return random.normal(random.PRNGKey(x), (int(1e12),))
with core.skipping_checks(): # check_jaxpr will materialize array
api.eval_shape(f, 0) # doesn't error
if __name__ == "__main__":