diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index f975f0bf7..0398d3cf0 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -14,6 +14,7 @@ from functools import partial import itertools +import unittest from absl.testing import absltest @@ -766,6 +767,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): @genNamedParametersNArgs(5) def testTruncnormPdf(self, shapes, dtypes): + if jtu.device_under_test() == "cpu": + raise unittest.SkipTest("TODO(b/282695039): test fails at LLVM head") rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.truncnorm.pdf lax_fun = lsp_stats.truncnorm.pdf