add rot90 and flip, adjust testOp test selection

closes #55
This commit is contained in:
Matthew Johnson 2018-12-11 08:54:35 -08:00
parent 567caecdea
commit 5a1aeca96c
3 changed files with 117 additions and 51 deletions

View File

@ -327,6 +327,30 @@ def transpose(x, axis=None):
return lax.transpose(x, axis)
@_wraps(onp.rot90)
def rot90(m, k=1, axes=(0, 1)):
ax1, ax2 = axes
if ax1 % m.ndim == ax2 % m.ndim:
raise ValueError("Axes must be different") # same as numpy error
k = k % 4
if k == 0:
return m
elif k == 2:
return flip(flip(m, ax1), ax2)
else:
perm = list(range(m.ndim))
perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
if k == 1:
return transpose(flip(m, ax2), perm)
else:
return flip(transpose(m, perm), ax2)
@_wraps(onp.flip)
def flip(m, axis):
return lax.rev(m, [axis])
@_wraps(onp.sinh)
def sinh(x):
x, = _promote_to_result_dtype(onp.sinh, x)
@ -454,7 +478,10 @@ def where(condition, x=None, y=None):
if not onp.issubdtype(_dtype(condition), onp.bool_):
condition = lax.ne(condition, zeros_like(condition))
condition, x, y = broadcast_arrays(condition, x, y)
return lax.select(condition, *_promote_dtypes(x, y))
if not x.size:
return x
else:
return lax.select(condition, *_promote_dtypes(x, y))
def broadcast_arrays(*args):

View File

@ -171,34 +171,36 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
JAX_COMPOUND_OP_RECORDS)
for shapes in filter(
_shapes_are_broadcast_compatible,
CombosWithReplacement(rec.shapes, rec.nargs))
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
for shapes in filter(
_shapes_are_broadcast_compatible,
CombosWithReplacement(rec.shapes, rec.nargs))
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS)))
def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
for rec in JAX_BITWISE_OP_RECORDS
for shapes in filter(
_shapes_are_broadcast_compatible,
CombosWithReplacement(rec.shapes, rec.nargs))
for dtypes in filter(
_dtypes_are_compatible_for_bitwise_ops,
CombosWithReplacement(rec.dtypes, rec.nargs))))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
for rec in JAX_BITWISE_OP_RECORDS
for shapes in filter(
_shapes_are_broadcast_compatible,
CombosWithReplacement(rec.shapes, rec.nargs))
for dtypes in filter(
_dtypes_are_compatible_for_bitwise_ops,
CombosWithReplacement(rec.dtypes, rec.nargs)))
for rec in JAX_BITWISE_OP_RECORDS))
def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
if not FLAGS.jax_enable_x64 and any(
onp.iinfo(dtype).bits == 64 for dtype in dtypes):
@ -622,6 +624,41 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
cfoo = api.jit(foo)
self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
"rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
for shape in [(3,), (2, 3)]
for dtype in default_dtypes
for axis in range(len(shape))
for rng in [jtu.rand_default()]))
def testFlip(self, shape, dtype, axis, rng):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
lnp_op = lambda x: lnp.flip(x, axis)
onp_op = lambda x: onp.flip(x, axis)
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_k={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), k, axes),
"rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes}
for shape, axes in [
[(2, 3), (0, 1)],
[(2, 3), (1, 0)],
[(4, 3, 2), (0, 2)],
[(4, 3, 2), (2, 1)],
]
for k in range(-3, 4)
for dtype in default_dtypes
for rng in [jtu.rand_default()]))
def testRot90(self, shape, dtype, k, axes, rng):
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
lnp_op = lambda x: lnp.rot90(x, k, axes)
onp_op = lambda x: onp.rot90(x, k, axes)
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
# TODO(mattjj): test infix operator overrides
def DISABLED_testRavel(self):

View File

@ -137,27 +137,29 @@ CombosWithReplacement = itertools.combinations_with_replacement
class LaxTest(jtu.JaxTestCase):
"""Numerical tests for LAX operations."""
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.op.__name__, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype}
for rec in LAX_OPS
for shape_group in compatible_shapes
for shapes in CombosWithReplacement(shape_group, rec.nargs)
for dtype in rec.dtypes))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.op.__name__, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype}
for shape_group in compatible_shapes
for shapes in CombosWithReplacement(shape_group, rec.nargs)
for dtype in rec.dtypes)
for rec in LAX_OPS))
def testOp(self, op, rng, shapes, dtype):
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
self._CompileAndCheck(op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.op.__name__, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
"tol": rec.tol}
for rec in LAX_OPS
for shape_group in compatible_shapes
for shapes in CombosWithReplacement(shape_group, rec.nargs)
for dtype in rec.dtypes))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.op.__name__, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
"tol": rec.tol}
for shape_group in compatible_shapes
for shapes in CombosWithReplacement(shape_group, rec.nargs)
for dtype in rec.dtypes)
for rec in LAX_OPS))
def testOpAgainstNumpy(self, op, rng, shapes, dtype, tol):
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
numpy_op = getattr(lax_reference, op.__name__)
@ -1436,16 +1438,16 @@ def check_grads_bilinear(f, args, order, atol=None, rtol=None):
class LaxAutodiffTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.op.__name__, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
"order": rec.order}
for rec in LAX_GRAD_OPS
for shape_group in compatible_shapes
for shapes in CombosWithReplacement(shape_group, rec.nargs)
for dtype in rec.dtypes
))
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.op.__name__, shapes, itertools.repeat(dtype)),
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
"order": rec.order}
for shape_group in compatible_shapes
for shapes in CombosWithReplacement(shape_group, rec.nargs)
for dtype in rec.dtypes)
for rec in LAX_GRAD_OPS))
def testOpGrad(self, op, rng, shapes, dtype, order):
if FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu"):
if dtype is onp.complex64: