# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
from functools import partial
import itertools
import math
import operator
import platform
import types
import unittest
from unittest import SkipTest

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax._src import core
from jax import jvp, grad
from jax import lax
import jax.numpy as jnp
from jax.test_util import check_grads
import jax.util

from jax.interpreters import batching
from jax.interpreters import xla
from jax._src import array
from jax._src import config
from jax._src import dtypes
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.util import NumpyComplexWarning, safe_zip
from jax._src.tree_util import tree_map

config.parse_flags_with_absl()


### lax tests

# We check cases where the preferred type is at least as wide as the input
# type: promotions within the same kind (floating to floating, integral to
# integral, complex to complex), plus integral-to-floating promotions. Other
# configurations are unsupported.
preferred_type_combinations = [
  (np.float16, np.float16), (np.float16, np.float32), (np.float16, np.float64),
  (dtypes.bfloat16, dtypes.bfloat16), (dtypes.bfloat16, np.float32),
  (dtypes.bfloat16, np.float64), (np.float32, np.float32), (np.float32, np.float64),
  (np.float64, np.float64), (np.int8, np.int8), (np.int8, np.int16), (np.int8, np.int32),
  (np.int8, np.int64), (np.int16, np.int16), (np.int16, np.int32), (np.int16, np.int64),
  (np.int32, np.int32), (np.int32, np.int64), (np.int64, np.int64),
  (np.complex64, np.complex64), (np.complex64, np.complex128), (np.complex128, np.complex128),
  (np.int8, np.float16), (np.int8, dtypes.bfloat16), (np.int8, np.float32), (np.int8, np.float64),
  (np.int16, np.float16), (np.int16, dtypes.bfloat16), (np.int16, np.float32), (np.int16, np.float64),
  (np.int32, np.float32), (np.int32, np.float64), (np.int64, np.float64)]

def _reduce_custom_add(x, y):
  return x + y

def _reduce_custom_mul(x, y):
  return x * y

def _reduce_custom_sub(x, y):
  return x - y

def _reduce_custom_min(x, y):
  return jnp.minimum(x, y)

def _reduce_custom_max(x, y):
  return jnp.maximum(x, y)
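
# These small binary functions are shaped like the user-supplied reduction
# computations accepted by lax.reduce and lax.reduce_window; presumably they
# serve as custom reducers in reduction tests elsewhere in this suite.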

class LaxTest(jtu.JaxTestCase):
  """Numerical tests for LAX operations."""

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
        [dict(op_name=rec.op, rng_factory=rec.rng_factory)],
        shapes=itertools.chain.from_iterable(
          itertools.combinations_with_replacement(shape_group, rec.nargs)
          for shape_group in lax_test_util.compatible_shapes),
        dtype=rec.dtypes)
      for rec in lax_test_util.lax_ops()))
  def testOp(self, op_name, rng_factory, shapes, dtype):
    rng = rng_factory(self.rng())
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = getattr(lax, op_name)
    self._CompileAndCheck(op, args_maker)

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
        [dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
        shapes=itertools.chain.from_iterable(
          itertools.combinations_with_replacement(shape_group, rec.nargs)
          for shape_group in lax_test_util.compatible_shapes),
        dtype=rec.dtypes)
      for rec in lax_test_util.lax_ops()))
  @jtu.ignore_warning(message="invalid value", category=RuntimeWarning)
  def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
    if (not config.enable_x64.value and op_name == "nextafter"
        and dtype == np.float64):
      raise SkipTest("64-bit mode disabled")
    rng = rng_factory(self.rng())
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = getattr(lax, op_name)
    numpy_op = getattr(lax_reference, op_name)
    tol = tol or jtu.default_tolerance()
    if jtu.test_device_matches(["tpu"]):
      if dtype in (np.float32, np.complex64) and op_name in (
        "acosh", "asinh", "betainc", "cos", "cosh", "digamma", "exp", "exp2", "igamma",
        "igammac", "log", "log1p", "logistic", "pow", "sin", "sinh", "tan"):
        tol = jtu.join_tolerance(tol, 2e-4)
      elif op_name == "asinh" and dtype == np.float16:
        tol = jtu.join_tolerance(tol, 1e-3)
      elif op_name == "lgamma" and dtype == np.float32:
        tol = jtu.join_tolerance(tol, 1e-3)
      elif op_name == "pow" and dtype == np.complex128:
        tol = jtu.join_tolerance(tol, 2e-15)
    self._CheckAgainstNumpy(numpy_op, op, args_maker, tol=tol)

  # TODO test shift_left, shift_right_arithmetic, shift_right_logical

  @jtu.sample_product(
    [dict(from_dtype=from_dtype, to_dtype=to_dtype)
     for from_dtype, to_dtype in itertools.product(
       [None, np.float32, np.int32, "float32", "int32"], repeat=2)],
    weak_type=[False, True],
  )
  def testConvertElementType(self, from_dtype, to_dtype, weak_type):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng((2, 3), from_dtype)]
    op = lambda x: lax_internal._convert_element_type(x, to_dtype, weak_type)
    self._CompileAndCheck(op, args_maker)

    x = rng((1,), from_dtype)
    out = op(x)
    self.assertEqual(out.dtype, dtypes.canonicalize_dtype(to_dtype or x.dtype))
    self.assertEqual(out.aval.weak_type, weak_type)

  def testConvertElementTypeOOB(self):
    out = lax.convert_element_type(2 ** 32, 'int32')
    self.assertEqual(out, 0)
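
  # Note on the expected value above: 2 ** 32 is congruent to 0 modulo 2**32,
  # so the out-of-bounds convert to int32 wraps around to 0 rather than
  # raising or saturating.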

  @jtu.sample_product(
    [dict(from_dtype=from_dtype, to_dtype=to_dtype)
     for from_dtype, to_dtype in itertools.product(
       [np.float32, np.int32, "float32", "int32"], repeat=2)],
  )
  def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng((2, 3), from_dtype)]
    op = lambda x: lax.convert_element_type(x, to_dtype)
    numpy_op = lambda x: lax_reference.convert_element_type(x, to_dtype)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    from_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
    to_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
    shape=[(), (2,), (2, 3)],
  )
  def testBitcastConvertType(self, from_dtype, to_dtype, shape):
    rng = jtu.rand_default(self.rng())
    nbits_in = dtypes.bit_width(from_dtype)
    nbits_out = dtypes.bit_width(to_dtype)
    if nbits_in < nbits_out:
      shape = (*shape, nbits_out // nbits_in)
    args_maker = lambda: [rng(shape, from_dtype)]
    jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CompileAndCheck(jnp_op, args_maker)

    # Test the shape and dtype of the output. We avoid testing the values here
    # because the bitwise representation may vary from platform to platform.
    out = jnp_op(*args_maker())
    if nbits_in == nbits_out:
      expected_shape = shape
    elif nbits_in < nbits_out:
      expected_shape = shape[:-1]
    else:
      expected_shape = (*shape, nbits_in // nbits_out)
    self.assertEqual(out.dtype, to_dtype)
    self.assertEqual(out.shape, expected_shape)
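
  # Worked example of the shape rule above: bitcasting float32[2, 3] to uint8
  # yields uint8[2, 3, 4] (32 // 8 = 4 output elements per input element),
  # while the reverse direction consumes a trailing axis of length 4:
  # uint8[2, 3, 4] -> float32[2, 3].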

  @jtu.sample_product(
    [dict(from_dtype=from_dtype, to_dtype=to_dtype)
     for from_dtype, to_dtype in itertools.product(
       ['int4', 'uint4', np.int8, np.uint8, np.int32, np.float16, np.float32],
       repeat=2)],
    shape=[(4,), (2, 4), (2, 3, 4)]
  )
  def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape):
    nbits_in = dtypes.bit_width(from_dtype)
    nbits_out = dtypes.bit_width(to_dtype)
    if nbits_in < nbits_out:
      shape = (*shape, nbits_out // nbits_in)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, from_dtype)]
    jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    np_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker)

  @jtu.sample_product(
    [dict(from_dtype=from_dtype, to_dtype=to_dtype)
     for from_dtype, to_dtype in itertools.product(
       [np.float32, np.int32, "float32", "int32"], repeat=2)],
    weak_type=[False, True],
  )
  def testBitcastConvertWeakType(self, from_dtype, to_dtype, weak_type):
    rng = jtu.rand_default(self.rng())
    x_in = lax_internal._convert_element_type(rng((2, 3), from_dtype),
                                              weak_type=weak_type)
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self.assertEqual(dtypes.is_weakly_typed(x_in), weak_type)
    x_out = op(x_in)
    self.assertEqual(dtypes.is_weakly_typed(x_out), False)
    x_out_jit = jax.jit(op)(x_in)
    self.assertEqual(dtypes.is_weakly_typed(x_out_jit), False)

  @jtu.sample_product(
    [dict(min_shape=min_shape, operand_shape=operand_shape, max_shape=max_shape)
     for min_shape, operand_shape, max_shape in [
       [(), (2, 3), ()],
       [(2, 3), (2, 3), ()],
       [(), (2, 3), (2, 3)],
       [(2, 3), (2, 3), (2, 3)],
     ]],
    dtype=lax_test_util.default_dtypes,
  )
  def testClamp(self, min_shape, operand_shape, max_shape, dtype):
    rng = jtu.rand_default(self.rng())
    shapes = [min_shape, operand_shape, max_shape]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    self._CompileAndCheck(lax.clamp, args_maker)

  @jtu.sample_product(
    [dict(min_shape=min_shape, operand_shape=operand_shape, max_shape=max_shape)
     for min_shape, operand_shape, max_shape in [
       [(), (2, 3), ()],
       [(2, 3), (2, 3), ()],
       [(), (2, 3), (2, 3)],
       [(2, 3), (2, 3), (2, 3)],
     ]],
    dtype=lax_test_util.default_dtypes,
  )
  def testClampAgainstNumpy(self, min_shape, operand_shape, max_shape, dtype):
    rng = jtu.rand_default(self.rng())
    shapes = [min_shape, operand_shape, max_shape]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    self._CheckAgainstNumpy(lax_reference.clamp, lax.clamp, args_maker)

  @jtu.sample_product(
    [dict(base_shape=shape, dim=dim) for shape in [(4,), (3, 4), (2, 3, 4)]
     for dim in range(len(shape))],
    num_arrs=[3],
    dtype=lax_test_util.default_dtypes,
  )
  def testConcatenate(self, dim, base_shape, dtype, num_arrs):
    rng = jtu.rand_default(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = lambda *args: lax.concatenate(args, dim)
    self._CompileAndCheck(op, args_maker)

  @jtu.sample_product(
    [dict(base_shape=shape, dim=dim) for shape in [(4,), (3, 4), (2, 3, 4)]
     for dim in range(len(shape))],
    num_arrs=[3],
    dtype=lax_test_util.default_dtypes,
  )
  def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs):
    rng = jtu.rand_default(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = lambda *args: lax.concatenate(args, dim)
    numpy_op = lambda *args: lax_reference.concatenate(args, dim)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)]
     for axis in range(len(shape))],
    num_pieces=range(3),
    dtype=lax_test_util.default_dtypes,
  )
  def testSplit(self, axis, base_shape, dtype, num_pieces):
    sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
    shape = list(base_shape)
    shape[axis] = np.sum(sizes)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.split(x, sizes, axis=axis)
    def numpy_op(x):
      return np.split(x, np.cumsum(sizes[:-1]), axis=axis)
    self._CompileAndCheck(op, args_maker)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)
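
  # The numpy_op above reflects the API difference: lax.split takes section
  # *sizes* while np.split takes cut *offsets*, and np.cumsum(sizes[:-1])
  # converts one to the other (e.g. sizes (2, 3) for a length-5 axis cut at
  # offset 2).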

  def testSplitErrors(self):
    with self.assertRaisesRegex(ValueError,
                                "Sizes passed to split must be nonnegative"):
      lax.split(np.arange(5), [-1])
    with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"):
      lax.split(np.arange(5), [6])
    with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"):
      lax.split(np.arange(5), sizes=(), axis=1)

  @jtu.sample_product(
    [
      dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
      for b, i, j in itertools.product([2, 3], repeat=3)
    ],
    dtype=lax_test_util.float_dtypes,
    strides=[(1, 1), (1, 2), (2, 1)],
    padding=["VALID", "SAME", "SAME_LOWER"],
  )
  def testConv(self, lhs_shape, rhs_shape, dtype, strides, padding):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv(lhs, rhs, strides, padding)

    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
     for b, i, j in itertools.product([2, 3], repeat=3)],
    [dict(dtype=dtype, preferred_element_type=preferred)
     for dtype, preferred in preferred_type_combinations]
  )
  @jax.default_matmul_precision("float32")
  def testConvPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type):
    if (not config.enable_x64.value and
        (dtype == np.float64 or preferred_element_type == np.float64
         or dtype == np.int64 or preferred_element_type == np.int64
         or dtype == np.complex128 or preferred_element_type == np.complex128)):
      raise SkipTest("64-bit mode disabled")
    if (jtu.test_device_matches(["tpu"]) and
        (dtype == np.complex128 or preferred_element_type == np.complex128)):
      raise SkipTest("np.complex128 is not yet supported on TPU")
    if jtu.test_device_matches(["gpu"]) and np.issubdtype(dtype, np.integer):
      # TODO(b/183565702): Support integer convolutions on CPU/GPU.
      raise SkipTest("Integer convolution not yet supported on GPU")
    # x64 implementation is only accurate to ~float32 precision for this case.
    if dtype == np.complex64 and preferred_element_type == np.complex128:
      tol = 1e-5
    else:
      tol = {np.float64: 1e-14}
    if (jtu.test_device_matches(["tpu"]) and dtype == np.float16 and
        preferred_element_type == np.float32):
      tol = 2e-3
    if (jtu.test_device_matches(["tpu"]) and dtype == jnp.bfloat16 and
        preferred_element_type == np.float32):
      tol = 1e-5

    rng = jtu.rand_default(self.rng())
    x = rng(lhs_shape, dtype)
    y = rng(rhs_shape, dtype)
    # We first compute the conv when both inputs are a lower-precision type and
    # preferred_element_type is a higher-precision type. We then compute results
    # where the inputs are first upcast to the higher-precision type and no
    # `preferred_element_type` is given. We expect the result to be extremely
    # similar given the semantics of `preferred_element_type`.
    result_with_preferred_type = lax.conv(
      x, y, (1, 1), "VALID",
      preferred_element_type=preferred_element_type)
    result_with_upcast_inputs = lax.conv(
      x.astype(preferred_element_type),
      y.astype(preferred_element_type),
      (1, 1), "VALID")
    self.assertArraysAllClose(
      result_with_preferred_type, result_with_upcast_inputs, rtol=tol, atol=tol)

  @jtu.sample_product(
    [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
     for b, i, j in itertools.product([2, 3], repeat=3)],
    dtype=lax_test_util.float_dtypes,
    strides=[(1, 1), (1, 2), (2, 1)],
    padding=["VALID", "SAME"],
  )
  def testConvAgainstNumpy(self, lhs_shape, rhs_shape, dtype, strides, padding):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    op = lambda lhs, rhs: lax.conv(lhs, rhs, strides, padding)
    numpy_op = lambda lhs, rhs: lax_reference.conv(lhs, rhs, strides, padding)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
     for b, i, j in itertools.product([1, 2, 3], repeat=3)],
    dtype=lax_test_util.float_dtypes,
    strides=[(1, 1), (1, 2), (2, 1)],
    padding=[((0, 0), (0, 0)), ((1, 2), (2, 0))],
    lhs_dilation=[(1, 1), (1, 2), (2, 2)],
    rhs_dilation=[(1, 1), (1, 2), (2, 2)],
  )
  def testConvWithGeneralPadding(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dilation, rhs_dilation):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_with_general_padding(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
     for b, i, j in itertools.product([1, 2, 3], repeat=3)],
    dtype=[np.float32],
    strides=[(1, 1), (1, 2), (2, 1)],
    padding=[((0, 0), (0, 0)), ((1, 2), (2, 0))],
    lhs_dilation=[(1, 1), (1, 2), (2, 2)],
    rhs_dilation=[(1, 1), (1, 2), (2, 2)],
  )
  def testConvWithGeneralPaddingAgainstNumpy(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation,
      rhs_dilation):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_with_general_padding(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
          precision=lax.Precision.HIGHEST)

    def numpy_fun(lhs, rhs):
      return lax_reference.conv_with_general_padding(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

    self._CheckAgainstNumpy(numpy_fun, fun, args_maker)

  @jtu.sample_product(
    [
      dict(
        lhs_shape=(b * batch_group_count, i * feature_group_count),
        rhs_shape=(j * feature_group_count * batch_group_count, i),
        batch_group_count=batch_group_count,
        feature_group_count=feature_group_count,
      )
      for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
      for b, i, j in itertools.product([2, 3], repeat=3)
    ],
    [dict(dimension_numbers=("NC", "OI", "NC"), perms=([0, 1], [0, 1]))],
    dtype=lax_test_util.all_dtypes,
  )
  def testConvGeneralDilated0D(self, lhs_shape, rhs_shape, dtype,
                               feature_group_count, batch_group_count,
                               dimension_numbers, perms):
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
      # TODO(b/183565702): Support integer convolutions on CPU/GPU.
      if jtu.test_device_matches(["gpu"]):
        raise SkipTest("Integer convolution not yet supported on GPU")
    rng = jtu.rand_small(self.rng())
    lhs_perm, rhs_perm = perms  # permute to compatible shapes

    def args_maker():
      return [lax.transpose(rng(lhs_shape, dtype), lhs_perm),
              lax.transpose(rng(rhs_shape, dtype), rhs_perm)]

    def fun(lhs, rhs):
      return lax.conv_general_dilated(
          lhs, rhs, window_strides=(), padding=(),
          dimension_numbers=dimension_numbers,
          feature_group_count=feature_group_count,
          batch_group_count=batch_group_count)

    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [
      dict(
        lhs_shape=(b * batch_group_count, i * feature_group_count, 9, w),
        rhs_shape=(j * feature_group_count * batch_group_count, i, 4, 5),
        batch_group_count=batch_group_count,
        feature_group_count=feature_group_count,
      )
      for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
      for w in [0, 10]
      for b, i, j in itertools.product([2, 3], repeat=3)
    ],
    [
      dict(
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        perms=([0, 1, 2, 3], [0, 1, 2, 3]),
      ),
      dict(
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        perms=([0, 2, 3, 1], [2, 3, 1, 0]),
      ),
      dict(
        dimension_numbers=("NCHW", "HWIO", "NHWC"),
        perms=([0, 1, 2, 3], [2, 3, 1, 0]),
      ),
    ],
    dtype=lax_test_util.all_dtypes,
    strides=[(1, 1), (2, 1)],
    padding=[((1, 2), (2, 0)), ((10, 8), (7, 13))],
    lhs_dilation=[(1, 1), (1, 2), (1, 4)],
    rhs_dilation=[(1, 1), (1, 2), (1, 4)],
  )
  def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides,
                             padding, lhs_dilation, rhs_dilation,
                             feature_group_count, batch_group_count,
                             dimension_numbers, perms):
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
      # TODO(b/183565702): Support integer convolutions on CPU/GPU.
      if jtu.test_device_matches(["gpu"]):
        raise SkipTest("Integer convolution not yet supported on GPU")
    rng = jtu.rand_small(self.rng())
    lhs_perm, rhs_perm = perms  # permute to compatible shapes

    def args_maker():
      return [lax.transpose(rng(lhs_shape, dtype), lhs_perm),
              lax.transpose(rng(rhs_shape, dtype), rhs_perm)]

    def fun(lhs, rhs):
      return lax.conv_general_dilated(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
          dimension_numbers, feature_group_count=feature_group_count,
          batch_group_count=batch_group_count)

    self._CompileAndCheck(fun, args_maker)

  @jax.default_matmul_precision("float32")
  def testConvGeneralDilatedPatchesOverlapping1D(self):
    lhs = np.array([[1]], np.float32).reshape((1, 1))
    patches = lax.conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=(),
      window_strides=(),
      padding='SAME'
    )
    self.assertAllClose(lhs, patches)

    dn = ('NHC', 'OIH', 'NHC')
    lhs = np.array([1, 2, 3, 4, 5], np.float32).reshape((1, -1, 1))

    patches = lax.conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=(2,),
      window_strides=(2,),
      padding='VALID',
      dimension_numbers=dn
    )
    self.assertAllClose(
      np.array([[1, 2],
                [3, 4]], np.float32).reshape((1, 2, 2)), patches)

    patches = lax.conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=(3,),
      window_strides=(1,),
      padding='SAME',
      dimension_numbers=dn
    )
    self.assertAllClose(
      np.array([[0, 1, 2],
                [1, 2, 3],
                [2, 3, 4],
                [3, 4, 5],
                [4, 5, 0]], np.float32).reshape((1, 5, 3)), patches)

    patches = lax.conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=(3,),
      window_strides=(1,),
      padding='SAME',
      rhs_dilation=(2,),
      dimension_numbers=dn
    )
    self.assertAllClose(
      np.array([[0, 1, 3],
                [0, 2, 4],
                [1, 3, 5],
                [2, 4, 0],
                [3, 5, 0]], np.float32).reshape((1, 5, 3)), patches)
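
  # In the expected arrays above, each row is the window of input values seen
  # at one output position; zeros mark taps where 'SAME' padding extends past
  # the input, and rhs_dilation spaces the taps two elements apart.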

  def testConvGeneralDilatedPatchesOverlapping2D(self):
    lhs = np.array([[1, 2, 3],
                    [4, 5, 6]], np.float32).reshape((1, 2, 3, 1))
    patches = lax.conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=(2, 2),
      window_strides=(1, 1),
      padding='SAME',
      dimension_numbers=('NHWC', 'OIHW', 'NHWC')
    )
    self.assertAllClose(np.array([[1, 2, 4, 5],
                                  [2, 3, 5, 6],
                                  [3, 0, 6, 0],
                                  [4, 5, 0, 0],
                                  [5, 6, 0, 0],
                                  [6, 0, 0, 0]],
                                 np.float32).reshape((1, 2, 3, 4)), patches)

  @jtu.sample_product(
    [
      dict(
        lhs_shape=lhs_shape,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        dimension_numbers=dim_nums,
      )
      for lhs_shape, filter_shape, strides, padding, dim_nums in [
        ((2, 5), (), (), [], ("NC", "OI", "CN")),
        ((2, 3, 4), (2,), (2,), [(0, 2)], ("CNH", "OHI", "HNC")),
        ((3, 1, 4, 5), (1, 3), (1, 3), [(3, 1), (2, 2)],
         ("NCHW", "OIHW", "NCHW")),
        ((3, 2, 5, 6), (4, 3), (4, 3), [(5, 2), (2, 4)], None),
        ((1, 2, 3, 4), (1, 1), (1, 1), [(0, 0), (0, 0)],
         ("NCWH", "OHWI", "CNHW")),
        ((1, 2, 3, 4), (3, 2), (1, 1), [(0, 0), (0, 0)],
         ("CWHN", "HOWI", "NCHW")),
        ((2, 3, 4, 5, 6), (2, 1, 3), (2, 1, 3), [(1, 2), (5, 3), (3, 5)],
         ("NHWDC", "HDIWO", "DCWNH")),
      ]
    ],
    dtype=lax_test_util.all_dtypes,
    precision=[
      None,
      lax.Precision.DEFAULT,
      lax.Precision.HIGH,
      lax.Precision.HIGHEST,
    ],
  )
  def testConvGeneralDilatedPatchesNonOverlapping(self,
                                                  lhs_shape,
                                                  filter_shape,
                                                  dtype,
                                                  strides,
                                                  padding,
                                                  dimension_numbers,
                                                  precision):
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
      # TODO(b/183565702): Support integer convolutions on CPU/GPU.
      if jtu.test_device_matches(["gpu"]):
        raise SkipTest("Integer convolution not yet supported on GPU")
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)

    if dimension_numbers is None:
      lhs_spec, rhs_spec, out_spec = "NCHW", "OIHW", "NCHW"
    else:
      lhs_spec, rhs_spec, out_spec = dimension_numbers

    filter_spec = ''.join(c for c in rhs_spec if c not in ('I', 'O'))
    patches_spec = out_spec.replace('C', 'C' + filter_spec.lower())

    full_padding = []
    for c in lhs_spec:
      if c in ('N', 'C'):
        full_padding += [(0, 0)]
      else:
        full_padding += [padding[filter_spec.index(c)]]

    lhs_padded = np.pad(lhs, full_padding, 'constant')
    out = lax.transpose(lhs_padded, [lhs_spec.index(c) for c in out_spec])

    patches = lax.conv_general_dilated_patches(
        lhs=lhs,
        filter_shape=filter_shape,
        window_strides=strides,
        padding=padding,
        dimension_numbers=dimension_numbers,
        precision=precision
    )

    source = []

    # Test that output spatial shape is factored into `#patches x patch_size`.
    for c in out_spec:
      out_c = out.shape[out_spec.index(c)]
      patch_c = patches.shape[out_spec.index(c)]

      if c == 'N':
        self.assertEqual(out_c, patch_c)
      elif c == 'C':
        self.assertEqual(out_c * math.prod(filter_shape), patch_c)
      else:
        self.assertEqual(out_c, patch_c * filter_shape[filter_spec.index(c)])

        source += [patches_spec.index(c), patches_spec.index(c.lower())]

    # Test that stacking patches together gives the source image, padded.
    c = out_spec.index('C')
    patches = patches.reshape(patches.shape[:c] +
                              (lhs_shape[lhs_spec.index('C')],) +
                              filter_shape +
                              patches.shape[c + 1:]
                              )
    patches = np.moveaxis(patches, source, range(len(source)))
    for i in range(len(filter_shape)):
      patches = patches.reshape(patches.shape[:i] + (-1,) +
                                patches.shape[2 + i:])
    patches = np.moveaxis(
        patches,
        range(len(filter_shape)),
        [out_spec.index(c) for c in out_spec if c not in ('N', 'C')])
    tol = None
    if (jtu.test_device_matches(["tpu"]) and
        precision in (None, lax.Precision.DEFAULT)):
      tol = 1e-3
    self.assertAllClose(out, patches, atol=tol, rtol=tol)

  @jtu.sample_product(
    [
      dict(n=n, lhs_spec=lhs_spec, rhs_spec=rhs_spec, out_spec=out_spec)
      for n in [1, 2]
      for lhs_spec in [
        "".join(s) for s in itertools.permutations("NCHWD"[:n + 2])
      ]
      for rhs_spec in [
        "".join(s) for s in itertools.permutations("OIHWDX"[:n + 2])
      ]
      for out_spec in [
        "".join(s) for s in itertools.permutations("NCHWDX"[:n + 2])
      ]
    ],
    dtype=lax_test_util.inexact_dtypes,
    precision=[
      None,
      lax.Precision.DEFAULT,
      lax.Precision.HIGH,
      lax.Precision.HIGHEST,
      (lax.Precision.DEFAULT, lax.Precision.HIGHEST),
    ],
    padding=["SAME", "VALID"],
  )
  def testConvGeneralDilatedLocal(self, dtype, precision, n, padding, lhs_spec,
                                  rhs_spec, out_spec):
    """Make sure LCN with tiled CNN kernel matches CNN."""
    lhs_spec_default = 'NCHWDX'[:n + 2]
    rhs_spec_default = 'OIHWDX'[:n + 2]

    rng = jtu.rand_small(self.rng())

    lhs_default = rng((2, 4, 7, 6, 5, 8)[:n + 2], dtype)
    rhs_default = rng((5, 4, 2, 3, 1, 2)[:n + 2], dtype)

    window_strides = (1, 2, 3, 4)[:n]
    rhs_dilation = (2, 1, 3, 2)[:n]

    lhs_perm = [lhs_spec_default.index(c) for c in lhs_spec]
    lhs = np.transpose(lhs_default, lhs_perm)

    rhs_perm = [rhs_spec_default.index(c) for c in rhs_spec]
    rhs = np.transpose(rhs_default, rhs_perm)

    kwargs = dict(
        lhs=lhs,
        window_strides=window_strides,
        padding=padding,
        rhs_dilation=rhs_dilation,
        dimension_numbers=(lhs_spec, rhs_spec, out_spec),
        precision=precision
    )

    out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs)

    rhs_local = np.moveaxis(rhs, (rhs_spec.index('O'), rhs_spec.index('I')),
                            (0, 1))
    rhs_local = rhs_local.reshape((rhs_local.shape[0], -1) + (1,) * n)

    rhs_shape = (rhs_local.shape[:2] +
                 tuple(out_conv.shape[out_spec.index(c)]
                       for c in rhs_spec_default[2:]))

    rhs_local = np.broadcast_to(rhs_local, rhs_shape)
    rhs_local = np.transpose(rhs_local, rhs_perm)

    filter_shape = [rhs.shape[i]
                    for i in range(n + 2) if rhs_spec[i] not in ('O', 'I')]
    out_local = lax.conv_general_dilated_local(rhs=rhs_local,
                                               filter_shape=filter_shape,
                                               **kwargs)

    self.assertAllClose(out_conv, out_local)

  # TODO(mattjj): test conv_general_dilated against numpy

  def testConv0DIsDot(self):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      return [rng((10, 5), np.float32), rng((5, 7), np.float32)]
    jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
                      padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
    self._CompileAndCheck(jnp_fun, args_maker)
    self._CheckAgainstNumpy(np.dot, jnp_fun, args_maker, tol=.1)
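
  # With no spatial dimensions and ('NC', 'IO', 'NC') dimension numbers, the
  # convolution above is just a matrix product, (10, 5) @ (5, 7) -> (10, 7),
  # which is why np.dot serves as the reference implementation.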

  def testGradConv0D(self):
    # Reproduces a failure in neural_tangents not caught in our presubmit tests
    # See cl/367416742.
    lhs = np.ones((2, 5), dtype=np.float32)
    rhs = np.ones((5, 10), dtype=np.float32)

    def f_jax(lhs, rhs):
      return lax.conv_general_dilated(
          lhs, rhs, window_strides=(),
          padding=(), lhs_dilation=(), rhs_dilation=(),
          dimension_numbers=lax.ConvDimensionNumbers((0, 1), (1, 0), (0, 1)),
          batch_group_count=1, feature_group_count=1, precision=None,
          preferred_element_type=None)
    res, pullback = jax.vjp(f_jax, lhs, rhs)
    grad = pullback(np.ones_like(res))
    self.assertAllClose((lhs * 10., rhs * 2.), grad)
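
  # Sanity check on the expected cotangents above: with all-ones inputs and an
  # all-ones output cotangent, d/dlhs = ones @ rhs.T = 10 * ones (rhs has 10
  # columns) and d/drhs = lhs.T @ ones = 2 * ones (lhs has 2 rows).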

  @staticmethod
  def _conv_transpose_via_grad(data, kernel, strides, padding,
                               rhs_dilation=None, dimension_numbers=None):
    """Helper method: calculates conv transpose via grad for testing."""
    assert len(data.shape) == len(kernel.shape)
    nspatial = len(data.shape) - 2
    one = (1,) * nspatial
    rhs_dilation = rhs_dilation or one
    dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                    dimension_numbers)
    in_shape = np.take(data.shape, dn.lhs_spec)
    in_sdims = in_shape[2:]
    k_shape = np.take(kernel.shape, dn.rhs_spec)
    k_sdims = k_shape[2:]
    e_k_sdims = [(k-1) * r + 1 for k, r in zip(k_sdims, rhs_dilation)]
    if padding == 'VALID':
      o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0)
                 for i in range(nspatial)]
    elif padding == 'SAME':
      o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)]
    o_shape = [in_shape[0], k_shape[1]] + o_sdims
    out_spec_inv = [x[0] for x in
                    sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
    o_layout = np.take(np.array(o_shape), out_spec_inv)
    placeholder = np.ones(o_layout, data.dtype)
    conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding,
                                              one, rhs_dilation, dn)
    _, g = jax.vjp(conv, placeholder)
    return g(data)[0]
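
  # The helper above constructs an all-ones placeholder with the transposed
  # convolution's output shape (i.e. the forward convolution's input shape),
  # then pulls `data` back through the forward convolution's VJP; because the
  # lhs gradient of a convolution is itself a transposed convolution, this
  # yields a reference for conv_transpose. The dilated kernel extent
  # (k - 1) * r + 1 drives the spatial-size computation.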

  @staticmethod
  def _transpose_conv_kernel(data, kernel, dimension_numbers):
    dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                    dimension_numbers)
    spatial_axes = np.array(dn.rhs_spec)[2:]
    for axis in spatial_axes:
      kernel = np.flip(kernel, axis)
    kernel = np.swapaxes(kernel, dn.rhs_spec[0], dn.rhs_spec[1])
    return kernel

  @jtu.sample_product(
    [
      dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for lhs_shape, rhs_shape in [
        ((b, 9, 10, i), (k, k, j, i))  # NB: i,j flipped in RHS for transpose
        for b, i, j, k in itertools.product([2, 3], [2, 3], [2, 3], [3, 4, 5])
      ]
    ],
    dtype=lax_test_util.float_dtypes,
    strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
    padding=["VALID", "SAME"],
    dspec=[
      ("NHWC", "HWIO", "NHWC"),
    ],
    rhs_dilation=[None, (2, 2)],
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides,
                           padding, dspec, rhs_dilation):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    # NB: this test calculates conv_transpose performing identically to the
    # lhs-grad of conv.
    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                rhs_dilation=rhs_dilation,
                                dimension_numbers=dspec,
                                transpose_kernel=True)

    def fun_via_grad(lhs, rhs):
      return self._conv_transpose_via_grad(lhs, rhs, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)

  @jtu.sample_product(
    [
      dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for lhs_shape, rhs_shape in [
        ((b, 9, 10, i), (k, k, i, j))
        for b, i, j, k in itertools.product([2, 3], [2, 3], [2, 3], [3, 4, 5])
      ]
    ],
    dtype=lax_test_util.float_dtypes,
    strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)],
    padding=["VALID", "SAME"],
    dspec=[
      ("NHWC", "HWIO", "NHWC"),
    ],
    rhs_dilation=[None, (2, 2)],
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides,
                          padding, dspec, rhs_dilation):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                rhs_dilation=rhs_dilation,
                                dimension_numbers=dspec,
                                transpose_kernel=False)

    def fun_via_grad(lhs, rhs):
      rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
      return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)

  @jtu.sample_product(
    [
      dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for lhs_shape, rhs_shape in [
        ((b, 10, i), (k, i, j))
        for b, i, j, k in itertools.product([2, 3], [2, 3], [2, 3], [3, 4, 5])
      ]
    ],
    dtype=lax_test_util.float_dtypes,
    strides=[(1,), (2,), (3,)],
    padding=["VALID", "SAME"],
    dspec=[
      ("NHC", "HIO", "NHC"),
    ],
    rhs_dilation=[None, (2,)],
  )
  def testConvTranspose1D(self, lhs_shape, rhs_shape, dtype, strides,
                          padding, dspec, rhs_dilation):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                dimension_numbers=dspec,
                                rhs_dilation=rhs_dilation,
                                transpose_kernel=False)

    def fun_via_grad(lhs, rhs):
      rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
      return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)

  @jtu.sample_product(
    [
      dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for lhs_shape, rhs_shape in [
        ((b, i), (i, j))
        for b, i, j in itertools.product([2, 3], [2, 3], [2, 3])
      ]
    ],
    dtype=lax_test_util.float_dtypes,
    strides=[()],
    padding=["VALID", "SAME"],
    dspec=[
      ("NC", "IO", "NC"),
    ],
    rhs_dilation=[None, ()],
  )
  def testConvTranspose0D(self, lhs_shape, rhs_shape, dtype, strides,
                          padding, dspec, rhs_dilation):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                dimension_numbers=dspec,
                                rhs_dilation=rhs_dilation,
                                transpose_kernel=False)

    def fun_via_grad(lhs, rhs):
      rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
      return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)

  def testConvTransposePaddingList(self):
    # Regression test for https://github.com/jax-ml/jax/discussions/8695
    a = jnp.ones((28,28))
    b = jnp.ones((3,3))
    c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1))
    self.assertAllClose(c, 9 * jnp.ones((1, 1, 26, 26)))
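
  # Expected values above: with explicit zero padding, a 3x3 kernel over a
  # 28x28 input gives 28 - 3 + 1 = 26 outputs per spatial dimension, and each
  # output sums a 3x3 window of ones, hence the constant 9.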

  def testConvInvalidPadding(self):
    x = jnp.ones((1, 10, 10, 5), dtype=jnp.bfloat16)
    with self.assertRaisesRegex(ValueError,
                                r"padding argument.*, got \(3, 3\)"):
      jax.lax.conv_general_dilated_patches(x, (5, 5), window_strides=(1, 1),
                                           padding=(3, 3))

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]],
    [dict(lhs_dtype=lhs_dtype, rhs_dtype=rhs_dtype)
     for lhs_dtype, rhs_dtype in
     itertools.chain(
       itertools.product(lax_test_util.int_dtypes +
                         lax_test_util.float_dtypes +
                         lax_test_util.complex_dtypes +
                         lax_test_util.uint_dtypes,
                         repeat=2),
       zip(lax_test_util.bool_dtypes, lax_test_util.bool_dtypes))],
    precision=[
      None,
      lax.Precision.DEFAULT,
      lax.Precision.HIGH,
      lax.Precision.HIGHEST,
      (lax.Precision.DEFAULT, lax.Precision.HIGHEST),
    ],
  )
  def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker)

  @parameterized.parameters([
      (algorithm, dtype)
      for algorithm, test_dtypes in [
          (lax.DotAlgorithm(
              lhs_precision_type=np.float32,
              rhs_precision_type=np.float32,
              accumulation_type=np.float32,
              lhs_component_count=1,
              rhs_component_count=1,
              num_primitive_operations=1,
              allow_imprecise_accumulation=False,
          ), [np.float32]),
          (lax.DotAlgorithm(
              lhs_precision_type=np.float16,
              rhs_precision_type=np.float16,
              accumulation_type=np.float32,
          ), [np.float16]),
          ("F16_F16_F32", [np.float16]),
          (lax.DotAlgorithmPreset.DEFAULT, lax_test_util.float_dtypes),
          (lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
          (lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
          (lax.DotAlgorithmPreset.F16_F16_F16, [np.float16]),
          (lax.DotAlgorithmPreset.F16_F16_F32, [np.float16]),
          (lax.DotAlgorithmPreset.BF16_BF16_BF16, [dtypes.bfloat16]),
          (lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]),
          (lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]),
          (lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]),
          (lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]),
          (lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]),
          (lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]),
          (lax.DotAlgorithmPreset.F64_F64_F64, [np.float64]),
      ] for dtype in test_dtypes
      if jtu.dtypes.supported([dtype])
  ])
  def testDotAlgorithm(self, algorithm, dtype):
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
if jtu.test_device_matches(["cpu"]):
|
|
|
|
if algorithm not in {
|
|
|
|
lax.DotAlgorithmPreset.DEFAULT,
|
|
|
|
lax.DotAlgorithmPreset.F16_F16_F16,
|
|
|
|
lax.DotAlgorithmPreset.F32_F32_F32,
|
|
|
|
lax.DotAlgorithmPreset.F64_F64_F64,
|
2024-12-07 11:13:37 -08:00
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
}:
|
|
|
|
raise SkipTest(
|
|
|
|
f"The dot algorithm '{algorithm}' is not supported on CPU.")
|
2024-09-25 06:16:22 -07:00
|
|
|
if jtu.test_device_matches(["gpu"]):
|
|
|
|
# GPU algorithm support is a little spotty. It is checked in
|
|
|
|
# xla/service/algorithm_util.cc and the logic is copied here.
|
|
|
|
if algorithm in {
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
lax.DotAlgorithmPreset.F16_F16_F32,
|
|
|
|
lax.DotAlgorithmPreset.TF32_TF32_F32,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
|
2024-09-25 06:16:22 -07:00
|
|
|
}:
|
|
|
|
if not jtu.is_cuda_compute_capability_at_least("8.0"):
|
|
|
|
raise SkipTest(
|
|
|
|
f"The dot algorithm '{algorithm}' requires CUDA compute "
|
|
|
|
"capability >= 8.0.")
|
|
|
|
elif algorithm not in {
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
lax.DotAlgorithmPreset.DEFAULT,
|
|
|
|
lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32,
|
|
|
|
lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
|
|
|
|
lax.DotAlgorithmPreset.F32_F32_F32,
|
|
|
|
lax.DotAlgorithmPreset.F64_F64_F64,
|
2024-09-25 06:16:22 -07:00
|
|
|
}:
|
|
|
|
raise SkipTest(
|
|
|
|
f"The dot algorithm '{algorithm}' is not supported on GPU.")
|
2024-10-21 14:29:57 -07:00
|
|
|
if jtu.test_device_matches(["tpu"]):
|
2024-12-19 05:03:34 -08:00
|
|
|
# TODO(apaszke): Remove after 12 weeks have passed.
|
|
|
|
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
|
|
|
self.skipTest("Requires libtpu built after 2024-12-19")
|
2024-10-21 14:29:57 -07:00
|
|
|
if algorithm not in {
|
|
|
|
lax.DotAlgorithmPreset.DEFAULT,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
|
|
|
|
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
|
|
|
|
}:
|
|
|
|
raise SkipTest(
|
|
|
|
f"The dot algorithm '{algorithm}' is not supported on TPU."
|
|
|
|
)
|
2024-09-25 06:16:22 -07:00
|
|
|
lhs_shape = (3, 4)
|
|
|
|
rhs_shape = (4, 3)
|
|
|
|
rng = jtu.rand_default(self.rng())
|
|
|
|
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker)
|
|
|
|
self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype)
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
|
|
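  # An f8 algorithm requires both operands to be float8 types; passing a
  # float32 lhs should be rejected at trace time with a ValueError.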
  def testDotAlgorithmInvalidFloat8Type(self):
    if jtu.test_device_matches(["cpu"]):
      raise SkipTest("Not supported on CPU.")
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn)
    with self.assertRaisesRegex(ValueError, "The dot algorithm"):
      lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32")

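  # With an explicit algorithm, inputs are cast to the algorithm's operand
  # types (f32 here) and the result is cast back to the original input dtype.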
  def testDotAlgorithmCasting(self):
    if jtu.test_device_matches(["tpu"]):
      raise SkipTest("F32_F32_F32 is not supported on TPU.")
    def fun(lhs, rhs):
      return lax.dot(lhs, rhs, precision="F32_F32_F32")
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
    self.assertEqual(fun(lhs, rhs).dtype, np.float16)

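  # When `preferred_element_type` is a storage type the algorithm can emit
  # directly, the lowering should not need an extra `convert` op.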
  def testDotAlgorithmAllowedOutputStorage(self):
    # see https://github.com/jax-ml/jax/issues/24794
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Only supported on GPU.")
    def fun(lhs, rhs):
      return lax.dot(lhs, rhs, precision="F16_F16_F32",
                     preferred_element_type=np.float16)
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
    self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text())

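  # `jax.default_matmul_precision` accepts algorithm names as well; the chosen
  # algorithm appears as an `algorithm = <...>` attribute in the lowered IR.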
  def testDotAlgorithmConfig(self):
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)

    expected = ("algorithm = <lhs_precision_type = f32, rhs_precision_type = "
                "f32, accumulation_type = f32")
    with jax.default_matmul_precision("F32_F32_F32"):
      hlo = jax.jit(lax.dot).lower(lhs, rhs).as_text()
    self.assertRegex(hlo, expected)

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]],
      [dict(dtype=d, preferred_element_type=p)
       for d, p in preferred_type_combinations],
  )
  def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
                              preferred_element_type):
    if (not config.enable_x64.value and
        (dtype == np.float64 or preferred_element_type == np.float64
         or dtype == np.int64 or preferred_element_type == np.int64)):
      raise SkipTest("64-bit mode disabled")
    if (jtu.test_device_matches(["tpu"]) and
        (dtype == np.complex128 or preferred_element_type == np.complex128)):
      raise SkipTest("np.complex128 is not yet supported on TPU")
    if jtu.test_device_matches(["gpu"]):
      # TODO(b/189287598)
      raise SkipTest("dot_general with preferred_element_type returns NaN "
                     "non-deterministically on GPU")
    rng = jtu.rand_default(self.rng())
    x = rng(lhs_shape, dtype)
    y = rng(rhs_shape, dtype)
    # We first compute the dot when both inputs are a lower-precision type and
    # preferred_element_type is a higher-precision type. We then compute results
    # where the inputs are first upcast to the higher-precision type and no
    # `preferred_element_type` is given. We expect the result to be extremely
    # similar given the semantics of `preferred_element_type`.
    result_with_preferred_type = lax.dot(x, y, preferred_element_type=preferred_element_type)
    result_with_upcast_inputs = lax.dot(
      x.astype(preferred_element_type),
      y.astype(preferred_element_type))
    self.assertArraysAllClose(result_with_preferred_type, result_with_upcast_inputs)

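  # Mixed float8 operand dtypes (e.g. e4m3 x e5m2) accumulate in f32; the
  # result should track an f32 upcast matmul within loose tolerances.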
  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]],
      [dict(dtype_lhs=dtype_lhs, dtype_rhs=dtype_rhs)
       for dtype_lhs, dtype_rhs in [(dtypes.float8_e4m3fn, dtypes.float8_e5m2),
                                    (dtypes.float8_e5m2, dtypes.float8_e4m3fn),
                                    (dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz),
                                    (dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)]],
  )
  def test_mixed_fp8_dot_general(self, lhs_shape, rhs_shape, dtype_lhs, dtype_rhs):
    if jtu.test_device_matches(["tpu"]):
      raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
    if not jtu.is_device_rocm() and (
        dtype_lhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] or
        dtype_rhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz]
    ):
      raise SkipTest(
          "float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm"
      )
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, dtype=dtype_lhs)
    rhs = rng(rhs_shape, dtype=dtype_rhs)
    dot_general_result = lax.dot(
        lhs, rhs,
        preferred_element_type=jnp.float32
    )

    lhs_upcasted = lhs.astype(jnp.float32)
    rhs_upcasted = rhs.astype(jnp.float32)
    dot_general_result_upcasted = lax.dot(
        lhs_upcasted, rhs_upcasted,
        preferred_element_type=jnp.float32
    )
    self.assertArraysAllClose(
        dot_general_result, dot_general_result_upcasted, rtol=1e-3, atol=1e-3)

  @jtu.sample_product(
      [
          dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
          for lhs_shape in [(3,), (4, 3)]
          for rhs_shape in [(3,), (3, 6)]
      ],
      dtype=lax_test_util.all_dtypes,
  )
  def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    tol = {
        np.float16: 1e-2,
        np.float64: max(jtu.default_tolerance()[np.dtype(np.float64)], 1e-14),
        np.complex128: max(jtu.default_tolerance()[np.dtype(np.complex128)],
                           1e-14),
        jnp.bfloat16: 1e-1
    }
    lax_op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckAgainstNumpy(lax_reference.dot, lax_op, args_maker, tol=tol)

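  # `dot_general` dimension_numbers have the form
  # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)); e.g. a plain
  # matmul of (3, 2) @ (2, 4) uses (([1], [0]), ([], [])).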
  @jtu.sample_product(
      [
          dict(
              lhs_shape=lhs_shape,
              rhs_shape=rhs_shape,
              lhs_contracting=lhs_contracting,
              rhs_contracting=rhs_contracting,
          )
          for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
              [(5,), (5,), [0], [0]],
              [(5, 7), (5,), [0], [0]],
              [(7, 5), (5,), [1], [0]],
              [(3, 5), (2, 5), [1], [1]],
              [(5, 3), (5, 2), [0], [0]],
              [(5, 3, 2), (5, 2, 4), [0], [0]],
              [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]],
              [(5, 3, 2), (3, 5, 2, 4), [0, 2], [1, 2]],
              [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
              [(3, 2), (2, 4), [1], [0]],
          ]
      ],
      dtype=lax_test_util.all_dtypes,
  )
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))

    def fun(lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers)

    self._CompileAndCheck(fun, args_maker, check_dtypes=False)

  @jtu.sample_product(
      [
          dict(
              lhs_shape=lhs_shape,
              rhs_shape=rhs_shape,
              dimension_numbers=dimension_numbers,
          )
          for lhs_shape, rhs_shape, dimension_numbers in [
              ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
              ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
              ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
          ]
      ],
      dtype=lax_test_util.all_dtypes,
  )
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers)

    self._CompileAndCheck(fun, args_maker, check_dtypes=False)

  @jtu.sample_product(
      [
          dict(
              lhs_shape=lhs_shape,
              rhs_shape=rhs_shape,
              dimension_numbers=dimension_numbers,
          )
          for lhs_shape, rhs_shape, dimension_numbers in [
              ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
              ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
              ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
          ]
      ],
      dtype=lax_test_util.all_dtypes,
  )
  def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
                                 dimension_numbers):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    op = lambda x, y: lax.dot_general(x, y, dimension_numbers)
    numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

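  # `lax.broadcast` prepends `broadcast_sizes` as new leading axes, so a ()
  # operand with sizes (1, 2) produces a (1, 2) result.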
  @jtu.sample_product(
      shape=[(), (2, 3)],
      dtype=lax_test_util.default_dtypes,
      broadcast_sizes=[(), (2,), (1, 2)],
  )
  def testBroadcast(self, shape, dtype, broadcast_sizes):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CompileAndCheck(op, args_maker)

  @jtu.sample_product(
      shape=[(), (2, 3)],
      dtype=lax_test_util.default_dtypes,
      broadcast_sizes=[(), (2,), (1, 2)],
  )
  def testBroadcastAgainstNumpy(self, shape, dtype, broadcast_sizes):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    numpy_op = lambda x: lax_reference.broadcast(x, broadcast_sizes)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

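  # `broadcast_in_dim` maps operand axis i to output axis `dimensions[i]`;
  # operand axes of size 1 are broadcast up to the corresponding output size.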
  @jtu.sample_product(
      [
          dict(inshape=inshape, outshape=outshape, dimensions=dimensions)
          for inshape, outshape, dimensions in [
              ([2], [2, 2], [0]),
              ([2], [2, 2], [1]),
              ([2], [2, 3], [0]),
              ([], [2, 3], []),
              ([1], [2, 3], [1]),
          ]
      ],
      dtype=lax_test_util.default_dtypes,
  )
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(inshape, dtype)]
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CompileAndCheck(op, args_maker)

  def testBroadcastInDimOperandShapeTranspose(self):
    # Regression test for https://github.com/jax-ml/jax/issues/5276
    def f(x):
      return lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(0, 1, 2)).sum()
    def g(x):
      return lax.broadcast_in_dim(x.reshape((3,)), (2, 3, 4), broadcast_dimensions=(1,)).sum()
    x = np.ones((1, 3, 1))
    self.assertArraysEqual(jax.grad(f)(x), jax.grad(g)(x))

  @parameterized.parameters(
      {"inshape": inshape, "outshape": outshape,
       "broadcast_dimensions": broadcast_dimensions, "err_msg": err_msg}
      for inshape, outshape, broadcast_dimensions, err_msg in [
          ([2], [2, 2], [0, 1], ('broadcast_dimensions must have length equal to '
                                 'operand ndim')),
          ([2, 2], [2], [0, 1], ('target broadcast shape must have equal or higher rank '
                                 'to the operand shape')),
          ([2], [2, 3], [2], ('broadcast_in_dim broadcast_dimensions must be a subset of output '
                              'dimensions')),
          ([2], [3], [0], ('operand dimension sizes must either be 1, or be '
                           'equal to their corresponding dimensions in the target broadcast shape')),
          ([2, 2], [2, 2], [1, 0], ('broadcast_dimensions must be strictly increasing')),
      ])
  def testBroadcastInDimShapeCheck(self, inshape, outshape, broadcast_dimensions, err_msg):
    rng = jtu.rand_default(self.rng())
    x = rng(inshape, np.float32)
    with self.assertRaisesRegex(TypeError, err_msg):
      lax.broadcast_in_dim(x, shape=outshape, broadcast_dimensions=broadcast_dimensions)

  @jtu.sample_product(
      [
          dict(inshape=inshape, outshape=outshape, dimensions=dimensions)
          for inshape, outshape, dimensions in [
              ([2], [2, 2], [0]),
              ([2], [2, 2], [1]),
              ([2], [2, 3], [0]),
              ([], [2, 3], []),
              ([1], [2, 3], [1]),
          ]
      ],
      dtype=lax_test_util.default_dtypes,
  )
  def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape, dimensions):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(inshape, dtype)]
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @parameterized.parameters(
      {"inshape": inshape, "dimensions": dimensions, "error_type": error_type,
       "err_msg": err_msg}
      for inshape, dimensions, error_type, err_msg in [
          ((1, 2, 3), (0, 0), ValueError, 'dimensions are not unique'),
          ((1, 2, 3), (3,), ValueError, 'axis 3 is out of bounds'),
          ((1, 2, 3), (-4,), ValueError, 'axis -4 is out of bounds'),
          ((1, 2, 3), (1,), ValueError, 'cannot select an axis to squeeze out'),
          ((1, 2, 3), (None,), TypeError, 'cannot be interpreted as an integer'),
      ])
  def testSqueezeShapeCheck(self, inshape, dimensions, error_type, err_msg):
    rng = jtu.rand_default(self.rng())
    x = rng(inshape, np.float32)
    with self.assertRaisesRegex(error_type, err_msg):
      lax.squeeze(x, dimensions=dimensions)

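  # `squeeze` removes only the listed size-1 axes (negative indices count from
  # the end); gradients are checked to third order in both fwd and rev mode.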
2022-10-03 13:36:01 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
[dict(arg_shape=arg_shape, dimensions=dimensions)
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
for arg_shape, dimensions in [
|
2022-10-03 13:36:01 +00:00
|
|
|
[(1,), (0,)],
|
|
|
|
[(1,), (-1,)],
|
|
|
|
[(2, 1, 4), (1,)],
|
|
|
|
[(2, 1, 3, 1), (1,)],
|
|
|
|
[(2, 1, 3, 1), (1, 3)],
|
|
|
|
[(2, 1, 3, 1), (3,)],
|
|
|
|
]],
|
|
|
|
)
|
2020-12-03 11:01:16 -08:00
|
|
|
def testSqueeze(self, arg_shape, dimensions):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-07-14 13:03:24 -07:00
|
|
|
args_maker = lambda: [rng(arg_shape, np.float32)]
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
op = lambda x: lax.squeeze(x, dimensions)
|
|
|
|
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(op, args_maker)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
Generic reduce window jvp
The problem is that we want to generically jvp and tranpose over any reduction_fn. Jax already handles some of the hard parts for us, namely, ensuring that the user provided fn is jax capturable. All that is left then, is to write a jvp and tranpose fn that utilize the jax utils correctly.
However, this is not so straightforward because in order to get the transpose of a reduction window, we need to be able to use both the tangents and primals. The current reduce_fn operates on (x, y) - but we actually need is, under jvp, to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`).
For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together - this means a good deal of changes to safety checks and assurances downstream (as well as unpacking) as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2] where the last dim is the stacked primal and tangent values.
In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition
PiperOrigin-RevId: 631916764
2024-05-08 13:59:50 -07:00
|
|
|
check_grads(op, args_maker(), 3, ["fwd", "rev"], eps=1.)
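
  # A minimal sketch of the squeeze semantics exercised above (shapes here are
  # illustrative, not part of the test matrix): squeezing drops only the
  # listed size-1 axes, and negative dimension indices count from the end.
  #
  #   lax.squeeze(jnp.ones((2, 1, 3, 1)), (1, 3)).shape  # -> (2, 3)
  #   lax.squeeze(jnp.ones((1,)), (-1,)).shape           # -> ()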

  @jtu.sample_product(
    input_type=["np.array", "jnp.array", "float", "np.float32"],
    jit=[True, False],
  )
  def testEmptySqueezeReturnType(self, input_type, jit):
    if input_type == "np.array":
      operand = np.arange(5)
    elif input_type == "jnp.array":
      operand = jnp.arange(5)
    elif input_type == "float":
      operand = 2.0
    elif input_type == "np.float32":
      operand = np.float32(2.0)
    else:
      raise ValueError(f"Unrecognized {input_type=}")

    op = lambda x: lax.squeeze(x, dimensions=())
    if jit:
      op = jax.jit(op)
    result = op(operand)
    self.assertIsInstance(result, jax.Array)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, out_shape=out_shape)
     for arg_shape, out_shape in [
       [(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
     ]],
    dtype=lax_test_util.default_dtypes,
  )
  def testReshape(self, arg_shape, out_shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(arg_shape, dtype)]
    op = lambda x: lax.reshape(x, out_shape)
    self._CompileAndCheck(op, args_maker)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, out_shape=out_shape)
     for arg_shape, out_shape in [
       [(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
     ]],
    dtype=lax_test_util.default_dtypes,
  )
  def testReshapeAgainstNumpy(self, arg_shape, out_shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(arg_shape, dtype)]
    op = lambda x: lax.reshape(x, out_shape)
    numpy_op = lambda x: lax_reference.reshape(x, out_shape)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  def testRoundRoundingMethods(self):
    x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)
    self.assertAllClose(lax.round(x, lax.RoundingMethod.AWAY_FROM_ZERO),
                        np.array([-3, -2, -1, 1, 2, 3], dtype=np.float32))
    self.assertAllClose(lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN),
                        np.array([-2, -2, 0, 0, 2, 2], dtype=np.float32))

  @jtu.sample_product(
    [dict(shape=shape, pads=pads) for shape, pads in [
        ((0, 2), [(1, 2, 1), (0, 1, 0)]),
        ((2, 3), [(1, 2, 1), (0, 1, 0)]),
        ((2,), [(1, 2, 0)]),
        ((1, 2), [(1, 2, 0), (3, 4, 0)]),
        ((1, 2), [(0, 0, 0), (0, 0, 0)]),
        ((2,), [(1, 2, 3),]),
        ((3, 2), [(1, 2, 1), (3, 4, 2)]),
        ((2,), [(-1, 2, 0),]),
        ((4, 2), [(-1, -2, 0), (1, 2, 0)]),
        ((4, 2), [(-1, 2, 0), (1, 2, 2)]),
        ((5,), [(-1, -2, 2),]),
        ((4, 2), [(-1, -2, 1), (1, 2, 2)])
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testPad(self, shape, dtype, pads):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    shape=[(2, 3)],
    dtype=lax_test_util.default_dtypes,
    pads=[
      [(0, 0, 0), (0, 0, 0)],  # no padding
      [(1, 1, 0), (2, 2, 0)],  # only positive edge padding
      [(1, 2, 1), (0, 1, 0)],  # edge padding and interior padding
      [(0, 0, 0), (-1, -1, 0)],  # negative padding
      [(0, 0, 0), (-2, -2, 4)],  # add big dilation then remove from edges
      [(0, 0, 0), (-2, -3, 1)],  # remove everything in one dimension
    ]
  )
  def testPadAgainstNumpy(self, shape, dtype, pads):
    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.pad(x, np.array(0, dtype), pads)
    numpy_op = lambda x: lax_reference.pad(x, np.array(0, dtype), pads)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  def testPadErrors(self):
    with self.assertRaisesRegex(ValueError, "padding_value must be a scalar"):
      lax.pad(np.zeros(2), np.zeros(2), [(0, 0, 0)])
    with self.assertRaisesRegex(ValueError, "padding_config"):
      lax.pad(np.zeros(2), 0., [(0, 1, 0), (0, 1, 0)])
    with self.assertRaisesRegex(ValueError, "interior padding in padding_config must be nonnegative"):
      lax.pad(np.zeros(2), 0., [(0, 1, -1)])
    with self.assertRaisesRegex(ValueError, "Dimension size after padding is not at least 0"):
      lax.pad(np.zeros(2), 0., [(-3, 0, 0)])
    with self.assertRaisesRegex(ValueError, "Dimension size after padding is not at least 0"):
      lax.pad(np.zeros(2), 0., [(-4, 0, 1)])
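
  # For reference, each (low, high, interior) triple in a padding_config pads
  # one dimension: `interior` copies of the padding value are inserted between
  # neighboring elements first, then `low`/`high` copies are added at the
  # edges (negative values trim). A worked example (values illustrative only):
  #
  #   lax.pad(jnp.array([1., 2.]), 0., [(1, 1, 1)])
  #   # -> [0., 1., 0., 2., 0.]   (len = 2 + 1 interior + 1 low + 1 high)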

  def testReverse(self):
    # Note: `rev` closes over `dimensions`, which is rebound below; the
    # closure reads the current binding each time `rev` is traced.
    rev = jax.jit(lambda operand: lax.rev(operand, dimensions))

    dimensions = []
    self.assertAllClose(np.array([0, 1, 2, 3]), rev(np.array([0, 1, 2, 3])),
                        check_dtypes=False)

    dimensions = [0]
    self.assertAllClose(np.array([3, 2, 1]), rev(np.array([1, 2, 3])),
                        check_dtypes=False)

    dimensions = [0, 1]
    self.assertAllClose(np.array([[6, 5, 4], [3, 2, 1]]),
                        rev(np.array([[1, 2, 3], [4, 5, 6]])),
                        check_dtypes=False)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, pred_shape=pred_shape)
     for arg_shape in [(), (3,), (2, 3)]
     for pred_shape in ([(), arg_shape] if arg_shape else [()])
    ],
    arg_dtype=lax_test_util.default_dtypes,
  )
  def testSelect(self, pred_shape, arg_shape, arg_dtype):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      return [rng(pred_shape, np.bool_), rng(arg_shape, arg_dtype),
              rng(arg_shape, arg_dtype)]
    self._CheckAgainstNumpy(lax_reference.select, lax.select, args_maker)
    self._CompileAndCheck(lax.select, args_maker)

  @jtu.sample_product(
    [
      dict(arg_shape=arg_shape, pred_shape=pred_shape)
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
    ],
    [
      dict(pred_dtype=pred_dtype, num_args=num_args)
      for (pred_dtype, num_args) in (
        list(
          itertools.product(
            [np.dtype(np.bool_), np.dtype(np.int32)], [1, 2]
          )
        )
        + [(np.dtype(np.int32), 6)]
      )
    ],
    arg_dtype=lax_test_util.default_dtypes,
  )
  def testSelectN(self, pred_dtype, pred_shape, arg_shape, arg_dtype, num_args):
    if pred_dtype == np.bool_:
      pred_rng = jtu.rand_default(self.rng())
    else:
      pred_rng = jtu.rand_int(self.rng(), low=-1, high=num_args + 1)
    rng = jtu.rand_default(self.rng())
    def args_maker():
      return [pred_rng(pred_shape, pred_dtype)] + (
        [rng(arg_shape, arg_dtype) for _ in range(num_args)])
    self._CheckAgainstNumpy(lambda c, *xs: np.choose(c, xs, mode='clip'),
                            lax.select_n, args_maker)
    self._CompileAndCheck(lax.select_n, args_maker)
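
  # A quick sketch of the select_n semantics mirrored by the np.choose
  # reference above (values illustrative, not part of the test): the
  # predicate picks per-element among the cases, with integer predicates
  # indexing the case list; the reference encodes out-of-range indices as
  # clamped, matching mode='clip'.
  #
  #   lax.select_n(jnp.array([0, 1, 1]), jnp.zeros(3), jnp.ones(3))
  #   # -> [0., 1., 1.]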

  @jtu.sample_product(
    [
      dict(
        shape=shape, starts=indices, limits=limit_indices, strides=strides
      )
      for shape, indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testSlice(self, shape, dtype, starts, limits, strides):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CompileAndCheck(op, args_maker)

  @jtu.sample_product(
    [
      dict(
        shape=shape, starts=indices, limits=limit_indices, strides=strides
      )
      for shape, indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testSliceAgainstNumpy(self, shape, dtype, starts, limits, strides):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.slice(x, starts, limits, strides)
    numpy_op = lambda x: lax_reference.slice(x, starts, limits, strides)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    [
      dict(shape=shape, indices=indices, size_indices=size_indices)
      for shape, indices, size_indices in [
        [(3,), np.array((1,)), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(5, 3), np.array((1, 1)), (3, 1)],
        [(7, 5, 3), np.array((4, 1, 0)), (2, 0, 1)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testDynamicSlice(self, shape, dtype, indices, size_indices):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), np.array(indices)]
    op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices)
    self._CompileAndCheck(op, args_maker)

  @jtu.sample_product(
    [
      dict(shape=shape, indices=indices, size_indices=size_indices)
      for shape, indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testDynamicSliceAgainstNumpy(self, shape, dtype, indices, size_indices):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), np.array(indices)]
    op = lambda x, s: lax.dynamic_slice(x, s, size_indices)
    numpy_op = lambda x, s: lax_reference.dynamic_slice(x, s, size_indices)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  def testDynamicSliceInDim(self):
    # Regression test for mixed type problem in dynamic_slice_in_dim.
    rng = jtu.rand_default(self.rng())
    x = rng((6, 7), np.int32)
    np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])

  def testDynamicSliceArraySliceSizes(self):
    rng = jtu.rand_default(self.rng())
    x = rng((6, 7), np.int32)
    np.testing.assert_equal(lax.dynamic_slice(x, [2, 3], jnp.array([2, 2])),
                            x[2:4, 3:5])
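
  # Note on the semantics checked above: dynamic_slice takes runtime start
  # indices and static slice sizes, and start indices that would read out of
  # bounds are adjusted so the slice stays in range. Illustrative sketch:
  #
  #   x = jnp.arange(5)
  #   lax.dynamic_slice(x, (4,), (3,))  # start clamped to 2 -> [2, 3, 4]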

  def testDynamicSliceWithNonScalarIndex(self):
    x = jnp.ones((6, 7), np.int32)
    with self.assertRaises(TypeError):
      lax.dynamic_slice_in_dim(x, jnp.array([2, 2]), 3)

  @jtu.sample_product(
    [
      dict(shape=shape, indices=indices, update_shape=update_shape)
      for shape, indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testDynamicUpdateSlice(self, shape, dtype, indices, update_shape):
    rng = jtu.rand_default(self.rng())

    def args_maker():
      return [rng(shape, dtype), rng(update_shape, dtype), np.array(indices)]

    self._CompileAndCheck(lax.dynamic_update_slice, args_maker)

  @jtu.sample_product(
    [
      dict(shape=shape, indices=indices, update_shape=update_shape)
      for shape, indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, indices,
                                         update_shape):
    rng = jtu.rand_default(self.rng())

    def args_maker():
      return [rng(shape, dtype), rng(update_shape, dtype), np.array(indices)]

    self._CheckAgainstNumpy(lax_reference.dynamic_update_slice,
                            lax.dynamic_update_slice, args_maker)

  def testDynamicUpdateSliceBatched(self):
    # Regression test for https://github.com/jax-ml/jax/issues/9083
    x = jnp.arange(5)
    y = jnp.arange(6, 9)
    ind = jnp.arange(6)
    expected = jnp.vstack([lax.dynamic_update_slice(x, y, (i,)) for i in ind])
    actual = jax.vmap(lax.dynamic_update_slice, (None, None, 0))(x, y, (ind,))
    self.assertAllClose(expected, actual)

  def testDynamicUpdateSliceWithNonScalarIndex(self):
    x = jnp.ones((6, 7), np.int32)
    with self.assertRaises(TypeError):
      lax.dynamic_update_slice_in_dim(x, jnp.ones((2, 7), np.int32),
                                      jnp.array([2, 2]), axis=0)

  @jtu.sample_product(
    [
      dict(shape=shape, perm=perm)
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testTranspose(self, shape, dtype, perm):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.transpose(x, perm)
    self._CompileAndCheck(op, args_maker)

  def testTransposeWithArrayPermutation(self):
    x = lax.transpose(np.ones((2, 3)), jnp.array([1, 0]))
    self.assertEqual((3, 2), x.shape)

  @jtu.sample_product(
    [
      dict(shape=shape, perm=perm)
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
    ],
    dtype=lax_test_util.default_dtypes,
  )
  def testTransposeAgainstNumpy(self, shape, dtype, perm):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.transpose(x, perm)
    numpy_op = lambda x: lax_reference.transpose(x, perm)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    [
      dict(
        op=rec.op,
        reference_op=rec.reference_op,
        init_val=rec.init_val,
        primitive=rec.primitive,
        dtype=dtype,
      )
      for rec in lax_test_util.lax_reduce_ops()
      for dtype in rec.dtypes
    ],
    [
      dict(shape=shape, dims=dims)
      for shape, dims in [
        [(3, 4, 5), (0,)],
        [(3, 4, 5), (1, 2)],
        [(3, 4, 5), (0, 2)],
        [(3, 4, 5), (0, 1, 2)],
      ]
    ],
  )
  def testReduce(self, op, reference_op, init_val, shape, dtype, dims, primitive):
    if not config.enable_x64.value and dtype in (np.float64, np.int64, np.uint64):
      raise SkipTest("x64 mode is disabled.")
    def reference_fun(operand):
      if hasattr(reference_op, "reduce"):
        initial = np.array(init_val, dtype=dtype)
        result = reference_op.reduce(operand, axis=dims, initial=initial)
      else:
        result = reference_op(operand, axis=dims)

      return result.astype(dtype)

    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    init_val = np.asarray(init_val).astype(dtype)
    fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
    args_maker = lambda: [rng(shape, dtype), init_val]
    self._CompileAndCheck(fun, args_maker)

    # we separately test the version that uses a concrete init_val because it
    # can hit different code paths
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker)
    self._CheckAgainstNumpy(reference_fun, fun, args_maker)

    # Check that the correct monoid reducer primitive is used inside the
    # jaxpr; this requires init_val (the monoid identity element) to be
    # static.
    jaxpr = jax.make_jaxpr(fun)(rng(shape, dtype))
    self.assertEqual(jaxpr.eqns[0].primitive, primitive)
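
  # A sketch of the monoid-detection behavior asserted above (illustrative):
  # with a static identity init_val, lax.reduce lowers to the matching monoid
  # primitive, so the jaxpr contains e.g. reduce_sum rather than a generic
  # reduce. Roughly:
  #
  #   jax.make_jaxpr(lambda x: lax.reduce(x, 0., lax.add, (0,)))(jnp.ones(3))
  #   # -> { lambda ; a:f32[3]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }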

  @jtu.sample_product(
    op=["add", "mul"],
    op_namespace=[lax, operator],
    arr_weak_type=[False, True],
    init_weak_type=[False, True],
  )
  def testReduceWeakType(self, op_namespace, op, arr_weak_type, init_weak_type):
    op = getattr(op_namespace, op)
    arr = lax_internal._convert_element_type(np.arange(10), int,
                                             weak_type=arr_weak_type)
    init = lax_internal._convert_element_type(1, int, weak_type=init_weak_type)
    fun = lambda arr, init: lax.reduce(arr, init, op, (0,))
    out = fun(arr, init)
    self.assertEqual(dtypes.is_weakly_typed(out), arr_weak_type and init_weak_type)
    out_jit = jax.jit(fun)(arr, init)
    self.assertEqual(dtypes.is_weakly_typed(out_jit), arr_weak_type and init_weak_type)

  def testReduceWindowScalar(self):
    rng = jtu.rand_small(self.rng())
    dtype = jnp.float32
    init_val = np.asarray(0, dtype=dtype)
    op = lax.add

    def fun(operand, init_val):
      return lax.reduce_window(
          operand, init_val, op, window_dimensions=(), window_strides=(),
          padding=(), base_dilation=(), window_dilation=())

    def reference_fun(operand, init_val):
      return lax_reference.reduce_window(
          operand, init_val, op, window_dimensions=(), window_strides=(),
          padding=(), base_dilation=())

    args_maker = lambda: [rng((), dtype), init_val]
    self._CompileAndCheck(fun, args_maker)
    self._CheckAgainstNumpy(reference_fun, fun, args_maker)
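
  # A minimal sum-pooling sketch of reduce_window for orientation (shapes
  # illustrative, not part of the test matrix): a (2,) window with stride 2
  # and no padding sums non-overlapping pairs.
  #
  #   lax.reduce_window(jnp.arange(4.), 0., lax.add,
  #                     window_dimensions=(2,), window_strides=(2,),
  #                     padding='VALID')
  #   # -> [1., 5.]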

  @jtu.sample_product(
    [dict(init_val=init_val, op=op, dtype=dtype)
     for init_val, op, dtypes in [
       (0, lax.add, [np.float32]),
       (-np.inf, lax.max, [np.float32]),
       (np.inf, lax.min, [np.float32]),
     ]
     for dtype in dtypes
    ],
    [dict(shape=shape, dims=dims, strides=strides, padding=padding,
          base_dilation=base_dilation, window_dilation=window_dilation)
     for shape, dims, strides, padding, base_dilation, window_dilation in (
       itertools.chain(
         itertools.product(
           [(4, 6)],
           [(2, 1), (1, 2)],
           [(1, 1), (2, 1), (1, 2)],
           ["VALID", "SAME", [(0, 3), (1, 2)]],
           [(1, 1), (2, 3)],
           [(1, 1), (1, 2)]),
         itertools.product(
           [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
           [(1, 2, 2, 1), (1, 1, 1, 1)],
           ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
           [(1, 1, 1, 1), (2, 1, 3, 2)],
           [(1, 1, 1, 1), (1, 2, 2, 1)])))
    ],
  )
  def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
                       base_dilation, window_dilation):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand, init_val):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    def reference_fun(operand, init_val):
      return lax_reference.reduce_window(operand, init_val, op, dims, strides,
                                         padding, base_dilation)

    args_maker = lambda: [rng(shape, dtype), init_val]
    self._CompileAndCheck(fun, args_maker)
    if all(d == 1 for d in window_dilation):
      self._CheckAgainstNumpy(reference_fun, fun, args_maker)

    # we separately test the version that uses a concrete init_val because it
    # can hit different code paths
    def fun(operand):
      return lax.reduce_window(
          operand,
          init_val,
          op,
          dims,
          strides,
          padding,
          base_dilation,
          window_dilation,
      )

    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker)

  # TODO(voz): These are broken out into their own test for two reasons:
  # 1. To show that general ops work; a small subset of ops, specifically the
  #    ones used in the test above (lax.add, lax.max, and lax.min), route to a
  #    monoid operator that *doesn't* pass JVP tests.
  # 2. Slightly different parameterization.
  @jtu.sample_product(
    [
      dict(init_val=init_val, op=op, dtype=dtype)
      for init_val, op, dtypes in [
        (1, _reduce_custom_add, [np.float32]),
        (0, _reduce_custom_mul, [np.float32]),
        (0, _reduce_custom_sub, [np.float32]),
      ]
      for dtype in dtypes
    ],
    [
      dict(
        shape=shape,
        dims=dims,
        strides=strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation,
      )
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ['VALID', 'SAME', [(0, 3), (1, 2)]],
            [(1, 1), (2, 3)],
            [(1, 1), (1, 2)],
          ),
          itertools.product(
            [(3, 2, 4, 6)],
            [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ['VALID', 'SAME', [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1), (2, 1, 3, 2)],
            [(1, 1, 1, 1), (1, 2, 2, 1)],
          ),
        )
      )
    ],
  )
  @jtu.skip_on_devices('gpu')  # jax.lax.mul has an XLA bug on GPU b/339071103
  @jtu.skip_on_devices('tpu')  # b/39342488
  def testReduceWindowGeneralJVP(
      self,
      op,
      init_val,
      dtype,
      shape,
      dims,
      strides,
      padding,
      base_dilation,
      window_dilation,
  ):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand, init_val):
      return lax.reduce_window(
          operand,
          init_val,
          op,
          dims,
          strides,
          padding,
          base_dilation,
          window_dilation,
      )

    args_maker = lambda: [rng(shape, dtype), init_val]
    self._CompileAndCheck(fun, args_maker)
    args = args_maker()
    init_val = args[1]

    # we separately test the version that uses a concrete init_val because it
    # can hit different code paths
    def fun2(operand):
      return lax.reduce_window(
          operand,
          init_val,
          op,
          dims,
          strides,
          padding,
          base_dilation,
          window_dilation,
      )

    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun2, args_maker)

    operand = args_maker()[0]
    jtu.check_jvp(fun2, partial(jax.jvp, fun2), (operand,))
    check_grads(fun2, (operand,), 3, ["fwd"], eps=1.)

  @jtu.sample_product(
    [
      dict(init_val=init_val, op=op, dtype=dtype)
      for init_val, op, dtypes in [
        (-np.inf, lax.max, [np.float32]),
        (np.inf, lax.min, [np.float32]),
        (0, lax.add, [np.float32]),
      ]
      for dtype in dtypes
    ],
    [
      dict(
        shape=shape,
        dims=dims,
        strides=strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation,
      )
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ['VALID', 'SAME', [(0, 3), (1, 2)]],
            [(1, 1), (2, 3)],
            [(1, 1), (1, 2)],
          ),
          itertools.product(
            [(3, 2, 4, 6)],
            [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ['VALID', 'SAME', [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1), (2, 1, 3, 2)],
            [(1, 1, 1, 1), (1, 2, 2, 1)],
          ),
        )
      )
    ],
  )
  @jtu.skip_on_devices('gpu')  # jax.lax.mul has an XLA bug on GPU b/339071103
  @jtu.skip_on_devices('tpu')  # b/39342488
  def testReduceWindowCustomSameAsMonoid(
      self,
      op,
      init_val,
      dtype,
      shape,
      dims,
      strides,
      padding,
      base_dilation,
      window_dilation,
  ):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(op_, operand_):
      return lax.reduce_window(
          operand_,
          init_val,
          op_,
          dims,
          strides,
          padding,
          base_dilation,
          window_dilation,
      )

    args_maker = lambda: [rng(shape, dtype)]
    args = args_maker()
    operand = args[0]
    rng = np.random.RandomState(0)
    tangent = tree_map(partial(jtu.rand_like, rng), operand)

    # There are "special" paths for "monoid" ops whose jvp is defined
    # separately, either for legacy reasons or for optimization; compare the
    # monoid path against the equivalent custom reducer and check that their
    # jvps agree.
    # TODO(voz): Look into the "monoid" paths and collapse them as necessary,
    # especially when we go to add support for (1) recursive is_jvp (hessians)
    # and (2) transpose.
    custom_equiv = {
        lax.max: _reduce_custom_max,
        lax.min: _reduce_custom_min,
        lax.add: _reduce_custom_add,
    }
    custom_op = custom_equiv[op]
    custom_primals, custom_tangents = jax.jvp(
        partial(fun, custom_op),
        primals=(operand,),
        tangents=(tangent,),
    )
    lax_primals, lax_tangents = jax.jvp(
        partial(fun, op),
        primals=(operand,),
        tangents=(tangent,),
    )
    # tol=None uses sane defaults; an explicit value (e.g. 1e-4) is useful
    # when debugging.
    tol = None
    jtu.check_close(
        lax_primals,
        custom_primals,
        atol=tol,
        rtol=tol,
        err_msg='Mismatched primal',
    )
    jtu.check_close(
        lax_tangents,
        custom_tangents,
        atol=tol,
        rtol=tol,
        err_msg='Mismatched tangents',
    )
    # Numerical jvp comparison for min and max values does not work - the
    # underlying implementation of the test util nans on infs.
    if init_val.item() in (np.inf, -np.inf):
      return
    op_bound_fn = partial(fun, op)
    jtu.check_jvp(
        op_bound_fn,
        partial(jax.jvp, op_bound_fn),
        (operand,),
    )
    check_grads(partial(fun, op), [operand], 3, ["fwd"], eps=1.)
    check_grads(partial(fun, custom_op), [operand], 3, ["fwd"], eps=1.)

  # TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU
  @jtu.sample_product(
    [
      dict(
        shape=shape,
        dims=dims,
        strides=strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation,
      )
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ['VALID', 'SAME', [(0, 3), (1, 2)]],
            [(1, 1), (2, 3)],
            [(1, 1), (1, 2)],
          ),
          itertools.product(
            [(3, 2, 4, 6)],
            [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ['VALID', 'SAME', [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1), (2, 1, 3, 2)],
            [(1, 1, 1, 1), (1, 2, 2, 1)],
          ),
        )
      )
    ],
    dtype=[np.float32],
  )
  @jtu.skip_on_devices('gpu')
  def testReduceWindowVariadic(self, dtype, shape, dims, strides, padding,
                               base_dilation, window_dilation):
    if (jtu.test_device_matches(["tpu"]) and
        any(d != 1 for d in window_dilation)):
      raise SkipTest("TPU support missing for arbitrary window dilation.")
    rng = jtu.rand_small(self.rng())
    init_values = (np.asarray(0, dtype=dtype), np.array(-np.inf, dtype=dtype))

    def reducer(xs, ys):
      x1, x2 = xs
      y1, y2 = ys
      return (x1 + y1, lax.max(x2, y2))

    def fun(*operands):
      return lax.reduce_window(operands, init_values, reducer, dims, strides,
                               padding, base_dilation, window_dilation)

    def reference_fun(*operands):
      return [
          lax_reference.reduce_window(operand, init_val, op, dims, strides,
                                      padding, base_dilation)
          for operand, init_val, op in zip(operands, init_values,
                                           [np.add, np.maximum])]

    args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker)
    if all(d == 1 for d in window_dilation):
      self._CheckAgainstNumpy(reference_fun, fun, args_maker)

  def testReduceWindowFailures(self):
    def empty_window_test():
      return lax.reduce_window(np.ones((1,)), 0., lax.add, padding='VALID',
                               window_dimensions=(0,), window_strides=(1,))

    def zero_stride_test():
      return lax.reduce_window(np.ones((1,)), 0., lax.add, padding='VALID',
                               window_dimensions=(1,), window_strides=(0,))

    for failure_fun in [empty_window_test, zero_stride_test]:
      with self.assertRaisesRegex(TypeError, "must have every element be"):
        failure_fun()

    with self.assertRaisesRegex(
        ValueError,
        "reduce_window output must have the same tree structure as the "
        "operands.*"):
      return lax.reduce_window(
          np.ones((1,)), 0., lambda x, y: [x + y],
          padding='VALID', window_dimensions=(1,), window_strides=(1,))

  @jtu.sample_product(
    [dict(shape=shape, window_dimensions=window_dimensions,
          base_dilation=base_dilation, window_dilation=window_dilation)
     for shape, window_dimensions, base_dilation, window_dilation in (
       itertools.chain(
         itertools.product(
           [(4, 6)],
           [(1, 1), (3, 4)],
           [(1, 1), (1, 2), (2, 13), (40, 60)],
           [(1, 1), (1, 2), (2, 13), (40, 60)]),
         itertools.product(
           [(3, 2, 4, 6)],
           [(1, 1, 1, 1), (2, 1, 2, 1)],
           [(1, 1, 1, 1), (1, 2, 2, 1), (30, 40, 3, 2)],
           [(1, 1, 1, 1), (1, 2, 2, 1), (30, 40, 3, 2)])))
    ],
  )
  def testReduceWindowShapeDilation(self, shape, window_dimensions,
                                    base_dilation, window_dilation):
    operand, padding, strides = np.ones(shape), 'SAME', (1,) * len(shape)
    result = lax.reduce_window(operand, 0., lax.add, padding=padding,
                               window_strides=strides,
                               window_dimensions=window_dimensions)
    # With a stride of 1 in each direction and a padding of 'SAME', the
    # shape of the input should be equal to the shape of the result according
    # to https://www.tensorflow.org/xla/operation_semantics#reducewindow.
    self.assertEqual(shape, result.shape)

  def testReduceWindowWithEmptyOutput(self):
    # https://github.com/jax-ml/jax/issues/10315
    shape = (5, 3, 2)
    operand, padding, strides = np.ones(shape), 'VALID', (1,) * len(shape)
    out = jax.eval_shape(lambda x: lax.reduce_window(x, 0., lax.add, padding=padding,
                                                     window_strides=strides,
                                                     window_dimensions=(3, 1, 1),
                                                     window_dilation=(3, 1, 1)), operand)
    self.assertEqual((0, 3, 2), out.shape)

  @jtu.sample_product(
    [dict(op=op, np_op=np_op) for op, np_op in [
       (lax.cumsum, np.cumsum),
       (lax.cumprod, np.cumprod),
       (lax.cummax, np.maximum.accumulate),
       (lax.cummin, np.minimum.accumulate),
    ]],
    [dict(shape=shape, axis=axis)
     for shape in [[10], [3, 4, 5]] for axis in range(len(shape))],
    dtype=lax_test_util.default_dtypes,
    reverse=[False, True],
  )
  def testCumulativeReduce(self, op, np_op, shape, dtype, axis, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    fun = partial(op, axis=axis, reverse=reverse)
    def np_fun(x):
      if reverse:
        return np.flip(np_op(np.flip(x, axis), axis=axis, dtype=dtype), axis)
      else:
        return np_op(x, axis=axis, dtype=dtype)
    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker)
    self._CheckAgainstNumpy(np_fun, fun, args_maker)
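
  # The reverse=True path checked above is equivalent to flipping along the
  # axis, accumulating, and flipping back, e.g. (illustrative values):
  #
  #   lax.cumsum(jnp.array([1., 2., 3.]), axis=0, reverse=True)
  #   # -> [6., 5., 3.]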

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [[10], [3, 4, 5]] for axis in range(len(shape))],
    dtype=lax_test_util.float_dtypes,
    reverse=[False, True],
  )
  def testCumulativeLogSumExp(self, shape, dtype, axis, reverse):
    # This op only works on floating-point types, so we've separated out the
    # test.
    rng = jtu.rand_small(self.rng())
    fun = partial(lax.cumlogsumexp, axis=axis, reverse=reverse)
    def np_fun(x):
      if reverse:
        return np.flip(np.logaddexp.accumulate(
            np.flip(x, axis), axis=axis, dtype=dtype), axis)
      else:
        return np.logaddexp.accumulate(x, axis=axis, dtype=dtype)
    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker)
    tol = None
    if jtu.test_device_matches(["tpu"]) and dtype == np.float32:
      tol = 1e-4
    self._CheckAgainstNumpy(np_fun, fun, args_maker, atol=tol, rtol=tol)

  @jtu.sample_product(
    shape=[(), (3,), (3, 4)],
    dtype=lax_test_util.float_dtypes,
    out_dtype=lax_test_util.float_dtypes,
  )
  def testReducePrecision(self, shape, dtype, out_dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    info = dtypes.finfo(out_dtype)
    fun = lambda x: lax.reduce_precision(x, info.nexp, info.nmant)
    np_fun = lambda x: np.asarray(x).astype(out_dtype).astype(dtype)
    self._CheckAgainstNumpy(np_fun, fun, args_maker)
    self._CompileAndCheck(fun, args_maker)

  def testReducePrecisionGrad(self):
    info = dtypes.finfo(jnp.dtype('bfloat16'))
    y, f_vjp = jax.vjp(lambda x: lax.reduce_precision(x, info.nexp, info.nmant), jnp.pi)
    y2, = f_vjp(jnp.pi)
    y3 = lax.reduce_precision(jnp.pi, info.nexp, info.nmant)
    self.assertArraysEqual(y, y2)
    self.assertArraysEqual(y, y3)
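
  # For orientation: reduce_precision keeps the input's dtype but rounds the
  # value to the narrower exponent/mantissa budget, so with bfloat16's
  # parameters (nexp=8, nmant=7) it acts like a round-trip through bfloat16,
  # which is exactly what the numpy reference in testReducePrecision encodes.
  # Illustrative sketch:
  #
  #   info = dtypes.finfo(jnp.dtype('bfloat16'))
  #   lax.reduce_precision(jnp.float32(np.pi), info.nexp, info.nmant)
  #   # == jnp.float32(jnp.bfloat16(np.pi))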

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [(5,), (5, 7)] for axis in [-1, len(shape) - 1]],
    dtype=lax_test_util.all_dtypes,
    is_stable=[False, True],
  )
  def testSort(self, shape, dtype, axis, is_stable):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    fun = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(dtype=lax_test_util.float_dtypes)
  def testSortFloatSpecialValues(self, dtype):
    # Test confirms that
    # - NaNs are sorted to the end, regardless of representation
    # - sign bit of 0.0 is ignored
    x = jnp.array([-np.inf, 0.0, -0.0, np.inf, np.nan, -np.nan], dtype=dtype)
    index = lax.iota(dtypes.int_, x.size)
    argsort = lambda x: lax.sort_key_val(x, lax.iota(dtypes.int_, x.size), is_stable=True)[1]
    self.assertArraysEqual(argsort(x), index)
    self.assertArraysEqual(jax.jit(argsort)(x), index)

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [(5,), (5, 7)] for axis in [-1, len(shape) - 1]],
    dtype=lax_test_util.all_dtypes,
    is_stable=[False, True],
  )
  def testSortAgainstNumpy(self, shape, dtype, axis, is_stable):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    def numpy_op(x):
      if is_stable:
        return lax_reference.sort(x, axis, kind='stable')
      else:
        return lax_reference.sort(x, axis)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [(3,), (5, 3)] for axis in [-1, len(shape) - 1]],
    key_dtype=lax_test_util.float_dtypes + lax_test_util.complex_dtypes +
              lax_test_util.int_dtypes + lax_test_util.uint_dtypes,
    val_dtype=[np.float32, np.int32, np.uint32],
    is_stable=[False, True],
  )
  def testSortKeyVal(self, shape, key_dtype, val_dtype, axis, is_stable):
    if (np.issubdtype(key_dtype, np.complexfloating) and
        jtu.test_device_matches(["cpu"])):
      raise SkipTest("Complex-valued sort not implemented")
    rng = jtu.rand_default(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(shape=shape, num_keys=num_keys)
     for shape in [(3, 5), (4, 3)] for num_keys in range(1, shape[0] + 1)],
    dtype=lax_test_util.all_dtypes,
  )
  def testSortNumKeys(self, shape, dtype, num_keys):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    lax_fun = lambda x: lax.sort(tuple(x), num_keys=num_keys)
    numpy_fun = lambda x: tuple(x[:, np.lexsort(x[:num_keys][::-1])])
    # self._CompileAndCheck(lax_fun, args_maker)
    self._CheckAgainstNumpy(numpy_fun, lax_fun, args_maker)
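
  # Added illustration (not in the original suite): with num_keys=2 the first
  # two operands act as a lexicographic key, and the resulting permutation is
  # applied to every operand.
  def testSortNumKeysSimpleExample(self):
    primary = jnp.array([1, 1, 0])
    secondary = jnp.array([2, 1, 5])
    p, s = lax.sort((primary, secondary), num_keys=2)
    self.assertArraysEqual(p, jnp.array([0, 1, 1]))
    self.assertArraysEqual(s, jnp.array([5, 1, 2]))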

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [(3,), (5, 3)] for axis in [-1, len(shape) - 1]],
    key_dtype=lax_test_util.float_dtypes + lax_test_util.complex_dtypes +
              lax_test_util.int_dtypes + lax_test_util.uint_dtypes,
    val_dtype=[np.float32, np.int32, np.uint32],
  )
  def testSortKeyValAgainstNumpy(self, shape, key_dtype, val_dtype, axis):
    if (np.issubdtype(key_dtype, np.complexfloating) and
        jtu.test_device_matches(["cpu"])):
      raise SkipTest("Complex-valued sort not implemented")
    rng = jtu.rand_default(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values

    op = lambda ks, vs: lax.sort_key_val(ks, vs, axis)
    numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  @jtu.sample_product(
    dtype=[np.float32, np.int32, np.uint32],
    shape=[(20,), (5, 20), (2000,)],
    k=[1, 3, 12],
    negative=[False, True],
  )
  def testTopK(self, shape, dtype, k, negative):
    def args_maker():
      flat_values = np.arange(math.prod(shape), dtype=dtype)
      values = self.rng().permutation(flat_values).reshape(shape)
      if negative:
        values = -values
      return [values]
    def reference_top_k(x):
      bcast_idxs = np.broadcast_to(np.arange(shape[-1], dtype=np.int32), shape)
      sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
      return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
    op = lambda vs: lax.top_k(vs, k=k)
    self._CheckAgainstNumpy(op, reference_top_k, args_maker)
    self._CompileAndCheck(op, args_maker)
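
  # Added illustration (not in the original suite): top_k returns the k
  # largest entries along the last axis together with their indices.
  def testTopKSimpleExample(self):
    vals, idxs = lax.top_k(jnp.array([1., 5., 3.]), k=2)
    self.assertArraysEqual(vals, jnp.array([5., 3.]))
    self.assertArraysEqual(idxs, jnp.array([1, 2], dtype=jnp.int32))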

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [((3, 2), (2, 4)),
                                  ((5, 3, 2), (5, 2, 4)),
                                  ((1, 2, 2, 3), (1, 2, 3, 1))]],
    dtype=lax_test_util.float_dtypes,
  )
  def testBatchMatMul(self, lhs_shape, rhs_shape, dtype):
    rng = jtu.rand_small(self.rng())
    arg_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CompileAndCheck(lax.batch_matmul, arg_maker)
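
  # Added illustration (not in the original suite): batch_matmul maps a
  # matrix product over the leading batch dimension.
  def testBatchMatMulShapeExample(self):
    lhs = jnp.ones((5, 3, 2))
    rhs = jnp.ones((5, 2, 4))
    self.assertEqual(lax.batch_matmul(lhs, rhs).shape, (5, 3, 4))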

  def testCollapse(self):

    @jax.jit
    def collapse_first_two(x):
      return lax.collapse(x, 0, 2)

    self.assertEqual((6,), collapse_first_two(np.zeros((2, 3))).shape)
    self.assertEqual((6, 4), collapse_first_two(np.zeros((2, 3, 4))).shape)
    self.assertEqual((2, 3, 4),
                     collapse_first_two(np.zeros((1, 2, 3, 4))).shape)

  def testCollapseLastTwo(self):

    @jax.jit
    def collapse_last_two_none_end(x):
      return lax.collapse(x, -2)

    @jax.jit
    def collapse_last_two_pos_end(x):
      # Explicit positive stop dimension; for a 4-D input this matches the
      # implicit `None` end above. (The original body was an exact copy of
      # collapse_last_two_none_end, contradicting the helper's name.)
      return lax.collapse(x, -2, 4)

    self.assertEqual((4, 3, 10),
                     collapse_last_two_none_end(np.zeros((4, 3, 2, 5))).shape)
    self.assertEqual((4, 3, 10),
                     collapse_last_two_pos_end(np.zeros((4, 3, 2, 5))).shape)
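
  # Added illustration (not in the original suite): collapse(x, start, stop)
  # merges the dimensions in [start, stop) into a single dimension.
  def testCollapseShapeExample(self):
    x = jnp.zeros((2, 3, 4))
    self.assertEqual(lax.collapse(x, 1, 3).shape, (2, 12))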

  @jtu.sample_product(
    [dict(shape=shape, idxs=idxs, axes=axes)
     for shape, idxs, axes in [
       [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
       [(3, 4, 5), (np.array([-1, -2]),), (0,)],
       [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
       [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), [0, 2]],
     ]],
    dtype=lax_test_util.all_dtypes,
  )
  def testIndexTake(self, shape, dtype, idxs, axes):
    rng = jtu.rand_default(self.rng())
    rand_idxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs)
    args_maker = lambda: [rng(shape, dtype), rand_idxs()]
    fun = lambda src, idxs: lax.index_take(src, idxs, axes)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(shape=shape, idxs=idxs, dnums=dnums, slice_sizes=slice_sizes)
     for shape, idxs, dnums, slice_sizes in [
       ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
         offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
        (1,)),
       ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
         offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
        (2,)),
       ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
         offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
        (1, 3)),
       ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
         offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
        (1, 3)),
       ((2, 5), np.array([[[0], [2]], [[1], [1]]]),
        lax.GatherDimensionNumbers(
          offset_dims=(), collapsed_slice_dims=(1,),
          start_index_map=(1,), operand_batching_dims=(0,),
          start_indices_batching_dims=(0,)),
        (1, 1)),
       ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
        lax.GatherDimensionNumbers(
          offset_dims=(2,), collapsed_slice_dims=(),
          start_index_map=(2,), operand_batching_dims=(0, 1),
          start_indices_batching_dims=(1, 0)),
        (1, 1, 3))
     ]],
    dtype=lax_test_util.all_dtypes,
  )
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(shape))
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(shape, dtype), rand_idxs()]
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CompileAndCheck(fun, args_maker)
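
  # Added illustration (not in the original suite): a gather configured to
  # select whole rows of a matrix, equivalent to NumPy-style fancy indexing
  # operand[row_indices].
  def testGatherSimpleExample(self):
    operand = jnp.arange(6.0).reshape(3, 2)
    indices = jnp.array([[2], [0]])
    dnums = lax.GatherDimensionNumbers(
        offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,))
    out = lax.gather(operand, indices, dnums, slice_sizes=(1, 2))
    self.assertArraysEqual(out, operand[jnp.array([2, 0])])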

  # These tests are adapted from the corresponding tests in
  # tensorflow/compiler/xla/service/shape_inference_test.cc with slight
  # variations to account for the implicit setting of index_vector_dim in JAX.
  @parameterized.named_parameters(
      {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape,
       "indices_shape": indices_shape,
       "dimension_numbers": dimension_numbers,
       "slice_sizes": slice_sizes, "msg": msg}
      for (testcase_name, operand_shape, indices_shape, dimension_numbers,
           slice_sizes, msg) in [
          ("NonAscendingWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 8, 7), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6), "offset_dims in gather op must be sorted"),
          ("RepeatedWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 7), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6), "offset_dims in gather op must not repeat"),
          ("WindowIndexOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 100, 101, 102), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6), "Offset dimension 2 in gather op is out of bounds"),
          ("WindowIndexBarelyOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 9), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6), "Offset dimension 4 in gather op is out of bounds"),
          ("MismatchingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6),
           ("All components of the offset index in a gather op must either be a "
            "offset dimension or explicitly collapsed/batching")),
          ("MismatchingElidedWindowDimsV2", (10, 9, 8, 7, 6, 5), (10, 4, 3, 2, 4),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,),
             start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,),
             start_indices_batching_dims=(0,)),
           (10, 9, 8, 7, 6, 5),
           ("All components of the offset index in a gather op must either be a "
            "offset dimension or explicitly collapsed/batching")),
          ("OutOfBoundsWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 19),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6),
           "Invalid collapsed_slice_dims set in gather op; valid range is"),
          ("RepeatedWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 3),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must not repeat"),
          ("MismatchingGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3)),
           (10, 9, 8, 7, 6),
           ("Gather op has 4 elements in start_index_map and the bound of "
            "dimension index_vector_dim=4 of indices is 5. These two "
            "numbers must be equal.")),
          ("OutOfBoundsGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 7)),
           (10, 9, 8, 7, 6), "Invalid start_index_map"),
          ("RepeatedGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 3)),
           (10, 9, 8, 7, 6), "start_index_map in gather op must not repeat"),
          ("NonAscendingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 1),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6),
           "collapsed_slice_dims in gather op must be sorted"),
          ("WindowBoundsTooLarge", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(2,),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 100, 6),
           "Slice size at index 3 in gather op is out of range"),
          ("MismatchingNumberOfWindowBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7),
           "Gather op must have one slice size for every input dimension"),
          ("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(1,),
             start_index_map=(0, 1, 2, 3, 4)),
           (10, 9, 8, 7, 6),
           ("Gather op can only collapse slice dims with bound 1, but bound "
            "is 9 for index 1 at position 0.")),
          ("RepeatedOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 4), operand_batching_dims=(2, 3, 3)),
           (10, 9, 8, 7, 6),
           "operand_batching_dims in gather op must not repeat"),
          ("NonAscendingOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 4), operand_batching_dims=(3, 2)),
           (10, 9, 8, 7, 6),
           "operand_batching_dims in gather op must be sorted"),
          ("OutOfBoundsOperandBatchingDims", (10, 9, 8, 7, 6),
           (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 2, 3, 4),
             operand_batching_dims=(0, 10)),
           (10, 9, 8, 7, 6),
           "Invalid operand_batching_dims set in gather op; valid range is"),
          ("NonDisjointCollapsedAndBatchingDims", (10, 9, 8, 7, 6),
           (5, 4, 3, 2, 3),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1, 2),
             start_index_map=(0, 1, 4), operand_batching_dims=(2, 3)),
           (10, 9, 8, 7, 6),
           ("collapsed_slice_dims and operand_batching_dims in gather op must be "
            "disjoint")),
          ("NonDisjointStartIndexMapAndBatchingDims", (10, 9, 8, 7, 6),
           (5, 4, 3, 2, 4),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 2, 4), operand_batching_dims=(2, 3)),
           (10, 9, 8, 7, 6),
           ("start_index_map and operand_batching_dims in gather op must be "
            "disjoint")),
          ("WindowBoundsNot1ForBatchingDim", (10, 9, 8, 7, 6), (9, 4, 3, 2, 4),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(),
             start_index_map=(0, 2, 3, 4), operand_batching_dims=(1,),
             start_indices_batching_dims=(0,)),
           (10, 9, 8, 7, 6),
           ("Gather op can only have operand batching dims with bound 0/1, but "
            "bound is 9 for index 1 at position 0.")),
          ("RepeatedStartIndicesBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 2, 3, 4),
             start_indices_batching_dims=(0, 1, 0)),
           (10, 9, 8, 7, 6),
           "start_indices_batching_dims in gather op must not repeat"),
          ("OutOfBoundsStartIndicesBatchingDims", (10, 9, 8, 7, 6),
           (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 2, 3, 4),
             start_indices_batching_dims=(0, 5)),
           (10, 9, 8, 7, 6),
           "Invalid start_indices_batching_dims set in gather op; valid range"),
          ("IndexVectorDimInStartIndicesBatchingDims", (10, 9, 8, 7, 6),
           (5, 4, 3, 2, 5),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1),
             start_index_map=(0, 1, 2, 3, 4),
             start_indices_batching_dims=(0, 4)),
           (10, 9, 8, 7, 6),
           ("Gather op cannot have the index vector dimension as a batching "
            "dimension")),
          ("MismatchingNumberOfBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 4),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6), collapsed_slice_dims=(1, 2),
             start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,),
             start_indices_batching_dims=(0, 1)),
           (10, 9, 8, 7, 6),
           ("Gather op requires equal numbers of operand_batching_dims and "
            "start_indices_batching_dims")),
          ("MismatchingBatchingDimSizes", (10, 9, 8, 7, 6), (10, 9, 3, 2, 3),
           lax.GatherDimensionNumbers(
             offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 3, 4),
             start_index_map=(2, 3, 4), operand_batching_dims=(0, 1),
             start_indices_batching_dims=(1, 0)),
           (10, 9, 8, 7, 6),
           ("Gather op requires operand batching dimensions and indices batching "
            "dimensions to have the same shape"))
      ]
  )
  def testGatherShapeCheckingRule(self, operand_shape, indices_shape,
                                  dimension_numbers, slice_sizes, msg):
    """Checks that malformed gather dimension numbers raise a shape error."""
    operand = np.ones(operand_shape, dtype=np.int32)
    indices = np.ones(indices_shape, dtype=np.int32)

    with self.assertRaisesRegex(TypeError, msg):
      lax.gather(operand, indices, dimension_numbers, slice_sizes)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
          dnums=dnums)
     for arg_shape, idxs, update_shape, dnums in [
       ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
         update_window_dims=(), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(),
         scatter_dims_to_operand_dims=(0,))),
       ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2),
        lax.ScatterDimensionNumbers(
          update_window_dims=(), inserted_window_dims=(1,),
          scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,),
          scatter_indices_batching_dims=(0,))),
       ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
        (3, 2, 3), lax.ScatterDimensionNumbers(
          update_window_dims=(2,), inserted_window_dims=(),
          scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
          scatter_indices_batching_dims=(1, 0)))
     ]],
    dtype=lax_test_util.inexact_dtypes,
    mode=["clip", "fill", None],
    op=[lax.scatter_add, lax.scatter_sub],
  )
  def testScatterAddSub(self, arg_shape, dtype, idxs, update_shape, dnums, mode, op):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(op, dimension_numbers=dnums, mode=mode)
    self._CompileAndCheck(fun, args_maker)
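
  # Added illustration (not in the original suite): scatter_add accumulates
  # updates into the operand, so a repeated index adds twice; this is the
  # lax-level form of operand.at[indices].add(updates).
  def testScatterAddSimpleExample(self):
    operand = jnp.zeros(5)
    indices = jnp.array([[1], [3], [1]])
    updates = jnp.ones(3)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(), inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    out = lax.scatter_add(operand, indices, updates, dnums)
    self.assertArraysEqual(out, jnp.array([0., 2., 0., 1., 0.]))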

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
          dnums=dnums)
     for arg_shape, idxs, update_shape, dnums in [
       ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
         update_window_dims=(), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(),
         scatter_dims_to_operand_dims=(0,))),
       ((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3),
        lax.ScatterDimensionNumbers(
          update_window_dims=(1,), inserted_window_dims=(0,),
          scatter_dims_to_operand_dims=(0,))),
       ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2),
        lax.ScatterDimensionNumbers(
          update_window_dims=(), inserted_window_dims=(1,),
          scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,),
          scatter_indices_batching_dims=(0,))),
       ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
        (3, 2, 3), lax.ScatterDimensionNumbers(
          update_window_dims=(2,), inserted_window_dims=(),
          scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
          scatter_indices_batching_dims=(1, 0)))
     ]],
    dtype=lax_test_util.float_dtypes,
  )
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter_min, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
          dnums=dnums)
     for arg_shape, idxs, update_shape, dnums in [
       ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
         update_window_dims=(), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(),
         scatter_dims_to_operand_dims=(0,))),
       ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2),
        lax.ScatterDimensionNumbers(
          update_window_dims=(), inserted_window_dims=(1,),
          scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,),
          scatter_indices_batching_dims=(0,))),
       ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
        (3, 2, 3), lax.ScatterDimensionNumbers(
          update_window_dims=(2,), inserted_window_dims=(),
          scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
          scatter_indices_batching_dims=(1, 0)))
     ]],
    dtype=lax_test_util.float_dtypes,
  )
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter_max, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape, dnums=dnums)
     for arg_shape, idxs, update_shape, dnums in [
       ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
         update_window_dims=(), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(),
         scatter_dims_to_operand_dims=(0,))),
       ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2),
        lax.ScatterDimensionNumbers(
          update_window_dims=(), inserted_window_dims=(1,),
          scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,),
          scatter_indices_batching_dims=(0,))),
       ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
        (3, 2, 3), lax.ScatterDimensionNumbers(
          update_window_dims=(2,), inserted_window_dims=(),
          scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
          scatter_indices_batching_dims=(1, 0)))
     ]],
    dtype=lax_test_util.float_dtypes,
  )
  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs()]
    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker)

  @jtu.sample_product(
    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
          dnums=dnums)
     for arg_shape, idxs, update_shape, dnums in [
       ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
         update_window_dims=(), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(),
         scatter_dims_to_operand_dims=(0,))),
       ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
         update_window_dims=(1,), inserted_window_dims=(0,),
         scatter_dims_to_operand_dims=(0,))),
       ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2),
        lax.ScatterDimensionNumbers(
          update_window_dims=(), inserted_window_dims=(1,),
          scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,),
          scatter_indices_batching_dims=(0,))),
       ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
        (3, 2, 3), lax.ScatterDimensionNumbers(
          update_window_dims=(2,), inserted_window_dims=(),
          scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
          scatter_indices_batching_dims=(1, 0)))
     ]],
    dtype=lax_test_util.float_dtypes,
  )
  def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker)
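
  # Added illustration (not in the original suite): plain scatter overwrites
  # the indexed elements, the lax-level form of operand.at[indices].set(updates).
  def testScatterSimpleExample(self):
    operand = jnp.zeros(4)
    indices = jnp.array([[1], [3]])
    updates = jnp.array([7., 9.])
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(), inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    out = lax.scatter(operand, indices, updates, dnums)
    self.assertArraysEqual(out, jnp.array([0., 7., 0., 9.]))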

  # These tests are adapted from the corresponding tests in
  # tensorflow/compiler/xla/service/shape_inference_test.cc with slight
  # variations to account for the implicit setting of index_vector_dim in JAX.
  @parameterized.named_parameters(
      {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape,
       "indices_shape": indices_shape, "update_shape": update_shape,
       "dimension_numbers": dimension_numbers,
       "msg": msg}
      for (testcase_name, operand_shape, indices_shape, update_shape,
           dimension_numbers, msg) in [
          ("ScatterWithUpdatesBiggerThanInput", (64, 48), (32, 1), (65, 32),
           lax.ScatterDimensionNumbers(
             update_window_dims=(0,), inserted_window_dims=(1,),
             scatter_dims_to_operand_dims=(1,)),
           "Bounds of the window dimensions"),
          ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), (32, 1),
           (32, 49), lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(1,)),
           "Bounds of the window dimensions"),
          ("ScatterWithUpdatesNotMatchingIndices", (64, 48), (32, 1),
           (64, 31), lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(1,)),
           "Bounds of the scatter dimensions"),
          ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), (32, 1),
           (31, 48), lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(1,)),
           "Bounds of the scatter dimensions"),
          ("ScatterNdWithUpdatesBiggerThanInput", (64, 48),
           (10, 9, 8, 7, 1), (10, 9, 8, 7, 65),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4,), inserted_window_dims=(1,),
             scatter_dims_to_operand_dims=(1,)),
           "Bounds of the window dimensions"),
          ("ScatterNdWithUpdatesNotMatchingIndices", (64, 48),
           (10, 9, 8, 7, 1), (9, 9, 8, 7, 64),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4,), inserted_window_dims=(1,),
             scatter_dims_to_operand_dims=(0,)),
           "Bounds of the scatter dimensions"),
          ("InvalidUpdates", (50, 49, 48, 47, 46), (10, 9, 8, 7, 5),
           (10, 9, 8, 7, 3, 2, 4, 1),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "Updates tensor must be of rank 7; got 8."),
          ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1),
           (10, 9, 8, 7, 6, 5, 4, 3, 2),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6, 8, 7), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "update_window_dims in scatter op must be sorted"),
          ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1),
           (10, 9, 8, 7, 6, 5, 4, 3, 2),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6, 7, 7), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "update_window_dims in scatter op must not repeat"),
          ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1),
           (10, 9, 8, 7, 6, 5, 4, 3, 2),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6, 7, 9), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "Invalid update_window_dims set in scatter op"),
          ("NonAscendingInsertedWindowDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(2, 1),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "inserted_window_dims in scatter op must be sorted"),
          ("RepeatedInsertedWindowDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 1),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "inserted_window_dims in scatter op must not repeat"),
          ("OutOfBoundsInsertedWindowDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 5),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)),
           "Invalid inserted_window_dims set in scatter op"),
          ("MismatchingScatterDimsToOperandDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2),
             scatter_dims_to_operand_dims=(0, 1, 2, 3)),
           ("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
            "the bound of dimension index_vector_dim=4 of indices "
            "is 5. These two numbers must be equal")),
          ("OutOfBoundsScatterDimsToOperandDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 10)),
           "Invalid scatter_dims_to_operand_dims mapping"),
          ("RepeatedValuesInScatterDimsToOperandDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2),
             scatter_dims_to_operand_dims=(0, 1, 2, 2, 3)),
           "scatter_dims_to_operand_dims in scatter op must not repeat"),
          ("InsufficientWindowDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1,),
             scatter_dims_to_operand_dims=(0, 1, 2, 3)),
           ("Scatter op has window of size 4; doesn't match operand of "
            "rank 5.")),
          ("InsufficientWindowDimsV2", (10, 49, 48, 47, 46, 45),
           (10, 9, 8, 7, 3), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1,),
             scatter_dims_to_operand_dims=(1, 2, 3),
             operand_batching_dims=(0,),
             scatter_indices_batching_dims=(0,)),
           ("Scatter op has window of size 5; doesn't match operand of "
            "rank 6.")),
          ("RepeatedOperandBatchingDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 4),
             operand_batching_dims=(2, 3, 3)),
           "operand_batching_dims in scatter op must not repeat"),
          ("NonAscendingOperandBatchingDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 4),
             operand_batching_dims=(3, 2)),
           "operand_batching_dims in scatter op must be sorted"),
          ("OutOfBoundsOperandBatchingDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4),
             operand_batching_dims=(0, 10)),
           ("Invalid operand_batching_dims set in scatter op; valid range "
            "is")),
          ("NonDisjointCollapsedAndBatchingDims", (50, 49, 48, 47, 46, 45),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 4),
             operand_batching_dims=(1, 2)),
           ("inserted_window_dims and operand_batching_dims in scatter op "
            "must be disjoint")),
          ("NonDisjointScatterDimsToOperandDimsAndBatchingDims",
           (50, 49, 48, 47, 46), (10, 9, 8, 7, 5),
           (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 2, 4),
             operand_batching_dims=(2, 3)),
           ("scatter_dims_to_operand_dims and operand_batching_dims in "
            "scatter op must be disjoint")),
          ("RepeatedScatterIndicesBatchingDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4),
             scatter_indices_batching_dims=(0, 1, 0)),
           "scatter_indices_batching_dims in scatter op must not repeat"),
          ("OutOfBoundsScatterIndicesBatchingDims", (50, 49, 48, 47, 46),
           (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4),
             scatter_indices_batching_dims=(0, 5)),
           ("Invalid scatter_indices_batching_dims set in scatter op; "
            "valid range")),
          ("IndexVectorDimInScatterIndicesBatchingDims",
           (50, 49, 48, 47, 46), (10, 9, 8, 7, 5),
           (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1),
             scatter_dims_to_operand_dims=(0, 1, 2, 3, 4),
             scatter_indices_batching_dims=(0, 4)),
           ("Scatter op cannot have the index vector dimension as a "
            "batching dimension")),
          ("MismatchingNumberOfBatchingDims", (50, 49, 48, 47, 46, 45),
           (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2),
             scatter_dims_to_operand_dims=(1, 2, 3, 4),
             operand_batching_dims=(0,),
             scatter_indices_batching_dims=(0, 1)),
           ("Scatter op requires equal numbers of operand_batching_dims "
            "and scatter_indices_batching_dims")),
          ("MismatchingBatchingDimSizes", (10, 9, 48, 47, 46, 45),
           (10, 9, 8, 7, 2), (10, 9, 8, 7, 3, 2, 4),
           lax.ScatterDimensionNumbers(
             update_window_dims=(4, 5, 6), inserted_window_dims=(2,),
             scatter_dims_to_operand_dims=(2, 3),
             operand_batching_dims=(0, 1),
             scatter_indices_batching_dims=(1, 0)),
           ("Scatter op requires operand batching dimensions and indices "
            "batching dimensions to have the same shape"))
      ]
  )
  def testScatterShapeCheckingRule(self, operand_shape, indices_shape,
                                   update_shape, dimension_numbers, msg):
    indices = np.zeros(indices_shape, dtype=np.int32)
    def f(x, y):
      operand = lax.broadcast(x, operand_shape)
      updates = lax.broadcast(y, update_shape)
      return lax.scatter(operand, indices, updates, dimension_numbers)
    with self.assertRaisesRegex(TypeError, msg):
      jax.eval_shape(f, np.int32(1), np.int32(1))

  def testIssue831(self):
    # Tests the DeviceTuple constant handler
    def f(x):
      g = lambda *args: args[1]
      return jax.jit(lax.fori_loop, static_argnums=(2,))(0, 10, g, x)

    jax.jit(f)(1.)  # doesn't crash

  def testReshapeWithUnusualShapes(self):
    ans = lax.reshape(np.ones((3,), np.float32), (lax.add(1, 2), 1))
    self.assertAllClose(ans, np.ones((3, 1), np.float32))

    self.assertRaisesRegex(
      TypeError,
      "Shapes must be 1D sequences of concrete values of integer type.*",
      lambda: lax.reshape(np.ones(3,), (np.array([3, 1]),)))

    self.assertRaisesRegex(
      TypeError,
      "Shapes must be 1D sequences of concrete values of integer type.*",
      lambda: lax.reshape(np.ones(3,), (1.5, 2.0)))

  def testDynamicSliceTypeErrors(self):
    self.assertRaisesRegex(
      TypeError,
      "index arguments to dynamic_slice must be integers of the same type",
      lambda: lax.dynamic_slice(np.ones((3, 4), dtype=np.float32),
                                (np.int32(1), np.int16(2)), (2, 2)))

  def testDynamicUpdateSliceTypeErrors(self):
    self.assertRaisesRegex(
      TypeError,
      "index arguments to dynamic_update_slice must be integers of the same "
      "type",
      lambda: lax.dynamic_update_slice(np.ones((3, 4), dtype=np.float32),
                                       np.zeros((2, 2), dtype=np.float32),
                                       (np.int32(1), np.int16(2))))

  def test_primitive_jaxtype_error(self):
    err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
               r"value is of type .* and was passed to the function at path args\[1\].")
    with jax.enable_checks(False):
      with self.assertRaisesRegex(TypeError, err_str):
        lax.add(1, 'hi')

  def test_reduction_with_repeated_axes_error(self):
    with self.assertRaisesRegex(ValueError, "duplicate value in 'axes' .*"):
      lax.reduce(np.arange(3), 0, lax.add, (0, 0))

  @parameterized.parameters([lax.rem, lax.lt, lax.gt, lax.ge, lax.le])
  def test_ops_do_not_accept_complex_dtypes(self, op):
    with self.assertRaisesRegex(TypeError, ".*does not accept dtype complex.*"):
      op(2+3j, 4+5j)

  @parameterized.parameters([lax.add, lax.mul, lax.div, lax.rem, lax.lt, lax.gt,
                             lax.ge, lax.le, lax.eq, lax.ne])
  def test_ops_error_on_mismatched_dtypes(self, op):
    with self.assertRaisesRegex(TypeError, ".*requires arguments to have the same dtypes.*"):
      op(0, 0.0)

  def test_population_count_booleans_not_supported(self):
    # https://github.com/jax-ml/jax/issues/3886
    msg = "population_count does not accept dtype bool"
    with self.assertRaisesRegex(TypeError, msg):
      lax.population_count(True)
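
  # Added illustration (not in the original suite): on integer inputs,
  # population_count returns the number of set bits per element.
  def testPopulationCountExample(self):
    out = lax.population_count(jnp.array([0, 1, 7, 255], dtype=jnp.uint8))
    self.assertArraysEqual(out, jnp.array([0, 1, 3, 8], dtype=jnp.uint8))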

  def test_conv_general_dilated_different_input_ranks_error(self):
    # https://github.com/jax-ml/jax/issues/4316
    msg = ("conv_general_dilated lhs and rhs must have the same number of "
           "dimensions")
    dimension_numbers = lax.ConvDimensionNumbers(lhs_spec=(0, 1, 2),
                                                 rhs_spec=(0, 1, 2),
                                                 out_spec=(0, 1, 2))
    kwargs = { 'window_strides': (1,)
             , 'padding': ((0, 0),)
             , 'lhs_dilation': (1,)
             , 'rhs_dilation': (1,)
             , 'dimension_numbers': dimension_numbers
             , 'feature_group_count': 1
             , 'batch_group_count': 1
             , 'precision': None
             }
    lhs, rhs = np.ones((1, 1, 1)), np.ones((1, 1, 1, 1))
    with self.assertRaisesRegex(ValueError, msg):
      lax.conv_general_dilated(lhs, rhs, **kwargs)
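
  # Added illustration (not in the original suite): with lhs and rhs of the
  # same rank the call goes through; a width-1 all-ones kernel over a single
  # channel is the identity.
  def testConvGeneralDilatedValidExample(self):
    lhs = np.ones((1, 1, 3), np.float32)  # (batch, channels, width)
    rhs = np.ones((1, 1, 1), np.float32)  # (out channels, in channels, width)
    out = lax.conv_general_dilated(lhs, rhs, window_strides=(1,),
                                   padding=((0, 0),))
    self.assertArraysEqual(out, lhs)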

  def test_window_strides_dimension_shape_rule(self):
    # https://github.com/jax-ml/jax/issues/5087
    msg = ("conv_general_dilated window and window_strides must have "
           "the same number of dimensions")
    lhs = jax.numpy.zeros((1, 1, 3, 3))
    rhs = np.zeros((1, 1, 1, 1))
    with self.assertRaisesRegex(ValueError, msg):
      jax.lax.conv(lhs, rhs, [1], 'SAME')

  def test_reduce_window_scalar_init_value_shape_rule(self):
    # https://github.com/jax-ml/jax/issues/4574
    args = { "operand": np.ones((4, 4), dtype=np.int32)
           , "init_value": np.zeros((1,), dtype=np.int32)
           , "computation": lax.max
           , "window_dimensions": (2, 2)
           , "window_strides": (2, 2)
           , "padding": "VALID"
           , "base_dilation": (1, 1)
           , "window_dilation": (1, 1)
           }

    msg = (r"reduce_window expected init_values to be scalars but init_values "
           r"have shapes \[\(1,\)\].")
    with self.assertRaisesRegex(TypeError, msg):
      lax.reduce_window(**args)
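
  # Added illustration (not in the original suite): the same call with a
  # scalar init value is well formed; a 2x2 max pool of ones is all ones.
  def testReduceWindowScalarInitValueExample(self):
    out = lax.reduce_window(np.ones((4, 4), dtype=np.int32), np.int32(0),
                            lax.max, (2, 2), (2, 2), "VALID")
    self.assertArraysEqual(out, np.ones((2, 2), dtype=np.int32))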

  def test_reduce_correctly_works_with_pytrees(self):
    operands = {'x': [np.ones(5), np.arange(5)]}
    init_values = {'x': [0., 0]}
    result = lax.reduce(operands, init_values,
                        lambda x, y: jax.tree.map(lax.add, x, y),
                        [0])
    self.assertDictEqual(result, {'x': [5., 10]})

  def test_reduce_with_mismatched_pytrees_errors(self):
    operands = {'x': np.ones(5)}
    bad_init_values = {'y': 0.}

    with self.assertRaisesRegex(ValueError, 'Operands must have the same '
                                'tree structure as init_values'):
      lax.reduce(operands, bad_init_values,
                 lambda x, y: dict(x=x['x'] + y['x']), [0])

  def test_reduce_with_nonscalar_inits_errors(self):
    operands = {'x': np.ones(5)}
    bad_init_values = {'x': np.ones(5)}

    with self.assertRaisesRegex(ValueError,
                                'reduce found non-scalar initial value'):
      lax.reduce(operands, bad_init_values,
                 lambda x, y: dict(x=x['x'] + y['x']), [0])

  def test_select_jvp_complexity(self):
    jaxpr = jax.make_jaxpr(lambda x: jax.jvp(lambda x: lax.select(True, x, x),
                                             (x,), (1.,)))(1.)
    self.assertLen(jaxpr.jaxpr.eqns, 2)

  def testRngBitGenerator(self):
    # This test covers the original behavior of lax.rng_bit_generator, which
    # required x64=True, and only checks shapes and jit invariance.
    if not config.enable_x64.value:
      raise SkipTest("RngBitGenerator requires 64bit key")

    key = np.array((1, 2)).astype(np.uint64)
    def fn(k):
      return lax.rng_bit_generator(
        k, shape=(5, 7), algorithm=lax.RandomAlgorithm.RNG_THREE_FRY)

    out = fn(key)
    out_jit = jax.jit(fn)(key)
    self.assertEqual(out[0].shape, (2,))
    self.assertEqual(out[1].shape, (5, 7))
    self.assertArraysEqual(out[0], out_jit[0])
    self.assertArraysEqual(out[1], out_jit[1])

  def testRngBitGenerator2(self):
    def f(key):
      return lax.rng_bit_generator(key, shape=(5, 7))

    key = np.array((1, 2, 3, 4)).astype(np.uint32)
    out1 = f(key)
    out2 = jax.jit(f)(key)
    self.assertEqual(out1[0].shape, (4,))
    self.assertEqual(out1[1].shape, (5, 7))
    self.assertArraysEqual(out1[0], out2[0])
    self.assertArraysEqual(out1[1], out2[1])
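
  # Added illustration (not in the original suite): rng_bit_generator is a
  # pure function of the key, so calling it twice with the same key yields
  # identical bits.
  def testRngBitGeneratorDeterministicExample(self):
    key = np.arange(4, dtype=np.uint32)
    _, bits1 = lax.rng_bit_generator(key, shape=(3,))
    _, bits2 = lax.rng_bit_generator(key, shape=(3,))
    self.assertArraysEqual(bits1, bits2)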

  @jtu.skip_on_devices("tpu")
  def testRngBitGeneratorReturnedKey(self):
    # This test ensures that the key bit-packing/unpacking operations used in
    # the translation rule for rng_bit_generator, on older jaxlibs and at time
    # of writing on GPU, are inverses of one another.
    key = np.array([3, 1, 4, 2], dtype=np.dtype('uint32'))
    new_key, _ = lax.rng_bit_generator(key, (0,))
    self.assertAllClose(key, new_key)

  def test_rng_bit_generator_vmap(self):
    def f(key):
      return lax.rng_bit_generator(key, shape=(5, 7))

    keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32)
    out_keys, bits = jax.vmap(f)(keys)
    self.assertEqual(out_keys.shape, (3, 4))
    self.assertEqual(bits.shape, (3, 5, 7))

  def test_rng_bit_generator_vmap_vmap(self):
    def f(key):
      return lax.rng_bit_generator(key, shape=(5, 7))

    keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32)
    out_keys, bits = jax.vmap(jax.vmap(f))(keys)
    self.assertEqual(out_keys.shape, (2, 3, 4))
    self.assertEqual(bits.shape, (2, 3, 5, 7))

  @jtu.sample_product(
    dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types,
    weak_type=[True, False],
  )
  def test_const(self, dtype, weak_type):
    if dtype in set(lax_test_util.python_scalar_types):
      val = dtype(0)
    else:
      val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)

    const = lax_internal._const(val, 0)
    self.assertEqual(dtypes.dtype(val, canonicalize=True),
                     dtypes.dtype(const, canonicalize=True))

  def testIgammaSpecial(self):
    self.assertEqual(lax.igamma(1., np.inf), 1.)
    self.assertEqual(lax.igammac(1., np.inf), 0.)

  def testRegressionIssue5728(self):
    # The computation in this test gave garbage data on CPU due to an LLVM bug.
    @jax.jit
    def f(inputs):
      out_action_2 = lax.slice_in_dim(inputs, 0, 15, axis=-1)
      mask = lax.slice_in_dim(inputs, 7, 22, axis=-1)
      out_action_2 = lax.select(lax.eq(mask, np.float32(0)),
                                lax.broadcast(np.float32(42), (1, 15)),
                                out_action_2)
      return lax.pad(out_action_2, np.float32(42), [(0, 0, 0), (0, 15, 0)])
    self.assertArraysEqual(np.full((1, 30), np.float32(42)),
                           f(np.zeros((1, 24), dtype=np.float32)))

  def testDynamicSliceUnsignedNoNormalization(self):
    # Test that no negative index correction is done for unsigned indices.
    f = lambda x, i: lax.dynamic_slice(x, [i], [1])
    x = np.arange(200)
    i = np.uint32(128)
    jaxpr = jax.make_jaxpr(f)(x, i)
    self.assertLen(jaxpr.eqns, 1)
    self.assertEqual(jaxpr.eqns[0].primitive, lax.dynamic_slice_p)

  def testDynamicSliceU8Index(self):
    # Regression test for u8 index in dynamic-slice (#6122)
    x = np.arange(200)
    np.testing.assert_equal(
      np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
|

  def test_dot_general_batching_python_builtin_arg(self):
    # https://github.com/jax-ml/jax/issues/16805
    @jax.remat
    def f(x):
      return jax.lax.dot_general(x, x, (([], []), ([], [])))

    jax.hessian(f)(1.0)  # don't crash
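
  # With empty contracting and batch dimensions, the dot_general above is just
  # a product of scalars, so jax.hessian mainly exercises dot_general's
  # batching rule when the arguments are plain Python floats.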

  def test_constant_folding_complex_to_real_scan_regression(self):
    # regression test for github.com/jax-ml/jax/issues/19059
    def g(hiddens):
      hiddens_aug = jnp.vstack((hiddens[0], hiddens))
      new_hiddens = hiddens_aug.copy()
      diff = new_hiddens[:-1] - hiddens
      out = jnp.trace(jnp.conj(diff).T @ diff).real
      return jnp.array(out, dtype=jnp.complex64)

    def _step(carry, arg):
      primals, f_vjp = jax.vjp(
          g,
          jax.random.normal(jax.random.key(0), (9, 8), dtype=jnp.complex64),
      )
      out = f_vjp(np.array(1.0 + 0j, 'complex64'))[0]
      return carry, carry

    a, b = jax.lax.scan(_step, 0, jnp.arange(4, dtype=jnp.complex64))

  @parameterized.parameters([float, np.array, np.float32, jnp.float32])
  def testAsarray(self, typ):
    x = typ(1.0)
    x_arr = lax_internal.asarray(x)
    self.assertArraysEqual(x, x_arr)
    self.assertIsInstance(x_arr, jax.Array)

    # jaxpr should not bind any primitives, whether called directly or
    # as a closure:
    jaxpr = jax.make_jaxpr(lax_internal.asarray)(x)
    self.assertLen(jaxpr.eqns, 0)

    asarray_closure = lambda: lax_internal.asarray(x)
    jaxpr = jax.make_jaxpr(asarray_closure)()
    self.assertLen(jaxpr.eqns, 0)

    # Regression test for https://github.com/jax-ml/jax/issues/19334
    # lax.asarray as a closure should not trigger transfer guard.
    with jax.transfer_guard('disallow'):
      jax.jit(asarray_closure)()

  def test_optimization_barrier(self):
    x = lax.optimization_barrier((2, 3))
    self.assertEqual((2, 3), x)


class LazyConstantTest(jtu.JaxTestCase):

  def _Check(self, make_const, expected):
    # check casting to ndarray works
    asarray_result = np.asarray(make_const())

    # check passing as an argument works (should hit constant handler)
    zero = np.array(0, expected.dtype)
    argument_result = lax.add(zero, make_const())

    # check closing over the constant in a compiled computation works
    jit_result = jax.jit(lambda x: lax.add(x, make_const()))(zero)

    # ensure they're all the same
    self.assertAllClose(asarray_result, expected)
    self.assertAllClose(argument_result, expected)
    self.assertAllClose(jit_result, expected)

    # ensure repr doesn't crash
    repr(make_const())
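
  # Every case below funnels through _Check, which exercises the three paths a
  # lazily constructed constant can take: host conversion via np.asarray,
  # being passed as a runtime argument, and being closed over under jit.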

  @jtu.sample_product(
    dtype=lax_test_util.default_dtypes + [None],
    shape=[(), (3,), (2, 3), (2, 3, 4), (1001, 1001)],
    fill_value=[0, 1, np.pi],
  )
  def testFilledConstant(self, shape, fill_value, dtype):
    make_const = lambda: lax.full(shape, fill_value, dtype)
    expected = np.full(shape, fill_value,
                       dtype or dtypes.dtype(fill_value))
    self._Check(make_const, expected)

  @jtu.sample_product(
    [dict(shape=shape, dimension=dimension)
     for shape in [(), (3,), (2, 3), (2, 3, 4)]
     # TODO(mattjj): re-enable (1001, 1001), (101, 101, 101),
     for dimension in range(len(shape))],
    dtype=lax_test_util.default_dtypes,
  )
  def testIotaConstant(self, dtype, shape, dimension):
    make_const = lambda: lax.broadcasted_iota(dtype, shape, dimension)

    arr = np.arange(shape[dimension], dtype=dtypes.canonicalize_dtype(dtype))
    singleton_shape = [1] * len(shape)
    singleton_shape[dimension] = shape[dimension]
    expected = np.broadcast_to(arr.reshape(singleton_shape), shape)

    self._Check(make_const, expected)

  @jtu.sample_product(
    [dict(shape=shape, axes=axes)
     for shape, axes in [
       [(2, 3), (0, 1)],
       [(2, 3, 4), (0, 1)],
       [(2, 3, 4), (0, 2)],
       [(2, 3, 4), (1, 2)],
       [(2, 3, 4), (0, 1, 2)],
       [(2, 3, 4, 2), (0, 1, 2)],
       [(2, 3, 4, 2), (0, 2, 3)],
       [(1001, 1001), (0, 1)],
     ]],
    dtype=lax_test_util.default_dtypes,
  )
  def testDeltaConstant(self, dtype, shape, axes):
    make_const = lambda: lax_internal._delta(dtype, shape, axes)
    # don't check the asarray case, just assume it's right
    expected = np.asarray(make_const())
    self._Check(make_const, expected)

  def testBroadcastInDim(self):
    arr = lax.full((2, 1), 1.) + 1.
    arr_np = np.full((2, 1), 1.) + 1.
    expected = lax_reference.broadcast_in_dim(arr_np, (2, 1, 3), (0, 2))
    make_const = lambda: lax.broadcast_in_dim(arr, (2, 1, 3), (0, 2))
    self._Check(make_const, expected)

  @jtu.sample_product(
    input_type=[int, float, np.int32, np.float32, np.array],
    dtype=[np.int32, np.float32],
    jit=[True, False],
    value=[0, 1],
  )
  def testConvertElementReturnType(self, input_type, dtype, value, jit):
    op = lambda x: lax.convert_element_type(x, dtype)
    if jit:
      op = jax.jit(op)
    result = op(input_type(value))
    assert isinstance(result, jax.Array)

  @jtu.sample_product(
    dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes)
  @jtu.ignore_warning(category=NumpyComplexWarning)
  def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
    x = jax.device_put(np.zeros(5, dtype_in))
    self.assertEqual(x.dtype, dtype_in)
    y = lax.convert_element_type(x, dtype_out)
    self.assertEqual(y.dtype, dtype_out)
    x_buf = x
    y_buf = y
    if np.dtype(dtype_in) == np.dtype(dtype_out):
      self.assertEqual(x_buf.unsafe_buffer_pointer(),
                       y_buf.unsafe_buffer_pointer())
    else:
      self.assertNotEqual(x_buf.unsafe_buffer_pointer(),
                          y_buf.unsafe_buffer_pointer())
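
  # unsafe_buffer_pointer() exposes the device buffer address, so the pointer
  # (in)equality checks above are a proxy for whether convert_element_type
  # made a copy.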

  @jtu.sample_product(
    index_dtype=jtu.dtypes.all_inexact + jtu.dtypes.boolean,
    jax_fn=[lax.argmin, lax.argmax],
  )
  def testArgMinMaxIndexDtypeError(self, jax_fn, index_dtype):
    with self.assertRaisesRegex(TypeError,
                                "index_dtype must be an integer type"):
      jax_fn(np.ones((2, 2)), axis=0, index_dtype=index_dtype)

  @parameterized.parameters([lax.argmin, lax.argmax])
  def testArgMinMaxEmptyError(self, jax_fn):
    with self.assertRaisesRegex(ValueError,
                                "require non-empty reduced dimension"):
      jax_fn(np.ones((0, 2)), axis=0, index_dtype=np.int32)

  @parameterized.parameters([lax.argmin, lax.argmax])
  def testArgMinMaxInvalidAxisError(self, jax_fn):
    with self.assertRaisesRegex(ValueError,
                                "Invalid axis -1 for operand"):
      jax_fn(np.ones((2, 3)), axis=-1, index_dtype=np.int32)

  @jtu.sample_product(
    jax_fn=[lax.argmin, lax.argmax],
    weak_type=[False, True],
  )
  def testArgMinMaxWeakType(self, jax_fn, weak_type):
    op = lambda x: jax_fn(x, axis=0, index_dtype=np.int32)
    x_in = lax_internal._convert_element_type(np.ones((2, 2)),
                                              weak_type=weak_type)
    self.assertEqual(dtypes.is_weakly_typed(x_in), weak_type)
    x_out = op(x_in)
    self.assertEqual(dtypes.is_weakly_typed(x_out), False)
    x_out_jit = jax.jit(op)(x_in)
    self.assertEqual(dtypes.is_weakly_typed(x_out_jit), False)

  def testArgMaxOfNanChoosesNaN(self):
    self.assertEqual(lax.argmax(np.array([0., np.nan]), axis=0,
                                index_dtype=np.int32), 1)

  unary_op_types = {}
  for r in lax_test_util.lax_ops():
    if r.nargs == 1:
      unary_op_types[r.op] = (unary_op_types.get(r.op, set()) |
                              {np.dtype(t) for t in r.dtypes})
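
  # unary_op_types maps each single-argument lax op name to the union of
  # dtypes declared by its test records; it drives the parameterization of
  # testUnaryWeakTypes below.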

  @parameterized.named_parameters(
      {"testcase_name": f"_{op}", "op_name": op, "rec_dtypes": dtypes}
      for op, dtypes in unary_op_types.items())
  def testUnaryWeakTypes(self, op_name, rec_dtypes):
    """Test that all lax unary ops propagate weak_type information appropriately."""
    if op_name == "bitwise_not":
      raise unittest.SkipTest("https://github.com/jax-ml/jax/issues/12066")
    # Find a valid dtype for the function.
    for dtype in [float, int, complex, bool]:
      dtype = dtypes.canonicalize_dtype(dtype)
      if dtype in rec_dtypes:
        py_val = dtype.type(1).item()
        lax_val = lax.full((), py_val, dtype)
        break
    else:
      raise ValueError(f"no available dtypes in {rec_dtypes}")

    op = getattr(lax, op_name)
    py_op = op(py_val)
    lax_op = op(lax_val)

    self.assertAllClose(py_op, lax_op, check_dtypes=True)
    self.assertFalse(lax_op.aval.weak_type)
    if type(py_val) == bool:
      # Booleans should have weak types stripped.
      self.assertFalse(py_op.aval.weak_type)
    else:
      self.assertTrue(py_op.aval.weak_type)

  def testCumsumLengthOne(self):
    # regression test for issue 4672
    x = lax.full((1,), 1)
    out = lax.cumsum(x)
    self.assertArraysEqual(out, x)

  def testLog1pNearOne(self):
    expected = np.log1p(np.float32(1e-5))
    np.testing.assert_array_almost_equal_nulp(
        expected.astype(np.float32), lax.log1p(np.float32(1e-5)))
    np.testing.assert_array_almost_equal_nulp(
        expected.astype(np.complex64), lax.log1p(np.complex64(1e-5)))


class FooTyRules:
  # handlers

  @staticmethod
  def physical_element_aval(dtype) -> core.ShapedArray:
    return core.ShapedArray((2,), jnp.dtype('uint32'))

  @staticmethod
  def result_handler(sticky_device, aval):
    def handler(_, buf):
      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
      return FooArray(aval.shape, buf)
    return handler

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed):
    def handler(arr):
      from jax._src.array import ArrayImpl
      if isinstance(arr, ArrayImpl):
        buf, = arr._arrays
      else:
        buf, = arr
      return FooArray(aval.shape, buf)
    return handler
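
# Physically, a foo value is a pair of uint32s: physical_element_aval above
# maps each foo element to uint32[2], which matches the trailing size-2 data
# dimension that FooArray asserts below.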

class FooTy(dtypes.ExtendedDType):
  type = dtypes.extended
  name = 'foo'
  _rules = FooTyRules

  def __hash__(self) -> int:
    return hash(FooTy)
  def __eq__(self, other) -> bool:
    return type(other) is FooTy
  def __repr__(self) -> str:
    return self.name
  __str__ = __repr__

# primitives

make_p = core.Primitive('make')
bake_p = core.Primitive('bake')
take_p = core.Primitive('take')
jake_p = core.Primitive('jake')

def make(shape): return make_p.bind(shape=tuple(shape))
def bake(k): return bake_p.bind(k)
def take(k): return take_p.bind(k)
def jake(k): return jake_p.bind(k)

@make_p.def_abstract_eval
def make_abstract_eval(*, shape):
  return core.ShapedArray(shape, FooTy())

@bake_p.def_abstract_eval
def bake_abstract_eval(x):
  if type(x.dtype) != FooTy: raise TypeError
  return core.ShapedArray(tuple(reversed(x.shape)), FooTy())

@take_p.def_abstract_eval
def take_abstract_eval(x):
  return core.ShapedArray(x.shape, jnp.dtype('float32'))

@jake_p.def_abstract_eval
def jake_abstract_eval(x):
  return x
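
# In summary: make creates a foo array, bake reverses its shape (a
# transpose), take maps it to an ordinary float32 array, and jake is
# shape-preserving. A rough sketch of how they compose (mirroring
# test_jit_multiple_primitives below):
#
#   k = make((3, 4))   # foo[3,4]
#   k = bake(k)        # foo[4,3]
#   x = take(k)        # f32[4,3], each element equal to k.size == 12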

# runtime ('outside jit') data types

class FooArray:
  shape: tuple[int, ...]
  data: jax.Array

  def __init__(self, shape, data):
    assert data.shape == (*shape, 2)
    self.shape = shape
    self.data = data

  def __repr__(self) -> str:
    shape = ','.join(map(str, self.shape))
    return f'foo[{shape}] with value\n{self.data}'

  size = property(lambda self: self.data.size // 2)
  ndim = property(lambda self: self.data.ndim - 1)

def shard_foo_array_handler(xs, shardings, layouts, copy_semantics):
  results = []
  for x, sharding in safe_zip(xs, shardings):
    device, = sharding._addressable_device_assignment
    aval = core.get_aval(x.data)
    results.append(pxla.batched_device_put(
        aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]))
  return results

def foo_array_constant_handler(x):
  return array._array_mlir_constant_handler(x.data)

def make_lowering(*, shape):
  return jnp.zeros((*shape, 2), 'uint32')

def bake_lowering(k):
  return k.T

def take_lowering(k):
  return jnp.broadcast_to(jnp.float32(k.size), k.shape)

def jake_lowering(k):
  return jnp.ones((*k.shape, 2), 'uint32')
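
# bake reverses the operand's shape, so its batching rule below applies the
# same reversal permutation to the batch dimension: an input batch axis at
# position i lands at position ndim - 1 - i in the output.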

def bake_vmap(batched_args, batch_dims):
  xs, = batched_args
  bdim_in, = batch_dims
  ys = bake(xs)
  perm = list(reversed(range(xs.ndim)))
  bdim_out = perm[bdim_in]
  return ys, bdim_out


# All tests in this test class are thread-hostile because they add and remove
# primitives from global maps.
@jtu.thread_unsafe_test_class()  # registration isn't thread-safe
class CustomElementTypesTest(jtu.JaxTestCase):

  def setUp(self):
    core.pytype_aval_mappings[FooArray] = \
        lambda x: core.ShapedArray(x.shape, FooTy(),
                                   sharding=core.get_cur_mesh_sharding())
    xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
    pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
    mlir._constant_handlers[FooArray] = foo_array_constant_handler
    mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
    mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))
    mlir.register_lowering(take_p, mlir.lower_fun(take_lowering, False))
    mlir.register_lowering(jake_p, mlir.lower_fun(jake_lowering, False))
    batching.defvectorized(take_p)
    batching.primitive_batchers[bake_p] = bake_vmap

  def tearDown(self):
    del core.pytype_aval_mappings[FooArray]
    del xla.canonicalize_dtype_handlers[FooArray]
    del mlir._constant_handlers[FooArray]
    del mlir._lowerings[make_p]
    del mlir._lowerings[bake_p]
    del mlir._lowerings[take_p]
    del batching.primitive_batchers[take_p]
    del batching.primitive_batchers[bake_p]

  def test_shaped_array_construction(self):
    aval = core.ShapedArray((), FooTy())
    self.assertEqual(aval.str_short(), 'foo[]')
    aval = core.ShapedArray((3, 4), FooTy())
    self.assertEqual(aval.str_short(), 'foo[3,4]')

  def test_make_jaxpr_identity(self):
    x = types.SimpleNamespace(shape=(3,), dtype=FooTy())
    jaxpr = jax.make_jaxpr(lambda x: x)(x).jaxpr
    # { lambda ; a:foo[3]. let in (a,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
    self.assertLen(jaxpr.outvars, 1)
    a, = jaxpr.outvars
    self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))

  # tests after here need the primitives

  def test_make_jaxpr_with_primitives(self):
    def f():
      k1 = make((3, 4))
      k2 = bake(k1)
      x = take(k2)
      return x

    jaxpr = jax.make_jaxpr(f)().jaxpr
    # { lambda ; . let
    #     a:foo[3,4] = make[shape=(3, 4)]
    #     b:foo[4,3] = bake a
    #     c:f32[4,3] = take b
    #   in (c,) }
    self.assertLen(jaxpr.invars, 0)
    self.assertLen(jaxpr.eqns, 3)
    e1, e2, e3 = jaxpr.eqns

    self.assertIs(e1.primitive, make_p)
    self.assertLen(e1.outvars, 1)
    a, = e1.outvars
    self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))

    self.assertIs(e2.primitive, bake_p)
    self.assertLen(e2.outvars, 1)
    b, = e2.outvars
    self.assertEqual(b.aval, core.ShapedArray((4, 3), FooTy()))

    self.assertIs(e3.primitive, take_p)
    self.assertLen(e3.outvars, 1)
    c, = e3.outvars
    self.assertEqual(c.aval, core.ShapedArray((4, 3), np.dtype('float32')))

  # tests after here need FooArray and lowerings

  def test_jit_closure(self):
    k = FooArray((), jnp.arange(2, dtype='uint32'))

    @jax.jit
    def f():
      jnp.add(1, 1)  # make jit not hit trivial dispatch path
      return k

    y = f()  # doesn't crash
    self.assertIsInstance(y, FooArray)
    self.assertEqual(y.shape, ())

  def test_jit_identity(self):
    k = FooArray((), jnp.arange(2, dtype='uint32'))

    @jax.jit
    def f(k):
      jnp.add(1, 1)  # make jit not hit trivial dispatch path
      return k

    y = f(k)  # doesn't crash
    self.assertIsInstance(y, FooArray)
    self.assertEqual(y.shape, ())

  def test_jit_multiple_primitives(self):
    @jax.jit
    def f():
      k1 = make((3,))
      k2 = bake(k1)
      y = take(k2)
      return y

    y = f()
    self.assertArraysAllClose(y, jnp.array([3., 3., 3.]), check_dtypes=False)

  def test_scan_jaxpr(self):
    ks = jax.jit(lambda: make((3, 4)))()
    f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
    jaxpr = jax.make_jaxpr(f)(ks).jaxpr
    # { lambda ; a:foo[3,4]. let
    #     b:foo[3,4] = scan[
    #       jaxpr={ lambda ; c:foo[4]. let d:foo[4] = bake c in (d,) }
    #     ] a
    #   in (b,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
    self.assertLen(jaxpr.eqns, 1)
    e, = jaxpr.eqns
    self.assertLen(e.outvars, 1)
    b, = e.outvars
    self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy()))
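
  # With _split_transpose=True, scan's transpose is emitted as two loops: a
  # sequential scan that back-propagates activation cotangents, plus a
  # trivially parallel scan (a map, with no consts or carry) that accumulates
  # the parameter gradients. The test below checks that the two gradient
  # outputs indeed come from those two separate loops.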

  def test_scan_jaxpr_split_transpose(self):
    def stage(x, w):
      x = x @ w
      x = jnp.tanh(x)
      return (x, ())
    def loss(ws, x, split_transpose=False):
      return jnp.sum(jax.lax.scan(stage, x, ws,
                                  _split_transpose=split_transpose)[0])

    def fn(*args, split_transpose=False):
      v, fn_transpose = jax.vjp(
          partial(loss, split_transpose=split_transpose), *args)
      grads = fn_transpose(1.0)
      return *grads, v

    # x : [batch, d_model]
    x = jax.random.uniform(jax.random.key(0), [256, 100])
    # wss : [layers, d_model, d_model]
    wss = jax.random.uniform(jax.random.key(1), [7, 100, 100])

    jaxpr = jax.make_jaxpr(partial(fn))(wss, x)
    jaxpr_split_transpose = jax.make_jaxpr(partial(fn, split_transpose=True))(
        wss, x
    )

    # Check that the shapes were preserved.
    self.assertEqual(jaxpr.in_avals, jaxpr_split_transpose.in_avals)
    self.assertEqual(jaxpr.out_avals, jaxpr_split_transpose.out_avals)

    # The first two outvars (corresponding to gradients of params and inputs)
    # must come from two different loops.
    ct_ws = jaxpr_split_transpose.jaxpr.outvars[0]
    ct_x = jaxpr_split_transpose.jaxpr.outvars[1]

    # The last two equations are the two loops we care about
    backprop_scan = jaxpr_split_transpose.jaxpr.eqns[-2]
    self.assertEqual(backprop_scan.primitive, jax.lax.scan_p)

    param_gradient_map = jaxpr_split_transpose.jaxpr.eqns[-1]
    self.assertEqual(param_gradient_map.primitive, jax.lax.scan_p)
    self.assertEqual(param_gradient_map.params['num_consts'], 0)
    self.assertEqual(param_gradient_map.params['num_carry'], 0)

    # Assert that parameter gradients come from the map.
    self.assertEqual(ct_ws, param_gradient_map.outvars[0])
    # And that activation gradients come from the scan.
    self.assertEqual(ct_x, backprop_scan.outvars[0])

  def test_scan_lowering(self):
    ks = jax.jit(lambda: make((3, 4)))()
    f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
    _, out = jax.jit(f)(ks)  # doesn't crash
    self.assertIsInstance(out, FooArray)
    self.assertEqual(out.shape, (3, 4))

  def test_vmap(self):
    ks = jax.jit(lambda: make((3, 4, 5)))()
    ys = jax.vmap(jax.jit(lambda k: take(bake(k))))(ks)
    expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
    self.assertAllClose(ys, expected)

  def test_slice(self):
    ks = jax.jit(lambda: make((3, 4)))()
    ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (2, 4))

  def test_dynamic_slice(self):
    ks = jax.jit(lambda: make((3, 4)))()
    ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (2, 4))

  def test_transpose(self):
    ks = jax.jit(lambda: make((3, 4)))()
    ys = jax.jit(lambda x: x.T)(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (4, 3))

  def test_gather(self):
    ks = jax.jit(lambda: make((3, 4)))()
    ys = jax.jit(lambda x: x[1])(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (4,))

    ks = jax.jit(lambda: make((3, 4, 5)))()

    ys = jax.jit(lambda x: x[1])(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (4, 5))

    ys = jax.jit(lambda x: x[1, 2:4])(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (2, 5))

    ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (2,))

    ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (3, 2, 1))

  def test_gather_batched_index_dtype(self):
    # Regression test for https://github.com/jax-ml/jax/issues/16557
    dtype = jnp.int8
    size = jnp.iinfo(dtype).max + 10
    indices = jnp.zeros(size, dtype=dtype)
    values = jnp.zeros((size, 1))
    results = jax.vmap(lambda x, i: jnp.take(x, i, axis=0))(values, indices)
    self.assertArraysEqual(results, jnp.zeros(size))

  @parameterized.parameters([
      (0,),
      (slice(1),),
      (np.array([0, 2]),),
      (np.array([False, True, True]),)
  ])
  def test_scatter(self, idx):
    k = jax.jit(lambda: make(()))()
    ks = jax.jit(lambda: make((3,)))()
    ys = jax.jit(lambda x, y: x.at[idx].set(y))(ks, k)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (3,))

  def test_equality(self):
    eq = jax.jit(lambda k1, k2: k1 == k2)
    ne = jax.jit(lambda k1, k2: k1 != k2)

    k1 = jax.jit(lambda: make(()))()
    k2 = jax.jit(lambda: jake(make(())))()

    self.assertTrue(eq(k1, k1))
    self.assertFalse(eq(k1, k2))
    self.assertTrue(ne(k1, k2))
    self.assertFalse(ne(k1, k1))

    size = 5
    idx = slice(2, 4)
    ks = jax.jit(lambda k: jake(make((size,))).at[idx].set(k))(k1)
    expected = jnp.zeros(size, dtype=bool).at[idx].set(True)
    self.assertArraysEqual(eq(k1, ks), expected)
    self.assertArraysEqual(ne(k1, ks), ~expected)

  def test_select(self):
    ks = jax.jit(lambda: make((3,)))()
    cs = jnp.array([True, False, False])
    ys = jax.jit(lax.select)(cs, ks, ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (3,))

  def test_xla_reverse_bug(self):
    # Regression test for b/248295786
    # This was an XLA bug related to an incorrect optimization of reverse
    def f(x):
      y = jnp.array([2, 5])
      return lax.rev(x * y, (0,))
    x = jnp.array([1, 2])
    self.assertArraysEqual(f(x), jax.jit(f)(x))

  # TODO(frostig,mattjj): more polymorphic primitives tests


class FunctionAccuracyTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      dict(testcase_name=f"_{dtype.__name__}", dtype=dtype)
      for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128]))
  def testMPMathUtils(self, dtype):
    try:
      import mpmath
    except ImportError as msg:
      self.skipTest(f'could not import mpmath: {msg}')

    prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype]
    is_complex = dtype().dtype.kind == 'c'

    def func(x):
      assert isinstance(x, mpmath.ctx_mp.mpnumeric)
      assert x.context.prec == prec
      assert isinstance(x, x.context.mpc if is_complex else x.context.mpf)
      return x

    ufunc = jtu.vectorize_with_mpmath(func, mpmath=mpmath)

    with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
      if is_complex:
        arr = jtu.complex_plane_sample(dtype=dtype, size_re=11)
      else:
        cdtype = getattr(np, ufunc.map_float_to_complex[dtype.__name__])
        arr = jtu.complex_plane_sample(dtype=cdtype, size_re=11, size_im=0)[1:2].real

    arr2 = ufunc.mptonp(ufunc.nptomp(arr))
    with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
      self.assertAllClose(arr, arr2, atol=0, rtol=0)

    arr3 = ufunc(arr)
    with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
      self.assertAllClose(arr, arr3, atol=0, rtol=0)

    if is_complex:
      # tests scale in normalize
      v = dtype(1.1071487177940644+1.1102230246251565e-16j)
      r = dtype(1.1071487177940644+0j)
      mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
      nr, nv = mnp.normalize(r, r, v)
      self.assertAllClose(nr, nv)

  _functions_on_complex_plane = [
      'arccos', 'arccosh', 'arcsin', 'arcsinh',
      'arctan', 'arctanh', 'conjugate', 'cos',
      'cosh', 'exp', 'exp2', 'expm1', 'log',
      'log10', 'log1p', 'sin', 'sinh', 'sqrt',
      'square', 'tan', 'tanh', 'sinc', 'positive',
      'negative', 'absolute', 'sign'
  ]

  @parameterized.named_parameters(
      dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
      for name, dtype in itertools.product(
        _functions_on_complex_plane,
        jtu.dtypes.supported([np.complex64, np.complex128]),
      ))
  @jtu.skip_on_devices("tpu")
  def testSuccessOnComplexPlane(self, name, dtype):
    self._testOnComplexPlaneWorker(name, dtype, 'success')

  @parameterized.named_parameters(
      dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
      for name, dtype in itertools.product(
        _functions_on_complex_plane,
        jtu.dtypes.supported([np.complex64, np.complex128]),
      ))
  @jtu.skip_on_devices("tpu")
  def testFailureOnComplexPlane(self, name, dtype):
    self._testOnComplexPlaneWorker(name, dtype, 'failure')

  def _testOnComplexPlaneWorker(self, name, dtype, kind):
    try:
      import mpmath
    except ImportError as msg:
      self.skipTest(f'could not import mpmath: {msg}')

    is_cpu = jtu.test_device_matches(["cpu"])
    machine = platform.machine()
    # TODO: remove is_arm_cpu as previously arm cpu related failures
    # were due to numpy issues. Confirm?
    is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm')
    is_cuda = jtu.test_device_matches(["cuda"])

    size_re = 11
    size_im = 11
    atol = None

    if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}:
      # TODO(pearu): eliminate this if-block when a fix to mpmath#787
      # becomes available
      extra_prec_multiplier = 20
    else:
      extra_prec_multiplier = 1

    mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1, extra_prec_multiplier=extra_prec_multiplier)
    mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=extra_prec_multiplier)

    ref_op = getattr(mnp, name)
    ref2_op = getattr(mnp2, name)
    jnp_op = getattr(jnp, name)

    with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
      args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),)
      result = np.asarray(jnp_op(*args))
      expected = ref_op(*args)
      expected2 = ref2_op(*args)

    normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result)

    # When comparing the results with expected, we'll divide the
    # complex plane grid into smaller regions and perform the
    # closeness tests on each region separately. The reason for this
    # is that the inaccuracy or incorrectness issues with a particular
    # function typically exist in specific regions while in other
    # regions the function is accurate. So, such a division of the
    # complex plane helps to identify the problematic regions as well
    # as to fix the inaccuracy or incorrectness issues.
    #
    # Regions in complex plane:
    #
    #          (  pinfj  )
    #   ( q2 )   (posj)    ( q1 )
    # (ninf) ( neg ) (zero) ( pos ) (pinf)
    #   ( q3 )   (negj)    ( q4 )
    #          (  ninfj  )
    #
    # In addition, the 1/3 middle parts of regions q1, q2, q3, q4,
    # neg, pos are tested separately as these don't contain extremely
    # small or extremely large values and functions on these regions
    # ought not to possess any incorrectness issues.

    s0, s1 = size_re, size_im
    s03, s13 = s0 // 3, s1 // 3

    s_dict = dict(
      q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
      q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
      q3=(slice(1, s0 + 1), slice(1, s1 + 1)),
      q4=(slice(1, s0 + 1), slice(s1 + 2, -1)),
      neg=(s0 + 1, slice(1, s1 + 1)),
      pos=(s0 + 1, slice(s1 + 2, -1)),
      negj=(slice(1, s0 + 1), s1 + 1),
      posj=(slice(s0 + 2, -1), s1 + 1),
      ninf=(slice(None), 0),
      pinf=(slice(None), -1),
      ninfj=(0, slice(None)),
      pinfj=(-1, slice(None)),
      zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
    )

    if s03 and s13:
      s_dict.update(
        mq1=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
        mq2=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
        mq3=(slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
        mq4=(slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
        mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)),
        mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
        mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1),
        mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1),
      )

    # The regions are split into real and imaginary parts (of function
    # return values) to (i) work around a numpy 1.x assert_allclose bug
    # in comparing complex infinities, and (ii) expose more details
    # about failing cases:
    s_dict_parts = dict()
    for k, v in s_dict.items():
      s_dict_parts[k + '.real'] = v
      s_dict_parts[k + '.imag'] = v

    # Start with an assumption that all regions are problematic for a
    # particular function:
    regions_with_inaccuracies = list(s_dict_parts)

    # Next, we'll remove non-problematic regions from the
    # regions_with_inaccuracies list by explicitly keeping problematic
    # regions:
    def regions_with_inaccuracies_keep(*to_keep):
      to_keep_parts = []
      for r in to_keep:
        if r.endswith('.real') or r.endswith('.imag'):
          to_keep_parts.append(r)
        else:
          to_keep_parts.append(r + '.real')
          to_keep_parts.append(r + '.imag')
      for item in regions_with_inaccuracies[:]:
        if item not in to_keep_parts:
          regions_with_inaccuracies.remove(item)

    if name == 'absolute':
      if is_cuda and dtype == np.complex128:
        regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real')
      else:
        regions_with_inaccuracies.clear()

    elif name == 'sign':
      regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')

    elif name == 'log':
      regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')

    elif name == 'log10':
      regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')

    elif name == 'exp':
      regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')

    elif name == 'exp2':
      if dtype == np.complex64:
        regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos.imag', 'mnegj', 'mposj')
      if dtype == np.complex128:
        regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mpos.imag')

    elif name == 'sinc':
      regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4',
                                     'mneg.real', 'mpos.real', 'mnegj', 'mposj',
                                     'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real')

    elif name == 'sinh':
      if is_cuda:
        regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg', 'pos',
                                       'ninf.imag', 'pinf.imag', 'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos',
                                       'ninfj.real', 'pinfj.real')
      if is_cpu:
        regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj.imag', 'posj.imag', 'ninf.imag', 'pinf.imag',
                                       'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos',
                                       'ninfj.real', 'pinfj.real')

    elif name == 'cosh':
      regions_with_inaccuracies_keep('neg.imag', 'pos.imag', 'ninf.imag', 'pinf.imag', 'mneg.imag', 'mpos.imag',
                                     'ninfj.imag', 'pinfj.imag')

    elif name == 'tanh':
      regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')

    elif name in {'cos', 'sin'}:
      regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')

    elif name in {'positive', 'negative', 'conjugate', 'sqrt', 'expm1', 'log1p', 'tan',
                  'arcsinh', 'arcsin', 'arccosh', 'arccos', 'arctan', 'arctanh', 'square'}:
      regions_with_inaccuracies.clear()

    else:
      assert 0  # unreachable

    # Finally, perform the closeness tests per region:
    unexpected_success_regions = []
    for region_name, region_slice in s_dict_parts.items():
      region = args[0][region_slice]
      if region_name.endswith('.real'):
        result_slice, expected_slice = result[region_slice].real, expected[region_slice].real
        normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].real, normalized_expected[region_slice].real
      elif region_name.endswith('.imag'):
        result_slice, expected_slice = result[region_slice].imag, expected[region_slice].imag
        normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].imag, normalized_expected[region_slice].imag
      else:
        result_slice, expected_slice = result[region_slice], expected[region_slice]
        normalized_result_slice, normalized_expected_slice = normalized_result[region_slice], normalized_expected[region_slice]

      inexact_indices = np.where(normalized_result_slice != normalized_expected_slice)

      if inexact_indices[0].size == 0:
        inexact_samples = ''
      else:
        inexact_samples = []
        for ind in zip(*inexact_indices):
          x = region[ind]
          y1, y2 = result[region_slice][ind], expected[region_slice][ind]
          ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind]
          if str(y1) == str(y2):  # skip equal nan-s
            continue
          max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y2) else np.inf
          inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]'))
        inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)])

      if kind == 'success' and region_name not in regions_with_inaccuracies:
        with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
          self.assertAllClose(
              normalized_result_slice, normalized_expected_slice, atol=atol,
              err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=},\n{inexact_samples}")

      if kind == 'failure' and region_name in regions_with_inaccuracies:
        try:
          with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}"):
            with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
              self.assertAllClose(normalized_result_slice, normalized_expected_slice)
        except AssertionError as msg:
          if str(msg).startswith('AssertionError not raised'):
            unexpected_success_regions.append(region_name)
          else:
            raise  # something else is wrong..

    def eliminate_parts(seq):
      # replace n.real and n.imag items in seq with n.
      result = []
      for part_name in seq:
        name = part_name.split('.')[0]
        if name in result:
          continue
        if name + '.real' in seq and name + '.imag' in seq:
          result.append(name)
        else:
          result.append(part_name)
      return result

    regions_with_inaccuracies = eliminate_parts(regions_with_inaccuracies)
    unexpected_success_regions = eliminate_parts(unexpected_success_regions)

    if kind == 'success' and regions_with_inaccuracies:
      reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
      raise unittest.SkipTest(reason)

    if kind == 'failure':
      if not regions_with_inaccuracies:
        raise unittest.SkipTest("no problematic regions")
      elif unexpected_success_regions:
        # This skip is expected to fire only while functions on the
        # problematic regions are being fixed in XLA, following a JAX PR
        # that enables testing these regions for success.
        raise unittest.SkipTest(
            f"detected success in regions {', '.join(unexpected_success_regions)}, please update regions_with_inaccuracies!"
        )


class CompositeTest(jtu.JaxTestCase):

  def test_composite(self):
    def my_square_impl(x):
      return x ** 2
    my_square = lax.composite(my_square_impl, name="my.square")

    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))

    mlir_module = jax.jit(my_square).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>) -> tensor<f32>', mlir_module)
    self.assertIn('@my.square(%arg0: tensor<f32>) -> tensor<f32> {', mlir_module)
    self.assertIn('stablehlo.multiply %arg0, %arg0 : tensor<f32>', mlir_module)
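
  # lax.composite wraps a decomposition function so that it lowers to a single
  # stablehlo.composite op whose body is kept as a private decomposition
  # function in the module, as the MLIR assertions above and below verify.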

  def test_composite_decorator(self):
    @partial(lax.composite, name="my.square")
    def my_square(x):
      return x ** 2

    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))

    mlir_module = jax.jit(my_square).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>) -> tensor<f32>', mlir_module)
    self.assertIn('@my.square(%arg0: tensor<f32>) -> tensor<f32> {', mlir_module)
    self.assertIn('stablehlo.multiply %arg0, %arg0 : tensor<f32>', mlir_module)

  def test_composite_with_jit_function(self):
    def my_square_impl(x):
      return x ** 2
    my_square = jax.jit(lax.composite(my_square_impl, name="my.square"))

    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))

    mlir_module = my_square.lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>) -> tensor<f32>', mlir_module)
    self.assertIn('@my.square(%arg0: tensor<f32>) -> tensor<f32> {', mlir_module)
    self.assertIn('stablehlo.multiply %arg0, %arg0 : tensor<f32>', mlir_module)

  def test_composite_with_attributes(self):
    # static_argnames is required here since `k` is a constant that would
    # normally come from a larger context; here we unit-test the composite
    # op on its own.
    @partial(jax.jit, static_argnames=['k'])
    @partial(lax.composite, name="my.top_k")
    def my_top_k(x, *, k):
      return lax.top_k(x, k)

    x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)
    output, indices = my_top_k(x, k=3)
    self.assertArraysEqual(output, jnp.array([5.0, 4.0, 3.0], dtype=jnp.float32))
    self.assertArraysEqual(indices, jnp.array([4, 3, 2], dtype=jnp.int32))

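    # The static keyword argument `k` is recorded in the lowered op's
    # `composite_attributes` dictionary rather than passed as an operand.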
    mlir_module = my_top_k.lower(x, k=3).as_text()
    self.assertIn(
        'stablehlo.composite "my.top_k" %arg0 '
        '{composite_attributes = {k = 3 : i64}, decomposition = @my.top_k} : '
        '(tensor<5xf32>) -> (tensor<3xf32>, tensor<3xi32>)', mlir_module)
    self.assertIn('@my.top_k(%arg0: tensor<5xf32>) -> (tensor<3xf32>, tensor<3xi32>) {', mlir_module)
    self.assertIn('chlo.top_k(%arg0, k = 3) : tensor<5xf32> -> (tensor<3xf32>, tensor<3xi32>)', mlir_module)

  def test_composite_attribute_dtypes(self):
    @jax.jit
    def my_tangent_composite_with_attributes(x):
      def decomposition(x, **_):
        return lax.sin(x) / lax.cos(x)

      return lax.composite(decomposition, "my.tangent")(
          x,
          dtype=np.dtype(np.float32),
          int=1,
          str="bar",
          tensor=np.zeros((1, 2), dtype=np.float32),
          tensor_r1=np.zeros((2,), dtype=np.float32),
      )

    pi = jnp.pi
    x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi], dtype=jnp.float32)
    output = my_tangent_composite_with_attributes(x)
    self.assertArraysAllClose(
        output, jnp.array([0.0, 1.0, -1.0, 0.0], dtype=jnp.float32)
    )

    mlir_module = my_tangent_composite_with_attributes.lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.tangent" %arg0 {composite_attributes = {'
        'dtype = f32, int = 1 : i64, str = "bar", '
        'tensor = dense<0.000000e+00> : tensor<1x2xf32>, '
        'tensor_r1 = dense<0.000000e+00> : tensor<2xf32>}, '
        'decomposition = @my.tangent} : (tensor<4xf32>) -> tensor<4xf32>',
        mlir_module)
    self.assertIn("func.func private @my.tangent", mlir_module)

  def test_composite_unsupported_attribute_dtypes(self):
    def my_tangent_composite_with_attributes(x):
      def decomposition(x, **_):
        return lax.sin(x) / lax.cos(x)

      return lax.composite(decomposition, "my.tangent")(
          x, tensor=jnp.zeros((1, 2), dtype=jnp.float32)
      )

    pi = jnp.pi
    x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi], dtype=jnp.float32)
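    # Under jit, the jnp.zeros attribute value is a tracer rather than a
    # concrete array, so lowering fails with an UnexpectedTracerError.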
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        "Note: If you are passing jax arrays as attributes, use numpy arrays "
        "instead."
    ):
      jax.jit(my_tangent_composite_with_attributes).lower(x).as_text()

  def test_composite_with_non_default_version(self):
    @partial(lax.composite, name="my.square", version=1)
    def my_square_with_version(x):
      return x ** 2

    x = jnp.array(2.0, dtype=jnp.float32)
    out = my_square_with_version(x)
    self.assertEqual(out, 4.0)

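    # The version number is carried through lowering as an i32 attribute on
    # the stablehlo.composite op.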
    mlir_module = jax.jit(my_square_with_version).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square, '
        'version = 1 : i32} : (tensor<f32>) -> tensor<f32>', mlir_module)

  def test_composite_with_no_args(self):
    @partial(lax.composite, name="my.one")
    def one():
      return jnp.array(1.0, dtype=jnp.float32)

    out = one()
    self.assertEqual(out, jnp.array(1.0, dtype=jnp.float32))

    mlir_module = jax.jit(one).lower().as_text()
    self.assertIn('stablehlo.composite "my.one"', mlir_module)
    self.assertIn('{decomposition = @my.one} : () -> tensor<f32>', mlir_module)
    self.assertIn('@my.one() -> tensor<f32>', mlir_module)
    self.assertIn('stablehlo.constant dense<1.000000e+00> : tensor<f32>', mlir_module)

  def test_composite_with_variadic_input_output(self):
    @partial(lax.composite, name="my.ident")
    def ident(*args):
      return args

    x = jnp.array(1.0, dtype=jnp.float32)
    y = jnp.array(2.0, dtype=jnp.float32)
    z = jnp.array(3.0, dtype=jnp.float32)
    a, b, c = ident(x, y, z)
    self.assertEqual(a, x)
    self.assertEqual(b, y)
    self.assertEqual(c, z)

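    # All positional arguments become operands of the single composite op,
    # and all outputs become its results.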
    mlir_module = jax.jit(ident).lower(x, y, z).as_text()
    self.assertIn(
        'stablehlo.composite "my.ident" %arg0, %arg1, %arg2 '
        '{decomposition = @my.ident} : (tensor<f32>, tensor<f32>, tensor<f32>) '
        '-> (tensor<f32>, tensor<f32>, tensor<f32>)', mlir_module)
    self.assertIn(
        '@my.ident(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) '
        '-> (tensor<f32>, tensor<f32>, tensor<f32>)', mlir_module)
    self.assertIn('return %arg0, %arg1, %arg2 : tensor<f32>, tensor<f32>, tensor<f32>', mlir_module)

  def test_composite_jvp(self):
    @partial(lax.composite, name="my.square")
    def my_square(x):
      return x ** 2

    with self.assertRaisesRegex(
        ValueError,
        "JVP rule for composite not implemented. You can use `jax.custom_jvp` "
        "to add support. See "
        "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html"
    ):
      jvp(my_square, (1.0,), (2.0,))

  def test_composite_grad(self):
    @partial(lax.composite, name="my.square")
    def my_square(x):
      return x ** 2

    with self.assertRaisesRegex(
        ValueError,
        "JVP rule for composite not implemented. You can use `jax.custom_jvp` "
        "to add support. See "
        "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html"
    ):
      grad(my_square)(1.0)
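
  # A minimal sketch, not from the original suite, of the workaround the
  # error message above suggests: wrap the composite in `jax.custom_jvp` and
  # supply the derivative by hand. The names `my_square_diff` and
  # `my_square_jvp` are illustrative assumptions.
  def test_composite_with_custom_jvp_sketch(self):
    @jax.custom_jvp
    @partial(lax.composite, name="my.square")
    def my_square_diff(x):
      return x ** 2

    @my_square_diff.defjvp
    def my_square_jvp(primals, tangents):
      (x,), (t,) = primals, tangents
      # d/dx x**2 = 2x; the custom rule lets jvp/grad bypass the composite.
      return my_square_diff(x), 2 * x * t

    primal_out, tangent_out = jvp(my_square_diff, (3.0,), (1.0,))
    self.assertEqual(primal_out, 9.0)
    self.assertEqual(tangent_out, 6.0)
    self.assertEqual(grad(my_square_diff)(3.0), 6.0)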

  def test_composite_with_array_consts(self):
    @partial(lax.composite, name="my.consts")
    def my_consts(x, /, *, scale):
      return jnp.round(x / scale)

    scale = np.array([0.5, 0.4, 0.3], dtype=np.float32)
    x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
    self.assertAllClose(my_consts(x, scale=scale), jnp.round(x / scale))

    # The constant must not appear as an extra input argument to the composite.
    mlir_module = jax.jit(partial(my_consts, scale=scale)).lower(x).as_text()
    self.assertIn(
        "@my.consts(%arg0: tensor<3xf32>) -> tensor<3xf32>", mlir_module
    )

  def test_composite_with_tracer_consts(self):
    def fun(x, scale):
      @partial(lax.composite, name="my.consts")
      def my_consts(y):
        return jnp.round(y / scale)
      return my_consts(x)

    scale = jnp.array([0.5, 0.4, 0.3], dtype=jnp.float32)
    x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
    self.assertAllClose(fun(x, scale), jnp.round(x / scale))
    self.assertAllClose(
        jax.jit(partial(fun, scale=scale))(x), jnp.round(x / scale))
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        "Found a JAX Tracer as a constant in the decomposition for the "
        "composite op 'my.consts'."):
      jax.jit(fun)(x, scale)


class RaggedTest(jtu.JaxTestCase):

  @jtu.sample_product(
      [
          {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
          {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
      ],
      dtype=jtu.dtypes.numeric,
  )
  def test_ragged_dot(self, m, k, n, num_groups, dtype):
    """Tests ragged_dot.

    ragged_dot is tested against the numpy reference implementation and by
    running it through JAX compilation.

    Raises:
      SkipTest: if the dtype is not supported.
    """
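    # lhs has shape (m, k); rhs stacks one (k, n) matrix per group, and
    # group_sizes (summing to m) assigns contiguous rows of lhs to groups.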
    lhs_shape = (m, k)
    rhs_shape = (num_groups, k, n)

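    # Draw num_groups - 1 random boundaries in [0, m) and convert the sorted
    # boundaries into per-group sizes that sum to m.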
    def group_sizes(m, num_groups):
      ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
      ends = jnp.concatenate(
          [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
      starts = jnp.concatenate(
          [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
      return ends - starts

    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [
        rng(lhs_shape, dtype),
        rng(rhs_shape, dtype),
        group_sizes(m, num_groups),
    ]
    self._CompileAndCheck(lax.ragged_dot, args_maker)
    self._CheckAgainstNumpy(
        lax_reference.ragged_dot, lax.ragged_dot, args_maker)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())