# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import partial
import itertools
import math
import operator
import platform
import types
import unittest
from unittest import SkipTest

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax._src import core
from jax import jvp, grad
from jax import lax
import jax.numpy as jnp
from jax.test_util import check_grads
import jax.util

from jax.interpreters import batching
from jax.interpreters import xla
from jax._src import array
from jax._src import config
from jax._src import dtypes
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.util import NumpyComplexWarning, safe_zip
from jax._src.tree_util import tree_map

config.parse_flags_with_absl()


### lax tests

# We check cases where the preferred type is at least as wide as the input
# type and where both are either both floating-point or both integral,
# which are the only supported configurations.
preferred_type_combinations = [
  (np.float16, np.float16), (np.float16, np.float32), (np.float16, np.float64),
  (dtypes.bfloat16, dtypes.bfloat16), (dtypes.bfloat16, np.float32),
  (dtypes.bfloat16, np.float64), (np.float32, np.float32),
  (np.float32, np.float64), (np.float64, np.float64),
  (np.int8, np.int8), (np.int8, np.int16), (np.int8, np.int32),
  (np.int8, np.int64), (np.int16, np.int16), (np.int16, np.int32),
  (np.int16, np.int64), (np.int32, np.int32), (np.int32, np.int64),
  (np.int64, np.int64),
  (np.complex64, np.complex64), (np.complex64, np.complex128),
  (np.complex128, np.complex128),
  (np.int8, np.float16), (np.int8, dtypes.bfloat16), (np.int8, np.float32),
  (np.int8, np.float64), (np.int16, np.float16), (np.int16, dtypes.bfloat16),
  (np.int16, np.float32), (np.int16, np.float64), (np.int32, np.float32),
  (np.int32, np.float64), (np.int64, np.float64)]


def _reduce_custom_add(x, y):
  return x + y

def _reduce_custom_mul(x, y):
  return x * y

def _reduce_custom_sub(x, y):
  return x - y

def _reduce_custom_min(x, y):
  return jnp.minimum(x, y)

def _reduce_custom_max(x, y):
  return jnp.maximum(x, y)


class LaxTest(jtu.JaxTestCase):
  """Numerical tests for LAX operations."""

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
        [dict(op_name=rec.op, rng_factory=rec.rng_factory)],
        shapes=itertools.chain.from_iterable(
          itertools.combinations_with_replacement(shape_group, rec.nargs)
          for shape_group in lax_test_util.compatible_shapes),
        dtype=rec.dtypes)
      for rec in lax_test_util.lax_ops()))
  def testOp(self, op_name, rng_factory, shapes, dtype):
    rng = rng_factory(self.rng())
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = getattr(lax, op_name)
    self._CompileAndCheck(op, args_maker)

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
        [dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
        shapes=itertools.chain.from_iterable(
          itertools.combinations_with_replacement(shape_group, rec.nargs)
          for shape_group in lax_test_util.compatible_shapes),
        dtype=rec.dtypes)
      for rec in lax_test_util.lax_ops()))
  @jtu.ignore_warning(message="invalid value", category=RuntimeWarning)
  def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
    if (not config.enable_x64.value and op_name == "nextafter"
        and dtype == np.float64):
      raise SkipTest("64-bit mode disabled")
    rng = rng_factory(self.rng())
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = getattr(lax, op_name)
    numpy_op = getattr(lax_reference, op_name)
    tol = tol or jtu.default_tolerance()
    if jtu.test_device_matches(["tpu"]):
      if dtype in (np.float32, np.complex64) and op_name in (
          "acosh", "asinh", "betainc", "cos", "cosh", "digamma", "exp", "exp2",
          "igamma", "igammac", "log", "log1p", "logistic", "pow", "sin",
          "sinh", "tan"):
        tol = jtu.join_tolerance(tol, 2e-4)
      elif op_name == "asinh" and dtype == np.float16:
        tol = jtu.join_tolerance(tol, 1e-3)
      elif op_name == "lgamma" and dtype == np.float32:
        tol = jtu.join_tolerance(tol, 1e-3)
      elif op_name == "pow" and dtype == np.complex128:
        tol = jtu.join_tolerance(tol, 2e-15)
    self._CheckAgainstNumpy(numpy_op, op, args_maker, tol=tol)

  # TODO test shift_left, shift_right_arithmetic, shift_right_logical

  @jtu.sample_product(
    [dict(from_dtype=from_dtype, to_dtype=to_dtype)
     for from_dtype, to_dtype in itertools.product(
       [None, np.float32, np.int32, "float32", "int32"], repeat=2)],
    weak_type=[False, True],
  )
  def testConvertElementType(self, from_dtype, to_dtype, weak_type):
    rng = jtu.rand_default(self.rng())
    args_maker =
lambda: [rng((2, 3), from_dtype)] op = lambda x: lax_internal._convert_element_type(x, to_dtype, weak_type) self._CompileAndCheck(op, args_maker) x = rng((1,), from_dtype) out = op(x) self.assertEqual(out.dtype, dtypes.canonicalize_dtype(to_dtype or x.dtype)) self.assertEqual(out.aval.weak_type, weak_type) def testConvertElementTypeOOB(self): out = lax.convert_element_type(2 ** 32, 'int32') self.assertEqual(out, 0) @jtu.sample_product( [dict(from_dtype=from_dtype, to_dtype=to_dtype) for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2)], ) def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng((2, 3), from_dtype)] op = lambda x: lax.convert_element_type(x, to_dtype) numpy_op = lambda x: lax_reference.convert_element_type(x, to_dtype) self._CheckAgainstNumpy(numpy_op, op, args_maker) @jtu.sample_product( from_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned, to_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned, shape = [(), (2,), (2, 3)] ) def testBitcastConvertType(self, from_dtype, to_dtype, shape): rng = jtu.rand_default(self.rng()) nbits_in = dtypes.bit_width(from_dtype) nbits_out = dtypes.bit_width(to_dtype) if nbits_in < nbits_out: shape = (*shape, nbits_out // nbits_in) args_maker = lambda: [rng(shape, from_dtype)] jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype) self._CompileAndCheck(jnp_op, args_maker) # Test the shape and dtype of the output. We avoid testing the values here # because the bitwise representation may vary from platform to platform. out = jnp_op(*args_maker()) if nbits_in == nbits_out: expected_shape = shape elif nbits_in < nbits_out: expected_shape = shape[:-1] else: expected_shape = (*shape, nbits_in // nbits_out) self.assertEqual(out.dtype, to_dtype) self.assertEqual(out.shape, expected_shape) @jtu.sample_product( [dict(from_dtype=from_dtype, to_dtype=to_dtype) for from_dtype, to_dtype in itertools.product( ['int4', 'uint4', np.int8, np.uint8, np.int32, np.float16, np.float32], repeat=2)], shape=[(4,), (2, 4), (2, 3, 4)] ) def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape): nbits_in = dtypes.bit_width(from_dtype) nbits_out = dtypes.bit_width(to_dtype) if nbits_in < nbits_out: shape = (*shape, nbits_out // nbits_in) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, from_dtype)] jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype) np_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype) self._CheckAgainstNumpy(np_op, jnp_op, args_maker) @jtu.sample_product( [dict(from_dtype=from_dtype, to_dtype=to_dtype) for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2)], weak_type=[False, True], ) def testBitcastConvertWeakType(self, from_dtype, to_dtype, weak_type): rng = jtu.rand_default(self.rng()) x_in = lax_internal._convert_element_type(rng((2, 3), from_dtype), weak_type=weak_type) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self.assertEqual(dtypes.is_weakly_typed(x_in), weak_type) x_out = op(x_in) self.assertEqual(dtypes.is_weakly_typed(x_out), False) x_out_jit = jax.jit(op)(x_in) self.assertEqual(dtypes.is_weakly_typed(x_out_jit), False) @jtu.sample_product( [dict(min_shape=min_shape, operand_shape=operand_shape, max_shape=max_shape) for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 
3)], [(2, 3), (2, 3), (2, 3)], ]], dtype=lax_test_util.default_dtypes, ) def testClamp(self, min_shape, operand_shape, max_shape, dtype): rng = jtu.rand_default(self.rng()) shapes = [min_shape, operand_shape, max_shape] args_maker = lambda: [rng(shape, dtype) for shape in shapes] self._CompileAndCheck(lax.clamp, args_maker) @jtu.sample_product( [dict(min_shape=min_shape, operand_shape=operand_shape, max_shape=max_shape) for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 3)], [(2, 3), (2, 3), (2, 3)], ]], dtype=lax_test_util.default_dtypes, ) def testClampAgainstNumpy(self, min_shape, operand_shape, max_shape, dtype): rng = jtu.rand_default(self.rng()) shapes = [min_shape, operand_shape, max_shape] args_maker = lambda: [rng(shape, dtype) for shape in shapes] self._CheckAgainstNumpy(lax_reference.clamp, lax.clamp, args_maker) @jtu.sample_product( [dict(base_shape=shape, dim=dim) for shape in [(4,), (3, 4), (2, 3, 4)] for dim in range(len(shape))], num_arrs=[3], dtype=lax_test_util.default_dtypes, ) def testConcatenate(self, dim, base_shape, dtype, num_arrs): rng = jtu.rand_default(self.rng()) shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] args_maker = lambda: [rng(shape, dtype) for shape in shapes] op = lambda *args: lax.concatenate(args, dim) self._CompileAndCheck(op, args_maker) @jtu.sample_product( [dict(base_shape=shape, dim=dim) for shape in [(4,), (3, 4), (2, 3, 4)] for dim in range(len(shape))], num_arrs=[3], dtype=lax_test_util.default_dtypes, ) def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): rng = jtu.rand_default(self.rng()) shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] args_maker = lambda: [rng(shape, dtype) for shape in shapes] op = lambda *args: lax.concatenate(args, dim) numpy_op = lambda *args: lax_reference.concatenate(args, dim) self._CheckAgainstNumpy(numpy_op, op, args_maker) @jtu.sample_product( [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)] for axis in range(len(shape))], num_pieces=range(3), dtype=lax_test_util.default_dtypes, ) def testSplit(self, axis, base_shape, dtype, num_pieces): sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) shape = list(base_shape) shape[axis] = np.sum(sizes) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] op = lambda x: lax.split(x, sizes, axis=axis) def numpy_op(x): return np.split(x, np.cumsum(sizes[:-1]), axis=axis) self._CompileAndCheck(op, args_maker) self._CheckAgainstNumpy(numpy_op, op, args_maker) def testSplitErrors(self): with self.assertRaisesRegex(ValueError, "Sizes passed to split must be nonnegative"): lax.split(np.arange(5), [-1]) with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"): lax.split(np.arange(5), [6]) with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"): lax.split(np.arange(5), sizes=(), axis=1) @jtu.sample_product( [ dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) for b, i, j in itertools.product([2, 3], repeat=3) ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1)], padding=["VALID", "SAME", "SAME_LOWER"], ) def testConv(self, lhs_shape, rhs_shape, dtype, strides, padding): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] def fun(lhs, rhs): return lax.conv(lhs, rhs, strides, padding) self._CompileAndCheck(fun, 
args_maker) @jtu.sample_product( [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) for b, i, j in itertools.product([2, 3], repeat=3)], [dict(dtype=dtype, preferred_element_type=preferred) for dtype, preferred in preferred_type_combinations] ) @jax.default_matmul_precision("float32") def testConvPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type): if (not config.enable_x64.value and (dtype == np.float64 or preferred_element_type == np.float64 or dtype == np.int64 or preferred_element_type == np.int64 or dtype == np.complex128 or preferred_element_type == np.complex128)): raise SkipTest("64-bit mode disabled") if (jtu.test_device_matches(["tpu"]) and (dtype == np.complex128 or preferred_element_type == np.complex128)): raise SkipTest("np.complex128 is not yet supported on TPU") if jtu.test_device_matches(["gpu"]) and np.issubdtype(dtype, np.integer): # TODO(b/183565702): Support integer convolutions on CPU/GPU. raise SkipTest("Integer convolution not yet supported on GPU") # x64 implementation is only accurate to ~float32 precision for this case. if dtype == np.complex64 and preferred_element_type == np.complex128: tol = 1e-5 else: tol = {np.float64: 1e-14} if (jtu.test_device_matches(["tpu"]) and dtype == np.float16 and preferred_element_type == np.float32): tol = 2e-3 if (jtu.test_device_matches(["tpu"]) and dtype == jnp.bfloat16 and preferred_element_type == np.float32): tol = 1e-5 rng = jtu.rand_default(self.rng()) x = rng(lhs_shape, dtype) y = rng(rhs_shape, dtype) # We first compute the conv when both inputs are a lower-precision type and # preferred_element_type is a higher-precision type. We then compute results # where the inputs are first upcast to the higher-precision type and no # `preferred_element_type` is given. We expect the result to be extremely # similar given the semantics of `preferred_element_type`. 
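    # (Illustrative sketch, not part of the original test logic: for example,
    # with dtype=np.float16 and preferred_element_type=np.float32 we expect
    #   lax.conv(x, y, (1, 1), "VALID", preferred_element_type=np.float32)
    # to closely match
    #   lax.conv(x.astype(np.float32), y.astype(np.float32), (1, 1), "VALID"),
    # since `preferred_element_type` requests a wider accumulation/output type
    # rather than changing the mathematical operation being performed.)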
result_with_preferred_type = lax.conv( x, y, (1, 1), "VALID", preferred_element_type=preferred_element_type) result_with_upcast_inputs = lax.conv( x.astype(preferred_element_type), y.astype(preferred_element_type), (1, 1), "VALID") self.assertArraysAllClose( result_with_preferred_type, result_with_upcast_inputs, rtol=tol, atol=tol) @jtu.sample_product( [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) for b, i, j in itertools.product([2, 3], repeat=3)], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1)], padding=["VALID", "SAME"], ) def testConvAgainstNumpy(self, lhs_shape, rhs_shape, dtype, strides, padding): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] op = lambda lhs, rhs: lax.conv(lhs, rhs, strides, padding) numpy_op = lambda lhs, rhs: lax_reference.conv(lhs, rhs, strides, padding) self._CheckAgainstNumpy(numpy_op, op, args_maker) @jtu.sample_product( [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) for b, i, j in itertools.product([1, 2, 3], repeat=3)], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1)], padding=[((0, 0), (0, 0)), ((1, 2), (2, 0))], lhs_dilation=[(1, 1), (1, 2), (2, 2)], rhs_dilation=[(1, 1), (1, 2), (2, 2)], ) def testConvWithGeneralPadding(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation, rhs_dilation): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] def fun(lhs, rhs): return lax.conv_with_general_padding( lhs, rhs, strides, padding, lhs_dilation, rhs_dilation) self._CompileAndCheck(fun, args_maker) @jtu.sample_product( [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) for b, i, j in itertools.product([1, 2, 3], repeat=3)], dtype=[np.float32], strides=[(1, 1), (1, 2), (2, 1)], padding=[((0, 0), (0, 0)), ((1, 2), (2, 0))], lhs_dilation=[(1, 1), (1, 2), (2, 2)], rhs_dilation=[(1, 1), (1, 2), (2, 2)], ) def testConvWithGeneralPaddingAgainstNumpy( self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation, rhs_dilation): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] def fun(lhs, rhs): return lax.conv_with_general_padding( lhs, rhs, strides, padding, lhs_dilation, rhs_dilation, precision=lax.Precision.HIGHEST) def numpy_fun(lhs, rhs): return lax_reference.conv_with_general_padding( lhs, rhs, strides, padding, lhs_dilation, rhs_dilation) self._CheckAgainstNumpy(numpy_fun, fun, args_maker) @jtu.sample_product( [ dict( lhs_shape=(b * batch_group_count, i * feature_group_count), rhs_shape=(j * feature_group_count * batch_group_count, i), batch_group_count=batch_group_count, feature_group_count=feature_group_count, ) for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)] for b, i, j in itertools.product([2, 3], repeat=3) ], [dict(dimension_numbers=("NC", "OI", "NC"), perms=([0, 1], [0, 1]))], dtype=lax_test_util.all_dtypes, ) def testConvGeneralDilated0D(self, lhs_shape, rhs_shape, dtype, feature_group_count, batch_group_count, dimension_numbers, perms): if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_): # TODO(b/183565702): Support integer convolutions on CPU/GPU. 
if jtu.test_device_matches(["gpu"]): raise SkipTest("Integer convolution not yet supported on GPU") rng = jtu.rand_small(self.rng()) lhs_perm, rhs_perm = perms # permute to compatible shapes def args_maker(): return [lax.transpose(rng(lhs_shape, dtype), lhs_perm), lax.transpose(rng(rhs_shape, dtype), rhs_perm)] def fun(lhs, rhs): return lax.conv_general_dilated( lhs, rhs, window_strides=(), padding=(), dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count) self._CompileAndCheck(fun, args_maker) @jtu.sample_product( [ dict( lhs_shape=(b * batch_group_count, i * feature_group_count, 9, w), rhs_shape=(j * feature_group_count * batch_group_count, i, 4, 5), batch_group_count=batch_group_count, feature_group_count=feature_group_count, ) for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)] for w in [0, 10] for b, i, j in itertools.product([2, 3], repeat=3) ], [ dict( dimension_numbers=("NCHW", "OIHW", "NCHW"), perms=([0, 1, 2, 3], [0, 1, 2, 3]), ), dict( dimension_numbers=("NHWC", "HWIO", "NHWC"), perms=([0, 2, 3, 1], [2, 3, 1, 0]), ), dict( dimension_numbers=("NCHW", "HWIO", "NHWC"), perms=([0, 1, 2, 3], [2, 3, 1, 0]), ), ], dtype=lax_test_util.all_dtypes, strides=[(1, 1), (2, 1)], padding=[((1, 2), (2, 0)), ((10, 8), (7, 13))], lhs_dilation=[(1, 1), (1, 2), (1, 4)], rhs_dilation=[(1, 1), (1, 2), (1, 4)], ) def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation, rhs_dilation, feature_group_count, batch_group_count, dimension_numbers, perms): if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_): # TODO(b/183565702): Support integer convolutions on CPU/GPU. if jtu.test_device_matches(["gpu"]): raise SkipTest("Integer convolution not yet supported on GPU") rng = jtu.rand_small(self.rng()) lhs_perm, rhs_perm = perms # permute to compatible shapes def args_maker(): return [lax.transpose(rng(lhs_shape, dtype), lhs_perm), lax.transpose(rng(rhs_shape, dtype), rhs_perm)] def fun(lhs, rhs): return lax.conv_general_dilated( lhs, rhs, strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count) self._CompileAndCheck(fun, args_maker) @jax.default_matmul_precision("float32") def testConvGeneralDilatedPatchesOverlapping1D(self): lhs = np.array([[1]], np.float32).reshape((1, 1)) patches = lax.conv_general_dilated_patches( lhs=lhs, filter_shape=(), window_strides=(), padding='SAME' ) self.assertAllClose(lhs, patches) dn = ('NHC', 'OIH', 'NHC') lhs = np.array([1, 2, 3, 4, 5], np.float32).reshape((1, -1, 1)) patches = lax.conv_general_dilated_patches( lhs=lhs, filter_shape=(2,), window_strides=(2,), padding='VALID', dimension_numbers=dn ) self.assertAllClose( np.array([[1, 2], [3, 4]], np.float32).reshape((1, 2, 2)), patches) patches = lax.conv_general_dilated_patches( lhs=lhs, filter_shape=(3,), window_strides=(1,), padding='SAME', dimension_numbers=dn ) self.assertAllClose( np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 0]], np.float32).reshape((1, 5, 3)), patches) patches = lax.conv_general_dilated_patches( lhs=lhs, filter_shape=(3,), window_strides=(1,), padding='SAME', rhs_dilation=(2,), dimension_numbers=dn ) self.assertAllClose( np.array([[0, 1, 3], [0, 2, 4], [1, 3, 5], [2, 4, 0], [3, 5, 0]], np.float32).reshape((1, 5, 3)), patches) def testConvGeneralDilatedPatchesOverlapping2D(self): lhs = np.array([[1, 2, 3], [4, 5, 6]], np.float32).reshape((1, 2, 3, 1)) patches = 
lax.conv_general_dilated_patches( lhs=lhs, filter_shape=(2, 2), window_strides=(1, 1), padding='SAME', dimension_numbers=('NHWC', 'OIHW', 'NHWC') ) self.assertAllClose(np.array([[1, 2, 4, 5], [2, 3, 5, 6], [3, 0, 6, 0], [4, 5, 0, 0], [5, 6, 0, 0], [6, 0, 0, 0]], np.float32).reshape((1, 2, 3, 4)), patches) @jtu.sample_product( [ dict( lhs_shape=lhs_shape, filter_shape=filter_shape, strides=strides, padding=padding, dimension_numbers=dim_nums, ) for lhs_shape, filter_shape, strides, padding, dim_nums in [ ((2, 5), (), (), [], ("NC", "OI", "CN")), ((2, 3, 4), (2,), (2,), [(0, 2)], ("CNH", "OHI", "HNC")), ( (3, 1, 4, 5), (1, 3), (1, 3), [(3, 1), (2, 2)], ("NCHW", "OIHW", "NCHW"), ), ((3, 2, 5, 6), (4, 3), (4, 3), [(5, 2), (2, 4)], None), ( (1, 2, 3, 4), (1, 1), (1, 1), [(0, 0), (0, 0)], ("NCWH", "OHWI", "CNHW"), ), ( (1, 2, 3, 4), (3, 2), (1, 1), [(0, 0), (0, 0)], ("CWHN", "HOWI", "NCHW"), ), ( (2, 3, 4, 5, 6), (2, 1, 3), (2, 1, 3), [(1, 2), (5, 3), (3, 5)], ("NHWDC", "HDIWO", "DCWNH"), ), ] ], dtype=lax_test_util.all_dtypes, precision=[ None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST, ], ) def testConvGeneralDilatedPatchesNonOverlapping(self, lhs_shape, filter_shape, dtype, strides, padding, dimension_numbers, precision): if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_): # TODO(b/183565702): Support integer convolutions on CPU/GPU. if jtu.test_device_matches(["gpu"]): raise SkipTest("Integer convolution not yet supported on GPU") rng = jtu.rand_small(self.rng()) lhs = rng(lhs_shape, dtype) if dimension_numbers is None: lhs_spec, rhs_spec, out_spec = "NCHW", "OIHW", "NCHW" else: lhs_spec, rhs_spec, out_spec = dimension_numbers filter_spec = ''.join(c for c in rhs_spec if c not in ('I', 'O')) patches_spec = out_spec.replace('C', 'C' + filter_spec.lower()) full_padding = [] for c in lhs_spec: if c in ('N', 'C'): full_padding += [(0, 0)] else: full_padding += [padding[filter_spec.index(c)]] lhs_padded = np.pad(lhs, full_padding, 'constant') out = lax.transpose(lhs_padded, [lhs_spec.index(c) for c in out_spec]) patches = lax.conv_general_dilated_patches( lhs=lhs, filter_shape=filter_shape, window_strides=strides, padding=padding, dimension_numbers=dimension_numbers, precision=precision ) source = [] # Test that output spatial shape is factored into `#patches x patch_size`. for c in out_spec: out_c = out.shape[out_spec.index(c)] patch_c = patches.shape[out_spec.index(c)] if c == 'N': self.assertEqual(out_c, patch_c) elif c == 'C': self.assertEqual(out_c * math.prod(filter_shape), patch_c) else: self.assertEqual(out_c, patch_c * filter_shape[filter_spec.index(c)]) source += [patches_spec.index(c), patches_spec.index(c.lower())] # Test that stacking patches together gives the source image, padded. 
c = out_spec.index('C') patches = patches.reshape(patches.shape[:c] + (lhs_shape[lhs_spec.index('C')],) + filter_shape + patches.shape[c + 1:] ) patches = np.moveaxis(patches, source, range(len(source))) for i in range(len(filter_shape)): patches = patches.reshape(patches.shape[:i] + (-1,) + patches.shape[2 + i:]) patches = np.moveaxis( patches, range(len(filter_shape)), [out_spec.index(c) for c in out_spec if c not in ('N', 'C')]) tol = None if (jtu.test_device_matches(["tpu"]) and precision in (None, lax.Precision.DEFAULT)): tol = 1e-3 self.assertAllClose(out, patches, atol=tol, rtol=tol) @jtu.sample_product( [ dict(n=n, lhs_spec=lhs_spec, rhs_spec=rhs_spec, out_spec=out_spec) for n in [1, 2] for lhs_spec in [ "".join(s) for s in itertools.permutations("NCHWD"[: n + 2]) ] for rhs_spec in [ "".join(s) for s in itertools.permutations("OIHWDX"[: n + 2]) ] for out_spec in [ "".join(s) for s in itertools.permutations("NCHWDX"[: n + 2]) ] ], dtype=lax_test_util.inexact_dtypes, precision=[ None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST, (lax.Precision.DEFAULT, lax.Precision.HIGHEST), ], padding=["SAME", "VALID"], ) def testConvGeneralDilatedLocal(self, dtype, precision, n, padding, lhs_spec, rhs_spec, out_spec): """Make sure LCN with tiled CNN kernel matches CNN.""" lhs_spec_default = 'NCHWDX'[:n + 2] rhs_spec_default = 'OIHWDX'[:n + 2] rng = jtu.rand_small(self.rng()) lhs_default = rng((2, 4, 7, 6, 5, 8)[:n + 2], dtype) rhs_default = rng((5, 4, 2, 3, 1, 2)[:n + 2], dtype) window_strides = (1, 2, 3, 4)[:n] rhs_dilation = (2, 1, 3, 2)[:n] lhs_perm = [lhs_spec_default.index(c) for c in lhs_spec] lhs = np.transpose(lhs_default, lhs_perm) rhs_perm = [rhs_spec_default.index(c) for c in rhs_spec] rhs = np.transpose(rhs_default, rhs_perm) kwargs = dict( lhs=lhs, window_strides=window_strides, padding=padding, rhs_dilation=rhs_dilation, dimension_numbers=(lhs_spec, rhs_spec, out_spec), precision=precision ) out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs) rhs_local = np.moveaxis(rhs, (rhs_spec.index('O'), rhs_spec.index('I')), (0, 1)) rhs_local = rhs_local.reshape((rhs_local.shape[0], -1) + (1,) * n) rhs_shape = (rhs_local.shape[:2] + tuple(out_conv.shape[out_spec.index(c)] for c in rhs_spec_default[2:])) rhs_local = np.broadcast_to(rhs_local, rhs_shape) rhs_local = np.transpose(rhs_local, rhs_perm) filter_shape = [rhs.shape[i] for i in range(n + 2) if rhs_spec[i] not in ('O', 'I')] out_local = lax.conv_general_dilated_local(rhs=rhs_local, filter_shape=filter_shape, **kwargs) self.assertAllClose(out_conv, out_local) # TODO(mattjj): test conv_general_dilated against numpy def testConv0DIsDot(self): rng = jtu.rand_default(self.rng()) def args_maker(): return [rng((10, 5), np.float32), rng((5, 7), np.float32)] jnp_fun = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) self._CompileAndCheck(jnp_fun, args_maker) self._CheckAgainstNumpy(np.dot, jnp_fun, args_maker, tol=.1) def testGradConv0D(self): # Reproduces a failure in neural_tangents not caught in our presubmit tests # See cl/367416742. 
lhs = np.ones((2, 5), dtype=np.float32) rhs = np.ones((5, 10), dtype=np.float32) def f_jax(lhs, rhs): return lax.conv_general_dilated( lhs, rhs, window_strides=(), padding=(), lhs_dilation=(), rhs_dilation=(), dimension_numbers=lax.ConvDimensionNumbers((0, 1), (1, 0), (0, 1)), batch_group_count=1, feature_group_count=1, precision=None, preferred_element_type=None) res, pullback = jax.vjp(f_jax, lhs, rhs) grad = pullback(np.ones_like(res)) self.assertAllClose((lhs * 10., rhs * 2.), grad) @staticmethod def _conv_transpose_via_grad(data, kernel, strides, padding, rhs_dilation=None, dimension_numbers=None): """Helper method: calculates conv transpose via grad for testing.""" assert len(data.shape) == len(kernel.shape) nspatial = len(data.shape) - 2 one = (1,) * nspatial rhs_dilation = rhs_dilation or one dn = lax.conv_dimension_numbers(data.shape, kernel.shape, dimension_numbers) in_shape = np.take(data.shape, dn.lhs_spec) in_sdims = in_shape[2:] k_shape = np.take(kernel.shape, dn.rhs_spec) k_sdims = k_shape[2:] e_k_sdims = [(k-1) * r + 1 for k, r in zip(k_sdims, rhs_dilation)] if padding == 'VALID': o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0) for i in range(nspatial)] elif padding == 'SAME': o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)] o_shape = [in_shape[0], k_shape[1]] + o_sdims out_spec_inv = [x[0] for x in sorted(enumerate(dn.out_spec), key=lambda x: x[1])] o_layout = np.take(np.array(o_shape), out_spec_inv) placeholder = np.ones(o_layout, data.dtype) conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding, one, rhs_dilation, dn) _, g = jax.vjp(conv, placeholder) return g(data)[0] @staticmethod def _transpose_conv_kernel(data, kernel, dimension_numbers): dn = lax.conv_dimension_numbers(data.shape, kernel.shape, dimension_numbers) spatial_axes = np.array(dn.rhs_spec)[2:] for axis in spatial_axes: kernel = np.flip(kernel, axis) kernel = np.swapaxes(kernel, dn.rhs_spec[0], dn.rhs_spec[1]) return kernel @jtu.sample_product( [ dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [ ( (b, 9, 10, i), (k, k, j, i), ) # NB: i,j flipped in RHS for transpose for b, i, j, k in itertools.product( [2, 3], [2, 3], [2, 3], [3, 4, 5] ) ] ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)], padding=["VALID", "SAME"], dspec=[ ("NHWC", "HWIO", "NHWC"), ], rhs_dilation=[None, (2, 2)], ) @jtu.skip_on_flag("jax_skip_slow_tests", True) def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides, padding, dspec, rhs_dilation): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] # NB: this test calculates conv_transpose performing identically to the # lhs-grad of conv. def fun(lhs, rhs): return lax.conv_transpose(lhs, rhs, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec, transpose_kernel=True) def fun_via_grad(lhs, rhs): return self._conv_transpose_via_grad(lhs, rhs, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec) # NB: below just checks for agreement, we're not calling numpy. 
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) @jtu.sample_product( [ dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [ ((b, 9, 10, i), (k, k, i, j)) for b, i, j, k in itertools.product( [2, 3], [2, 3], [2, 3], [3, 4, 5] ) ] ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)], padding=["VALID", "SAME"], dspec=[ ("NHWC", "HWIO", "NHWC"), ], rhs_dilation=[None, (2, 2)], ) @jtu.skip_on_flag("jax_skip_slow_tests", True) def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides, padding, dspec, rhs_dilation): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] def fun(lhs, rhs): return lax.conv_transpose(lhs, rhs, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec, transpose_kernel=False) def fun_via_grad(lhs, rhs): rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec) return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec) # NB: below just checks for agreement, we're not calling numpy. self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) @jtu.sample_product( [ dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [ ((b, 10, i), (k, i, j)) for b, i, j, k in itertools.product( [2, 3], [2, 3], [2, 3], [3, 4, 5] ) ] ], dtype=lax_test_util.float_dtypes, strides=[(1,), (2,), (3,)], padding=["VALID", "SAME"], dspec=[ ("NHC", "HIO", "NHC"), ], rhs_dilation=[None, (2,)], ) def testConvTranspose1D(self, lhs_shape, rhs_shape, dtype, strides, padding, dspec, rhs_dilation): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] def fun(lhs, rhs): return lax.conv_transpose(lhs, rhs, strides, padding, dimension_numbers=dspec, rhs_dilation=rhs_dilation, transpose_kernel=False) def fun_via_grad(lhs, rhs): rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec) return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec) # NB: below just checks for agreement, we're not calling numpy. self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) @jtu.sample_product( [ dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [ ((b, i), (i, j)) for b, i, j in itertools.product([2, 3], [2, 3], [2, 3]) ] ], dtype=lax_test_util.float_dtypes, strides=[()], padding=["VALID", "SAME"], dspec=[ ("NC", "IO", "NC"), ], rhs_dilation=[None, ()], ) def testConvTranspose0D(self, lhs_shape, rhs_shape, dtype, strides, padding, dspec, rhs_dilation): rng = jtu.rand_small(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] def fun(lhs, rhs): return lax.conv_transpose(lhs, rhs, strides, padding, dimension_numbers=dspec, rhs_dilation=rhs_dilation, transpose_kernel=False) def fun_via_grad(lhs, rhs): rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec) return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec) # NB: below just checks for agreement, we're not calling numpy. 
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) def testConvTransposePaddingList(self): # Regression test for https://github.com/jax-ml/jax/discussions/8695 a = jnp.ones((28,28)) b = jnp.ones((3,3)) c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1)) self.assertAllClose(c, 9 * jnp.ones((1, 1, 26, 26))) def testConvInvalidPadding(self): x = jnp.ones((1, 10, 10, 5), dtype=jnp.bfloat16) with self.assertRaisesRegex(ValueError, r"padding argument.*, got \(3, 3\)"): jax.lax.conv_general_dilated_patches(x, (5, 5), window_strides=(1, 1), padding=(3, 3)) @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]], [dict(lhs_dtype=lhs_dtype, rhs_dtype=rhs_dtype) for lhs_dtype, rhs_dtype in itertools.chain( itertools.product(lax_test_util.int_dtypes + lax_test_util.float_dtypes + lax_test_util.complex_dtypes + lax_test_util.uint_dtypes, repeat=2), zip(lax_test_util.bool_dtypes, lax_test_util.bool_dtypes))], precision=[ None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST, (lax.Precision.DEFAULT, lax.Precision.HIGHEST), ], ) def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker) @parameterized.parameters([ (algorithm, dtype) for algorithm, test_dtypes in [ (lax.DotAlgorithm( lhs_precision_type=np.float32, rhs_precision_type=np.float32, accumulation_type=np.float32, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False, ), [np.float32]), (lax.DotAlgorithm( lhs_precision_type=np.float16, rhs_precision_type=np.float16, accumulation_type=np.float32, ), [np.float16]), ("F16_F16_F32", [np.float16]), (lax.DotAlgorithmPreset.DEFAULT, lax_test_util.float_dtypes), (lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes), (lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes), (lax.DotAlgorithmPreset.F16_F16_F16, [np.float16]), (lax.DotAlgorithmPreset.F16_F16_F32, [np.float16]), (lax.DotAlgorithmPreset.BF16_BF16_BF16, [dtypes.bfloat16]), (lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]), (lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]), (lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]), (lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]), (lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]), (lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]), (lax.DotAlgorithmPreset.F64_F64_F64, [np.float64]), ] for dtype in test_dtypes if jtu.dtypes.supported([dtype]) ]) def testDotAlgorithm(self, algorithm, dtype): if jtu.test_device_matches(["cpu"]): if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, lax.DotAlgorithmPreset.F16_F16_F16, lax.DotAlgorithmPreset.F32_F32_F32, lax.DotAlgorithmPreset.F64_F64_F64, lax.DotAlgorithmPreset.BF16_BF16_F32, lax.DotAlgorithmPreset.BF16_BF16_F32_X3, lax.DotAlgorithmPreset.BF16_BF16_F32_X6, }: raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. 
if algorithm in { lax.DotAlgorithmPreset.F16_F16_F32, lax.DotAlgorithmPreset.TF32_TF32_F32, lax.DotAlgorithmPreset.BF16_BF16_F32, lax.DotAlgorithmPreset.BF16_BF16_F32_X3, lax.DotAlgorithmPreset.BF16_BF16_F32_X6, }: if not jtu.is_cuda_compute_capability_at_least("8.0"): raise SkipTest( f"The dot algorithm '{algorithm}' requires CUDA compute " "capability >= 8.0.") elif algorithm not in { lax.DotAlgorithmPreset.DEFAULT, lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, lax.DotAlgorithmPreset.F32_F32_F32, lax.DotAlgorithmPreset.F64_F64_F64, }: raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on GPU.") if jtu.test_device_matches(["tpu"]): # TODO(apaszke): Remove after 12 weeks have passed. if not jtu.if_cloud_tpu_at_least(2024, 12, 19): self.skipTest("Requires libtpu built after 2024-12-19") if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, lax.DotAlgorithmPreset.BF16_BF16_F32, lax.DotAlgorithmPreset.BF16_BF16_F32_X3, lax.DotAlgorithmPreset.BF16_BF16_F32_X6, }: raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on TPU." ) lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker) self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): if jtu.test_device_matches(["cpu"]): raise SkipTest("Not supported on CPU.") lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn) with self.assertRaisesRegex(ValueError, "The dot algorithm"): lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32") def testDotAlgorithmCasting(self): if jtu.test_device_matches(["tpu"]): raise SkipTest("F32_F32_F32 is not supported on TPU.") def fun(lhs, rhs): return lax.dot(lhs, rhs, precision="F32_F32_F32") lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) self.assertEqual(fun(lhs, rhs).dtype, np.float16) def testDotAlgorithmAllowedOutputStorage(self): # see https://github.com/jax-ml/jax/issues/24794 if not jtu.test_device_matches(["gpu"]): self.skipTest("Only supported on GPU.") def fun(lhs, rhs): return lax.dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float16) lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text()) def testDotAlgorithmConfig(self): lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) expected = ("algorithm = core.ShapedArray: return core.ShapedArray((2,), jnp.dtype('uint32')) @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): buf.aval = core.ShapedArray(buf.shape, buf.dtype) return FooArray(aval.shape, buf) return handler @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): def handler(arr): from jax._src.array import ArrayImpl if isinstance(arr, ArrayImpl): buf, = arr._arrays else: buf, = arr return FooArray(aval.shape, buf) return handler class FooTy(dtypes.ExtendedDType): type = dtypes.extended name = 'foo' _rules = FooTyRules def __hash__(self) -> int: return hash(FooTy) def __eq__(self, other) -> bool: return type(other) is FooTy def 
__repr__(self) -> str: return self.name __str__ = __repr__ # primitives make_p = core.Primitive('make') bake_p = core.Primitive('bake') take_p = core.Primitive('take') jake_p = core.Primitive('jake') def make(shape): return make_p.bind(shape=tuple(shape)) def bake(k): return bake_p.bind(k) def take(k): return take_p.bind(k) def jake(k): return jake_p.bind(k) @make_p.def_abstract_eval def make_abstract_eval(*, shape): return core.ShapedArray(shape, FooTy()) @bake_p.def_abstract_eval def bake_abstract_eval(x): if type(x.dtype) != FooTy: raise TypeError return core.ShapedArray(tuple(reversed(x.shape)), FooTy()) @take_p.def_abstract_eval def take_abstract_eval(x): return core.ShapedArray(x.shape, jnp.dtype('float32')) @jake_p.def_abstract_eval def jake_abstract_eval(x): return x # runtime ('outside jit') data types class FooArray: shape: tuple[int, ...] data: jax.Array def __init__(self, shape, data): assert data.shape == (*shape, 2) self.shape = shape self.data = data def __repr__(self) -> str: shape = ','.join(map(str, self.shape)) return f'foo[{shape}] with value\n{self.data}' size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) def shard_foo_array_handler(xs, shardings, layouts, copy_semantics): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment aval = core.get_aval(x.data) results.append(pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) return results def foo_array_constant_handler(x): return array._array_mlir_constant_handler(x.data) def make_lowering(*, shape): return jnp.zeros((*shape, 2), 'uint32') def bake_lowering(k): return k.T def take_lowering(k): return jnp.broadcast_to(jnp.float32(k.size), k.shape) def jake_lowering(k): return jnp.ones((*k.shape, 2), 'uint32') def bake_vmap(batched_args, batch_dims): xs, = batched_args bdim_in, = batch_dims ys = bake(xs) perm = list(reversed(range(xs.ndim))) bdim_out = perm[bdim_in] return ys, bdim_out # All tests in this test class are thread-hostile because they add and remove # primitives from global maps. 
@jtu.thread_unsafe_test_class() # registration isn't thread-safe class CustomElementTypesTest(jtu.JaxTestCase): def setUp(self): core.pytype_aval_mappings[FooArray] = \ lambda x: core.ShapedArray(x.shape, FooTy(), sharding=None) xla.canonicalize_dtype_handlers[FooArray] = lambda x: x pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler mlir._constant_handlers[FooArray] = foo_array_constant_handler mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False)) mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False)) mlir.register_lowering(take_p, mlir.lower_fun(take_lowering, False)) mlir.register_lowering(jake_p, mlir.lower_fun(jake_lowering, False)) batching.defvectorized(take_p) batching.primitive_batchers[bake_p] = bake_vmap def tearDown(self): del core.pytype_aval_mappings[FooArray] del xla.canonicalize_dtype_handlers[FooArray] del mlir._constant_handlers[FooArray] del mlir._lowerings[make_p] del mlir._lowerings[bake_p] del mlir._lowerings[take_p] del batching.primitive_batchers[take_p] del batching.primitive_batchers[bake_p] def test_shaped_array_construction(self): aval = core.ShapedArray((), FooTy()) self.assertEqual(aval.str_short(), 'foo[]') aval = core.ShapedArray((3, 4), FooTy()) self.assertEqual(aval.str_short(), 'foo[3,4]') def test_make_jaxpr_identity(self): x = types.SimpleNamespace(shape=(3,), dtype=FooTy()) jaxpr = jax.make_jaxpr(lambda x: x)(x).jaxpr # { lambda ; a:foo[3]. let in (a,) } self.assertLen(jaxpr.invars, 1) a, = jaxpr.invars self.assertEqual(a.aval, core.ShapedArray((3,), FooTy())) self.assertLen(jaxpr.outvars, 1) a, = jaxpr.outvars self.assertEqual(a.aval, core.ShapedArray((3,), FooTy())) # tests after here need the primitives def test_make_jaxpr_with_primitives(self): def f(): k1 = make((3, 4)) k2 = bake(k1) x = take(k2) return x jaxpr = jax.make_jaxpr(f)().jaxpr # { lambda ; . let # a:foo[3,4] = make[shape=(3, 4)] # b:foo[4,3] = bake a # c:f32[4,3] = take b # in (c,) } self.assertLen(jaxpr.invars, 0) self.assertLen(jaxpr.eqns, 3) e1, e2, e3 = jaxpr.eqns self.assertIs(e1.primitive, make_p) self.assertLen(e1.outvars, 1) a, = e1.outvars self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy())) self.assertIs(e2.primitive, bake_p) self.assertLen(e2.outvars, 1) b, = e2.outvars self.assertEqual(b.aval, core.ShapedArray((4, 3), FooTy())) self.assertIs(e3.primitive, take_p) self.assertLen(e3.outvars, 1) c, = e3.outvars self.assertEqual(c.aval, core.ShapedArray((4, 3), np.dtype('float32'))) # tests after here need FooArray and lowerings def test_jit_closure(self): k = FooArray((), jnp.arange(2, dtype='uint32')) @jax.jit def f(): jnp.add(1, 1) # make jit not hit trivial dispatch path return k y = f() # doesn't crash self.assertIsInstance(y, FooArray) self.assertEqual(y.shape, ()) def test_jit_identity(self): k = FooArray((), jnp.arange(2, dtype='uint32')) @jax.jit def f(k): jnp.add(1, 1) # make jit not hit trivial dispatch path return k y = f(k) # doesn't crash self.assertIsInstance(y, FooArray) self.assertEqual(y.shape, ()) def test_jit_multiple_primitives(self): @jax.jit def f(): k1 = make((3,)) k2 = bake(k1) y = take(k2) return y y = f() self.assertArraysAllClose(y, jnp.array([3., 3., 3.]), check_dtypes=False) def test_scan_jaxpr(self): ks = jax.jit(lambda: make((3, 4)))() f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks) jaxpr = jax.make_jaxpr(f)(ks).jaxpr # { lambda ; a:foo[3,4]. let # b:foo[3,4] = scan[ # jaxpr={ lambda ; c:foo[4]. 
let d:foo[4] = bake c in (d,) } # ] a # in (b,) } self.assertLen(jaxpr.invars, 1) a, = jaxpr.invars self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy())) self.assertLen(jaxpr.eqns, 1) e, = jaxpr.eqns self.assertLen(e.outvars, 1) b, = e.outvars self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy())) def test_scan_jaxpr_split_transpose(self): def stage(x, w): x = x @ w x = jnp.tanh(x) return (x, ()) def loss(ws, x, split_transpose=False): return jnp.sum(jax.lax.scan(stage, x, ws, _split_transpose=split_transpose)[0]) def fn(*args, split_transpose=False): v, fn_transpose = jax.vjp( partial(loss, split_transpose=split_transpose), *args) grads = fn_transpose(1.0) return *grads, v # x : [batch, d_model] x = jax.random.uniform(jax.random.key(0), [256, 100]) # wss : [layers, d_model, d_model] wss = jax.random.uniform(jax.random.key(1), [7, 100, 100]) jaxpr = jax.make_jaxpr(partial(fn))(wss, x) jaxpr_split_transpose = jax.make_jaxpr(partial(fn, split_transpose=True))( wss, x ) # Check that the shapes were preserved. self.assertEqual(jaxpr.in_avals, jaxpr_split_transpose.in_avals) self.assertEqual(jaxpr.out_avals, jaxpr_split_transpose.out_avals) # The first two outvars (corresponding to gradients of params and inputs) # must come from two different loops. ct_ws = jaxpr_split_transpose.jaxpr.outvars[0] ct_x = jaxpr_split_transpose.jaxpr.outvars[1] # The last two equations are the two loops we care about backprop_scan = jaxpr_split_transpose.jaxpr.eqns[-2] self.assertEqual(backprop_scan.primitive, jax.lax.scan_p) param_gradient_map = jaxpr_split_transpose.jaxpr.eqns[-1] self.assertEqual(param_gradient_map.primitive, jax.lax.scan_p) self.assertEqual(param_gradient_map.params['num_consts'], 0) self.assertEqual(param_gradient_map.params['num_carry'], 0) # Assert that parameter gradients come from the map. self.assertEqual(ct_ws, param_gradient_map.outvars[0]) # And that activation gradients come from the scan. 
self.assertEqual(ct_x, backprop_scan.outvars[0]) def test_scan_lowering(self): ks = jax.jit(lambda: make((3, 4)))() f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks) _, out = jax.jit(f)(ks) # doesn't crash self.assertIsInstance(out, FooArray) self.assertEqual(out.shape, (3, 4)) def test_vmap(self): ks = jax.jit(lambda: make((3, 4, 5)))() ys = jax.vmap(jax.jit(lambda k: take(bake(k))))(ks) expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32') self.assertAllClose(ys, expected) def test_slice(self): ks = jax.jit(lambda: make((3, 4)))() ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (2, 4)) def test_dynamic_slice(self): ks = jax.jit(lambda: make((3, 4)))() ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (2, 4)) def test_transpose(self): ks = jax.jit(lambda: make((3, 4)))() ys = jax.jit(lambda x: x.T)(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (4, 3)) def test_gather(self): ks = jax.jit(lambda: make((3, 4)))() ys = jax.jit(lambda x: x[1])(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (4,)) ks = jax.jit(lambda: make((3, 4, 5)))() ys = jax.jit(lambda x: x[1])(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (4, 5)) ys = jax.jit(lambda x: x[1, 2:4])(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (2, 5)) ys = jax.jit(lambda x: x[1, 2:4, 3])(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (2,)) ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (3, 2, 1)) def test_gather_batched_index_dtype(self): # Regression test for https://github.com/jax-ml/jax/issues/16557 dtype = jnp.int8 size = jnp.iinfo(dtype).max + 10 indices = jnp.zeros(size, dtype=dtype) values = jnp.zeros((size, 1)) results = jax.vmap(lambda x, i: jnp.take(x, i, axis=0))(values, indices) self.assertArraysEqual(results, jnp.zeros(size)) @parameterized.parameters([ (0,), (slice(1),), (np.array([0, 2]),), (np.array([False, True, True]),) ]) def test_scatter(self, idx): k = jax.jit(lambda: make(()))() ks = jax.jit(lambda: make((3,)))() ys = jax.jit(lambda x, y: x.at[idx].set(y))(ks, k) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (3,)) def test_equality(self): eq = jax.jit(lambda k1, k2: k1 == k2) ne = jax.jit(lambda k1, k2: k1 != k2) k1 = jax.jit(lambda: make(()))() k2 = jax.jit(lambda: jake(make(())))() self.assertTrue(eq(k1, k1)) self.assertFalse(eq(k1, k2)) self.assertTrue(ne(k1, k2)) self.assertFalse(ne(k1, k1)) size = 5 idx = slice(2, 4) ks = jax.jit(lambda k: jake(make((size,))).at[idx].set(k))(k1) expected = jnp.zeros(size, dtype=bool).at[idx].set(True) self.assertArraysEqual(eq(k1, ks), expected) self.assertArraysEqual(ne(k1, ks), ~expected) def test_select(self): ks = jax.jit(lambda: make((3,)))() cs = jnp.array([True, False, False]) ys = jax.jit(lax.select)(cs, ks, ks) self.assertIsInstance(ys, FooArray) self.assertEqual(ys.shape, (3,)) def test_xla_reverse_bug(self): # Regression test for b/248295786 # This was an XLA bug related to an incorrect optimization of reverse def f(x): y = jnp.array([2, 5]) return lax.rev(x * y, (0,)) x = jnp.array([1, 2]) self.assertArraysEqual(f(x), jax.jit(f)(x)) # TODO(frostig,mattjj): more polymorphic primitives tests class FunctionAccuracyTest(jtu.JaxTestCase): @parameterized.named_parameters( dict(testcase_name=f"_{dtype.__name__}", 
dtype=dtype) for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128])) def testMPMathUtils(self, dtype): try: import mpmath except ImportError as msg: self.skipTest(f'could not import mpmath: {msg}') prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype] is_complex = dtype().dtype.kind == 'c' def func(x): assert isinstance(x, mpmath.ctx_mp.mpnumeric) assert x.context.prec == prec assert isinstance(x, x.context.mpc if is_complex else x.context.mpf) return x ufunc = jtu.vectorize_with_mpmath(func, mpmath=mpmath) with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): if is_complex: arr = jtu.complex_plane_sample(dtype=dtype, size_re=11) else: cdtype = getattr(np, ufunc.map_float_to_complex[dtype.__name__]) arr = jtu.complex_plane_sample(dtype=cdtype, size_re=11, size_im=0)[1:2].real arr2 = ufunc.mptonp(ufunc.nptomp(arr)) with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): self.assertAllClose(arr, arr2, atol=0, rtol=0) arr3 = ufunc(arr) with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): self.assertAllClose(arr, arr3, atol=0, rtol=0) if is_complex: # tests scale in normalize v = dtype(1.1071487177940644+1.1102230246251565e-16j) r = dtype(1.1071487177940644+0j) mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1) nr, nv = mnp.normalize(r, r, v) self.assertAllClose(nr, nv) _functions_on_complex_plane = [ 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'conjugate', 'cos', 'cosh', 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'sinc', 'positive', 'negative', 'absolute', 'sign' ] @parameterized.named_parameters( dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype) for name, dtype in itertools.product( _functions_on_complex_plane, jtu.dtypes.supported([np.complex64, np.complex128]), )) @jtu.skip_on_devices("tpu") def testSuccessOnComplexPlane(self, name, dtype): self._testOnComplexPlaneWorker(name, dtype, 'success') @parameterized.named_parameters( dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype) for name, dtype in itertools.product( _functions_on_complex_plane, jtu.dtypes.supported([np.complex64, np.complex128]), )) @jtu.skip_on_devices("tpu") def testFailureOnComplexPlane(self, name, dtype): self._testOnComplexPlaneWorker(name, dtype, 'failure') def _testOnComplexPlaneWorker(self, name, dtype, kind): try: import mpmath except ImportError as msg: self.skipTest(f'could not import mpmath: {msg}') is_cpu = jtu.test_device_matches(["cpu"]) machine = platform.machine() # TODO: remove is_arm_cpu as previously arm cpu related failures # were due to numpy issues. Confirm? 
is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm') is_cuda = jtu.test_device_matches(["cuda"]) size_re = 11 size_im = 11 atol = None if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}: # TODO(pearu): eliminate this if-block when a fix to mpmath#787 # becomes available extra_prec_multiplier = 20 else: extra_prec_multiplier = 1 mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1, extra_prec_multiplier=extra_prec_multiplier) mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=extra_prec_multiplier) ref_op = getattr(mnp, name) ref2_op = getattr(mnp2, name) jnp_op = getattr(jnp, name) with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),) result = np.asarray(jnp_op(*args)) expected = ref_op(*args) expected2 = ref2_op(*args) normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result) # When comparing the results with expected, we'll divide the # complex plane grid into smaller regions and perform the # closeness tests on each region separately. The reason for this # is that the inaccuracy or incorrectness issues with a particular # function exists typically in specific regions while in other # regions the function is accurate. So, such a division of the # complex plane helps to identify the problematic regions as well # as to fix the inaccuracy or incorrectness issues. # # Regions in complex plane: # # ( pinfj ) # ( q2 ) (posj) ( q1 ) # (ninf) ( neg ) (zero) ( pos ) (pinf) # ( q3 ) (negj) ( q4 ) # ( ninfj ) # # In addition, the 1/3 middle parts of regions q1, q2, q3, q4, # neg, pos are tested separately as these don't contain extremely # small or extremelly large values and functions on these regions # ought not to possess any incorrectness issues. 
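    # (Illustrative note, a rough reading of the slice table below rather than
    # part of the original test: assuming the grid from
    # jtu.complex_plane_sample orders values along each axis as
    # -inf, negative finites, zero, positive finites, +inf, then 'ninf'/'pinf'
    # select the first/last column (real part -inf/+inf), 'ninfj'/'pinfj'
    # select the first/last row (imaginary part -inf/+inf), 'zero' is the
    # single center point, and q1..q4 are the four finite quadrants; the m*
    # entries pick out the middle thirds of those regions.)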
s0, s1 = size_re, size_im s03, s13 = s0 // 3, s1 // 3 s_dict = dict( q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)), q2=(slice(s0 + 2, -1), slice(1, s1 + 1)), q3=(slice(1, s0 + 1), slice(1, s1 + 1)), q4=(slice(1, s0 + 1), slice(s1 + 2, -1)), neg=(s0 + 1, slice(1, s1 + 1)), pos=(s0 + 1, slice(s1 + 2, -1)), negj=(slice(1, s0 + 1), s1 + 1), posj=(slice(s0 + 2, -1), s1 + 1), ninf=(slice(None), 0), pinf=(slice(None), -1), ninfj=(0, slice(None)), pinfj=(-1, slice(None)), zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)), ) if s03 and s13: s_dict.update( mq1 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), mq2 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)), mq3 = (slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)), mq4 = (slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)), mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1), mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1), ) # The regions are split to real and imaginary parts (of function # return values) to (i) workaround numpy 1.x assert_allclose bug # in comparing complex infinities, and (ii) expose more details # about failing cases: s_dict_parts = dict() for k, v in s_dict.items(): s_dict_parts[k + '.real'] = v s_dict_parts[k + '.imag'] = v # Start with an assumption that all regions are problematic for a # particular function: regions_with_inaccuracies = list(s_dict_parts) # Next, we'll remove non-problematic regions from the # regions_with_inaccuracies list by explicitly keeping problematic # regions: def regions_with_inaccuracies_keep(*to_keep): to_keep_parts = [] for r in to_keep: if r.endswith('.real') or r.endswith('.imag'): to_keep_parts.append(r) else: to_keep_parts.append(r + '.real') to_keep_parts.append(r + '.imag') for item in regions_with_inaccuracies[:]: if item not in to_keep_parts: regions_with_inaccuracies.remove(item) if name == 'absolute': if is_cuda and dtype == np.complex128: regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real') else: regions_with_inaccuracies.clear() elif name == 'sign': regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') elif name == 'log': regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag') elif name == 'log10': regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag') elif name == 'exp': regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag') elif name == 'exp2': if dtype == np.complex64: regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos.imag', 'mnegj', 'mposj') if dtype == np.complex128: regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mpos.imag') elif name == 'sinc': regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg.real', 'mpos.real', 'mnegj', 'mposj', 'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real') elif name == 'sinh': if is_cuda: regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg', 'pos', 'ninf.imag', 'pinf.imag', 'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos', 'ninfj.real', 'pinfj.real') if is_cpu: regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj.imag', 'posj.imag', 
                                       'ninf.imag', 'pinf.imag', 'mq1.real', 'mq2.real',
                                       'mq3.real', 'mq4.real', 'mneg', 'mpos',
                                       'ninfj.real', 'pinfj.real')

    elif name == 'cosh':
      regions_with_inaccuracies_keep('neg.imag', 'pos.imag', 'ninf.imag', 'pinf.imag',
                                     'mneg.imag', 'mpos.imag', 'ninfj.imag', 'pinfj.imag')

    elif name == 'tanh':
      regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')

    elif name in {'cos', 'sin'}:
      regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')

    elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1',
                  'log1p', 'tan', 'arcsinh', 'arcsin', 'arccosh', 'arccos',
                  'arctan', 'arctanh', 'square'}:
      regions_with_inaccuracies.clear()

    else:
      assert 0  # unreachable

    # Finally, perform the closeness tests per region:
    unexpected_success_regions = []
    for region_name, region_slice in s_dict_parts.items():
      region = args[0][region_slice]
      if region_name.endswith('.real'):
        result_slice, expected_slice = result[region_slice].real, expected[region_slice].real
        normalized_result_slice, normalized_expected_slice = (
            normalized_result[region_slice].real, normalized_expected[region_slice].real)
      elif region_name.endswith('.imag'):
        result_slice, expected_slice = result[region_slice].imag, expected[region_slice].imag
        normalized_result_slice, normalized_expected_slice = (
            normalized_result[region_slice].imag, normalized_expected[region_slice].imag)
      else:
        result_slice, expected_slice = result[region_slice], expected[region_slice]
        normalized_result_slice, normalized_expected_slice = (
            normalized_result[region_slice], normalized_expected[region_slice])

      inexact_indices = np.where(normalized_result_slice != normalized_expected_slice)

      if inexact_indices[0].size == 0:
        inexact_samples = ''
      else:
        inexact_samples = []
        for ind in zip(*inexact_indices):
          x = region[ind]
          y1, y2 = result[region_slice][ind], expected[region_slice][ind]
          ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind]
          if str(y1) == str(y2):  # skip equal nan-s
            continue
          max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y2) else np.inf
          inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]'))
        inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)])

      if kind == 'success' and region_name not in regions_with_inaccuracies:
        with jtu.ignore_warning(category=RuntimeWarning,
                                message="overflow encountered in.*"):
          self.assertAllClose(
              normalized_result_slice, normalized_expected_slice, atol=atol,
              err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=},\n{inexact_samples}")

      if kind == 'failure' and region_name in regions_with_inaccuracies:
        try:
          with self.assertRaises(AssertionError,
                                 msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}"):
            with jtu.ignore_warning(category=RuntimeWarning,
                                    message="overflow encountered in.*"):
              self.assertAllClose(normalized_result_slice, normalized_expected_slice)
        except AssertionError as msg:
          if str(msg).startswith('AssertionError not raised'):
            unexpected_success_regions.append(region_name)
          else:
            raise  # something else is wrong..

    def eliminate_parts(seq):
      # Replace n.real and n.imag items in seq with n.
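      # For example (illustration only, not part of the original test logic):
      #   eliminate_parts(['q1.real', 'q1.imag', 'q2.real']) -> ['q1', 'q2.real']
      # because only q1 has both of its .real and .imag parts present in seq.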
      result = []
      for part_name in seq:
        name = part_name.split('.')[0]
        if name in result:
          continue
        if name + '.real' in seq and name + '.imag' in seq:
          result.append(name)
        else:
          result.append(part_name)
      return result

    regions_with_inaccuracies = eliminate_parts(regions_with_inaccuracies)
    unexpected_success_regions = eliminate_parts(unexpected_success_regions)

    if kind == 'success' and regions_with_inaccuracies:
      reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
      raise unittest.SkipTest(reason)

    if kind == 'failure':
      if not regions_with_inaccuracies:
        raise unittest.SkipTest("no problematic regions")
      elif unexpected_success_regions:
        # This skip should take effect only while functions are being fixed on
        # problematic regions in XLA; such fixes should be followed up by a
        # JAX PR that enables testing these functions on those regions for
        # success.
        raise unittest.SkipTest(
            f"detected success in regions {', '.join(unexpected_success_regions)}, "
            "please update regions_with_inaccuracies!"
        )


class CompositeTest(jtu.JaxTestCase):

  def test_composite(self):
    def my_square_impl(x):
      return x ** 2

    my_square = lax.composite(my_square_impl, name="my.square")

    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))

    mlir_module = jax.jit(my_square).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>) -> tensor<f32>', mlir_module)
    self.assertIn('@my.square(%arg0: tensor<f32>) -> tensor<f32> {', mlir_module)
    self.assertIn('stablehlo.multiply %arg0, %arg0 : tensor<f32>', mlir_module)

  def test_composite_decorator(self):
    @partial(lax.composite, name="my.square")
    def my_square(x):
      return x ** 2

    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))

    mlir_module = jax.jit(my_square).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>) -> tensor<f32>', mlir_module)
    self.assertIn('@my.square(%arg0: tensor<f32>) -> tensor<f32> {', mlir_module)
    self.assertIn('stablehlo.multiply %arg0, %arg0 : tensor<f32>', mlir_module)

  def test_composite_with_jit_function(self):
    def my_square_impl(x):
      return x ** 2

    my_square = jax.jit(lax.composite(my_square_impl, name="my.square"))

    x = jnp.array(2.0, dtype=jnp.float32)
    output = my_square(x)
    self.assertEqual(output, jnp.array(4.0, dtype=jnp.float32))

    mlir_module = my_square.lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square} : '
        '(tensor<f32>) -> tensor<f32>', mlir_module)
    self.assertIn('@my.square(%arg0: tensor<f32>) -> tensor<f32> {', mlir_module)
    self.assertIn('stablehlo.multiply %arg0, %arg0 : tensor<f32>', mlir_module)

  def test_composite_with_attributes(self):
    # static_argnames is required here since k is a constant that would
    # normally come from a larger context; this test exercises just the one
    # op (composite).
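    # A brief sketch of what this exercises (the exact attribute encoding is
    # asserted below): because k is a static argument, it is a plain Python
    # int at trace time, so lax.composite can record it as the composite
    # attribute `k = 3 : i64` instead of passing it as an operand.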
    @partial(jax.jit, static_argnames=['k'])
    @partial(lax.composite, name="my.top_k")
    def my_top_k(x, *, k):
      return lax.top_k(x, k)

    x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)
    output, indices = my_top_k(x, k=3)
    self.assertArraysEqual(output, jnp.array([5.0, 4.0, 3.0], dtype=jnp.float32))
    self.assertArraysEqual(indices, jnp.array([4, 3, 2], dtype=jnp.int32))

    mlir_module = my_top_k.lower(x, k=3).as_text()
    self.assertIn(
        'stablehlo.composite "my.top_k" %arg0 '
        '{composite_attributes = {k = 3 : i64}, decomposition = @my.top_k} : '
        '(tensor<5xf32>) -> (tensor<3xf32>, tensor<3xi32>)', mlir_module)
    self.assertIn(
        '@my.top_k(%arg0: tensor<5xf32>) -> (tensor<3xf32>, tensor<3xi32>) {',
        mlir_module)
    self.assertIn(
        'chlo.top_k(%arg0, k = 3) : tensor<5xf32> -> (tensor<3xf32>, tensor<3xi32>)',
        mlir_module)

  def test_composite_attribute_dtypes(self):
    @jax.jit
    def my_tangent_composite_with_attributes(x):
      def decomposition(x, **_):
        return lax.sin(x) / lax.cos(x)

      return lax.composite(decomposition, "my.tangent")(
          x,
          dtype=np.dtype(np.float32),
          int=1,
          omit=None,
          str="bar",
          tensor=np.zeros((1, 2), dtype=np.float32),
          tensor_r1=np.zeros((2,), dtype=np.float32),
      )

    pi = jnp.pi
    x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi], dtype=jnp.float32)
    output = my_tangent_composite_with_attributes(x)
    self.assertArraysAllClose(
        output, jnp.array([0.0, 1.0, -1.0, 0.0], dtype=jnp.float32)
    )

    mlir_module = my_tangent_composite_with_attributes.lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.tangent" %arg0 {composite_attributes = {'
        'dtype = f32, int = 1 : i64, str = "bar", '
        'tensor = dense<0.000000e+00> : tensor<1x2xf32>, '
        'tensor_r1 = dense<0.000000e+00> : tensor<2xf32>}, '
        'decomposition = @my.tangent} : (tensor<4xf32>) -> tensor<4xf32>',
        mlir_module)
    self.assertIn("func.func private @my.tangent", mlir_module)

  def test_composite_unsupported_attribute_dtypes(self):
    def my_tangent_composite_with_attributes(x):
      def decomposition(x, **_):
        return lax.sin(x) / lax.cos(x)

      return lax.composite(decomposition, "my.tangent")(
          x, tensor=jnp.zeros((1, 2), dtype=jnp.float32)
      )

    pi = jnp.pi
    x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi], dtype=jnp.float32)
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        "Note: If you are passing jax arrays as attributes, use numpy arrays "
        "instead."
    ):
      jax.jit(my_tangent_composite_with_attributes).lower(x).as_text()

  def test_composite_with_non_default_version(self):
    @partial(lax.composite, name="my.square", version=1)
    def my_square_with_version(x):
      return x ** 2

    x = jnp.array(2.0, dtype=jnp.float32)
    out = my_square_with_version(x)
    self.assertEqual(out, 4.0)

    mlir_module = jax.jit(my_square_with_version).lower(x).as_text()
    self.assertIn(
        'stablehlo.composite "my.square" %arg0 {decomposition = @my.square, '
        'version = 1 : i32} : (tensor<f32>) -> tensor<f32>', mlir_module)

  def test_composite_with_no_args(self):
    @partial(lax.composite, name="my.one")
    def one():
      return jnp.array(1.0, dtype=jnp.float32)

    out = one()
    self.assertEqual(out, jnp.array(1.0, dtype=jnp.float32))

    mlir_module = jax.jit(one).lower().as_text()
    self.assertIn('stablehlo.composite "my.one"', mlir_module)
    self.assertIn('{decomposition = @my.one} : () -> tensor<f32>', mlir_module)
    self.assertIn('@my.one() -> tensor<f32>', mlir_module)
    self.assertIn('stablehlo.constant dense<1.000000e+00> : tensor<f32>', mlir_module)

  def test_composite_with_variadic_input_output(self):
    @partial(lax.composite, name="my.ident")
    def ident(*args):
      return args

    x = jnp.array(1.0, dtype=jnp.float32)
    y = jnp.array(2.0, dtype=jnp.float32)
    z = jnp.array(3.0, dtype=jnp.float32)
    a, b, c = ident(x, y, z)
    self.assertEqual(a, x)
    self.assertEqual(b, y)
    self.assertEqual(c, z)

    mlir_module = jax.jit(ident).lower(x, y, z).as_text()
    self.assertIn(
        'stablehlo.composite "my.ident" %arg0, %arg1, %arg2 '
        '{decomposition = @my.ident} : (tensor<f32>, tensor<f32>, tensor<f32>) '
        '-> (tensor<f32>, tensor<f32>, tensor<f32>)', mlir_module)
    self.assertIn(
        '@my.ident(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) '
        '-> (tensor<f32>, tensor<f32>, tensor<f32>)', mlir_module)
    self.assertIn(
        'return %arg0, %arg1, %arg2 : tensor<f32>, tensor<f32>, tensor<f32>',
        mlir_module)

  def test_composite_jvp(self):
    @partial(lax.composite, name="my.square")
    def my_square(x):
      return x ** 2

    with self.assertRaisesRegex(
        ValueError,
        "JVP rule for composite not implemented. You can use `jax.custom_jvp` "
        "to add support. See "
        "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html"
    ):
      jvp(my_square, (1.0,), (2.0,))

  def test_composite_grad(self):
    @partial(lax.composite, name="my.square")
    def my_square(x):
      return x ** 2

    with self.assertRaisesRegex(
        ValueError,
        "JVP rule for composite not implemented. You can use `jax.custom_jvp` "
        "to add support. See "
        "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html"
    ):
      grad(my_square)(1.0)

  def test_composite_with_array_consts(self):
    @partial(lax.composite, name="my.consts")
    def my_consts(x, /, *, scale):
      return jnp.round(x / scale)

    scale = np.array([0.5, 0.4, 0.3], dtype=np.float32)
    x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
    self.assertAllClose(my_consts(x, scale=scale), jnp.round(x / scale))

    # The constant must not appear as an extra input argument to the composite.
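    # (Sketch of the expectation, inferred from the assertion below: since
    # scale is a concrete numpy array rather than a traced value, it should be
    # baked into the decomposition as a constant, so the lowered @my.consts
    # signature takes only %arg0.)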
    mlir_module = jax.jit(partial(my_consts, scale=scale)).lower(x).as_text()
    self.assertIn(
        "@my.consts(%arg0: tensor<3xf32>) -> tensor<3xf32>", mlir_module
    )

  def test_composite_with_tracer_consts(self):
    def fun(x, scale):
      @partial(lax.composite, name="my.consts")
      def my_consts(y):
        return jnp.round(y / scale)

      return my_consts(x)

    scale = jnp.array([0.5, 0.4, 0.3], dtype=jnp.float32)
    x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
    self.assertAllClose(fun(x, scale), jnp.round(x / scale))
    self.assertAllClose(
        jax.jit(partial(fun, scale=scale))(x), jnp.round(x / scale))

    with self.assertRaisesRegex(
        UnexpectedTracerError,
        "Found a JAX Tracer as a constant in the decomposition for the "
        "composite op 'my.consts'."):
      jax.jit(fun)(x, scale)


class RaggedTest(jtu.JaxTestCase):

  @jtu.sample_product(
      [
          {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
          {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
      ],
      dtype=jtu.dtypes.numeric,
  )
  def test_ragged_dot(self, m, k, n, num_groups, dtype):
    """Tests ragged_dot.

    The ragged_dot is tested against the numpy reference implementation and by
    running JAX compilation.

    Raises:
      SkipTest: if the dtype is not supported.
    """
    lhs_shape = (m, k)
    rhs_shape = (num_groups, k, n)

    def group_sizes(m, num_groups):
      ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
      ends = jnp.concatenate(
          [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
      starts = jnp.concatenate(
          [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
      return ends - starts

    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [
        rng(lhs_shape, dtype),
        rng(rhs_shape, dtype),
        group_sizes(m, num_groups),
    ]
    self._CompileAndCheck(lax.ragged_dot, args_maker)
    self._CheckAgainstNumpy(
        lax_reference.ragged_dot, lax.ragged_dot, args_maker)

  @parameterized.parameters(
      {
          "lhs_shape": lhs_shape,
          "rhs_shape": rhs_shape,
          "group_sizes_shape": group_sizes_shape,
          "ragged_dot_dimension_numbers": ragged_dot_dimension_numbers,
          "err_msg": err_msg,
      }
      for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [
          (
              [11, 5],
              [3, 5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[0, 1],
                  rhs_group_dimensions=[0],
              ),
              "ragged_dot_general expects exactly one lhs ragged dimension",
          ),
          (
              [11, 5],
              [3, 5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[2],
                  rhs_group_dimensions=[0],
              ),
              (
                  "ragged_dot_general requires lhs ragged dimension numbers to "
                  "be nonnegative and less than the number of axes of the lhs"
              ),
          ),
          (
              [11, 5],
              [3, 5, 7],
              [2, 3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[0],
                  rhs_group_dimensions=[0],
              ),
              r"expected group_sizes to have shape \(3,\), got \(2, 3\)",
          ),
          (
              [19, 17, 11, 5],
              [3, 19, 5, 7],
              [19, 11, 3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([3], [2]), ([0], [1])),
                  lhs_ragged_dimensions=[2],
                  rhs_group_dimensions=[0],
              ),
              (
                  r"expected group_sizes to have shape \(19, 17, 3\), "
                  r"got \(19, 11, 3\)"
              ),
          ),
          (
              [19, 11, 17, 5],
              [19, 17, 5, 7],
              [19, 11, 3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])),
                  lhs_ragged_dimensions=[3],
                  rhs_group_dimensions=[],
              ),
              (
                  r"expected group_sizes to have shape \(19, 17, 3\), "
                  r"got \(19, 11, 3\)"
              ),
          ),
          (
              [17, 19, 11, 5],
              [17, 19, 5, 7],
              [19, 3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])),
                  lhs_ragged_dimensions=[1],
                  rhs_group_dimensions=[],
              ),
              (
                  r"expected group_sizes to have shape \(17, 3\), "
                  r"got \(19, 3\)"
              ),
          ),
          (
              [19, 11, 5],
              [19, 5, 7],
              [19, 3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([2], [1]), ([0], [0])),
                  lhs_ragged_dimensions=[1],
                  rhs_group_dimensions=[0],
              ),
              (
                  "ragged_dot_general requires rhs group dimension numbers to "
                  "be distinct from contracting and batch dimensions"
              ),
          ),
          (
              [11, 3],
              [3, 3, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[0],
                  rhs_group_dimensions=[1],
              ),
              (
                  "ragged_dot_general requires rhs group dimension numbers to "
                  "be distinct from contracting and batch dimensions"
              ),
          ),
          (
              [11, 5],
              [3, 5, 7],
              [2],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[0],
                  rhs_group_dimensions=[0],
              ),
              "expected rhs group dimension size to be 2, got 3",
          ),
          (
              [2, 11, 5],
              [3, 2, 5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([2], [2]), ([0], [1])),
                  lhs_ragged_dimensions=[0],
                  rhs_group_dimensions=[0],
              ),
              (
                  "ragged_dot_general requires zero group dimensions in "
                  "the rhs when lhs ragged dimension is contracting or batch"
              ),
          ),
          (
              [11, 5],
              [3, 5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[1],
                  rhs_group_dimensions=[0],
              ),
              (
                  "ragged_dot_general requires zero group dimensions in "
                  "the rhs when lhs ragged dimension is contracting or batch"
              ),
          ),
          (
              [11, 5],
              [5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [0]), ([], [])),
                  lhs_ragged_dimensions=[0],
                  rhs_group_dimensions=[],
              ),
              (
                  "ragged_dot_general requires exactly one rhs group dimension "
                  "when lhs ragged dimension is noncontracting"
              ),
          ),
      ]
  )
  def test_ragged_dot_general_shape_inference_failure(
      self, lhs_shape, rhs_shape, group_sizes_shape,
      ragged_dot_dimension_numbers, err_msg):
    lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
    rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
    group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
    with self.assertRaisesRegex(TypeError, err_msg):
      lax.ragged_dot_general(lhs, rhs, group_sizes,
                             ragged_dot_dimension_numbers)

  @parameterized.parameters(
      {
          "lhs_shape": lhs_shape,
          "rhs_shape": rhs_shape,
          "group_sizes_shape": group_sizes_shape,
          "ragged_dnums": ragged_dnums,
          "out_shape": out_shape,
      }
      for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [
          (
              [11, 5],
              [3, 5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [1]), ([], [])),
                  lhs_ragged_dimensions=[0],
                  rhs_group_dimensions=[0],
              ),
              (11, 7),
          ),
          (
              [11, 5],
              [5, 7],
              [3],
              lax.RaggedDotDimensionNumbers(
                  dot_dimension_numbers=(([1], [0]), ([], [])),
                  lhs_ragged_dimensions=[1],
                  rhs_group_dimensions=[],
              ),
              (3, 11, 7),
          ),
      ]
  )
  def test_ragged_dot_general_shape_inference_success(
      self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape):
    lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
    rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
    group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
    self.assertEqual(
        lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape,
        out_shape,
    )


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())