Add ability to specify individual test targets via a regex (#3549)

* Add ability to specify individual test targets

* fix missing imports

* Use re.search and include test class name
This commit is contained in:
Jake Vanderplas 2020-06-29 14:16:51 -07:00 committed by GitHub
commit a9fad49e1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 87 additions and 62 deletions

View File

@ -108,6 +108,12 @@ file directly to see more detailed information about the cases being run::
You can skip a few tests known as slow, by passing environment variable
JAX_SKIP_SLOW_TESTS=1.
To specify a particular set of tests to run from a test file, you can pass a string
or regular expression via the ``--test_targets`` flag. For example, you can run all
the tests of ``jax.numpy.pad`` using::
python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
Type checking

View File

@ -244,4 +244,4 @@ class ControlExampleTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -95,4 +95,4 @@ class ExamplesTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -320,4 +320,4 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -171,4 +171,4 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -396,4 +396,4 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
self.assertEqual(f(tf.ones([])), 1.)
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -23,6 +23,7 @@ import tensorflow as tf # type: ignore[import]
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()
@ -45,4 +46,4 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -58,4 +58,4 @@ class StaxTest(tf_test_util.JaxToTfTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -22,6 +22,7 @@ import unittest
import warnings
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -58,6 +59,12 @@ flags.DEFINE_bool(
'Skip tests marked as slow (> 5 sec).'
)
flags.DEFINE_string(
'test_targets', '',
'Regular expression specifying which tests to run, called via re.match on '
'the test name. If empty or unspecified, run all tests.'
)
EPS = 1e-4
def _dtype(x):
@ -706,6 +713,16 @@ def cases_from_gens(*gens):
yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens)
class JaxTestLoader(absltest.TestLoader):
def getTestCaseNames(self, testCaseClass):
names = super().getTestCaseNames(testCaseClass)
if FLAGS.test_targets:
pattern = re.compile(FLAGS.test_targets)
names = [name for name in names
if pattern.search(f"{testCaseClass.__name__}.{name}")]
return names
class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""

View File

@ -3267,4 +3267,4 @@ class BufferDonationTest(jtu.JaxTestCase):
self.assertEqual(buffer.is_deleted(), deleted)
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -68,4 +68,4 @@ class ApiUtilTest(jtu.JaxTestCase):
api_util.rebase_donate_argnums(donate, static))
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -135,4 +135,4 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -968,4 +968,4 @@ class BatchingTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -90,4 +90,4 @@ class CallbackTest(jtu.JaxTestCase):
jnp.array([4.0, 6.0]))
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -435,4 +435,4 @@ class CoreTest(jtu.JaxTestCase):
core.check_jaxpr(jaxpr)
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -51,4 +51,4 @@ class DebugNaNsTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -143,6 +143,5 @@ class DoubleDoubleTest(jtu.JaxTestCase):
return result
self.assertAllClose(op(*args), class_op(*args))
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -182,4 +182,4 @@ class DtypesTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -381,4 +381,4 @@ class FftTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -268,4 +268,4 @@ class GeneratedFunTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1069,4 +1069,4 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -20,13 +20,13 @@ import jax
from jax import lax, numpy as np
from jax.config import config
from jax.lib import xla_client
import jax.test_util
import jax.test_util as jtu
import numpy as onp
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class InfeedTest(jax.test_util.JaxTestCase):
class InfeedTest(jtu.JaxTestCase):
def testInfeed(self):
@jax.jit
@ -92,4 +92,4 @@ class InfeedTest(jax.test_util.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -17,6 +17,7 @@ from absl.testing import absltest
import jax.numpy as jnp
from jax.tools.jax_to_hlo import jax_to_hlo
from jax.lib import xla_client
from jax import test_util as jtu
class JaxToHloTest(absltest.TestCase):
@ -73,4 +74,4 @@ class JaxToHloTest(absltest.TestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -360,4 +360,4 @@ class JetTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1046,4 +1046,4 @@ class LaxAutodiffTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -2356,4 +2356,4 @@ class LaxControlFlowTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -341,4 +341,4 @@ class EinsumTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -992,4 +992,4 @@ class IndexedUpdateTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -3904,4 +3904,4 @@ class NumpyGradTests(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -209,4 +209,4 @@ class VectorizeTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -159,4 +159,4 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -219,4 +219,4 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1803,4 +1803,4 @@ class LazyConstantTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -689,4 +689,4 @@ class LaxVmapTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1321,4 +1321,4 @@ class ScipyLinalgTest(jtu.JaxTestCase):
jtu.check_grads(jsp.linalg.expm, (a,), modes=["fwd"], order=1)
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -402,4 +402,4 @@ class LoopsTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -733,4 +733,4 @@ class MaskingTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -87,4 +87,4 @@ class MetadataTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -219,4 +219,4 @@ class MultiDeviceTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -191,4 +191,4 @@ class MultiBackendTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -218,4 +218,4 @@ class NNInitializersTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -236,4 +236,4 @@ class ODETest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -304,4 +304,4 @@ class OptimizerTests(jtu.JaxTestCase):
self.assertEqual(ans, expected)
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -19,7 +19,7 @@ from absl.testing import absltest
from jax import numpy as jnp
from jax.experimental import optimizers
from jax.experimental import optix
import jax.test_util
import jax.test_util as jtu
from jax.tree_util import tree_leaves
import numpy as np
@ -60,7 +60,7 @@ class OptixTest(absltest.TestCase):
for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
np.testing.assert_allclose(x, y, rtol=1e-5)
jax.test_util.skip_on_devices("tpu")
jtu.skip_on_devices("tpu")
def test_apply_every(self):
# The frequency of the application of sgd
k = 4
@ -148,4 +148,4 @@ class OptixTest(absltest.TestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -330,4 +330,4 @@ class ParallelizeTest(jtu.JaxTestCase):
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1651,4 +1651,4 @@ class ShardArgsTest(jtu.JaxTestCase):
self.assertAllClose(buf[0].to_py(), x[idx], check_dtypes=False)
if __name__ == '__main__':
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -150,4 +150,4 @@ class TestPolynomial(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -21,7 +21,7 @@ import jax
import jax.numpy as jnp
import jax.profiler
from jax.config import config
import jax.test_util
import jax.test_util as jtu
try:
import portpicker
@ -68,4 +68,4 @@ class ProfilerTest(unittest.TestCase):
del x
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -735,4 +735,4 @@ class LaxRandomTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -153,4 +153,4 @@ class NdimageTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -109,4 +109,4 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -460,4 +460,4 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -494,4 +494,4 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -258,4 +258,4 @@ class StaxTest(jtu.JaxTestCase):
self.assertEqual(out_shape, out.shape)
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -167,4 +167,4 @@ class TreeTest(jtu.JaxTestCase):
self.assertTrue(tree_util.all_leaves([leaf]))
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -65,4 +65,4 @@ class UtilTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -142,4 +142,4 @@ class VectorizeTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -16,6 +16,7 @@
from absl.testing import absltest
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax import test_util as jtu
class XlaBridgeTest(absltest.TestCase):
@ -57,4 +58,4 @@ class XlaBridgeTest(absltest.TestCase):
if __name__ == "__main__":
absltest.main()
absltest.main(testLoader=jtu.JaxTestLoader())