mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix CI test failure in line_search_test.
A recent XLA change means that certain matrix-vector products may now be lowered as matrix-matrix multiplications. This may mean that we use lower precisions where we previously did not. PiperOrigin-RevId: 633949879
This commit is contained in:
parent
66a92c41f6
commit
7653db80fe
1
tests/third_party/scipy/line_search_test.py
vendored
1
tests/third_party/scipy/line_search_test.py
vendored
@ -86,6 +86,7 @@ class TestLineSearch(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
name=['_line_func_1', '_line_func_2'],
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_line_search_wolfe2(self, name):
|
||||
def bind_index(func, idx):
|
||||
# Remember Python's closure semantics!
|
||||
|
Loading…
x
Reference in New Issue
Block a user