mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add ability to specify individual test targets via a regex (#3549)
* Add ability to specify individual test targets * fix missing imports * Use re.search and include test class name
This commit is contained in:
commit
a9fad49e1b
@ -108,6 +108,12 @@ file directly to see more detailed information about the cases being run::
|
||||
You can skip a few tests known as slow, by passing environment variable
|
||||
JAX_SKIP_SLOW_TESTS=1.
|
||||
|
||||
To specify a particular set of tests to run from a test file, you can pass a string
|
||||
or regular expression via the ``--test_targets`` flag. For example, you can run all
|
||||
the tests of ``jax.numpy.pad`` using::
|
||||
|
||||
python tests/lax_numpy_test.py --test_targets="testPad"
|
||||
|
||||
The Colab notebooks are tested for errors as part of the documentation build.
|
||||
|
||||
Type checking
|
||||
|
@ -244,4 +244,4 @@ class ControlExampleTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -95,4 +95,4 @@ class ExamplesTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -320,4 +320,4 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -171,4 +171,4 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -396,4 +396,4 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual(f(tf.ones([])), 1.)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -23,6 +23,7 @@ import tensorflow as tf # type: ignore[import]
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -45,4 +46,4 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -58,4 +58,4 @@ class StaxTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -22,6 +22,7 @@ import unittest
|
||||
import warnings
|
||||
import zlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
@ -58,6 +59,12 @@ flags.DEFINE_bool(
|
||||
'Skip tests marked as slow (> 5 sec).'
|
||||
)
|
||||
|
||||
flags.DEFINE_string(
|
||||
'test_targets', '',
|
||||
'Regular expression specifying which tests to run, called via re.match on '
|
||||
'the test name. If empty or unspecified, run all tests.'
|
||||
)
|
||||
|
||||
EPS = 1e-4
|
||||
|
||||
def _dtype(x):
|
||||
@ -706,6 +713,16 @@ def cases_from_gens(*gens):
|
||||
yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens)
|
||||
|
||||
|
||||
class JaxTestLoader(absltest.TestLoader):
|
||||
def getTestCaseNames(self, testCaseClass):
|
||||
names = super().getTestCaseNames(testCaseClass)
|
||||
if FLAGS.test_targets:
|
||||
pattern = re.compile(FLAGS.test_targets)
|
||||
names = [name for name in names
|
||||
if pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||
return names
|
||||
|
||||
|
||||
class JaxTestCase(parameterized.TestCase):
|
||||
"""Base class for JAX tests including numerical checks and boilerplate."""
|
||||
|
||||
|
@ -3267,4 +3267,4 @@ class BufferDonationTest(jtu.JaxTestCase):
|
||||
self.assertEqual(buffer.is_deleted(), deleted)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -68,4 +68,4 @@ class ApiUtilTest(jtu.JaxTestCase):
|
||||
api_util.rebase_donate_argnums(donate, static))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -135,4 +135,4 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -968,4 +968,4 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -90,4 +90,4 @@ class CallbackTest(jtu.JaxTestCase):
|
||||
jnp.array([4.0, 6.0]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -435,4 +435,4 @@ class CoreTest(jtu.JaxTestCase):
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -51,4 +51,4 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -143,6 +143,5 @@ class DoubleDoubleTest(jtu.JaxTestCase):
|
||||
return result
|
||||
self.assertAllClose(op(*args), class_op(*args))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -182,4 +182,4 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -381,4 +381,4 @@ class FftTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -268,4 +268,4 @@ class GeneratedFunTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1069,4 +1069,4 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -20,13 +20,13 @@ import jax
|
||||
from jax import lax, numpy as np
|
||||
from jax.config import config
|
||||
from jax.lib import xla_client
|
||||
import jax.test_util
|
||||
import jax.test_util as jtu
|
||||
import numpy as onp
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
class InfeedTest(jax.test_util.JaxTestCase):
|
||||
class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
def testInfeed(self):
|
||||
@jax.jit
|
||||
@ -92,4 +92,4 @@ class InfeedTest(jax.test_util.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -17,6 +17,7 @@ from absl.testing import absltest
|
||||
import jax.numpy as jnp
|
||||
from jax.tools.jax_to_hlo import jax_to_hlo
|
||||
from jax.lib import xla_client
|
||||
from jax import test_util as jtu
|
||||
|
||||
|
||||
class JaxToHloTest(absltest.TestCase):
|
||||
@ -73,4 +74,4 @@ class JaxToHloTest(absltest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -360,4 +360,4 @@ class JetTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1046,4 +1046,4 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -2356,4 +2356,4 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -341,4 +341,4 @@ class EinsumTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -992,4 +992,4 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -3904,4 +3904,4 @@ class NumpyGradTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -209,4 +209,4 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -159,4 +159,4 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -219,4 +219,4 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1803,4 +1803,4 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -689,4 +689,4 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1321,4 +1321,4 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(jsp.linalg.expm, (a,), modes=["fwd"], order=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -402,4 +402,4 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -733,4 +733,4 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -87,4 +87,4 @@ class MetadataTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -219,4 +219,4 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -191,4 +191,4 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -218,4 +218,4 @@ class NNInitializersTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -236,4 +236,4 @@ class ODETest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -304,4 +304,4 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
self.assertEqual(ans, expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -19,7 +19,7 @@ from absl.testing import absltest
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import optix
|
||||
import jax.test_util
|
||||
import jax.test_util as jtu
|
||||
from jax.tree_util import tree_leaves
|
||||
import numpy as np
|
||||
|
||||
@ -60,7 +60,7 @@ class OptixTest(absltest.TestCase):
|
||||
for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
|
||||
np.testing.assert_allclose(x, y, rtol=1e-5)
|
||||
|
||||
jax.test_util.skip_on_devices("tpu")
|
||||
jtu.skip_on_devices("tpu")
|
||||
def test_apply_every(self):
|
||||
# The frequency of the application of sgd
|
||||
k = 4
|
||||
@ -148,4 +148,4 @@ class OptixTest(absltest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -330,4 +330,4 @@ class ParallelizeTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1651,4 +1651,4 @@ class ShardArgsTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(buf[0].to_py(), x[idx], check_dtypes=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -150,4 +150,4 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -21,7 +21,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
from jax.config import config
|
||||
import jax.test_util
|
||||
import jax.test_util as jtu
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
@ -68,4 +68,4 @@ class ProfilerTest(unittest.TestCase):
|
||||
del x
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -735,4 +735,4 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -153,4 +153,4 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -109,4 +109,4 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -460,4 +460,4 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -494,4 +494,4 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -258,4 +258,4 @@ class StaxTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_shape, out.shape)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -167,4 +167,4 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertTrue(tree_util.all_leaves([leaf]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -65,4 +65,4 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -142,4 +142,4 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -16,6 +16,7 @@
|
||||
from absl.testing import absltest
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.lib import xla_client as xc
|
||||
from jax import test_util as jtu
|
||||
|
||||
|
||||
class XlaBridgeTest(absltest.TestCase):
|
||||
@ -57,4 +58,4 @@ class XlaBridgeTest(absltest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user