Fixed tests for X64

This commit is contained in:
George Necula 2019-11-14 12:54:30 +01:00
parent 4bdfe5a66c
commit c6d3270512
2 changed files with 16 additions and 10 deletions

View File

@ -82,12 +82,18 @@ Then, from the repository root directory run
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default 10):
cases that are generated and checked for each test (default is 10). The automated tests
currently use 25:
.. code-block:: shell
JAX_NUM_GENERATED_CASES=100 pytest -n auto tests
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
The automated tests also run the tests with default 64-bit floats and ints:
.. code-block:: shell
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
You can run a more specific set of tests using
`pytest <https://docs.pytest.org/en/latest/usage.html#specifying-tests-selecting-tests>`_'s

View File

@ -198,13 +198,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.while_loop(lambda c: (1., 1.), lambda c: c, 0.)
with self.assertRaisesRegex(TypeError,
re.escape("cond_fun must return a boolean scalar, but got output type(s) [ShapedArray(float32[])].")):
lax.while_loop(lambda c: 1., lambda c: c, 0.)
lax.while_loop(lambda c: np.float32(1.), lambda c: c, np.float32(0.))
with self.assertRaisesRegex(TypeError,
re.escape("body_fun output and input must have same type structure, got PyTreeDef(tuple, [*,*]) and *.")):
lax.while_loop(lambda c: True, lambda c: (1., 1.), 0.)
with self.assertRaisesRegex(TypeError,
re.escape("body_fun output and input must have identical types, got ShapedArray(bool[]) and ShapedArray(float32[]).")):
lax.while_loop(lambda c: True, lambda c: True, 0.)
lax.while_loop(lambda c: True, lambda c: True, np.float32(0.))
def testNestedWhileWithDynamicUpdateSlice(self):
num = 5
@ -521,9 +521,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.cond(True,
1., lambda top: 1., 2., lambda fop: (2., 2.))
with self.assertRaisesRegex(TypeError,
re.escape("true_fun and false_fun output must have identical types, got ShapedArray(int32[1]) and ShapedArray(int32[]).")):
re.escape("true_fun and false_fun output must have identical types, got ShapedArray(float32[1]) and ShapedArray(float32[]).")):
lax.cond(True,
1., lambda top: np.array([1]), 2., lambda fop: 1)
1., lambda top: np.array([1.], np.float32),
2., lambda fop: np.float32(1.))
def testCondOneBranchConstant(self):
@ -882,14 +883,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
'scan got value with no leading axis to scan over.*',
lambda: lax.scan(plus_one, p0, list(range(5))))
@jtu.skip_on_flag('jax_enable_x64', True) # With float64 error messages are different; hard to check precisely
def testScanTypeErrors(self):
"""Test typing error messages for scan."""
a = np.arange(5)
# Body output not a tuple
with self.assertRaisesRegex(TypeError,
re.escape("scan body output must be a pair, got ShapedArray(int32[]).")):
lax.scan(lambda c, x: 0, 0, a)
re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
lax.scan(lambda c, x: np.float32(0.), 0, a)
with self.assertRaisesRegex(TypeError,
re.escape("scan carry output and input must have same type structure, "
"got PyTreeDef(tuple, [*,*,*]) and PyTreeDef(tuple, [*,PyTreeDef(tuple, [*,*])])")):
@ -900,7 +900,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError,
re.escape("scan carry output and input must have identical types, "
"got ShapedArray(int32[]) and ShapedArray(float32[]).")):
lax.scan(lambda c, x: (0, x), 1.0, a)
lax.scan(lambda c, x: (np.int32(0), x), np.float32(1.0), a)
with self.assertRaisesRegex(TypeError,
re.escape("scan carry output and input must have same type structure, got * and PyTreeDef(tuple, [*,*]).")):
lax.scan(lambda c, x: (0, x), (1, 2), np.arange(5))