Fix _CheckAgainstNumpy arg order (#3935)

This commit is contained in:
Julius Kunze 2020-08-03 17:17:48 +02:00 committed by GitHub
parent 03df35a9d1
commit 7de784afbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 46 deletions

View File

@ -2582,7 +2582,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
np_op = lambda x: np.asarray(x).astype(jnp.int32)
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2605,7 +2605,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_op = lambda x: jnp.asarray(x).view(dtype)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testPathologicalFloats(self):
@ -2625,7 +2625,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_op = lambda x: np.asarray(x).view('float32').view('uint32')
jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32')
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test other ndarray-like method overrides
@ -2663,7 +2663,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if axis is not None:
jnp_fun = partial(jnp_fun, axis=axis)
np_fun = partial(np_fun, axis=axis)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2680,7 +2680,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.skipTest("complex sort not supported on TPU")
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jnp.sort_complex, np.sort_complex, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp.sort_complex, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2701,7 +2701,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [input_type(rng(shape, dtype))]
jnp_op = lambda x: jnp.lexsort(x, axis=axis)
np_op = lambda x: np.lexsort(x, axis=axis)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2723,7 +2723,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if axis is not None:
jnp_fun = partial(jnp_fun, axis=axis)
np_fun = partial(np_fun, axis=axis)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2739,7 +2739,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.skipTest("complex sort not supported on TPU")
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jnp.msort, np.msort, args_maker)
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
self._CompileAndCheck(jnp.msort, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2765,7 +2765,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), np.array(shifts)]
jnp_op = partial(jnp.roll, axis=axis)
np_op = partial(np.roll, axis=axis)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2784,7 +2784,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.rollaxis, axis=axis, start=start)
np_op = partial(np.rollaxis, axis=axis, start=start)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2804,7 +2804,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
np_op = partial(np.packbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2824,7 +2824,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
@ -2851,7 +2851,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng_indices = jtu.rand_int(self.rng(), -5, 5)
jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode)
np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testTakeEmpty(self):
@ -2892,7 +2892,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if hasattr(np, "take_along_axis"):
np_op = lambda x, i: np.take_along_axis(x, i, axis=axis)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(

View File

@ -195,7 +195,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
op = getattr(lax, op_name)
numpy_op = getattr(lax_reference, op_name)
self._CheckAgainstNumpy(op, numpy_op, args_maker, tol=tol)
self._CheckAgainstNumpy(numpy_op, op, args_maker, tol=tol)
# TODO test shift_left, shift_right_arithmetic, shift_right_logical
@ -224,7 +224,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax.convert_element_type(x, to_dtype)
numpy_op = lambda x: lax_reference.convert_element_type(x, to_dtype)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_from_dtype={}_to_dtype={}"
@ -251,7 +251,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
numpy_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
@ -294,7 +294,7 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
shapes = [min_shape, operand_shape, max_shape]
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
self._CheckAgainstNumpy(lax.clamp, lax_reference.clamp, args_maker)
self._CheckAgainstNumpy(lax_reference.clamp, lax.clamp, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
@ -333,7 +333,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
op = lambda *args: lax.concatenate(args, dim)
numpy_op = lambda *args: lax_reference.concatenate(args, dim)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -378,7 +378,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
op = lambda lhs, rhs: lax.conv(lhs, rhs, strides, padding)
numpy_op = lambda lhs, rhs: lax_reference.conv(lhs, rhs, strides, padding)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
@ -501,7 +501,7 @@ class LaxTest(jtu.JaxTestCase):
jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(jnp_fun, np.dot, args_maker, tol=.1)
self._CheckAgainstNumpy(np.dot, jnp_fun, args_maker, tol=.1)
@staticmethod
@ -581,7 +581,7 @@ class LaxTest(jtu.JaxTestCase):
dimension_numbers=dspec)
# NB: below just checks for agreement, we're not calling numpy.
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -619,7 +619,7 @@ class LaxTest(jtu.JaxTestCase):
dimension_numbers=dspec)
# NB: below just checks for agreement, we're not calling numpy.
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -656,7 +656,7 @@ class LaxTest(jtu.JaxTestCase):
dimension_numbers=dspec)
# NB: below just checks for agreement, we're not calling numpy.
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -693,7 +693,7 @@ class LaxTest(jtu.JaxTestCase):
dimension_numbers=dspec)
# NB: below just checks for agreement, we're not calling numpy.
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
@ -731,7 +731,7 @@ class LaxTest(jtu.JaxTestCase):
1e-14)
}
lax_op = partial(lax.dot, precision=lax.Precision.HIGHEST)
self._CheckAgainstNumpy(lax_op, lax_reference.dot, args_maker, tol=tol)
self._CheckAgainstNumpy(lax_reference.dot, lax_op, args_maker, tol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -813,7 +813,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
op = lambda x, y: lax.dot_general(x, y, dimension_numbers)
numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
@ -844,7 +844,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.broadcast(x, broadcast_sizes)
numpy_op = lambda x: lax_reference.broadcast(x, broadcast_sizes)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
@ -912,7 +912,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(inshape, dtype)]
op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_dimensions={}".format(
@ -952,7 +952,7 @@ class LaxTest(jtu.JaxTestCase):
op = lambda x: lax.squeeze(x, dimensions)
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
@ -988,7 +988,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(arg_shape, dtype)]
op = lambda x: lax.reshape(x, out_shape)
numpy_op = lambda x: lax_reference.reshape(x, out_shape)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_pads={}"
@ -1022,7 +1022,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.pad(x, np.array(0, dtype), pads)
numpy_op = lambda x: lax_reference.pad(x, np.array(0, dtype), pads)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
def testReverse(self):
rev = api.jit(lambda operand: lax.rev(operand, dimensions))
@ -1072,7 +1072,7 @@ class LaxTest(jtu.JaxTestCase):
return [rng(pred_shape, np.bool_), rng(arg_shape, arg_dtype),
rng(arg_shape, arg_dtype)]
rng = rng_factory(self.rng())
return self._CheckAgainstNumpy(lax.select, lax_reference.select, args_maker)
return self._CheckAgainstNumpy(lax_reference.select, lax.select, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1126,7 +1126,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.slice(x, starts, limits, strides)
numpy_op = lambda x: lax_reference.slice(x, starts, limits, strides)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
@ -1167,7 +1167,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), np.array(start_indices)]
op = lambda x, s: lax.dynamic_slice(x, s, size_indices)
numpy_op = lambda x, s: lax_reference.dynamic_slice(x, s, size_indices)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
def testDynamicSliceInDim(self):
# Regression test for mixed type problem in dynamic_slice_in_dim.
@ -1219,8 +1219,8 @@ class LaxTest(jtu.JaxTestCase):
return [rng(shape, dtype), rng(update_shape, dtype),
np.array(start_indices)]
self._CheckAgainstNumpy(lax.dynamic_update_slice,
lax_reference.dynamic_update_slice, args_maker)
self._CheckAgainstNumpy(lax_reference.dynamic_update_slice,
lax.dynamic_update_slice, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_perm={}".format(
@ -1257,7 +1257,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.transpose(x, perm)
numpy_op = lambda x: lax_reference.transpose(x, perm)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}"
@ -1346,7 +1346,7 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker)
if all(d == 1 for d in window_dilation):
self._CheckAgainstNumpy(fun, reference_fun, args_maker)
self._CheckAgainstNumpy(reference_fun, fun, args_maker)
# we separately test the version that uses a concrete init_val because it
# can hit different code paths
@ -1380,7 +1380,7 @@ class LaxTest(jtu.JaxTestCase):
np_fun = partial(np_op, axis=axis, dtype=dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
self._CheckAgainstNumpy(fun, np_fun, args_maker)
self._CheckAgainstNumpy(np_fun, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_isstable={}".format(
@ -1423,7 +1423,7 @@ class LaxTest(jtu.JaxTestCase):
return lax_reference.sort(x, axis, kind='stable')
else:
return lax_reference.sort(x, axis)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
@ -1474,7 +1474,7 @@ class LaxTest(jtu.JaxTestCase):
lax_fun = lambda x: lax.sort(tuple(x), num_keys=num_keys)
numpy_fun = lambda x: tuple(x[:, np.lexsort(x[:num_keys][::-1])])
# self._CompileAndCheck(lax_fun, args_maker)
self._CheckAgainstNumpy(lax_fun, numpy_fun, args_maker)
self._CheckAgainstNumpy(numpy_fun, lax_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
@ -1505,7 +1505,7 @@ class LaxTest(jtu.JaxTestCase):
op = lambda ks, vs: lax.sort_key_val(ks, vs, axis)
numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}".format(

View File

@ -98,9 +98,9 @@ class NdimageTest(jtu.JaxTestCase):
if dtype in float_dtypes:
epsilon = max([dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
for d in [dtype, coords_dtype]])
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=100*epsilon)
self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=100*epsilon)
else:
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0)
self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)
def testMapCoordinatesErrors(self):
x = np.arange(5.0)
@ -130,7 +130,7 @@ class NdimageTest(jtu.JaxTestCase):
lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order)
osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order)
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker)
self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)
def testContinuousGradients(self):
# regression test for https://github.com/google/jax/issues/3024