mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
X32 tests: fail on dtype warnings
This commit is contained in:
parent
4504f27492
commit
f74235cdae
@ -54,7 +54,7 @@ def compute_weight_mat(input_size: int, output_size: int, scale,
|
||||
# want to interpolate.
|
||||
kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
|
||||
|
||||
sample_f = ((np.arange(output_size) + 0.5) * inv_scale -
|
||||
sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale -
|
||||
translation * inv_scale - 0.5)
|
||||
x = (
|
||||
jnp.abs(sample_f[jnp.newaxis, :] -
|
||||
|
@ -6378,7 +6378,7 @@ def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
"See https://github.com/google/jax#current-gotchas for more.")
|
||||
fun_name = f"requested in {fun_name}" if fun_name else ""
|
||||
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=2)
|
||||
|
||||
|
||||
def _canonicalize_axis(axis, num_dims):
|
||||
|
@ -4985,7 +4985,7 @@ def _unimplemented_setitem(self, i, x):
|
||||
def _operator_round(number, ndigits=None):
|
||||
out = round(number, decimals=ndigits or 0)
|
||||
# If `ndigits` is None, for a builtin float round(7.5) returns an integer.
|
||||
return out.astype(int_) if ndigits is None else out
|
||||
return out.astype(int) if ndigits is None else out
|
||||
|
||||
_operators = {
|
||||
"getitem": _rewriting_take,
|
||||
|
@ -51,4 +51,4 @@ def logcdf(x, loc=0, scale=1):
|
||||
|
||||
@_wraps(osp_stats.norm.ppf, update_doc=False)
|
||||
def ppf(q, loc=0, scale=1):
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, 'float64')
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, float)
|
||||
|
@ -2,7 +2,6 @@
|
||||
filterwarnings =
|
||||
error
|
||||
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
|
||||
ignore:Explicitly requested dtype.*is not available.*:UserWarning
|
||||
ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
|
||||
# The rest are for experimental/jax_to_tf
|
||||
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
|
||||
|
@ -563,7 +563,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
def testCumProd(self):
|
||||
x = jnp.arange(9).reshape(3, 3) + 1
|
||||
y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
|
||||
self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y)
|
||||
self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y)
|
||||
|
||||
def testSelect(self):
|
||||
pred = np.array([True, False])
|
||||
@ -930,12 +930,11 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
def f(key):
|
||||
def body_fn(uk):
|
||||
key = uk[1]
|
||||
u = random.uniform(key, (), dtype=jnp.float64)
|
||||
u = random.uniform(key, ())
|
||||
key, _ = random.split(key)
|
||||
return u, key
|
||||
|
||||
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn,
|
||||
(jnp.float64(1.), key))
|
||||
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
|
||||
return u
|
||||
|
||||
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash
|
||||
|
@ -17,6 +17,7 @@ import operator
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.experimental.doubledouble import doubledouble, _DoubleDouble
|
||||
@ -104,7 +105,7 @@ class DoubleDoubleTest(jtu.JaxTestCase):
|
||||
for val in ["6.0221409e23", "3.14159265358", "0", 123456789]
|
||||
))
|
||||
def testClassInstantiation(self, dtype, val):
|
||||
dtype = jnp.dtype(dtype).type
|
||||
dtype = dtypes.canonicalize_dtype(dtype).type
|
||||
self.assertEqual(dtype(val), _DoubleDouble(val, dtype).to_array())
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -77,6 +77,8 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_swap={}_jit={}".format(swap, jit),
|
||||
"swap": swap, "jit": jit}
|
||||
for swap in [False, True] for jit in [False, True])
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype.*")
|
||||
def testBinaryPromotion(self, swap, jit):
|
||||
testcases = [
|
||||
(jnp.array(1.), 0., jnp.float_),
|
||||
@ -200,6 +202,8 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
val = jaxtype.type(0)
|
||||
self.assertIs(dtypes._jax_type(val), jaxtype)
|
||||
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype.*")
|
||||
def testObservedPromotionTable(self):
|
||||
"""Test that the weak & strong dtype promotion table does not change over time."""
|
||||
# Note: * here refers to weakly-typed values
|
||||
|
@ -34,6 +34,7 @@ config.parse_flags_with_absl()
|
||||
def jvp_taylor(fun, primals, series):
|
||||
# Computes the Taylor series the slow way, with nested jvp.
|
||||
order, = set(map(len, series))
|
||||
primals = tuple(jnp.asarray(p) for p in primals)
|
||||
def composition(eps):
|
||||
taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
|
||||
for i in range(len(terms))]) for terms in series]
|
||||
|
@ -1030,36 +1030,36 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testSegmentSum(self):
|
||||
data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
|
||||
segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])
|
||||
data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
|
||||
segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])
|
||||
|
||||
# test with explicit num_segments
|
||||
ans = ops.segment_sum(data, segment_ids, num_segments=4)
|
||||
expected = np.array([13, 2, 7, 4])
|
||||
expected = jnp.array([13, 2, 7, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test with explicit num_segments larger than the higher index.
|
||||
ans = ops.segment_sum(data, segment_ids, num_segments=5)
|
||||
expected = np.array([13, 2, 7, 4, 0])
|
||||
expected = jnp.array([13, 2, 7, 4, 0])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test without explicit num_segments
|
||||
ans = ops.segment_sum(data, segment_ids)
|
||||
expected = np.array([13, 2, 7, 4])
|
||||
expected = jnp.array([13, 2, 7, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test with negative segment ids and segment ids larger than num_segments,
|
||||
# that will be wrapped with the `mod`.
|
||||
segment_ids = np.array([0, 4, 8, 1, 2, -6, -1, 3])
|
||||
segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3])
|
||||
ans = ops.segment_sum(data, segment_ids, num_segments=4)
|
||||
expected = np.array([13, 2, 7, 4])
|
||||
expected = jnp.array([13, 2, 7, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# test with negative segment ids and without without explicit num_segments
|
||||
# such as num_segments is defined by the smaller index.
|
||||
segment_ids = np.array([3, 3, 3, 4, 5, 5, -7, -6])
|
||||
segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
|
||||
ans = ops.segment_sum(data, segment_ids)
|
||||
expected = np.array([1, 3, 0, 13, 2, 7, 0])
|
||||
expected = jnp.array([1, 3, 0, 13, 2, 7, 0])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testIndexDtypeError(self):
|
||||
|
@ -3198,7 +3198,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
def testNpMean(self):
|
||||
# from https://github.com/google/jax/issues/125
|
||||
x = lax.add(jnp.eye(3, dtype=jnp.float_), 0.)
|
||||
x = lax.add(jnp.eye(3, dtype=float), 0.)
|
||||
ans = np.mean(x)
|
||||
self.assertAllClose(ans, np.array(1./3), check_dtypes=False)
|
||||
|
||||
@ -3825,6 +3825,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testLongLong(self):
|
||||
self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)))
|
||||
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype.*")
|
||||
def testArange(self):
|
||||
# test cases inspired by dask tests at
|
||||
# https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92
|
||||
|
@ -32,8 +32,8 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
float_types = [np.float32, np.float64]
|
||||
complex_types = [np.complex64, np.complex128]
|
||||
float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
|
||||
|
||||
def matmul_high_precision(a, b):
|
||||
|
@ -304,7 +304,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testEigvalsInf(self):
|
||||
# https://github.com/google/jax/issues/2661
|
||||
x = jnp.array([[jnp.inf]], jnp.float64)
|
||||
x = jnp.array([[jnp.inf]])
|
||||
self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -61,9 +61,8 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
|
||||
self.assertAllClose(5., api.grad(f_op)(2.))
|
||||
self.assertAllClose(5., api.grad(f_op)(2.))
|
||||
inc_batch = np.arange(5, dtype=jnp.float_)
|
||||
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch],
|
||||
dtype=jnp.float_),
|
||||
inc_batch = np.arange(5.0)
|
||||
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch]),
|
||||
api.vmap(f_op)(inc_batch))
|
||||
|
||||
|
||||
|
@ -407,7 +407,7 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)
|
||||
|
||||
def rnn_reference(W, seqs, targets):
|
||||
total_loss = jnp.array(0, jnp.float_)
|
||||
total_loss = jnp.array(0.0)
|
||||
for xs, target in zip(seqs, targets):
|
||||
h = jnp.zeros(n)
|
||||
for x in xs:
|
||||
|
@ -21,7 +21,6 @@ import itertools
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import scipy.stats
|
||||
|
||||
from jax import core
|
||||
@ -67,8 +66,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
check_grads(nn.softplus, (float('nan'),), order=1,
|
||||
rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
|
||||
|
||||
@parameterized.parameters([
|
||||
int, jnp.int32, float, jnp.float64, jnp.float32, jnp.float64,])
|
||||
@parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
|
||||
def testSoftplusZero(self, dtype):
|
||||
self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0)))
|
||||
|
||||
@ -212,7 +210,7 @@ class NNInitializersTest(jtu.JaxTestCase):
|
||||
"shape": shape, "dtype": dtype}
|
||||
for rec in INITIALIZER_RECS
|
||||
for shape in rec.shapes
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testInitializer(self, initializer, shape, dtype):
|
||||
rng = random.PRNGKey(0)
|
||||
val = initializer(rng, shape, dtype)
|
||||
@ -228,7 +226,7 @@ class NNInitializersTest(jtu.JaxTestCase):
|
||||
"shape": shape, "dtype": dtype}
|
||||
for rec in INITIALIZER_RECS
|
||||
for shape in rec.shapes
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testInitializerProvider(self, initializer_provider, shape, dtype):
|
||||
rng = random.PRNGKey(0)
|
||||
initializer = initializer_provider(dtype=dtype)
|
||||
|
@ -282,7 +282,7 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
assert trip == 75
|
||||
return opt_final
|
||||
|
||||
initial_params = jnp.float64(0.5)
|
||||
initial_params = jnp.array(0.5)
|
||||
minimize_structure(initial_params)
|
||||
|
||||
def loss(test_params):
|
||||
|
@ -77,11 +77,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
|
||||
if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
|
||||
raise SkipTest("can't test float64 agreement")
|
||||
|
||||
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
|
||||
numpy_bits = np.array(1., dtype).view(bits_dtype)
|
||||
xla_bits = api.jit(
|
||||
@ -156,7 +153,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits32, expected32)
|
||||
|
||||
bits64 = jax._src.random._random_bits(key, 64, (3,))
|
||||
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
||||
bits64 = jax._src.random._random_bits(key, 64, (3,))
|
||||
if FLAGS.jax_enable_x64:
|
||||
expected64 = np.array([3982329540505020460, 16822122385914693683,
|
||||
7882654074788531506], dtype=np.uint64)
|
||||
@ -199,7 +197,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in [np.float16, np.float32, np.float64]))
|
||||
for dtype in float_dtypes))
|
||||
def testNormal(self, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
@ -213,7 +211,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in [np.float16, np.float32, np.float64]))
|
||||
for dtype in float_dtypes))
|
||||
def testTruncatedNormal(self, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
|
||||
@ -231,7 +229,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in [np.float32, np.float64, np.int32, np.int64]))
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
|
||||
def testShuffle(self, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
x = np.arange(100).astype(dtype)
|
||||
@ -252,7 +250,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
np.dtype(dtype).name, shape, replace, weighted, array_input),
|
||||
"dtype": dtype, "shape": shape, "replace": replace,
|
||||
"weighted": weighted, "array_input": array_input}
|
||||
for dtype in [np.float32, np.float64, np.int32, np.int64]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
|
||||
for shape in [(), (5,), (4, 5)]
|
||||
for replace in [True, False]
|
||||
for weighted in [True, False]
|
||||
@ -280,7 +278,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"dtype": dtype, "shape": shape}
|
||||
for dtype in [np.float32, np.float64, np.int32, np.int64]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
|
||||
for shape in [100, (10, 10), (10, 5, 2)]))
|
||||
def testPermutationArray(self, dtype, shape):
|
||||
key = random.PRNGKey(0)
|
||||
@ -322,7 +320,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_p={}_dtype={}".format(p, np.dtype(dtype).name),
|
||||
"p": p, "dtype": dtype}
|
||||
for p in [0.1, 0.5, 0.9]
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testBernoulli(self, p, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
p = np.array(p, dtype=dtype)
|
||||
@ -345,7 +343,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
([[.5, .1], [.5, .9]], 0),
|
||||
]
|
||||
for sample_shape in [(10000,), (5000, 2)]
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testCategorical(self, p, axis, dtype, sample_shape):
|
||||
key = random.PRNGKey(0)
|
||||
p = np.array(p, dtype=dtype)
|
||||
@ -397,7 +395,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in [np.float16, np.float32, np.float64]))
|
||||
for dtype in float_dtypes))
|
||||
def testCauchy(self, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.cauchy(key, (10000,), dtype)
|
||||
@ -415,7 +413,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for alpha in [
|
||||
np.array([0.2, 1., 5.]),
|
||||
]
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
|
||||
def testDirichlet(self, alpha, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
@ -449,7 +447,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
|
||||
"a": a, "dtype": dtype}
|
||||
for a in [0.1, 1., 10.]
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testGamma(self, a, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
@ -527,7 +525,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testGumbel(self, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.gumbel(key, (10000,), dtype)
|
||||
@ -571,7 +569,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
|
||||
"b": b, "dtype": dtype}
|
||||
for b in [0.1, 1., 10.]
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testPareto(self, b, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
|
||||
@ -592,7 +590,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_df={}_dtype={}".format(df, np.dtype(dtype).name),
|
||||
"df": df, "dtype": dtype}
|
||||
for df in [0.1, 1., 10.]
|
||||
for dtype in [np.float32, np.float64]))
|
||||
for dtype in jtu.dtypes.floating))
|
||||
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
|
||||
def testT(self, df, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user