X32 tests: fail on dtype warnings

Jake VanderPlas 2020-12-08 13:03:30 -08:00
parent 4504f27492
commit f74235cdae
18 changed files with 52 additions and 51 deletions

View File

@@ -54,7 +54,7 @@ def compute_weight_mat(input_size: int, output_size: int, scale,
   # want to interpolate.
   kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
-  sample_f = ((np.arange(output_size) + 0.5) * inv_scale -
+  sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale -
               translation * inv_scale - 0.5)
   x = (
       jnp.abs(sample_f[jnp.newaxis, :] -
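
An aside on the np.arange to jnp.arange swap above (not part of the diff, and assuming the default x32 mode): NumPy constructors default to 64-bit dtypes, while jnp constructors return JAX's canonical dtypes directly, so nothing downstream has to truncate a 64-bit host array.

```python
# Minimal sketch, assuming jax runs in its default x32 mode.
import numpy as np
import jax.numpy as jnp

print(np.arange(4).dtype)   # int64 on most 64-bit platforms
print(jnp.arange(4).dtype)  # int32: already the canonical dtype, no truncation needed
```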

View File

@@ -6378,7 +6378,7 @@ def _check_user_dtype_supported(dtype, fun_name=None):
           "See https://github.com/google/jax#current-gotchas for more.")
   fun_name = f"requested in {fun_name}" if fun_name else ""
   truncated_dtype = dtypes.canonicalize_dtype(dtype).name
-  warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
+  warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=2)

 def _canonicalize_axis(axis, num_dims):
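
The stacklevel=2 added above changes where Python attributes the warning. A minimal standalone sketch (the helper name below is illustrative, not JAX's):

```python
import warnings

def _check_dtype(dtype):
    # stacklevel=2 attributes the warning to _check_dtype's caller, so the
    # reported file and line point at user (or test) code rather than this helper.
    warnings.warn(f"Explicitly requested dtype {dtype} is not available",
                  stacklevel=2)

def user_code():
    _check_dtype("float64")  # the warning is reported against this line

user_code()
```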

View File

@@ -4985,7 +4985,7 @@ def _unimplemented_setitem(self, i, x):
 def _operator_round(number, ndigits=None):
   out = round(number, decimals=ndigits or 0)
   # If `ndigits` is None, for a builtin float round(7.5) returns an integer.
-  return out.astype(int_) if ndigits is None else out
+  return out.astype(int) if ndigits is None else out

 _operators = {
     "getitem": _rewriting_take,

View File

@@ -51,4 +51,4 @@ def logcdf(x, loc=0, scale=1):

 @_wraps(osp_stats.norm.ppf, update_doc=False)
 def ppf(q, loc=0, scale=1):
-  return jnp.asarray(special.ndtri(q) * scale + loc, 'float64')
+  return jnp.asarray(special.ndtri(q) * scale + loc, float)
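
The distinction exploited above, in isolation (a hedged sketch, assuming the default x32 mode): a literal 'float64' is a concrete 64-bit request and triggers the "Explicitly requested dtype ..." UserWarning, whereas the builtin `float` means "the default float type" and simply follows the x64 flag.

```python
import jax.numpy as jnp

print(jnp.asarray(1.0, float).dtype)      # float32 in x32 mode, no warning
print(jnp.asarray(1.0, 'float64').dtype)  # also float32, but a UserWarning is emitted first
```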

View File

@@ -2,7 +2,6 @@
 filterwarnings =
     error
     ignore:No GPU/TPU found, falling back to CPU.:UserWarning
-    ignore:Explicitly requested dtype.*is not available.*:UserWarning
    ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
    # The rest are for experimental/jax_to_tf
    ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
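
With the ignore entry removed, pytest's blanket `error` filter now turns the dtype warning into a test failure. A standalone reproduction sketch (hedged: it assumes the x64 flag is left at its default, disabled):

```python
import warnings
import jax.numpy as jnp

with warnings.catch_warnings():
    # Mirror the pytest.ini behaviour for this one message: escalate it to an error.
    warnings.filterwarnings("error", message="Explicitly requested dtype.*")
    try:
        jnp.arange(3, dtype=jnp.float64)  # a 64-bit request under x32
    except UserWarning as e:
        print("would now fail the test suite:", e)
```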

View File

@@ -563,7 +563,7 @@ class BatchingTest(jtu.JaxTestCase):
   def testCumProd(self):
     x = jnp.arange(9).reshape(3, 3) + 1
     y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
-    self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y)
+    self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y)

   def testSelect(self):
     pred = np.array([True, False])
@@ -930,12 +930,11 @@ class BatchingTest(jtu.JaxTestCase):
     def f(key):
       def body_fn(uk):
         key = uk[1]
-        u = random.uniform(key, (), dtype=jnp.float64)
+        u = random.uniform(key, ())
         key, _ = random.split(key)
         return u, key
-      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn,
-                            (jnp.float64(1.), key))
+      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
      return u

    print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash

View File

@@ -17,6 +17,7 @@ import operator
 from absl.testing import absltest
 from absl.testing import parameterized

+from jax import dtypes
 from jax import numpy as jnp
 from jax import test_util as jtu
 from jax.experimental.doubledouble import doubledouble, _DoubleDouble
@@ -104,7 +105,7 @@ class DoubleDoubleTest(jtu.JaxTestCase):
       for val in ["6.0221409e23", "3.14159265358", "0", 123456789]
   ))
   def testClassInstantiation(self, dtype, val):
-    dtype = jnp.dtype(dtype).type
+    dtype = dtypes.canonicalize_dtype(dtype).type
     self.assertEqual(dtype(val), _DoubleDouble(val, dtype).to_array())

   @parameterized.named_parameters(jtu.cases_from_list(

View File

@@ -77,6 +77,8 @@ class DtypesTest(jtu.JaxTestCase):
       {"testcase_name": "_swap={}_jit={}".format(swap, jit),
        "swap": swap, "jit": jit}
       for swap in [False, True] for jit in [False, True])
+  @jtu.ignore_warning(category=UserWarning,
+                      message="Explicitly requested dtype.*")
   def testBinaryPromotion(self, swap, jit):
     testcases = [
       (jnp.array(1.), 0., jnp.float_),
@@ -200,6 +202,8 @@ class TestPromotionTables(jtu.JaxTestCase):
       val = jaxtype.type(0)
       self.assertIs(dtypes._jax_type(val), jaxtype)

+  @jtu.ignore_warning(category=UserWarning,
+                      message="Explicitly requested dtype.*")
   def testObservedPromotionTable(self):
     """Test that the weak & strong dtype promotion table does not change over time."""
     # Note: * here refers to weakly-typed values
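
The decorator applied above comes from jax.test_util. For readers unfamiliar with it, here is a hedged sketch of a helper in the same spirit (not JAX's actual implementation); built on contextlib, such a helper works both as a decorator, as here, and as a `with` block, as in the random test further below.

```python
import contextlib
import warnings

@contextlib.contextmanager
def ignore_warning(*, category=Warning, message=""):
    # Suppress matching warnings only inside the decorated test or with-block,
    # leaving the suite-wide "warnings are errors" policy intact elsewhere.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=message, category=category)
        yield
```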

View File

@@ -34,6 +34,7 @@ config.parse_flags_with_absl()
 def jvp_taylor(fun, primals, series):
   # Computes the Taylor series the slow way, with nested jvp.
   order, = set(map(len, series))
+  primals = tuple(jnp.asarray(p) for p in primals)
   def composition(eps):
     taylor_terms = [sum([eps ** (i+1) * terms[i] / fact(i + 1)
                          for i in range(len(terms))]) for terms in series]
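
A hedged note on the added jnp.asarray line: converting the primals up front pins them to JAX's canonical dtypes, so the nested-jvp reference computation sees the same precision as the function under test. For example:

```python
import jax.numpy as jnp

primals = (3, 0.5)  # Python scalars, as they might arrive from a test case
# Canonical 32-bit dtypes (int32, float32) in x32 mode:
print([jnp.asarray(p).dtype for p in primals])
```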

View File

@@ -1030,36 +1030,36 @@ class IndexedUpdateTest(jtu.JaxTestCase):
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testSegmentSum(self):
-    data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
-    segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])
+    data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
+    segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])

     # test with explicit num_segments
     ans = ops.segment_sum(data, segment_ids, num_segments=4)
-    expected = np.array([13, 2, 7, 4])
+    expected = jnp.array([13, 2, 7, 4])
     self.assertAllClose(ans, expected, check_dtypes=False)

     # test with explicit num_segments larger than the higher index.
     ans = ops.segment_sum(data, segment_ids, num_segments=5)
-    expected = np.array([13, 2, 7, 4, 0])
+    expected = jnp.array([13, 2, 7, 4, 0])
     self.assertAllClose(ans, expected, check_dtypes=False)

     # test without explicit num_segments
     ans = ops.segment_sum(data, segment_ids)
-    expected = np.array([13, 2, 7, 4])
+    expected = jnp.array([13, 2, 7, 4])
     self.assertAllClose(ans, expected, check_dtypes=False)

     # test with negative segment ids and segment ids larger than num_segments,
     # that will be wrapped with the `mod`.
-    segment_ids = np.array([0, 4, 8, 1, 2, -6, -1, 3])
+    segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3])
     ans = ops.segment_sum(data, segment_ids, num_segments=4)
-    expected = np.array([13, 2, 7, 4])
+    expected = jnp.array([13, 2, 7, 4])
     self.assertAllClose(ans, expected, check_dtypes=False)

     # test with negative segment ids and without without explicit num_segments
     # such as num_segments is defined by the smaller index.
-    segment_ids = np.array([3, 3, 3, 4, 5, 5, -7, -6])
+    segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
     ans = ops.segment_sum(data, segment_ids)
-    expected = np.array([1, 3, 0, 13, 2, 7, 0])
+    expected = jnp.array([1, 3, 0, 13, 2, 7, 0])
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testIndexDtypeError(self):
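
A small standalone usage sketch of the API exercised above (independent of the test itself, and hedged: it uses jax.ops.segment_sum as in this era of the codebase); building the inputs with jnp.array keeps them in canonical dtypes rather than NumPy's 64-bit defaults.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])
# Sums the data entries that share a segment id: [5+1+7, 2, 3+4, 1+3]
print(ops.segment_sum(data, segment_ids, num_segments=4))  # [13  2  7  4]
```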

View File

@@ -3198,7 +3198,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
   def testNpMean(self):
     # from https://github.com/google/jax/issues/125
-    x = lax.add(jnp.eye(3, dtype=jnp.float_), 0.)
+    x = lax.add(jnp.eye(3, dtype=float), 0.)
     ans = np.mean(x)
     self.assertAllClose(ans, np.array(1./3), check_dtypes=False)
@@ -3825,6 +3825,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
   def testLongLong(self):
     self.assertAllClose(np.int64(7), api.jit(lambda x: x)(np.longlong(7)))

+  @jtu.ignore_warning(category=UserWarning,
+                      message="Explicitly requested dtype.*")
   def testArange(self):
     # test cases inspired by dask tests at
     # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92

View File

@@ -32,8 +32,8 @@ from jax.config import config
 config.parse_flags_with_absl()

-float_types = [np.float32, np.float64]
-complex_types = [np.complex64, np.complex128]
+float_types = jtu.dtypes.floating
+complex_types = jtu.dtypes.complex

 def matmul_high_precision(a, b):
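
jtu.dtypes.floating and jtu.dtypes.complex are flag-aware dtype lists. A hedged sketch of the underlying idea (not JAX's implementation): only advertise the dtypes the current x64 setting supports, so parameterized tests never request an unavailable dtype and trip the warnings-as-errors filter.

```python
import numpy as np
from jax import dtypes

def supported_floating():
    # float64 is only on the menu when the x64 flag is enabled.
    x64_enabled = dtypes.canonicalize_dtype(np.float64) == np.float64
    return [np.float16, np.float32] + ([np.float64] if x64_enabled else [])
```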

View File

@@ -304,7 +304,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   @jtu.skip_on_devices("gpu", "tpu")
   def testEigvalsInf(self):
     # https://github.com/google/jax/issues/2661
-    x = jnp.array([[jnp.inf]], jnp.float64)
+    x = jnp.array([[jnp.inf]])
     self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))

   @parameterized.named_parameters(jtu.cases_from_list(

View File

@@ -61,9 +61,8 @@ class LoopsTest(jtu.JaxTestCase):
     self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
     self.assertAllClose(5., api.grad(f_op)(2.))
     self.assertAllClose(5., api.grad(f_op)(2.))
-    inc_batch = np.arange(5, dtype=jnp.float_)
-    self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch],
-                                  dtype=jnp.float_),
+    inc_batch = np.arange(5.0)
+    self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch]),
                         api.vmap(f_op)(inc_batch))

View File

@@ -407,7 +407,7 @@ class MaskingTest(jtu.JaxTestCase):
     ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)

     def rnn_reference(W, seqs, targets):
-      total_loss = jnp.array(0, jnp.float_)
+      total_loss = jnp.array(0.0)
       for xs, target in zip(seqs, targets):
         h = jnp.zeros(n)
         for x in xs:

View File

@@ -21,7 +21,6 @@ import itertools
 from absl.testing import absltest
 from absl.testing import parameterized

 import numpy as np
-import scipy.stats

 from jax import core
@@ -67,8 +66,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     check_grads(nn.softplus, (float('nan'),), order=1,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)

-  @parameterized.parameters([
-      int, jnp.int32, float, jnp.float64, jnp.float32, jnp.float64,])
+  @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
   def testSoftplusZero(self, dtype):
     self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0)))
@@ -212,7 +210,7 @@ class NNInitializersTest(jtu.JaxTestCase):
        "shape": shape, "dtype": dtype}
       for rec in INITIALIZER_RECS
       for shape in rec.shapes
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testInitializer(self, initializer, shape, dtype):
     rng = random.PRNGKey(0)
     val = initializer(rng, shape, dtype)
@@ -228,7 +226,7 @@ class NNInitializersTest(jtu.JaxTestCase):
        "shape": shape, "dtype": dtype}
       for rec in INITIALIZER_RECS
       for shape in rec.shapes
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testInitializerProvider(self, initializer_provider, shape, dtype):
     rng = random.PRNGKey(0)
     initializer = initializer_provider(dtype=dtype)

View File

@@ -282,7 +282,7 @@ class OptimizerTests(jtu.JaxTestCase):
       assert trip == 75
       return opt_final

-    initial_params = jnp.float64(0.5)
+    initial_params = jnp.array(0.5)
     minimize_structure(initial_params)

     def loss(test_params):

View File

@@ -77,11 +77,8 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
-    if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
-      raise SkipTest("can't test float64 agreement")
     bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
     numpy_bits = np.array(1., dtype).view(bits_dtype)
     xla_bits = api.jit(
@@ -156,7 +153,8 @@ class LaxRandomTest(jtu.JaxTestCase):
     expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
     self.assertArraysEqual(bits32, expected32)

-    bits64 = jax._src.random._random_bits(key, 64, (3,))
+    with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
+      bits64 = jax._src.random._random_bits(key, 64, (3,))
     if FLAGS.jax_enable_x64:
       expected64 = np.array([3982329540505020460, 16822122385914693683,
                              7882654074788531506], dtype=np.uint64)
@@ -199,7 +197,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
-      for dtype in [np.float16, np.float32, np.float64]))
+      for dtype in float_dtypes))
   def testNormal(self, dtype):
     key = random.PRNGKey(0)
     rand = lambda key: random.normal(key, (10000,), dtype)
@@ -213,7 +211,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
-      for dtype in [np.float16, np.float32, np.float64]))
+      for dtype in float_dtypes))
   def testTruncatedNormal(self, dtype):
     key = random.PRNGKey(0)
     rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
@@ -231,7 +229,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
-      for dtype in [np.float32, np.float64, np.int32, np.int64]))
+      for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
   def testShuffle(self, dtype):
     key = random.PRNGKey(0)
     x = np.arange(100).astype(dtype)
@@ -252,7 +250,7 @@ class LaxRandomTest(jtu.JaxTestCase):
           np.dtype(dtype).name, shape, replace, weighted, array_input),
       "dtype": dtype, "shape": shape, "replace": replace,
       "weighted": weighted, "array_input": array_input}
-      for dtype in [np.float32, np.float64, np.int32, np.int64]
+      for dtype in jtu.dtypes.floating + jtu.dtypes.integer
      for shape in [(), (5,), (4, 5)]
      for replace in [True, False]
      for weighted in [True, False]
@@ -280,7 +278,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
        "dtype": dtype, "shape": shape}
-      for dtype in [np.float32, np.float64, np.int32, np.int64]
+      for dtype in jtu.dtypes.floating + jtu.dtypes.integer
       for shape in [100, (10, 10), (10, 5, 2)]))
   def testPermutationArray(self, dtype, shape):
     key = random.PRNGKey(0)
@@ -322,7 +320,7 @@ class LaxRandomTest(jtu.JaxTestCase):
       {"testcase_name": "_p={}_dtype={}".format(p, np.dtype(dtype).name),
        "p": p, "dtype": dtype}
       for p in [0.1, 0.5, 0.9]
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testBernoulli(self, p, dtype):
     key = random.PRNGKey(0)
     p = np.array(p, dtype=dtype)
@@ -345,7 +343,7 @@ class LaxRandomTest(jtu.JaxTestCase):
         ([[.5, .1], [.5, .9]], 0),
       ]
       for sample_shape in [(10000,), (5000, 2)]
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testCategorical(self, p, axis, dtype, sample_shape):
     key = random.PRNGKey(0)
     p = np.array(p, dtype=dtype)
@@ -397,7 +395,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
-      for dtype in [np.float16, np.float32, np.float64]))
+      for dtype in float_dtypes))
   def testCauchy(self, dtype):
     key = random.PRNGKey(0)
     rand = lambda key: random.cauchy(key, (10000,), dtype)
@@ -415,7 +413,7 @@ class LaxRandomTest(jtu.JaxTestCase):
       for alpha in [
         np.array([0.2, 1., 5.]),
       ]
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   @jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
   def testDirichlet(self, alpha, dtype):
     key = random.PRNGKey(0)
@@ -449,7 +447,7 @@ class LaxRandomTest(jtu.JaxTestCase):
       {"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
        "a": a, "dtype": dtype}
       for a in [0.1, 1., 10.]
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testGamma(self, a, dtype):
     key = random.PRNGKey(0)
     rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@@ -527,7 +525,7 @@ class LaxRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testGumbel(self, dtype):
     key = random.PRNGKey(0)
     rand = lambda key: random.gumbel(key, (10000,), dtype)
@@ -571,7 +569,7 @@ class LaxRandomTest(jtu.JaxTestCase):
       {"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
        "b": b, "dtype": dtype}
       for b in [0.1, 1., 10.]
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   def testPareto(self, b, dtype):
     key = random.PRNGKey(0)
     rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
@@ -592,7 +590,7 @@ class LaxRandomTest(jtu.JaxTestCase):
       {"testcase_name": "_df={}_dtype={}".format(df, np.dtype(dtype).name),
        "df": df, "dtype": dtype}
       for df in [0.1, 1., 10.]
-      for dtype in [np.float32, np.float64]))
+      for dtype in jtu.dtypes.floating))
   @jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
   def testT(self, df, dtype):
     key = random.PRNGKey(0)