mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove warning suppression for tuple and list arguments to reductions. (#3545)
Fix callers.
This commit is contained in:
parent
677baa54dd
commit
e680304dca
@ -341,7 +341,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
f_jax=f_jax)
|
||||
for f_jax in REDUCE))
|
||||
def test_reduce_ops_with_numerical_input(self, f_jax):
|
||||
values = [np.array([1, 2, 3], dtype=np.float32)]
|
||||
values = np.array([1, 2, 3], dtype=np.float32)
|
||||
self.ConvertAndCompare(f_jax, values, with_function=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -367,7 +367,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
f_jax=f_jax)
|
||||
for f_jax in REDUCE))
|
||||
def test_reduce_ops_with_boolean_input(self, f_jax):
|
||||
values = [np.array([True, False, True], dtype=np.bool_)]
|
||||
values = np.array([True, False, True], dtype=np.bool_)
|
||||
self.ConvertAndCompare(f_jax, values, with_function=True)
|
||||
|
||||
def test_gather_rank_change(self):
|
||||
|
@ -1522,8 +1522,8 @@ def _make_reduction(np_fun, op, init_val, preproc=None, bool_op=None,
|
||||
raise ValueError("reduction does not support the `out` argument.")
|
||||
|
||||
if isinstance(a, (list, tuple)):
|
||||
msg = "Reductions won't accept lists and tuples" + \
|
||||
"in future versions, only scalars and ndarrays"
|
||||
msg = ("jax.numpy reductions won't accept lists and tuples in future "
|
||||
"versions, only scalars and ndarrays")
|
||||
warnings.warn(msg, category=FutureWarning)
|
||||
a = a if isinstance(a, ndarray) else asarray(a)
|
||||
a = preproc(a) if preproc else a
|
||||
|
@ -107,7 +107,7 @@ def matrix_rank(M, tol=None):
|
||||
return jnp.any(M != 0).astype(jnp.int32)
|
||||
S = svd(M, full_matrices=False, compute_uv=False)
|
||||
if tol is None:
|
||||
tol = S.max() * jnp.max(M.shape) * jnp.finfo(S.dtype).eps
|
||||
tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
|
||||
return jnp.sum(S > tol)
|
||||
|
||||
|
||||
|
@ -4,7 +4,6 @@ filterwarnings =
|
||||
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
|
||||
ignore:Explicitly requested dtype.*is not available.*:UserWarning
|
||||
ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
|
||||
ignore:Reductions won't accept lists and tuples*:FutureWarning
|
||||
# The rest are for experimental/jax_to_tf
|
||||
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
|
||||
ignore:can't resolve package from __spec__ or __package__:ImportWarning
|
||||
|
Loading…
x
Reference in New Issue
Block a user