mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix test failures in JAX under NumPy 1.25.0rc1.
`jnp.finfo(...)` of an Array type yields: ``` TypeError: unhashable type: 'ArrayImpl' ``` However, `np.finfo(...)` no longer accepts NumPy arrays as input either, so it would be consistent to require the user to pass a dtype where they are currently passing an array. PiperOrigin-RevId: 539174254
This commit is contained in:
parent
6234d438af
commit
ef3f2abfd2
@ -408,8 +408,8 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
|
||||
)
|
||||
# Step sizes which are too small causes the optimizer to get stuck with a
|
||||
# direction of zero in <64 bit mode - avoid with a floor on minimum step size.
|
||||
alpha_k = state.a_star
|
||||
alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64)
|
||||
alpha_k = jnp.asarray(state.a_star)
|
||||
alpha_k = jnp.where((jnp.finfo(alpha_k.dtype).bits != 64)
|
||||
& (jnp.abs(alpha_k) < 1e-8),
|
||||
jnp.sign(alpha_k) * 1e-8,
|
||||
alpha_k)
|
||||
|
@ -210,7 +210,7 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
def testQdwhWithTinyElement(self, m, n, r, c, dtype):
|
||||
"""Tests qdwh on matrix with zeros and close-to-zero entries."""
|
||||
a = jnp.zeros((m, n), dtype=dtype)
|
||||
tiny_elem = jnp.finfo(a).tiny
|
||||
tiny_elem = jnp.finfo(a.dtype).tiny
|
||||
a = a.at[r, c].set(tiny_elem)
|
||||
|
||||
is_hermitian = _check_symmetry(a)
|
||||
|
@ -198,7 +198,7 @@ class SvdTest(jtu.JaxTestCase):
|
||||
def testSvdOnTinyElement(self, m, n, r, c, dtype):
|
||||
"""Tests SVD on matrix of zeros and close-to-zero entries."""
|
||||
a = jnp.zeros((m, n), dtype=dtype)
|
||||
tiny_element = jnp.finfo(a).tiny
|
||||
tiny_element = jnp.finfo(a.dtype).tiny
|
||||
a = a.at[r, c].set(tiny_element)
|
||||
|
||||
@jax.jit
|
||||
|
Loading…
x
Reference in New Issue
Block a user