diff --git a/tests/third_party/scipy/line_search_test.py b/tests/third_party/scipy/line_search_test.py index 9b2480053..5a22372b7 100644 --- a/tests/third_party/scipy/line_search_test.py +++ b/tests/third_party/scipy/line_search_test.py @@ -86,6 +86,7 @@ class TestLineSearch(jtu.JaxTestCase): @jtu.sample_product( name=['_line_func_1', '_line_func_2'], ) + @jax.default_matmul_precision("float32") def test_line_search_wolfe2(self, name): def bind_index(func, idx): # Remember Python's closure semantics!