mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
567caecdea
commit
5a1aeca96c
@ -327,6 +327,30 @@ def transpose(x, axis=None):
|
||||
return lax.transpose(x, axis)
|
||||
|
||||
|
||||
@_wraps(onp.rot90)
|
||||
def rot90(m, k=1, axes=(0, 1)):
|
||||
ax1, ax2 = axes
|
||||
if ax1 % m.ndim == ax2 % m.ndim:
|
||||
raise ValueError("Axes must be different") # same as numpy error
|
||||
k = k % 4
|
||||
if k == 0:
|
||||
return m
|
||||
elif k == 2:
|
||||
return flip(flip(m, ax1), ax2)
|
||||
else:
|
||||
perm = list(range(m.ndim))
|
||||
perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
|
||||
if k == 1:
|
||||
return transpose(flip(m, ax2), perm)
|
||||
else:
|
||||
return flip(transpose(m, perm), ax2)
|
||||
|
||||
|
||||
@_wraps(onp.flip)
|
||||
def flip(m, axis):
|
||||
return lax.rev(m, [axis])
|
||||
|
||||
|
||||
@_wraps(onp.sinh)
|
||||
def sinh(x):
|
||||
x, = _promote_to_result_dtype(onp.sinh, x)
|
||||
@ -454,7 +478,10 @@ def where(condition, x=None, y=None):
|
||||
if not onp.issubdtype(_dtype(condition), onp.bool_):
|
||||
condition = lax.ne(condition, zeros_like(condition))
|
||||
condition, x, y = broadcast_arrays(condition, x, y)
|
||||
return lax.select(condition, *_promote_dtypes(x, y))
|
||||
if not x.size:
|
||||
return x
|
||||
else:
|
||||
return lax.select(condition, *_promote_dtypes(x, y))
|
||||
|
||||
|
||||
def broadcast_arrays(*args):
|
||||
|
@ -171,34 +171,36 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def _GetArgsMaker(self, rng, shapes, dtypes):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
|
||||
dtypes),
|
||||
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
|
||||
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
|
||||
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
|
||||
JAX_COMPOUND_OP_RECORDS)
|
||||
for shapes in filter(
|
||||
_shapes_are_broadcast_compatible,
|
||||
CombosWithReplacement(rec.shapes, rec.nargs))
|
||||
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
|
||||
dtypes),
|
||||
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
|
||||
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
|
||||
for shapes in filter(
|
||||
_shapes_are_broadcast_compatible,
|
||||
CombosWithReplacement(rec.shapes, rec.nargs))
|
||||
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
|
||||
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS)))
|
||||
def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
|
||||
dtypes),
|
||||
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
|
||||
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
|
||||
for rec in JAX_BITWISE_OP_RECORDS
|
||||
for shapes in filter(
|
||||
_shapes_are_broadcast_compatible,
|
||||
CombosWithReplacement(rec.shapes, rec.nargs))
|
||||
for dtypes in filter(
|
||||
_dtypes_are_compatible_for_bitwise_ops,
|
||||
CombosWithReplacement(rec.dtypes, rec.nargs))))
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
|
||||
dtypes),
|
||||
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
|
||||
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
|
||||
for rec in JAX_BITWISE_OP_RECORDS
|
||||
for shapes in filter(
|
||||
_shapes_are_broadcast_compatible,
|
||||
CombosWithReplacement(rec.shapes, rec.nargs))
|
||||
for dtypes in filter(
|
||||
_dtypes_are_compatible_for_bitwise_ops,
|
||||
CombosWithReplacement(rec.dtypes, rec.nargs)))
|
||||
for rec in JAX_BITWISE_OP_RECORDS))
|
||||
def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
|
||||
if not FLAGS.jax_enable_x64 and any(
|
||||
onp.iinfo(dtype).bits == 64 for dtype in dtypes):
|
||||
@ -622,6 +624,41 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
cfoo = api.jit(foo)
|
||||
self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
"rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
|
||||
for shape in [(3,), (2, 3)]
|
||||
for dtype in default_dtypes
|
||||
for axis in range(len(shape))
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testFlip(self, shape, dtype, axis, rng):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
lnp_op = lambda x: lnp.flip(x, axis)
|
||||
onp_op = lambda x: onp.flip(x, axis)
|
||||
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_k={}_axes={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), k, axes),
|
||||
"rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes}
|
||||
for shape, axes in [
|
||||
[(2, 3), (0, 1)],
|
||||
[(2, 3), (1, 0)],
|
||||
[(4, 3, 2), (0, 2)],
|
||||
[(4, 3, 2), (2, 1)],
|
||||
]
|
||||
for k in range(-3, 4)
|
||||
for dtype in default_dtypes
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testRot90(self, shape, dtype, k, axes, rng):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
lnp_op = lambda x: lnp.rot90(x, k, axes)
|
||||
onp_op = lambda x: onp.rot90(x, k, axes)
|
||||
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
# TODO(mattjj): test infix operator overrides
|
||||
|
||||
def DISABLED_testRavel(self):
|
||||
|
@ -137,27 +137,29 @@ CombosWithReplacement = itertools.combinations_with_replacement
|
||||
class LaxTest(jtu.JaxTestCase):
|
||||
"""Numerical tests for LAX operations."""
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.op.__name__, shapes, itertools.repeat(dtype)),
|
||||
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype}
|
||||
for rec in LAX_OPS
|
||||
for shape_group in compatible_shapes
|
||||
for shapes in CombosWithReplacement(shape_group, rec.nargs)
|
||||
for dtype in rec.dtypes))
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.op.__name__, shapes, itertools.repeat(dtype)),
|
||||
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype}
|
||||
for shape_group in compatible_shapes
|
||||
for shapes in CombosWithReplacement(shape_group, rec.nargs)
|
||||
for dtype in rec.dtypes)
|
||||
for rec in LAX_OPS))
|
||||
def testOp(self, op, rng, shapes, dtype):
|
||||
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
|
||||
self._CompileAndCheck(op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.op.__name__, shapes, itertools.repeat(dtype)),
|
||||
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
|
||||
"tol": rec.tol}
|
||||
for rec in LAX_OPS
|
||||
for shape_group in compatible_shapes
|
||||
for shapes in CombosWithReplacement(shape_group, rec.nargs)
|
||||
for dtype in rec.dtypes))
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.op.__name__, shapes, itertools.repeat(dtype)),
|
||||
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
|
||||
"tol": rec.tol}
|
||||
for shape_group in compatible_shapes
|
||||
for shapes in CombosWithReplacement(shape_group, rec.nargs)
|
||||
for dtype in rec.dtypes)
|
||||
for rec in LAX_OPS))
|
||||
def testOpAgainstNumpy(self, op, rng, shapes, dtype, tol):
|
||||
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
|
||||
numpy_op = getattr(lax_reference, op.__name__)
|
||||
@ -1436,16 +1438,16 @@ def check_grads_bilinear(f, args, order, atol=None, rtol=None):
|
||||
|
||||
class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.op.__name__, shapes, itertools.repeat(dtype)),
|
||||
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
|
||||
"order": rec.order}
|
||||
for rec in LAX_GRAD_OPS
|
||||
for shape_group in compatible_shapes
|
||||
for shapes in CombosWithReplacement(shape_group, rec.nargs)
|
||||
for dtype in rec.dtypes
|
||||
))
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
rec.op.__name__, shapes, itertools.repeat(dtype)),
|
||||
"op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
|
||||
"order": rec.order}
|
||||
for shape_group in compatible_shapes
|
||||
for shapes in CombosWithReplacement(shape_group, rec.nargs)
|
||||
for dtype in rec.dtypes)
|
||||
for rec in LAX_GRAD_OPS))
|
||||
def testOpGrad(self, op, rng, shapes, dtype, order):
|
||||
if FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu"):
|
||||
if dtype is onp.complex64:
|
||||
|
Loading…
x
Reference in New Issue
Block a user