Fix test failures in JAX under NumPy 1.25.0rc1.

`jnp.finfo(...)` of an Array type yields:

```
TypeError: unhashable type: 'ArrayImpl'
```

However, `np.finfo(...)` no longer accepts NumPy arrays as input either, so it would be consistent to require the user to pass a dtype where they are currently passing an array.

PiperOrigin-RevId: 539174254
This commit is contained in:
Peter Hawkins 2023-06-09 14:10:00 -07:00 committed by jax authors
parent 6234d438af
commit ef3f2abfd2
3 changed files with 4 additions and 4 deletions

View File

@ -408,8 +408,8 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
)
# Step sizes which are too small causes the optimizer to get stuck with a
# direction of zero in <64 bit mode - avoid with a floor on minimum step size.
alpha_k = state.a_star
alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64)
alpha_k = jnp.asarray(state.a_star)
alpha_k = jnp.where((jnp.finfo(alpha_k.dtype).bits != 64)
& (jnp.abs(alpha_k) < 1e-8),
jnp.sign(alpha_k) * 1e-8,
alpha_k)

View File

@ -210,7 +210,7 @@ class QdwhTest(jtu.JaxTestCase):
def testQdwhWithTinyElement(self, m, n, r, c, dtype):
"""Tests qdwh on matrix with zeros and close-to-zero entries."""
a = jnp.zeros((m, n), dtype=dtype)
tiny_elem = jnp.finfo(a).tiny
tiny_elem = jnp.finfo(a.dtype).tiny
a = a.at[r, c].set(tiny_elem)
is_hermitian = _check_symmetry(a)

View File

@ -198,7 +198,7 @@ class SvdTest(jtu.JaxTestCase):
def testSvdOnTinyElement(self, m, n, r, c, dtype):
"""Tests SVD on matrix of zeros and close-to-zero entries."""
a = jnp.zeros((m, n), dtype=dtype)
tiny_element = jnp.finfo(a).tiny
tiny_element = jnp.finfo(a.dtype).tiny
a = a.at[r, c].set(tiny_element)
@jax.jit