Fix CI test failure in line_search_test.

A recent XLA change means that certain matrix-vector products may now be lowered as matrix-matrix multiplications. This may mean that we use lower precisions where we previously did not.

PiperOrigin-RevId: 633949879
This commit is contained in:
Peter Hawkins 2024-05-15 07:37:24 -07:00 committed by jax authors
parent 66a92c41f6
commit 7653db80fe

View File

@ -86,6 +86,7 @@ class TestLineSearch(jtu.JaxTestCase):
@jtu.sample_product(
name=['_line_func_1', '_line_func_2'],
)
@jax.default_matmul_precision("float32")
def test_line_search_wolfe2(self, name):
def bind_index(func, idx):
# Remember Python's closure semantics!