# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import itertools

from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp

from jax import api
from jax import numpy as lnp
from jax import test_util as jtu
from jax.config import config

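# Naming note: `onp` is the original NumPy, while `lnp` is the LAX-backed
# jax.numpy implementation whose behavior these tests compare against it.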
FLAGS = config.FLAGS

all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]

float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64]
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes


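# An OpRecord describes one op-level test case: the function name (looked up
# on both onp and lnp), its arity, the dtypes to exercise, a random input
# generator, the autodiff modes to cover, and an optional test-name override.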
OpRecord = collections.namedtuple("OpRecord", ["name", "nargs", "dtypes", "rng",
                                               "diff_modes", "test_name"])


def op_record(name, nargs, dtypes, rng, diff_modes, test_name=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, rng, diff_modes, test_name)


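# Ops whose jax.numpy version is expected to behave exactly like the NumPy
# function (roughly one LAX primitive per op, as the name suggests).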
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("bitwise_and", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_not", 1, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_or", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_xor", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("ceil", 1, float_dtypes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("exp", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("log", 1, numeric_dtypes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("logical_not", 1, default_dtypes, jtu.rand_bool(), []),
    op_record("logical_or", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("maximum", 2, default_dtypes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, default_dtypes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, default_dtypes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, float_dtypes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("tanh", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, default_dtypes, jtu.rand_default(), ["rev"]),
]

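# Ops implemented as compositions of more primitive operations. Entries that
# appear twice with different generators (expm1, log1p) cover both large
# inputs and the accuracy-sensitive region near zero.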
JAX_COMPOUND_OP_RECORDS = [
    op_record("cosh", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("divide", 2, default_dtypes, jtu.rand_nonzero(), ["rev"]),
    op_record("expm1", 1, numeric_dtypes, jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, numeric_dtypes, jtu.rand_small_positive(), []),
    op_record("floor_divide", 2, default_dtypes, jtu.rand_nonzero(), ["rev"]),
    op_record("isclose", 2, float_dtypes, jtu.rand_small_positive(), []),
    op_record("log1p", 1, numeric_dtypes, jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, numeric_dtypes, jtu.rand_small_positive(), []),
    op_record("logaddexp", 2, float_dtypes, jtu.rand_default(), ["rev"]),
    op_record("ravel", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("remainder", 2, default_dtypes, jtu.rand_nonzero(), []),
    op_record("sinh", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("sqrt", 1, default_dtypes, jtu.rand_positive(), ["rev"]),
    op_record("transpose", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("true_divide", 2, default_dtypes, jtu.rand_nonzero(), ["rev"]),
    op_record("where", 3, (onp.float32, onp.int64), jtu.rand_some_zero(), []),
]

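# Reductions; prod draws small positive inputs, presumably to avoid overflow
# when many elements are multiplied together.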
JAX_REDUCER_RECORDS = [
    op_record("all", 1, bool_dtypes, jtu.rand_default(), []),
    op_record("any", 1, bool_dtypes, jtu.rand_default(), []),
    op_record("max", 1, default_dtypes, jtu.rand_default(), []),
    op_record("mean", 1, default_dtypes, jtu.rand_default(), []),
    op_record("min", 1, default_dtypes, jtu.rand_default(), []),
    op_record("prod", 1, default_dtypes, jtu.rand_small_positive(), []),
    op_record("sum", 1, default_dtypes, jtu.rand_default(), []),
    op_record("var", 1, default_dtypes, jtu.rand_default(), []),
]

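# rand_some_equal makes duplicate extrema likely, so these cases also probe
# argmin/argmax tie-breaking (NumPy returns the first occurrence).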
JAX_ARGMINMAX_RECORDS = [
    op_record("argmin", 1, default_dtypes, jtu.rand_some_equal(), []),
    op_record("argmax", 1, default_dtypes, jtu.rand_some_equal(), []),
]

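# Shorthand: CombosWithReplacement(seq, n) yields unordered n-tuples, e.g.
# over two float dtypes it gives (f32, f32), (f32, f64), (f64, f64) rather
# than all ordered pairs.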
CombosWithReplacement = itertools.combinations_with_replacement


class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed NumPy implementation."""

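  # Most tests are driven by an "args_maker": a zero-argument callable that
  # builds fresh random inputs each time it is called, so both the reference
  # run and the compiled run can draw their own arguments.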
  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

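  # For every op record, check agreement with NumPy on random inputs, then
  # check the jax.numpy function again under jit compilation.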
  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                    dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape)))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

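  # The shape pairs below exercise onp.dot's contraction rule: the last axis
  # of the lhs is contracted with the second-to-last (or only) axis of the
  # rhs.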
  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

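  # Unlike dot, matmul treats leading dimensions as batch dimensions and
  # broadcasts them; the "tensor-tensor-broadcast" case exercises that.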
  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)])
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for decimals in [0, 1, -2])
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

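  # The operands agree on every dimension except the concatenation axis,
  # whose sizes cycle through 3, 1, 4 so the inputs are genuinely ragged
  # along that axis.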
  @parameterized.named_parameters(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape)))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    wrapped_axis = axis % len(base_shape)
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
          [onp.float32],
          [onp.float32, onp.float32],
          [onp.float32, onp.int32, onp.float32],
          [onp.float32, onp.int64, onp.float32],
          [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()])
  def testStack(self, shape, dtypes, rng):
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    # Reference (onp) op first, matching the argument order used elsewhere.
    self._CheckAgainstNumpy(onp.stack, lnp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape=[{}]_indtype={}_outdtype={}".format(
          "_".join(str(d) for d in shape),
          onp.dtype(fill_value_dtype).name, onp.dtype(out_dtype).name),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in all_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes)
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes)
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

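  # out_shape may be a tuple, a bare int, or contain -1 (an inferred
  # dimension), mirroring the forms onp.reshape accepts.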
  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ])
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: lnp.reshape(x, out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape)))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes)
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes)
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

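  # These arguments mix Python ints, floats, nested lists, and ndarrays to
  # exercise dtype inference and promotion during array construction.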
  @parameterized.named_parameters(
      {"testcase_name": "_arg{}".format(i), "arg": arg}
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ]))
  def testArray(self, arg):
    args_maker = lambda: [arg]
    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testAllClose(self):
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def DISABLED_testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for
      # stride-zero arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting
      # facilities, we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

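  # onp.broadcast_to returns a view with zero strides along the broadcast
  # dimensions, so closing over `const` under jit exercises the constant
  # handler's stride-zero path.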
  def testZeroStridesConstantHandler(self):
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)

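  # Plain Python lists are not accepted where arrays are expected; the error
  # should be a TypeError both eagerly and under jit.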
  def testNonArrayErrorMessage(self):
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

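  # Under jit, arguments are traced abstractly, so Python-level control flow
  # that depends on an argument's value (range(n), `if x > 0.`) raises a
  # TypeError.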
  def testAbstractionErrorMessage(self):

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))

  def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  # TODO(mattjj): test infix operator overrides

  def DISABLED_testRavel(self):
    # TODO(mattjj): support this method-based syntax?
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)

  # TODO(mattjj): test other ndarray-like method overrides


if __name__ == "__main__":
  config.config_with_absl()
  absltest.main()