# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test inputs and invocations for JAX primitives.

The idea is that we want to list all the JAX numeric primitives along with a
set of inputs that should cover the use cases for each primitive. We keep
these separate from any particular test suite so that we can reuse them to
build multiple kinds of tests. For example, we can use the harnesses to check
that each primitive is compiled correctly, or that we can apply a certain
transformation, e.g., `vmap`.

See the `Harness` class below for how to define a harness, describing one use
case of one primitive.

Some use cases are known to be partially implemented in JAX, e.g., because of
an implementation limitation. We do have harnesses for those cases too, but
we filter them out. Instead of writing this information as conditions inside
one particular test, we write it as `Limitation` objects that can be reused in
multiple tests and can also be used to generate documentation, e.g., the
report of
[unsupported and partially-implemented JAX primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)

The limitations are used to filter out of the tests those harnesses that are
known to fail. Each Limitation is specific to one harness.
"""

import operator
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, NamedTuple, Sequence, Tuple, Union

from functools import partial

from absl import testing
import jax
from jax import config
from jax import dtypes
from jax import ad_util
from jax import test_util as jtu
from jax import lax
from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow
from jax.interpreters import xla
from jax.lib import xla_client

import numpy as np

FLAGS = config.FLAGS

Rng = Any  # A random number generator
DType = Any


class RandArg(NamedTuple):
  """Descriptor for a randomly generated argument.

  See the description of `Harness`.
  """
  shape: Tuple[int, ...]
  dtype: DType


class StaticArg(NamedTuple):
  """Descriptor for a static argument.

  See the description of `Harness`.
  """
  value: Any


class CustomArg(NamedTuple):
  """Descriptor for a dynamic argument generated from a PRNG factory.

  See the description of `Harness`.
  """
  make: Callable[[Rng], Any]  # Called with a Rng to make a tensor


ArgDescriptor = Union[RandArg, StaticArg, CustomArg, Any]


class Harness:
  """Specifies inputs and callable for a primitive.

  See the module docstring for an introduction to harnesses.

  A harness is conceptually a callable and a list of arguments that together
  exercise a use case. The harness can optionally have additional parameters
  that can be used by the test.

  The arguments are specified through argument descriptors. An argument
  descriptor can be:

    * a numeric value or ndarray, or
    * an instance of ``RandArg(shape, dtype)``, to be used with a PRNG to
      generate a random tensor of the given shape and dtype, or
    * an instance of ``CustomArg(fun)``, to be used with a PRNG, or
    * an instance of ``StaticArg(value)``.
      Often these are the non-array arguments, e.g., a shape.

  The given callable will be passed one argument corresponding to each
  argument descriptor, e.g., `harness.fun(*harness.args_maker(rng))`.
  However, in many applications we only care about the non-static arguments.
  For that purpose, you can use
  `harness.dyn_fun(*harness.dyn_args_maker(rng))`, where `harness.dyn_fun`
  is `harness.fun` specialized to the static arguments.

  For example, a harness for ``lax.take(arr, indices, axis=None)`` may want
  to expose as external (non-static) arguments the array and the indices, and
  keep the axis as a static argument (technically specializing `take` to an
  axis):

    Harness(lax.gather_p,
            f"take_axis={axis}",
            lax.take,
            [RandArg((2, 4), np.float32),
             np.array([-1, 0, 1]),
             StaticArg(axis)],
            axis=axis)

  Each harness can have a list of Limitations that describe the cases when
  the harness may not be fully implemented.
  """
  # The group name most often is the primitive name.
  group_name: str
  # Descriptive name of the harness, used as a testcase_name. Unique in a group.
  name: str
  # The function taking all arguments (static and dynamic).
  fun: Callable
  # Describes how to construct arguments, see the class docstring.
  arg_descriptors: Sequence[ArgDescriptor]
  dtype: DType
  # A set of limitations describing the cases that are not supported or
  # partially implemented in JAX for this harness.
  jax_unimplemented: Sequence["Limitation"]
  rng_factory: Callable
  # Carry some arbitrary parameters that the test can access.
  params: Dict[str, Any]

  def __init__(self,
               group_name,
               name,
               fun,
               arg_descriptors,
               *,
               dtype,
               rng_factory=jtu.rand_default,
               jax_unimplemented: Sequence["Limitation"] = (),
               **params):
    """See the class docstring."""
    self.group_name = group_name
    self.name = name
    self.fun = fun  # type: ignore[assignment]
    self.arg_descriptors = arg_descriptors
    self.rng_factory = rng_factory  # type: ignore[assignment]
    self.jax_unimplemented = jax_unimplemented
    self.dtype = dtype
    self.params = params

  def __str__(self):
    return self.fullname

  @property
  def fullname(self):
    return self.name if self.group_name is None else f"{self.group_name}_{self.name}"

  def _arg_maker(self, arg_descriptor, rng: Rng):
    if isinstance(arg_descriptor, StaticArg):
      return arg_descriptor.value
    if isinstance(arg_descriptor, RandArg):
      return self.rng_factory(rng)(arg_descriptor.shape, arg_descriptor.dtype)
    if isinstance(arg_descriptor, CustomArg):
      return arg_descriptor.make(rng)
    return arg_descriptor

  def args_maker(self, rng: Rng) -> Sequence:
    """All-argument maker, including the static ones."""
    return [self._arg_maker(ad, rng) for ad in self.arg_descriptors]

  def dyn_args_maker(self, rng: Rng) -> Sequence:
    """A dynamic-argument maker, for use with `dyn_fun`."""
    return [
        self._arg_maker(ad, rng)
        for ad in self.arg_descriptors
        if not isinstance(ad, StaticArg)
    ]

  def dyn_fun(self, *dyn_args):
    """Invokes `fun` given just the dynamic arguments."""
    all_args = self._args_from_dynargs(dyn_args)
    return self.fun(*all_args)

  def _args_from_dynargs(self, dyn_args: Sequence) -> Sequence:
    """All arguments, including the static ones."""
    next_dynamic_argnum = 0
    all_args = []
    for ad in self.arg_descriptors:
      if isinstance(ad, StaticArg):
        all_args.append(ad.value)
      else:
        all_args.append(dyn_args[next_dynamic_argnum])
        next_dynamic_argnum += 1
    return all_args

  def filter(self,
             device_under_test: str,
             *,
             include_jax_unimpl: bool = False,
             one_containing: Optional[str] = None) -> bool:
    if not include_jax_unimpl:
      if any([
          device_under_test in l.devices
          for l in self.jax_unimplemented
          if l.filter(device=device_under_test, dtype=self.dtype)
      ]):
        return False
    if one_containing is not None and one_containing not in self.fullname:
      return False
    return True
def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
  """User-friendly description of a set of dtypes."""
  if not dtype_list and empty_means_all:
    return "all"

  names = set([np.dtype(dt).name for dt in dtype_list])
  signed = {"int8", "int16", "int32", "int64"}
  if all([t in names for t in signed]):
    names = (names - signed) | {"signed"}
  unsigned = {"uint8", "uint16", "uint32", "uint64"}
  if all([t in names for t in unsigned]):
    names = (names - unsigned) | {"unsigned"}
  integer = {"signed", "unsigned"}
  if all([t in names for t in integer]):
    names = (names - integer) | {"integer"}

  floating = {"bfloat16", "float16", "float32", "float64"}
  if all([t in names for t in floating]):
    names = (names - floating) | {"floating"}

  complex_ = {"complex64", "complex128"}
  if all([t in names for t in complex_]):
    names = (names - complex_) | {"complex"}

  inexact = {"floating", "complex"}
  if all([t in names for t in inexact]):
    names = (names - inexact) | {"inexact"}

  all_types = {"integer", "inexact", "bool"}
  if all([t in names for t in all_types]):
    names = (names - all_types) | {"all"}

  return ", ".join(sorted(list(names)))


##### All harnesses in this file.
all_harnesses: List[Harness] = []


def define(
    group_name,  # Should be the primitive name, as much as possible
    name,
    fun,
    arg_descriptors,
    *,
    dtype,
    rng_factory=jtu.rand_default,
    jax_unimplemented: Sequence["Limitation"] = (),
    **params):
  """Defines a harness and stores it in `all_harnesses`. See `Harness`."""
  group_name = str(group_name)
  h = Harness(
      group_name,
      name,
      fun,
      arg_descriptors,
      rng_factory=rng_factory,
      jax_unimplemented=jax_unimplemented,
      dtype=dtype,
      **params)
  all_harnesses.append(h)


class Limitation:
  """Encodes conditions under which a harness is limited, e.g., not runnable in JAX.

  See the module docstring for an introduction to harnesses and limitations.
  """

  def __init__(
      self,
      description: str,
      *,
      enabled: bool = True,
      devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
      dtypes: Union[DType, Sequence[DType]] = (),
      skip_run: bool = False,
  ):
    """Args:

      description: text to augment the harness group name with the description
        of the limitation. Used for reports.
      enabled: whether this limitation is enabled for the harness in which
        it appears. This is only used during testing to know whether to ignore
        harness errors. Use this sparingly, prefer `devices` and `dtypes` for
        enabled conditions that are included in reports.
      devices: the list of device types for which this applies. Used for
        filtering during harness execution, and for reports.
      dtypes: the list of dtypes for which this applies. Used for filtering
        during harness execution, and for reports.
      skip_run: whether harnesses with this limitation should be skipped
        entirely, rather than run and checked for errors.
    """
    assert isinstance(description, str), f"{description}"
    self.description = description
    self.skip_run = skip_run
    if isinstance(devices, str):
      devices = (devices,)
    else:
      devices = tuple(devices)
    self.devices = devices
    if not isinstance(dtypes, Iterable):
      dtypes = (dtypes,)
    else:
      dtypes = tuple(dtypes)
    self.dtypes = dtypes
    self.enabled = enabled  # Does it apply to the current harness?
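  # Illustrative sketch (not one of the real harness definitions in this
  # file): a Limitation is attached to a harness via `define`, and
  # `Harness.filter` later matches it against the device and dtype under
  # test. The group name, primitive, dtype, and devices below are arbitrary
  # examples:
  #
  #   define(
  #       "example_group", "basic",
  #       lax.exp, [RandArg((3,), np.float16)],
  #       jax_unimplemented=[
  #           Limitation("unimplemented", devices="tpu",
  #                      dtypes=[np.float16])
  #       ],
  #       dtype=np.float16)
  #
  #   # For the resulting harness h = all_harnesses[-1]:
  #   #   h.filter("tpu") -> False (excluded: the limitation is enabled and
  #   #                             matches both the device and the dtype)
  #   #   h.filter("cpu") -> True  (kept: "cpu" is not in the devices)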
  def __str__(self):
    return (f"\"{self.description}\" devices={self.devices} "
            f"dtypes={[np.dtype(dt).name for dt in self.dtypes]}" +
            (" (skip_run) " if self.skip_run else ""))

  __repr__ = __str__

  def filter(self,
             device: Optional[str] = None,
             dtype: Optional[DType] = None) -> bool:
    """Checks whether the limitation is enabled for the given dtype and device."""
    return (self.enabled and
            (not self.dtypes or dtype is None or dtype in self.dtypes) and
            (device is None or device in self.devices))


def parameterized(harnesses: Iterable[Harness],
                  *,
                  one_containing: Optional[str] = None,
                  include_jax_unimpl: bool = False):
  """Decorator for tests. The tests receive a `harness` argument.

  The `JAX_TEST_HARNESS_ONE_CONTAINING` environment variable is useful for
  debugging. If given, then it picks only one harness whose name contains the
  string. The whole set of parameterized tests is reduced to one test, whose
  name is not decorated to make it easier to pick with an IDE for running.
  """
  one_containing = one_containing or os.environ.get(
      "JAX_TEST_HARNESS_ONE_CONTAINING")
  cases = tuple(
      # Change the testcase name to include the harness name.
      dict(
          testcase_name=harness.fullname if one_containing is None else "",
          harness=harness) for harness in harnesses if harness.filter(
              jtu.device_under_test(),
              one_containing=one_containing,
              include_jax_unimpl=include_jax_unimpl))
  if one_containing is not None:
    if not cases:
      raise ValueError(
          f"Cannot find test case with name containing {one_containing}."
          " Names are:\n" +
          "\n".join([harness.fullname for harness in harnesses]))
    cases = cases[0:1]
  if not cases:
    # We filtered out all the harnesses.
    return jtu.skip_on_devices(jtu.device_under_test())
  return testing.parameterized.named_parameters(*cases)


###############################################################################
### Harness definitions ###
###############################################################################


def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
  define(
      str(prim),
      f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
      prim.bind, [RandArg(shape, dtype)],
      prim=prim,
      dtype=dtype,
      shape=shape)


for dtype in (set(jtu.dtypes.all) -
              set(jtu.dtypes.all_unsigned + jtu.dtypes.boolean)):
  _make_unary_elementwise_harness(prim=lax.abs_p, dtype=dtype)

for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
  _make_unary_elementwise_harness(prim=lax.acosh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.asinh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.atanh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)

for dtype in jtu.dtypes.all_floating:
  _make_unary_elementwise_harness(prim=lax.bessel_i0e_p,
dtype=dtype) _make_unary_elementwise_harness(prim=lax.bessel_i1e_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.ceil_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.erf_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.erf_inv_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.erfc_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.floor_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.is_finite_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.lgamma_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.digamma_p, dtype=dtype) for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean): _make_unary_elementwise_harness(prim=lax.neg_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.sign_p, dtype=dtype) def _make_round_harness(name, *, shape=(100, 100), dtype=np.float32, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO, operand=None): operand = operand if operand is not None else RandArg(shape, dtype) define( "round", f"{name}_shape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_roundingmethod={rounding_method}", lax.round, [operand, StaticArg(rounding_method)], dtype=dtype, operand=operand, rounding_method=rounding_method) # Validate dtypes for dtype in jtu.dtypes.all_floating: _make_round_harness("dtypes", dtype=dtype) for rounding_method in [ lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN ]: operand = np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32) _make_round_harness( "rounding_methods", operand=operand, rounding_method=rounding_method) # Validate edge cases for name, operand in [ # Checks that https://github.com/google/jax/issues/4952 is resolved ("round_away_from_0", np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)), ]: _make_round_harness(f"edge_case_{name}", operand=operand) def _make_convert_element_type_harness(name, *, shape=(100, 100), dtype=np.float32, new_dtype=np.float32): define( "convert_element_type", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}", lambda arg: (lax.convert_element_type_p.bind(arg, new_dtype=new_dtype, weak_type=False)), [RandArg(shape, dtype)], shape=shape, dtype=dtype, new_dtype=new_dtype) for old_dtype in jtu.dtypes.all: # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating # point numbers and new_dtype is an unsigned integer. See issue # https://github.com/google/jax/issues/5082 for details. for new_dtype in (jtu.dtypes.all if not (dtypes.issubdtype(old_dtype, np.floating) or dtypes.issubdtype(old_dtype, np.complexfloating)) else set(jtu.dtypes.all) - set(jtu.dtypes.all_unsigned)): _make_convert_element_type_harness( "dtypes_to_dtypes", dtype=old_dtype, new_dtype=new_dtype) def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3): define( "integer_pow", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_y={y}", lax.integer_pow, [RandArg(shape, dtype), StaticArg(y)], shape=shape, dtype=dtype, y=y) for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean): # Validate dtypes and y values _make_integer_pow_harness("dtypes", dtype=dtype) # Validate overflow behavior by dtype. 
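  # (Illustrative aside, stating an assumption rather than documented
  # behavior: with y=127 the exact power overflows every integer dtype here,
  # e.g. 3**127 does not fit in np.int8. XLA computes integer_pow by
  # repeated multiplication, which typically wraps around in two's
  # complement, so these harnesses exercise that dtype-dependent wrap-around
  # rather than an exact mathematical result.)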
  _make_integer_pow_harness("overflow", y=127, dtype=dtype)

for dtype in jtu.dtypes.all_inexact:
  # Validate negative y by dtype
  _make_integer_pow_harness("negative_exp", y=-127, dtype=dtype)


def _make_pow_harness(name,
                      *,
                      shapes=((20, 30), (20, 30)),
                      dtype=np.float32,
                      lhs=None,
                      rhs=None):
  lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
  rhs = RandArg(shapes[1], dtype) if rhs is None else rhs
  define(
      "pow",
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs.shape, dtype)}",
      lax.pow, [lhs, rhs],
      lhs=lhs,
      rhs=rhs,
      dtype=dtype)


for dtype in jtu.dtypes.all_inexact:
  # Validate dtypes
  _make_pow_harness("dtypes", dtype=dtype)

# Validate broadcasting behavior
for shapes in [
    ((), (4, 5, 6)),  # broadcasting lhs
    ((4, 5, 6), ()),  # broadcasting rhs
    ((4, 1, 6), (4, 5, 6)),  # broadcasting lhs on a specific axis
    ((4, 5, 6), (4, 1, 6)),  # broadcasting rhs on a specific axis
]:
  _make_pow_harness("broadcast", shapes=shapes)


def _make_reshape_harness(name,
                          *,
                          shape=(2, 3),
                          new_sizes=(3, 2),
                          dimensions=(0, 1),
                          dtype=np.float32):
  define(
      "reshape",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newsizes={new_sizes}_dimensions={dimensions}",
      lax.reshape,
      [RandArg(shape, dtype),
       StaticArg(new_sizes),
       StaticArg(dimensions)],
      shape=shape,
      dtype=dtype,
      new_sizes=new_sizes,
      dimensions=dimensions)


# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_reshape_harness("dtypes", dtype=dtype)

# Validate new_sizes
for shape, new_sizes, dimensions in [
    ((3, 4, 5), (3, 20), (0, 1, 2)),  # merging two dimensions
    ((3, 4, 5), (4, 15), (0, 1, 2)),  # changing leading dimension
]:
  _make_reshape_harness(
      "new_sizes", shape=shape, new_sizes=new_sizes, dimensions=dimensions)

# Validate dimensions collapsing order
for shape, new_sizes, dimensions in [
    ((3, 4, 5), (3, 20), (2, 1, 0)),  # transpose shape (0, 1, 2) into (2, 1, 0)
    ((3, 4, 5), (3, 20), (2, 0, 1)),  # transpose shape (0, 1, 2) into (2, 0, 1)
]:
  _make_reshape_harness(
      "dimensions", shape=shape, new_sizes=new_sizes, dimensions=dimensions)


def _make_rev_harness(name, *, shape=(4, 5), dtype=np.float32,
                      dimensions=(0,)):
  define(
      "rev",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimensions={dimensions}",
      lax.rev,
      [RandArg(shape, dtype), StaticArg(dimensions)],
      shape=shape,
      dtype=dtype,
      dimensions=dimensions)


# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_rev_harness("dtypes", dtype=dtype)

# Validate dimensions
for shape, dimensions in [
    ((3, 4, 5), ()),  # empty dimensions
    ((3, 4, 5), (0, 2)),  # some dimensions
    ((3, 4, 5), (0, 1, 2)),  # all dimensions (ordered)
    ((3, 4, 5), (2, 0, 1)),  # all dimensions (unordered)
]:
  _make_rev_harness("dimensions", shape=shape, dimensions=dimensions)


def _make_device_put_harness(name,
                             *,
                             shape=(3, 4),
                             dtype=np.float32,
                             device=None):
  _device_fn = lambda: jax.devices(device)[0] if device is not None else None
  define(
      "device_put",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_device={device}",
      lambda x: xla.device_put_p.bind(x, device=_device_fn()),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      device=device)


# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_device_put_harness("dtypes", dtype=dtype)

# Validate devices
_make_device_put_harness("devices", device="cpu")


def _make_bitcast_convert_type_harness(name,
                                       *,
                                       shape=(2, 3),
                                       dtype=np.float32,
                                       new_dtype=np.float32):
  define(
      "bitcast_convert_type",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newdtype={np.dtype(new_dtype).name}",
      lambda x:
(lax.bitcast_convert_type_p.bind(x, new_dtype=new_dtype)), [RandArg(shape, dtype)], shape=shape, dtype=dtype, new_dtype=new_dtype) def _can_bitcast(dtype, target_dtype): def _to_equivalence_class(dtype): if dtypes.issubdtype(dtype, np.integer): return dtypes.iinfo(dtype).bits elif dtypes.issubdtype(dtype, np.floating): return dtypes.finfo(dtype).bits else: assert dtype == np.bool_ or dtypes.issubdtype(dtype, np.complexfloating) # Complex and boolean types can only be cast to themselves return np.dtype(dtype).name return _to_equivalence_class(dtype) == _to_equivalence_class(target_dtype) # Validate dtypes combinations for dtype in jtu.dtypes.all: for new_dtype in filter(partial(_can_bitcast, dtype), jtu.dtypes.all): _make_bitcast_convert_type_harness( "dtypes_to_new_dtypes", dtype=dtype, new_dtype=new_dtype) def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32): define( ad_util.add_any_p, f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}", ad_util.add_jaxvals_p.bind, list(map(lambda s: RandArg(s, dtype), shapes)), dtype=dtype, shapes=shapes) for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean): _make_add_any_harness("dtypes", dtype=dtype) for dtype in jtu.dtypes.all: shape: Tuple[int, ...] = (20, 20) define( ad_util.stop_gradient_p, f"{jtu.format_shape_dtype_string(shape, dtype)}", ad_util.stop_gradient_p.bind, [RandArg(shape, dtype)], shape=shape, dtype=dtype) _LAX_COMPARATORS = (lax.eq_p, lax.ge_p, lax.gt_p, lax.le_p, lax.lt_p, lax.ne_p) def _make_comparator_harness(name, *, dtype=np.float32, op=lax.eq_p, lhs_shape=(), rhs_shape=()): define( op.name, f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}", lambda *args: op.bind(*args), [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype)], op=op, lhs_shape=lhs_shape, rhs_shape=rhs_shape, dtype=dtype) for op in _LAX_COMPARATORS: for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else set(jtu.dtypes.all) - set(jtu.dtypes.complex)): # Validate dtypes _make_comparator_harness("dtypes", dtype=dtype, op=op) # Validate broadcasting behavior for lhs_shape, rhs_shape in [ ((), (2, 3)), # broadcast scalar ((1, 2), (3, 2)), # broadcast along specific axis ]: _make_comparator_harness( "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, op=op) for dtype in jtu.dtypes.all: shape = (3, 4, 5) define( "zeros_like", f"shape={jtu.format_shape_dtype_string(shape, dtype)}", ad_util.zeros_like_p.bind, [RandArg(shape, dtype)], shape=shape, dtype=dtype) for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned: arg = np.array([-1, -2, 0, 1], dtype=dtype) define( "population_count", f"{jtu.dtype_str(dtype)}", lax.population_count, [arg], dtype=dtype) def _get_max_identity(dtype): if dtypes.issubdtype(dtype, np.inexact): return np.array(-np.inf, dtype) elif dtypes.issubdtype(dtype, np.integer): return np.array(dtypes.iinfo(dtype).min, dtype) elif dtypes.issubdtype(dtype, np.bool_): return np.array(False, np.bool_) def _get_min_identity(dtype): if dtypes.issubdtype(dtype, np.inexact): return np.array(np.inf, dtype) elif dtypes.issubdtype(dtype, np.integer): return np.array(dtypes.iinfo(dtype).max, dtype) elif dtypes.issubdtype(dtype, np.bool_): return np.array(True, np.bool_) def _make_argminmax_harness(prim, name, *, shape=(15,), dtype=jnp.float32, axes=(0,), index_dtype=np.int32, arr=None): arr = arr if arr is not None else RandArg(shape, dtype) dtype, shape = arr.dtype, arr.shape index_dtype 
= dtypes.canonicalize_dtype(index_dtype) define( prim, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}", lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr], shape=shape, dtype=dtype, axes=axes, index_dtype=index_dtype, prim=prim) for prim in [lax.argmin_p, lax.argmax_p]: for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex): # Validate dtypes for each primitive _make_argminmax_harness(prim, "dtypes", dtype=dtype) # Validate axes for each primitive; non major axis _make_argminmax_harness(prim, "axes", shape=(18, 12), axes=(1,)) # Validate index dtype for each primitive for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned: _make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype) # TODO(bchetioui): the below documents a limitation of argmin and argmax when a # dimension of the input is too large. However, it is not categorizable as it # seems that the converter fails before reaching the actual primitive call. This # suggests that we may need to harden the converter to handle inputs this big. # + tuple( # Document limitation in case of too large axis # _make_argminmax_harness("overflow_axis", prim=prim, # arr=np.ones((2**31,), dtype=np.uint8)) # for prim in [lax.argmin_p, lax.argmax_p] # ) def _make_iota_harness(name, *, shape=(2, 3), dtype=np.float32, dimension=0): define( lax.iota_p, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimension={dimension}", lambda dtype, shape, dim: (lax.iota_p.bind(dtype=dtype, shape=shape, dimension=dim)), [StaticArg(dtype), StaticArg(shape), StaticArg(dimension)], shape=shape, dtype=dtype, dimension=dimension) for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean): _make_iota_harness("dtypes", dtype=dtype) # Validate broadcasting for shape, dimension in [ ((4, 8, 1, 1), 1), # broadcasting along non-major dimension ((4, 8, 1, 1), 2), # broadcasting along dimension == 1 ]: _make_iota_harness("broadcasting", shape=shape, dimension=dimension) def _make_div_rem_harness(prim, name, *, shapes=((2,), (2,)), dtype=np.float32, arrs=(None, None)): lhs, rhs = arrs def _make_non_zero(rng): return jtu.rand_nonzero(rng)(shapes[1], dtype) lhs = RandArg(shapes[0], dtype) if lhs is None else lhs rhs_shape = rhs.shape if rhs is not None else shapes[1] rhs_dtype = rhs.dtype if rhs is not None else dtype rhs = CustomArg(_make_non_zero) if rhs is None else rhs define( prim, f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, lhs.dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}", prim.bind, [lhs, rhs], dtype=dtype, lhs=lhs, rhs=rhs, prim=prim) for prim in [lax.div_p, lax.rem_p]: for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean) - ( set() if prim is lax.div_p else set(jtu.dtypes.complex)): _make_div_rem_harness(prim, "dtypes", dtype=dtype) # Validate broadcasting for shapes in [ ((2, 1, 3), (2, 4, 3)), # broadcast dividend ((2, 4, 3), (2, 1, 3)), # broadcast divisor ]: _make_div_rem_harness(prim, "broadcast", shapes=shapes) # Validate singularity points for name, arrs in [ ("positive_by_0", (np.ones( (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))), # TODO: this fails on CPU, different result # ("positive_by_0_int32", (np.ones((2,), dtype=np.int32), # np.zeros((2,), dtype=np.int32))), ("negative_by_0", (-np.ones( (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))), ("0_by_0", (np.zeros( (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))), ("inf_by_inf", (np.array([np.inf], dtype=np.float32), np.array([np.inf], 
dtype=np.float32))), ]: _make_div_rem_harness(prim, f"singularity_{name}", arrs=arrs) def _make_binary_elementwise_harnesses(prim, dtypes, default_dtype=np.float32, jax_unimplemented=lambda **kwargs: []): def _make(name, *, shapes=((20, 20), (20, 20)), dtype): lhs_shape, rhs_shape = shapes define( prim, f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}", prim.bind, [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype)], jax_unimplemented=jax_unimplemented( dtype=dtype, prim=prim, shapes=shapes), prim=prim, dtype=dtype, shapes=shapes) return (tuple( # Validate dtypes _make("dtypes", dtype=dtype) for dtype in dtypes) + tuple( # Validate broadcasting _make("broadcasting", dtype=default_dtype, shapes=shapes) for shapes in [ ((20, 20), (1, 20)), # broadcasting rhs ((1, 20), (20, 20)), # broadcasting lhs ])) _make_binary_elementwise_harnesses( prim=lax.add_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean)) _make_binary_elementwise_harnesses( prim=lax.mul_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean)) _make_binary_elementwise_harnesses( prim=lax.atan2_p, dtypes=jtu.dtypes.all_floating) _make_binary_elementwise_harnesses( prim=lax.igamma_p, dtypes=jtu.dtypes.all_floating) _make_binary_elementwise_harnesses( prim=lax.igammac_p, dtypes=jtu.dtypes.all_floating) _make_binary_elementwise_harnesses( prim=lax.nextafter_p, dtypes=jtu.dtypes.all_floating) _make_binary_elementwise_harnesses( prim=lax.and_p, default_dtype=np.int32, dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean) _make_binary_elementwise_harnesses( prim=lax.or_p, default_dtype=np.int32, dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean) _make_binary_elementwise_harnesses( prim=lax.xor_p, default_dtype=np.int32, dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean) _make_binary_elementwise_harnesses( prim=lax.shift_left_p, default_dtype=np.int32, dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned) _make_binary_elementwise_harnesses( prim=lax.shift_right_logical_p, default_dtype=np.int32, dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned) _make_binary_elementwise_harnesses( prim=lax.shift_right_arithmetic_p, default_dtype=np.int32, dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned) _make_binary_elementwise_harnesses( prim=lax.sub_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean)) _min_max_special_cases = tuple( (lhs, rhs) for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex for lhs, rhs in [(np.array([np.inf, np.inf], dtype=dtype), np.array([np.nan, np.nan], dtype=dtype)), (np.array([-np.inf, -np.inf], dtype=dtype), np.array([np.nan, np.nan], dtype=dtype))]) _make_binary_elementwise_harnesses(prim=lax.min_p, dtypes=jtu.dtypes.all) # Validate special cases for lhs, rhs in _min_max_special_cases: define( lax.min_p, f"inf_nan_{jtu.dtype_str(lhs.dtype)}_{lhs[0]}_{rhs[0]}", lax.min_p.bind, [lhs, rhs], prim=lax.min_p, dtype=lhs.dtype) _make_binary_elementwise_harnesses(prim=lax.max_p, dtypes=jtu.dtypes.all) # Validate special cases for lhs, rhs in _min_max_special_cases: define( lax.max_p, f"inf_nan_{jtu.dtype_str(lhs.dtype)}_{lhs[0]}_{rhs[0]}", lax.max_p.bind, [lhs, rhs], prim=lax.max_p, dtype=lhs.dtype) def _make_broadcast_in_dim_harness(name, *, dtype=np.float32, shape=(2,), outshape=(2,), broadcast_dimensions=(0,)): define( lax.broadcast_in_dim_p, f"{name}_shape={jtu.format_shape_dtype_string(shape, 
dtype)}_outshape={outshape}_broadcastdimensions={broadcast_dimensions}", lambda operand: lax.broadcast_in_dim_p.bind( operand, shape=outshape, broadcast_dimensions=broadcast_dimensions), [RandArg(shape, dtype)], shape=shape, dtype=dtype, outshape=outshape, broadcast_dimensions=broadcast_dimensions) for dtype in jtu.dtypes.all: _make_broadcast_in_dim_harness("dtypes", dtype=dtype) # Validate parameter combinations for shape, outshape, broadcast_dimensions in [ [(2,), (3, 2), (1,)], # add major dimension [(2,), (2, 3), (0,)], # add inner dimension [(), (2, 3), ()], # use scalar shape [(1, 2), (4, 3, 2), (0, 2)], # map size 1 dim to different output dim value ]: _make_broadcast_in_dim_harness( "parameter_combinations", shape=shape, outshape=outshape, broadcast_dimensions=broadcast_dimensions) def _make_broadcast_harness(name, *, dtype=np.float32, shape=(2,), sizes=()): define( lax.broadcast_p, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_sizes={sizes}", lambda operand: lax.broadcast_p.bind(operand, sizes=sizes), [RandArg(shape, dtype)], shape=shape, dtype=dtype, sizes=sizes) for dtype in jtu.dtypes.all: _make_broadcast_harness("dtypes", dtype=dtype) # Validate sizes for sizes in [ (2,), # broadcast 1 dim (1, 2, 3), # broadcast n > 1 dims ]: _make_broadcast_harness("sizes", sizes=sizes) for dtype in jtu.dtypes.all_floating: for arg1, arg2, arg3 in [ (np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype), np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype), np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6], dtype=dtype)) ]: define( lax.regularized_incomplete_beta_p, f"_{jtu.dtype_str(dtype)}", lax.betainc, [arg1, arg2, arg3], dtype=dtype) ## GATHER # Validate dtypes for dtype in set(jtu.dtypes.all): indices = np.array(2, dtype=np.int32) shape = (10,) axis = 0 define( lax.gather_p, f"dtypes_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}", lambda a, i, axis: jnp.take(a, i, axis=axis), [ RandArg(shape, dtype), indices, StaticArg(axis) ], dtype=dtype) # Construct gather harnesses using take _gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10)) for indices in [ # Ensure each set of indices has a distinct shape np.array(2, dtype=np.int32), np.array([2], dtype=np.int32), np.array([2, 4], dtype=np.int32), np.array([[2, 4], [5, 6]], dtype=np.int32), np.array([0, 1, 10], dtype=np.int32), # Index out of bounds np.array([0, 1, 2, -1], dtype=np.int32), # Index out of bounds ]: for axis in [0, 1, 2]: define( lax.gather_p, f"from_take_indices_shape={indices.shape}_axis={axis}", lambda a, i, axis: jnp.take(a, i, axis=axis), [_gather_input, indices, StaticArg(axis)], dtype=_gather_input.dtype) # Directly from lax.gather in lax_test.py. 
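# (Illustrative aside, not one of the lax_test.py cases below: with
# offset_dims=() and collapsed_slice_dims=(0,), gather is plain fancy
# indexing. Runnable if uncommented; the values are arbitrary examples:
#
#   op = np.arange(5, dtype=np.float32)
#   idxs = np.array([[0], [2]])
#   dnums = lax.GatherDimensionNumbers(
#       offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
#   lax.gather(op, idxs, dimension_numbers=dnums, slice_sizes=(1,))
#   # -> [0., 2.], i.e. op[[0, 2]]
# )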
for shape, idxs, dnums, slice_sizes in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (( 10, 5, ), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ]: dtype = np.float32 define( lax.gather_p, f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}", lambda op, idxs, dnums, slice_sizes: lax.gather( op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes), [RandArg(shape, dtype), idxs, StaticArg(dnums), StaticArg(slice_sizes)], dtype=dtype) def _make_scatter_harness(name, *, shape=(5,), f_lax=lax.scatter_min, indices_are_sorted=False, unique_indices=False, scatter_indices=np.array([[0], [2]]), update_shape=(2,), dtype=np.float32, dimension_numbers=((), (0,), (0,))): dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers) define( f_lax.__name__, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}" .replace(" ", ""), partial( f_lax, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices), [ RandArg(shape, dtype), StaticArg(scatter_indices), RandArg(update_shape, dtype), StaticArg(dimension_numbers) ], jax_unimplemented=[ Limitation( "unimplemented", devices="tpu", dtypes=np.complex64, enabled=(f_lax in [lax.scatter_max, lax.scatter_min])) ], f_lax=f_lax, shape=shape, dtype=dtype, scatter_indices=scatter_indices, update_shape=update_shape, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices) # Validate dtypes for dtype in jtu.dtypes.all: for f_lax in [lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min]: if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_: continue _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax) # Validate f_lax/update_jaxpr # We explicitly decide against testing lax.scatter, as its reduction function # is lambda x, y: y, which is not commutative and thus makes results # non-deterministic when an index into the operand is updated several times. # Validate shapes, dimension numbers and scatter indices for shape, scatter_indices, update_shape, dimension_numbers in [ ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))), ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,))) ]: _make_scatter_harness( "shapes_and_dimension_numbers", shape=shape, update_shape=update_shape, scatter_indices=np.array(scatter_indices), dimension_numbers=dimension_numbers) # Validate sorted indices _make_scatter_harness("indices_are_sorted", indices_are_sorted=True) # Validate unique_indices # `unique_indices` does not affect correctness, only performance, and thus # does not need to be tested here. 
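# (Illustrative aside on why overlap matters, runnable if uncommented; the
# values are arbitrary examples:
#
#   lax.scatter_add(
#       np.zeros((3,), np.float32),
#       np.array([[0], [0]]),          # duplicate indices
#       np.ones((2,), np.float32),
#       lax.ScatterDimensionNumbers(
#           update_window_dims=(),
#           inserted_window_dims=(0,),
#           scatter_dims_to_operand_dims=(0,)),
#       unique_indices=True)
#
# XLA may assume the two updates touch distinct elements, so element 0 is
# not guaranteed to be 2.0; with unique_indices=False it is.)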
# If/when it makes sense to add a test with `unique_indices=True`, particular
# care will have to be taken with regard to the choice of parameters, as the
# results are only predictable when all the indices to be updated are
# pairwise non-overlapping. Identifying such cases is non-trivial.
_make_scatter_harness("unique_indices", unique_indices=False)

for dtype in jtu.dtypes.all:
  arg_shape = (2, 3)
  for pads in [
      [(0, 0, 0), (0, 0, 0)],  # no padding
      [(1, 1, 0), (2, 2, 0)],  # only positive edge padding
      [(1, 2, 1), (0, 1, 0)],  # edge padding and interior padding
      [(0, 0, 0), (-1, -1, 0)],  # negative padding
      [(0, 0, 0), (-2, -2, 4)],  # add big dilation then remove from edges
      [(0, 0, 0), (-2, -3, 1)],  # remove everything in one dimension
  ]:
    define(
        lax.pad_p,
        f"inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_pads={pads}",
        lax.pad,
        [RandArg(arg_shape, dtype),
         np.array(0, dtype),
         StaticArg(pads)],
        rng_factory=jtu.rand_small,
        arg_shape=arg_shape,
        dtype=dtype,
        pads=pads)


def _make_select_harness(name,
                         *,
                         shape_pred=(2, 3),
                         shape_args=(2, 3),
                         dtype=np.float32):
  define(
      lax.select_p,
      f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.bool_)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
      lax.select, [
          RandArg(shape_pred, np.bool_),
          RandArg(shape_args, dtype),
          RandArg(shape_args, dtype)
      ],
      shape_pred=shape_pred,
      shape_args=shape_args,
      dtype=dtype)


for dtype in jtu.dtypes.all:
  _make_select_harness("dtypes", dtype=dtype)

# Validate shapes
_make_select_harness("shapes", shape_pred=(), shape_args=(18,))


def _make_transpose_harness(name,
                            *,
                            shape=(2, 3),
                            permutation=(1, 0),
                            dtype=np.float32):
  define(
      lax.transpose_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_permutation={permutation}"
      .replace(" ", ""),
      lambda x: lax.transpose_p.bind(x, permutation=permutation),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      permutation=permutation)


for dtype in jtu.dtypes.all:
  _make_transpose_harness("dtypes", dtype=dtype)

# Validate permutations
for shape, permutation in [
    ((2, 3, 4), (0, 1, 2)),  # identity
    ((2, 3, 4), (1, 2, 0)),  # transposition
]:
  _make_transpose_harness("permutations", shape=shape, permutation=permutation)


## CUMREDUCE
def _make_cumreduce_harness(name,
                            *,
                            f_jax=lax_control_flow.cummin,
                            shape=(8, 9),
                            dtype=np.float32,
                            axis=0,
                            reverse=False):
  limitations = []
  if f_jax.__name__ != "cumsum":
    limitations.append(
        Limitation(
            "unimplemented",
            devices="tpu",
            dtypes=np.complex64,
        ))
  define(
      f_jax.__name__,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}_reverse={reverse}",
      f_jax, [RandArg(shape, dtype),
              StaticArg(axis),
              StaticArg(reverse)],
      jax_unimplemented=limitations,
      f_jax=f_jax,
      shape=shape,
      dtype=dtype,
      axis=axis,
      reverse=reverse)


# Validate dtypes for each function
for f_jax in [
    lax_control_flow.cummin, lax_control_flow.cummax, lax_control_flow.cumsum,
    lax_control_flow.cumprod
]:
  for dtype in jtu.dtypes.all:
    if dtype == np.bool_:
      continue
    _make_cumreduce_harness("dtype_by_fun", dtype=dtype, f_jax=f_jax)

  # Validate axis for each function
  shape = (8, 9)
  for axis in range(len(shape)):
    _make_cumreduce_harness("axis_by_fun", axis=axis, f_jax=f_jax, shape=shape)

  # Validate reverse for each function
  _make_cumreduce_harness("reverse", reverse=True, f_jax=f_jax)


### TOP_K
def _make_top_k_harness(name,
                        *,
                        operand=None,
                        shape=(5, 3),
                        dtype=np.float32,
                        k=2):
  if operand is None:
    operand = RandArg(shape, dtype)
  define(
      lax.top_k_p,
      f"{name}_inshape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_k={k}",
lax.top_k, [operand, StaticArg(k)], shape=operand.shape, dtype=operand.dtype, k=k) for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex): # Validate dtypes _make_top_k_harness("dtypes", dtype=dtype) # Validate implicit properties of the sort for name, operand, k in [("stability", np.array([5, 7, 5, 8, 8, 5], dtype=np.int32), 3), ("sort_inf_nan", np.array([+np.inf, np.nan, -np.nan, -np.inf, 3], dtype=np.float32), 5)]: _make_top_k_harness(name, operand=operand, k=k) ### SORT def _make_sort_harness(name, *, operands=None, shape=(5, 7), dtype=np.float32, dimension=0, is_stable=False, num_keys=1): if operands is None: operands = [RandArg(shape, dtype)] define( lax.sort_p, f"{name}_num_arrays={len(operands)}_shape={jtu.format_shape_dtype_string(operands[0].shape, operands[0].dtype)}_axis={dimension}_isstable={is_stable}_num_keys={num_keys}", lambda *args: lax.sort_p.bind( *args[:-3], dimension=args[-3], is_stable=args[-2], num_keys=args[-1] ), [ *operands, StaticArg(dimension), StaticArg(is_stable), StaticArg(num_keys) ], shape=operands[0].shape, dimension=dimension, dtype=operands[0].dtype, is_stable=is_stable, num_keys=num_keys, num_arrays=len(operands)) _lax_sort_multiple_array_shape = (100,) # In order to test lexicographic ordering and sorting stability, the first # array contains only integers 0 and 1 _lax_sort_multiple_array_first_arg = ( np.random.uniform(0, 2, _lax_sort_multiple_array_shape).astype(np.int32)) # Validate dtypes for dtype in jtu.dtypes.all: _make_sort_harness("dtypes", dtype=dtype) # Validate dimensions for dimension in [0, 1]: _make_sort_harness("dimensions", dimension=dimension) # Validate stable sort _make_sort_harness("is_stable", is_stable=True) # Potential edge cases for operands, dimension in [ ([np.array([+np.inf, np.nan, -np.nan, -np.inf, 2], dtype=np.float32)], 0) ]: _make_sort_harness("edge_cases", operands=operands, dimension=dimension) # Validate multiple arrays, num_keys, and is_stable for is_stable in [False, True]: for operands in ( [ _lax_sort_multiple_array_first_arg, RandArg(_lax_sort_multiple_array_shape, np.int32) ], [ _lax_sort_multiple_array_first_arg, RandArg(_lax_sort_multiple_array_shape, np.int32), RandArg(_lax_sort_multiple_array_shape, np.float32) ], ): for num_keys in range(1, len(operands) + 1): _make_sort_harness( "multiple_arrays", operands=operands, num_keys=num_keys, is_stable=is_stable, shape=_lax_sort_multiple_array_first_arg.shape, dtype=_lax_sort_multiple_array_first_arg.dtype) def _make_cholesky_arg(shape, dtype, rng): a = jtu.rand_default(rng)(shape, dtype) return np.matmul(a, jnp.conj(np.swapaxes(a, -1, -2))) for dtype in jtu.dtypes.all_inexact: for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]: define( lax.linalg.cholesky_p, f"shape={jtu.format_shape_dtype_string(shape, dtype)}", lambda *args: lax.linalg.cholesky_p.bind(*args), [CustomArg(partial(_make_cholesky_arg, shape, dtype))], jax_unimplemented=[ Limitation( "unimplemented", dtypes=[np.float16], devices=("cpu", "gpu")) ], shape=shape, dtype=dtype) for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex: for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]: for full_matrices in [False, True]: define( lax.linalg.qr_p, f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}", lax.linalg.qr, [RandArg(shape, dtype), StaticArg(full_matrices)], # See jax.lib.lapack.geqrf for the list of compatible types jax_unimplemented=[ Limitation( "unimplemented", devices=("cpu", "gpu"), dtypes=[np.float16, 
dtypes.bfloat16]), ], shape=shape, dtype=dtype, full_matrices=full_matrices) def _make_fft_harness(name, *, shape=(14, 15, 16, 17), dtype=np.float32, fft_type=xla_client.FftType.FFT, fft_lengths=(17,)): def _fft_rng_factory(dtype): _all_integers = ( jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean) # For integer types, use small values to keep the errors small if dtype in _all_integers: return jtu.rand_small else: return jtu.rand_default define( lax.fft_p, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_ffttype={fft_type}_fftlengths={fft_lengths}", lambda *args: lax.fft_p.bind( args[0], fft_type=args[1], fft_lengths=args[2]), [RandArg(shape, dtype), StaticArg(fft_type), StaticArg(fft_lengths)], rng_factory=_fft_rng_factory(dtype), shape=shape, dtype=dtype, fft_type=fft_type, fft_lengths=fft_lengths) # FFT, IFFT, RFFT, IRFFT for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])): # Validate dtypes per FFT type for dtype in (jtu.dtypes.floating if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex): shape = (14, 15, 16, 17) for fft_lengths in [ (shape[-1],) if fft_type != xla_client.FftType.IRFFT else ((shape[-1] - 1) * 2,) ]: _make_fft_harness( "dtypes", shape=shape, dtype=dtype, fft_type=fft_type, fft_lengths=fft_lengths) # Validate dimensions per FFT type for dtype in [ np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64 ]: for dims in [1, 2, 3]: for fft_lengths in [ shape[-dims:] if fft_type != xla_client.FftType.IRFFT else shape[-dims:-1] + ((shape[-1] - 1) * 2,) ]: _make_fft_harness( "dims", shape=shape, fft_type=fft_type, fft_lengths=fft_lengths, dtype=dtype) for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex: for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]: for full_matrices in [False, True]: for compute_uv in [False, True]: define( lax.linalg.svd_p, f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}", lambda *args: lax.linalg.svd_p.bind( args[0], full_matrices=args[1], compute_uv=args[2]), [ RandArg(shape, dtype), StaticArg(full_matrices), StaticArg(compute_uv) ], jax_unimplemented=[ Limitation( "unimplemented", devices=("cpu", "gpu"), dtypes=[np.float16, dtypes.bfloat16]), Limitation( "complex not implemented. 
Works in JAX for CPU and GPU with custom kernels",
                    devices="tpu",
                    dtypes=[np.complex64, np.complex128])
            ],
            shape=shape,
            dtype=dtype,
            full_matrices=full_matrices,
            compute_uv=compute_uv)

for dtype in jtu.dtypes.all_inexact:
  for shape in [(0, 0), (5, 5), (2, 6, 6)]:
    for compute_left_eigenvectors in [False, True]:
      for compute_right_eigenvectors in [False, True]:
        define(
            lax.linalg.eig_p,
            f"shape={jtu.format_shape_dtype_string(shape, dtype)}_computelefteigenvectors={compute_left_eigenvectors}_computerighteigenvectors={compute_right_eigenvectors}",
            lax.linalg.eig, [
                RandArg(shape, dtype),
                StaticArg(compute_left_eigenvectors),
                StaticArg(compute_right_eigenvectors)
            ],
            jax_unimplemented=[
                Limitation(
                    "only supported on CPU in JAX", devices=("tpu", "gpu")),
                Limitation(
                    "unimplemented",
                    devices="cpu",
                    dtypes=[np.float16, dtypes.bfloat16])
            ],
            shape=shape,
            dtype=dtype,
            compute_left_eigenvectors=compute_left_eigenvectors,
            compute_right_eigenvectors=compute_right_eigenvectors)


def _make_triangular_eigh_operand(shape, dtype, lower: bool, rng: Rng):
  # For testing eigh we only need self-adjoint matrices; restricting the
  # operand to its lower/upper triangle is done by the harness function
  # itself (see the jnp.tril/jnp.triu below).
  operand = jtu.rand_default(rng)(shape, dtype)
  # Make operand self-adjoint
  operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2
  return operand


for dtype in jtu.dtypes.all_inexact:
  for shape in [(0, 0), (50, 50), (2, 20, 20)]:
    for lower in [False, True]:
      define(
          lax.linalg.eigh_p,
          f"shape={jtu.format_shape_dtype_string(shape, dtype)}_lower={lower}",
          # Make operand lower/upper triangular
          lambda operand, lower, symmetrize_input: (lax.linalg.eigh(
              jnp.tril(operand) if lower else jnp.triu(operand), lower,
              symmetrize_input)),
          # lax.linalg.eigh,
          [
              CustomArg(
                  partial(_make_triangular_eigh_operand, shape, dtype, lower)),
              StaticArg(lower),
              StaticArg(False)
          ],
          jax_unimplemented=[
              Limitation(
                  "complex eigh not supported",
                  devices="tpu",
                  dtypes=[np.complex64, np.complex128]),
              Limitation(
                  "unimplemented",
                  devices=("cpu", "gpu"),
                  dtypes=[np.float16, dtypes.bfloat16]),
          ],
          shape=shape,
          dtype=dtype,
          lower=lower)

for dtype in jtu.dtypes.all_inexact:
  for shape in [
      (5, 5),  # square
      (3, 5, 5),  # batched
      (3, 5),  # non-square
  ]:
    define(
        lax.linalg.lu_p,
        f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
        lax.linalg.lu, [RandArg(shape, dtype)],
        jax_unimplemented=[
            Limitation("unimplemented", dtypes=[np.float16, dtypes.bfloat16])
        ],
        shape=shape,
        dtype=dtype)


def _make_triangular_solve_harness(name,
                                   *,
                                   left_side=True,
                                   lower=False,
                                   ab_shapes=((4, 4), (4, 1)),
                                   dtype=np.float32,
                                   transpose_a=False,
                                   conjugate_a=False,
                                   unit_diagonal=False):
  a_shape, b_shape = ab_shapes
  f_lax = lambda a, b: (
      lax.linalg.triangular_solve_p.bind(
          a,
          b,
          left_side=left_side,
          lower=lower,
          transpose_a=transpose_a,
          conjugate_a=conjugate_a,
          unit_diagonal=unit_diagonal))

  define(
      lax.linalg.triangular_solve_p,
      f"{name}_a={jtu.format_shape_dtype_string(a_shape, dtype)}_b={jtu.format_shape_dtype_string(b_shape, dtype)}_leftside={left_side}_lower={lower}_transposea={transpose_a}_conjugatea={conjugate_a}_unitdiagonal={unit_diagonal}",
      f_lax, [RandArg(a_shape, dtype),
              RandArg(b_shape, dtype)],
      jax_unimplemented=[
          Limitation("unimplemented", devices="gpu", dtypes=[np.float16]),
      ],
      dtype=dtype,
      a_shape=a_shape,
      b_shape=b_shape,
      left_side=left_side,
      lower=lower,
      transpose_a=transpose_a,
      conjugate_a=conjugate_a,
      unit_diagonal=unit_diagonal)


# Validate dtypes
# This first harness runs the tests for all dtypes using default values for
# all the other parameters, except unit_diagonal
(to ensure that # tf.linalg.set_diag works reliably for all dtypes). Variations of other # parameters can thus safely skip testing their corresponding default value. # Note that this validates solving on the left. for dtype in jtu.dtypes.all_inexact: for unit_diagonal in [False, True]: _make_triangular_solve_harness( "dtypes", dtype=dtype, unit_diagonal=unit_diagonal) # Validate shapes when solving on the right for ab_shapes in [ ((4, 4), (1, 4)), # standard ((2, 8, 8), (2, 10, 8)), # batched ]: _make_triangular_solve_harness( "shapes_right", ab_shapes=ab_shapes, left_side=False) # Validate transformations of a complex matrix for lower in [False, True]: for transpose_a in [False, True]: for conjugate_a in [False, True]: _make_triangular_solve_harness( "complex_transformations", dtype=np.complex64, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a) # Validate transformations of a real matrix for lower in [False, True]: for transpose_a in [False, True]: # conjugate_a is irrelevant for real dtypes, and is thus omitted _make_triangular_solve_harness( "real_transformations", dtype=np.float32, lower=lower, transpose_a=transpose_a) def _make_linear_solve_harnesses(): def linear_solve(a, b, solve, transpose_solve=None, symmetric=False): matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST) return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric) def explicit_jacobian_solve(matvec, b): return lax.stop_gradient(jnp.linalg.solve(jax.api.jacobian(matvec)(b), b)) def _make_harness(name, *, shape=(4, 4), dtype=np.float32, symmetric=False, solvers=(explicit_jacobian_solve, explicit_jacobian_solve)): solve, transpose_solve = solvers transpose_solve_name = transpose_solve.__name__ if transpose_solve else None def _make_first_argument(rng): a = jtu.rand_default(rng)(shape, dtype) if symmetric: a = a + a.T return a define( lax.linear_solve_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_b={jtu.format_shape_dtype_string(shape[:-1], dtype)}_solve={solve.__name__}_transposesolve={transpose_solve_name}_symmetric={symmetric}", linear_solve, [ CustomArg(_make_first_argument), RandArg(shape[:-1], dtype), StaticArg(solve), StaticArg(transpose_solve), StaticArg(symmetric) ], shape=shape, dtype=dtype, solve=solve, transpose_solve=transpose_solve, symmetric=symmetric) for dtype in jtu.dtypes.all_floating: if not dtype in [np.float16, dtypes.bfloat16]: _make_harness("dtypes", dtype=dtype) # Validate symmetricity _make_harness("symmetric", symmetric=True) # Validate removing transpose_solve _make_harness("transpose_solve", solvers=(explicit_jacobian_solve, None)) _make_linear_solve_harnesses() def _make_slice_harness(name, shape=(3,), start_indices=(1,), limit_indices=(2,), strides=None, dtype=np.float32): define( lax.slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore lax.slice, [ RandArg(shape, dtype), # type: ignore StaticArg(start_indices), # type: ignore StaticArg(limit_indices), # type: ignore StaticArg(strides) ], # type: ignore dtype=dtype, shape=shape, # type: ignore start_indices=start_indices, # type: ignore limit_indices=limit_indices) # type: ignore # Test first all dtypes for dtype in jtu.dtypes.all: _make_slice_harness("dtypes", dtype=dtype) # Now test many shapes for shape, start_indices, limit_indices, strides in [ ((3,), (1,), (2,), None), ((7,), (4,), (7,), None), ((5,), (1,), (5,), (2,)), ((8,), (1,), (6,), (2,)), ((5, 3), (1, 1), (3, 2), 
None), ((5, 3), (1, 1), (3, 1), None), ((7, 5, 3), (4, 0, 1), (7, 1, 3), None), ((5, 3), (1, 1), (2, 1), (1, 1)), ((5, 3), (1, 1), (5, 3), (2, 1)), ]: _make_slice_harness( "shapes", shape=shape, start_indices=start_indices, limit_indices=limit_indices, strides=strides) def _make_complex_harness(name, *, shapes=((3, 4), (3, 4)), dtype=np.float32): define( lax.complex_p, f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}", lax.complex_p.bind, [RandArg(shapes[0], dtype), RandArg(shapes[1], dtype)], shapes=shapes, dtype=dtype) for dtype in jtu.dtypes.floating: _make_complex_harness("dtypes", dtype=dtype) # Validate broadcasting for shapes in [ ((3, 2), (3, 1)), # broadcast imaginary part ((3, 1), (3, 2)), # broadcast real part ]: _make_complex_harness("broadcast", shapes=shapes) def _make_conj_harness(name, *, shape=(3, 4), dtype=np.float32, **kwargs): define( lax.conj_p, f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}_kwargs={kwargs}" .replace(" ", ""), lambda x: lax.conj_p.bind(x, **kwargs), [RandArg(shape, dtype)], shape=shape, dtype=dtype, **kwargs) for dtype in jtu.dtypes.floating + jtu.dtypes.complex: _make_conj_harness("dtypes", dtype=dtype) # Validate kwargs _make_conj_harness("kwargs", _input_dtype=np.float32) def _make_real_imag_harness(prim, name, *, shape=(2, 3), dtype=np.float32): define( prim, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}", prim.bind, [RandArg(shape, dtype)], shape=shape, dtype=dtype, prim=prim) for dtype in jtu.dtypes.complex: for prim in [lax.real_p, lax.imag_p]: _make_real_imag_harness(prim, "dtypes", dtype=dtype) def _make_dynamic_slice_harness(name, shape=(3,), start_indices=(1,), limit_indices=(2,), dtype=np.float32): define( lax.dynamic_slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_start_indices={start_indices}_limit_indices={limit_indices}", # type: ignore lax.dynamic_slice, [ RandArg(shape, dtype), # type: ignore np.array(list(start_indices)), StaticArg(tuple(map(operator.sub, limit_indices, start_indices))) ], # type: ignore dtype=dtype, shape=shape, # type: ignore start_indices=start_indices, # type: ignore limit_indices=limit_indices) # type: ignore # Test first all dtypes for dtype in jtu.dtypes.all: _make_dynamic_slice_harness("dtypes", dtype=dtype) # Now test many shapes for shape, start_indices, limit_indices in [ ((3,), (1,), (2,)), ((7,), (4,), (7,)), ((5,), (1,), (5,)), ((8,), (1,), (6,)), ((5, 3), (1, 1), (3, 2)), ((7, 5, 3), (4, 0, 1), (7, 1, 3)), ((5, 3), (1, 1), (2, 1)), # out-of-bounds cases, allowed for dynamic_slice ((5,), (-1,), (0,)), ((5,), (-1,), (1,)), ((5,), (-4,), (-2,)), ((5,), (-5,), (-2,)), ((5,), (-6,), (-5,)), ((5,), (-10,), (-9,)), ((5,), (-100,), (-99,)), ((5,), (5,), (6,)), ((5,), (10,), (11,)), ((5,), (3,), (6,)) ]: _make_dynamic_slice_harness( "shapes", shape=shape, start_indices=start_indices, limit_indices=limit_indices) def _make_dynamic_update_slice_harness(name, shape=(3,), start_indices=(1,), dtype=np.float32, update_shape=(1,)): define( lax.dynamic_update_slice_p, ( f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}" f"_start_indices={start_indices}"), lax.dynamic_update_slice, [ RandArg(shape, dtype), # type: ignore RandArg(update_shape, dtype), # type: ignore np.array(start_indices) ], # type: ignore dtype=dtype, shape=shape, # type: ignore start_indices=start_indices, # type: ignore 
def _make_dynamic_update_slice_harness(name,
                                       shape=(3,),
                                       start_indices=(1,),
                                       dtype=np.float32,
                                       update_shape=(1,)):
  define(
      lax.dynamic_update_slice_p,
      (
          f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"  # type: ignore
          f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
          f"_start_indices={start_indices}"),
      lax.dynamic_update_slice,
      [
          RandArg(shape, dtype),  # type: ignore
          RandArg(update_shape, dtype),  # type: ignore
          np.array(start_indices)
      ],  # type: ignore
      dtype=dtype,
      shape=shape,  # type: ignore
      start_indices=start_indices,  # type: ignore
      update_shape=update_shape)  # type: ignore


# First test all dtypes
for dtype in jtu.dtypes.all:
  _make_dynamic_update_slice_harness("dtypes", dtype=dtype)
# Now test many shapes
for shape, start_indices, update_shape in [
    ((3,), (1,), (1,)),
    ((5, 3), (1, 1), (3, 1)),
    ((7, 5, 3), (4, 1, 0), (2, 0, 1)),
    ((3,), (-1,), (1,)),  # out-of-bounds
    ((3,), (10,), (1,)),  # out-of-bounds
    ((3,), (10,), (2,)),  # out-of-bounds
]:
  _make_dynamic_update_slice_harness(
      "shapes",
      shape=shape,
      start_indices=start_indices,
      update_shape=update_shape)


def _make_squeeze_harness(name,
                          shape=(1, 2),
                          dimensions=(0,),
                          dtype=np.float32):
  define(
      lax.squeeze_p,
      f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_dimensions={dimensions}",  # type: ignore
      lax.squeeze,
      [RandArg(shape, dtype), StaticArg(dimensions)],  # type: ignore[has-type]
      dtype=dtype,
      arg_shape=shape,
      dimensions=dimensions)  # type: ignore[has-type]


# First test all dtypes
for dtype in set(jtu.dtypes.all):
  _make_squeeze_harness("dtypes", dtype=dtype)
# Now test many shapes
for shape, dimensions in [
    ((1,), (0,)),
    ((1,), (-1,)),
    ((2, 1, 4), (1,)),
    ((2, 1, 4), (-2,)),
    ((2, 1, 3, 1), (1,)),
    ((2, 1, 3, 1), (1, 3)),
    ((2, 1, 3, 1), (3,)),
    ((2, 1, 3, 1), (1, -1)),
]:
  _make_squeeze_harness("shapes", shape=shape, dimensions=dimensions)
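# Hedged illustration: lax.squeeze removes only the size-1 dimensions listed
# in `dimensions` (negative indices count from the end). A hypothetical
# sketch, not used by the harnesses, of the behavior covered above:
def _example_squeeze():
  x = np.zeros((2, 1, 3, 1), np.float32)
  assert lax.squeeze(x, dimensions=(1, 3)).shape == (2, 3)
  assert lax.squeeze(x, dimensions=(1,)).shape == (2, 3, 1)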
def _make_select_and_scatter_add_harness(name,
                                         *,
                                         shape=(2, 4, 6),
                                         dtype=np.float32,
                                         select_prim=lax.ge_p,
                                         window_dimensions=(2, 2, 2),
                                         window_strides=(1, 1, 1),
                                         padding=((0, 0), (0, 0), (0, 0)),
                                         nb_inactive_dims=0):
  ones = (1,) * len(shape)
  cotangent_shape = jax.api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, window_dimensions,
                                           window_strides, padding, ones, ones),
      np.ones(shape, dtype)).shape
  define(
      lax.select_and_scatter_add_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}",
      lax._select_and_scatter_add, [
          RandArg(cotangent_shape, dtype),
          RandArg(shape, dtype),
          StaticArg(select_prim),
          StaticArg(window_dimensions),
          StaticArg(window_strides),
          StaticArg(padding)
      ],
      jax_unimplemented=[
          Limitation(
              "works only for 2 or more inactive dimensions",
              devices="tpu",
              enabled=(nb_inactive_dims < 2))
      ],
      shape=shape,
      dtype=dtype,
      select_prim=select_prim,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding)


for dtype in set(jtu.dtypes.all) - set([np.complex64, np.complex128]):
  _make_select_and_scatter_add_harness("dtypes", dtype=dtype)

# Validate different selection primitives
_make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)

# Validate padding
for padding in [
    # TODO(bchetioui): commented out the test based on
    # https://github.com/google/jax/issues/4690
    # ((1, 2), (2, 3), (3, 4))  # non-zero padding
    ((1, 1), (1, 1), (1, 1))  # non-zero padding
]:
  _make_select_and_scatter_add_harness("padding", padding=padding)

# Validate window_dimensions; uneven dimensions
_make_select_and_scatter_add_harness(
    "window_dimensions", window_dimensions=(1, 2, 3))

# Validate window_strides
# smaller than/same as/bigger than corresponding window dimension
_make_select_and_scatter_add_harness("window_strides", window_strides=(1, 2, 3))

# Validate dtypes on TPU
for dtype in set(jtu.dtypes.all) - set(
    [np.bool_, np.complex64, np.complex128, np.int8, np.uint8]):
  for window_strides, window_dimensions, nb_inactive_dims in [
      ((1, 2, 1), (1, 3, 1), 2)
  ]:
    _make_select_and_scatter_add_harness(
        "tpu_dtypes",
        dtype=dtype,
        nb_inactive_dims=nb_inactive_dims,
        window_strides=window_strides,
        window_dimensions=window_dimensions)


def _make_select_and_gather_add_harness(name,
                                        *,
                                        shape=(4, 6),
                                        dtype=np.float32,
                                        select_prim=lax.le_p,
                                        padding="VALID",
                                        window_dimensions=(2, 2),
                                        window_strides=(1, 1),
                                        base_dilation=(1, 1),
                                        window_dilation=(1, 1)):
  if isinstance(padding, str):
    padding = tuple(
        lax.padtype_to_pads(shape, window_dimensions, window_strides, padding))
  define(
      lax.select_and_gather_add_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
      lax._select_and_gather_add, [
          RandArg(shape, dtype),
          RandArg(shape, dtype),
          StaticArg(select_prim),
          StaticArg(window_dimensions),
          StaticArg(window_strides),
          StaticArg(padding),
          StaticArg(base_dilation),
          StaticArg(window_dilation)
      ],
      shape=shape,
      dtype=dtype,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      base_dilation=base_dilation,
      window_dilation=window_dilation)


for dtype in jtu.dtypes.all_floating:
  for select_prim in [lax.ge_p, lax.le_p]:
    _make_select_and_gather_add_harness(
        "dtypes", dtype=dtype, select_prim=select_prim)

# Validate selection primitives
_make_select_and_gather_add_harness("select_prim", select_prim=lax.ge_p)
# Validate window dimensions
_make_select_and_gather_add_harness(
    "window_dimensions", window_dimensions=(2, 3))
# Validate window strides
_make_select_and_gather_add_harness("window_strides", window_strides=(2, 3))
# Validate padding
_make_select_and_gather_add_harness("padding", padding="SAME")
# Validate dilations
for base_dilation, window_dilation in [
    ((2, 3), (1, 1)),  # base dilation, no window dilation
    ((1, 1), (2, 3)),  # no base dilation, window dilation
    ((2, 3), (3, 2))  # base dilation, window dilation
]:
  _make_select_and_gather_add_harness(
      "dilations", base_dilation=base_dilation, window_dilation=window_dilation)


def _make_reduce_window_harness(name,
                                *,
                                shape=(4, 6),
                                base_dilation=(1, 1),
                                computation=lax.add,
                                window_dimensions=(2, 2),
                                window_dilation=(1, 1),
                                init_value=0,
                                window_strides=(1, 1),
                                dtype=np.float32,
                                padding=((0, 0), (0, 0))):
  prim_name = f"reduce_window_{computation.__name__}"
  limitations = []
  if computation.__name__ in ("max", "mul", "min"):
    limitations.append(
        Limitation("unimplemented in XLA", devices="tpu", dtypes=np.complex64))
  define(
      prim_name,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}"
      .replace(" ", ""),
      lax.reduce_window,
      [
          RandArg(shape, dtype),
          # Must be static to trigger the picking of the reducers
          StaticArg(np.array(init_value, dtype=dtype)),
          StaticArg(computation),
          StaticArg(window_dimensions),
          StaticArg(window_strides),
          StaticArg(padding),
          StaticArg(base_dilation),
          StaticArg(window_dilation)
      ],
      jax_unimplemented=limitations,
      shape=shape,
      dtype=dtype,
      init_value=np.array(init_value, dtype=dtype),
      computation=computation,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      base_dilation=base_dilation,
      window_dilation=window_dilation)
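# Hedged illustration: lax.reduce_window folds `computation` over sliding
# windows, starting from `init_value`. With lax.max, init -inf, a (2,) window
# and unit stride this is a 1D max-pool. A hypothetical sketch, not used by
# the harnesses:
def _example_reduce_window():
  x = np.array([1., 3., 2.], np.float32)
  out = lax.reduce_window(x, np.float32(-np.inf), lax.max, (2,), (1,),
                          ((0, 0),))
  np.testing.assert_allclose(out, np.array([3., 3.], np.float32))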
# Validate dtypes across all execution paths
# This first harness runs the tests for all dtypes using default values for
# the other parameters (outside of computation and its init_value), through
# several execution paths. Variations of other parameters can thus safely
# skip testing their corresponding default value.
for dtype in jtu.dtypes.all:
  for computation, init_value in [
      (lax.min, _get_min_identity(dtype)),  # path through reduce_window_min
      (lax.max, _get_max_identity(dtype)),  # path through TF reduce_window_max
      (lax.max, 1),  # path through reduce_window
      (lax.add, 0),  # path through reduce_window_sum
      (lax.add, 1),  # path through reduce_window
      (lax.mul, 0),  # path through reduce_window
      (lax.mul, 1),  # path through reduce_window
      (lax.mul, 2),  # path through reduce_window
  ]:
    if dtype == np.bool_ and (computation in [lax.add, lax.mul]):
      continue
    _make_reduce_window_harness(
        "dtypes", dtype=dtype, computation=computation, init_value=init_value)
# Validate window_dimensions
_make_reduce_window_harness("window_dimensions", window_dimensions=(1, 1))
# Validate window_strides
_make_reduce_window_harness("window_strides", window_strides=(1, 2))
# Validate padding
_make_reduce_window_harness("padding", padding=((1, 2), (0, 3)))
# Validate base_dilation
_make_reduce_window_harness("base_dilation", base_dilation=(1, 2))
# Validate window_dilation
_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
# Validate squeezing behavior and dimensions in tf.nn.max_pool
for shape, window_dimensions in [
    ((2,), (2,)),  # 1 spatial dimension, left and right squeeze
    ((2, 1), (2, 1)),  # 1 spatial dimension, left squeeze
    ((1, 2), (1, 2)),  # 1 spatial dimension, right squeeze
    ((1, 2, 1), (1, 2, 1)),  # 1 spatial dimension, no squeeze
    ((2, 4), (2, 2)),  # 2 spatial dimensions, left and right squeeze
    ((2, 4, 3), (2, 2, 2)),  # 3 spatial dimensions, left and right squeeze
    ((1, 4, 3, 2, 1), (1, 2, 2, 2, 1))  # 3 spatial dimensions, no squeeze
]:
  _make_reduce_window_harness(
      "squeeze_dim",
      computation=lax.max,
      shape=shape,
      dtype=np.float32,
      init_value=-np.inf,
      base_dilation=tuple([1] * len(shape)),
      window_dilation=tuple([1] * len(shape)),
      padding=tuple([(0, 0)] * len(shape)),
      window_strides=tuple([1] * len(shape)),
      window_dimensions=window_dimensions)


def _make_reducer_harness(prim,
                          name,
                          *,
                          shape=(2, 3),
                          axes=(0,),
                          dtype=np.int32):
  define(
      prim,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
      lambda arg: prim.bind(arg, axes=axes), [RandArg(shape, dtype)],
      prim=prim,
      shape=shape,
      dtype=dtype,
      axes=axes)


for prim in [
    lax.reduce_sum_p, lax.reduce_prod_p, lax.reduce_max_p, lax.reduce_min_p,
    lax.reduce_or_p, lax.reduce_and_p
]:
  for dtype in {
      lax.reduce_sum_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
      lax.reduce_prod_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
      lax.reduce_max_p: jtu.dtypes.all,
      lax.reduce_min_p: jtu.dtypes.all,
      lax.reduce_or_p: jtu.dtypes.boolean,
      lax.reduce_and_p: jtu.dtypes.boolean
  }[prim]:
    _make_reducer_harness(prim, "dtypes", dtype=dtype)

for dtype in (np.float32, np.float64):
  for shape in ((), (3,)):
    define(
        "random_gamma",
        f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
        jax.jit(jax.random.gamma),
        [np.array([42, 43], dtype=np.uint32),
         RandArg(shape, dtype)],
        dtype=dtype)

for key_i, key in enumerate([
    np.array([0, 0], dtype=np.uint32),
    np.array([42, 43], dtype=np.uint32),
    np.array([0xFFFFFFFF, 0], dtype=np.uint32),
    np.array([0, 0xFFFFFFFF], dtype=np.uint32),
    np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
]):
  define(
      "random_split",
      f"i={key_i}",
      jax.jit(lambda key: jax.random.split(key, 2)), [key],
      dtype=key.dtype)
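# Hedged note: the "random_split" harnesses above feed raw uint32 key arrays
# to jax.random.split, which is deterministic in the key. A hypothetical
# sketch, not used by the harnesses:
def _example_random_split():
  key = np.array([42, 43], dtype=np.uint32)
  new_keys = jax.random.split(key, 2)
  assert new_keys.shape == (2, 2)  # two new keys, each made of two uint32s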
def _make_clamp_harness(name,
                        *,
                        min_shape=(),
                        operand_shape=(2, 3),
                        max_shape=(),
                        dtype=np.float32,
                        min_max=None):
  min_arr, max_arr = (
      min_max if min_max is not None else
      [RandArg(min_shape, dtype), RandArg(max_shape, dtype)])
  define(
      lax.clamp_p,
      f"{name}_min={jtu.format_shape_dtype_string(min_arr.shape, min_arr.dtype)}_operand={jtu.format_shape_dtype_string(operand_shape, dtype)}_max={jtu.format_shape_dtype_string(max_arr.shape, max_arr.dtype)}",
      lax.clamp, [min_arr, RandArg(operand_shape, dtype), max_arr],
      min_shape=min_arr.shape,
      operand_shape=operand_shape,
      max_shape=max_arr.shape,
      dtype=dtype)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_]):
  _make_clamp_harness("dtypes", dtype=dtype)

# Validate broadcasting of min/max arrays
for min_shape, operand_shape, max_shape in [
    ((), (2, 3), (2, 3)),  # no broadcasting for max
    ((2, 3), (2, 3), ()),  # no broadcasting for min
    ((2, 3), (2, 3), (2, 3)),  # no broadcasting
]:
  _make_clamp_harness(
      "broadcasting",
      min_shape=min_shape,
      max_shape=max_shape,
      operand_shape=operand_shape)

# Validate clamping when minval > maxval, and when minval < maxval
for is_ordered, min_arr, max_arr in [
    (False, np.array(4., dtype=np.float32), np.array(1., dtype=np.float32)),
    (True, np.array(1., dtype=np.float32), np.array(4., dtype=np.float32))
]:
  _make_clamp_harness(
      f"order={is_ordered}", min_max=(min_arr, max_arr), dtype=np.float32)


def _make_dot_general_harness(name,
                              *,
                              lhs_shape=(3, 4),
                              rhs_shape=(4, 2),
                              dtype=np.float32,
                              precision=None,
                              dimension_numbers=(((1,), (0,)), ((), ())),
                              preferred_element_type=None):
  suffix = ""
  if precision is not None:
    suffix += f"_precision={precision}"
  if preferred_element_type is not None:
    suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}"
  define(
      lax.dot_general_p,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}{suffix}"
      .replace(" ", ""),
      lax.dot_general, [
          RandArg(lhs_shape, dtype),
          RandArg(rhs_shape, dtype),
          StaticArg(dimension_numbers),
          StaticArg(precision),
          StaticArg(preferred_element_type)
      ],
      dtype=dtype,
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      jax_unimplemented=[
          Limitation(
              "preferred_element_type=c128 not implemented",
              devices="tpu",
              dtypes=np.complex64,
              enabled=(preferred_element_type in [np.complex128])),
          Limitation(
              "preferred_element_type=f64 crashes (b/187884887)",
              devices="tpu",
              dtypes=(np.float16, jnp.bfloat16, np.float32),
              enabled=(preferred_element_type in [np.float64]),
              skip_run=True),
          Limitation(
              "preferred_element_type=i64 not implemented",
              devices="tpu",
              dtypes=(np.int8, np.int16, np.int32),
              enabled=(preferred_element_type in [np.int64])),
      ],
  )


# There are two execution paths in the conversion of dot_general. The main path
# uses tf.einsum, while special cases use tf.linalg.matmul. For that reason,
# the tests below are designed to perform the same checks on both execution
# paths.
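# Hedged illustration of the main path: a dot_general contraction corresponds
# to an einsum. For the default 2D case (contract lhs axis 1 with rhs axis 0,
# no batch dimensions), a hypothetical sketch, not used by the harnesses:
def _example_dot_general_as_einsum():
  lhs = np.ones((3, 4), np.float32)
  rhs = np.ones((4, 2), np.float32)
  out = lax.dot_general(lhs, rhs, dimension_numbers=(((1,), (0,)), ((), ())))
  np.testing.assert_allclose(out, jnp.einsum("ij,jk->ik", lhs, rhs))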
# Validate dtypes and precision
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
for dtype in jtu.dtypes.all:
  for precision in [
      None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
  ]:
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 4), (4, 2), (((1,), (0,)), ((), ()))),
        ((1, 3, 4), (1, 4, 3), (((2, 1), (1, 2)), ((0,), (0,)))),
        # Some batch dimensions
        ((7, 3, 4), (7, 4), (((2,), (1,)), ((0,), (0,)))),
    ]:
      _make_dot_general_harness(
          "dtypes_and_precision",
          precision=precision,
          lhs_shape=lhs_shape,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          dtype=dtype)

# Validate batch dimensions
for lhs_shape, rhs_shape, dimension_numbers in [
    # Unique pattern that can go through tf.linalg.matmul
    ((4, 4, 3, 3, 4), (4, 4, 3, 4, 2), (((4,), (3,)), ((0, 1, 2), (0, 1, 2)))),
    # Main path with out of order batch dimensions
    ((8, 4, 3, 3, 4), (4, 8, 3, 4, 2), (((4, 3), (3, 2)), ((0, 1), (1, 0))))
]:
  _make_dot_general_harness(
      "batch_dimensions",
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers)

# Validate squeezing behavior for matmul path
for lhs_shape, rhs_shape, dimension_numbers in [
    ((4,), (4, 4), (((0,), (0,)), ((), ()))),  # (1, 4) -> (4,)
    ((4, 4), (4,), (((1,), (0,)), ((), ()))),  # (4, 1) -> (4,)
    ((4,), (4,), (((0,), (0,)), ((), ()))),  # (1, 1) -> ()
]:
  _make_dot_general_harness(
      "squeeze",
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers)

# Validate preferred element type
# From lax_test.py
preferred_type_combinations = [
    (np.float16, np.float16), (np.float16, np.float32),
    (np.float16, np.float64), (dtypes.bfloat16, np.float32),
    (dtypes.bfloat16, np.float64), (np.float32, np.float32),
    (np.float32, np.float64), (np.int8, np.int16), (np.int8, np.int32),
    (np.int8, np.int64), (np.int16, np.int32), (np.int16, np.int64),
    (np.int32, np.int32), (np.int32, np.int64),
    (np.complex64, np.complex128)
]

for lhs_shape in [(3,), (4, 3)]:
  for rhs_shape in [(3,), (3, 6)]:
    for dtype, preferred_element_type in preferred_type_combinations:
      _make_dot_general_harness(
          "preferred",
          dtype=dtype,
          lhs_shape=lhs_shape,
          rhs_shape=rhs_shape,
          dimension_numbers=(((len(lhs_shape) - 1,), (0,)), ((), ())),
          preferred_element_type=preferred_element_type)


def _make_concatenate_harness(name,
                              *,
                              shapes=[(2, 3), (2, 3)],
                              dimension=0,
                              dtype=np.float32):
  shapes_str = "_".join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)
  define(
      lax.concatenate_p,
      f"{name}_shapes={shapes_str}_dimension={dimension}",
      lambda *args: lax.concatenate_p.bind(*args, dimension=dimension),
      [RandArg(shape, dtype) for shape in shapes],
      shapes=shapes,
      dtype=dtype,
      dimension=dimension)


for dtype in jtu.dtypes.all:
  _make_concatenate_harness("dtypes", dtype=dtype)

# Validate dimension; non-major axis
_make_concatenate_harness("dimension", dimension=1)

# Validate > 2 operands
for shapes in [
    [(2, 3, 4), (3, 3, 4), (4, 3, 4)],  # 3 operands
]:
  _make_concatenate_harness("nb_operands", shapes=shapes)
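# Hedged illustration: concatenate joins its operands along `dimension`;
# along axis 1, two (2, 3) operands yield a (2, 6) result. A hypothetical
# sketch, not used by the harnesses:
def _example_concatenate():
  x = np.ones((2, 3), np.float32)
  out = lax.concatenate([x, x], dimension=1)
  assert out.shape == (2, 6)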
def _make_conv_harness(name,
                       *,
                       lhs_shape=(2, 3, 9, 10),
                       rhs_shape=(3, 3, 4, 5),
                       dtype=np.float32,
                       window_strides=(1, 1),
                       precision=None,
                       padding=((0, 0), (0, 0)),
                       lhs_dilation=(1, 1),
                       rhs_dilation=(1, 1),
                       feature_group_count=1,
                       dimension_numbers=("NCHW", "OIHW", "NCHW"),
                       batch_group_count=1,
                       enable_xla=True):
  define(
      lax.conv_general_dilated_p,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_enablexla={enable_xla}"
      .replace(" ", ""),
      lax.conv_general_dilated, [
          RandArg(lhs_shape, dtype),
          RandArg(rhs_shape, dtype),
          StaticArg(window_strides),
          StaticArg(padding),
          StaticArg(lhs_dilation),
          StaticArg(rhs_dilation),
          StaticArg(dimension_numbers),
          StaticArg(feature_group_count),
          StaticArg(batch_group_count),
          StaticArg(precision)
      ],
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dtype=dtype,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision=precision,
      enable_xla=enable_xla)


# Validate dtypes and precision
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
for dtype in jtu.dtypes.all_inexact:
  for precision in [
      None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
  ]:
    _make_conv_harness("dtype_precision", dtype=dtype, precision=precision)

# Validate variations of feature_group_count and batch_group_count
for batch_group_count, feature_group_count in [
    (1, 2),  # feature_group_count != 1
    (2, 1),  # batch_group_count != 1
]:
  for lhs_shape, rhs_shape in [
      ((2 * batch_group_count, 3 * feature_group_count, 9, 10),
       (3 * feature_group_count * batch_group_count, 3, 4, 5))
  ]:
    _make_conv_harness(
        "group_counts",
        lhs_shape=lhs_shape,
        rhs_shape=rhs_shape,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count)

# Validate variations of window_strides
for window_strides in [(2, 3)]:
  _make_conv_harness("window_strides", window_strides=window_strides)

# Validate variations of padding
for padding in [
    ((1, 2), (0, 0)),  # padding only one spatial axis
    ((1, 2), (2, 1))  # padding on both spatial axes
]:
  _make_conv_harness("padding", padding=padding)

# Validate variations of dilations
for lhs_dilation, rhs_dilation in [
    ((2, 2), (1, 1)),  # dilation only on LHS (transposed)
    ((1, 1), (2, 3)),  # dilation only on RHS (atrous)
    ((2, 3), (3, 2))  # dilation on both LHS and RHS (transposed & atrous)
]:
  _make_conv_harness(
      "dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)

# Dimension numbers and corresponding permutation
for dimension_numbers, lhs_shape, rhs_shape in [
    (("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)),  # TF default
    (("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)),  # custom
]:
  _make_conv_harness(
      "dimension_numbers",
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers)
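# Hedged illustration: with the default ("NCHW", "OIHW", "NCHW") dimension
# numbers and zero padding, each spatial output size is input - window + 1.
# A hypothetical sketch for the default harness shapes, not used by the
# harnesses:
def _example_conv_output_shape():
  lhs = np.zeros((2, 3, 9, 10), np.float32)  # N=2, C=3, H=9, W=10
  rhs = np.zeros((3, 3, 4, 5), np.float32)  # O=3, I=3, H=4, W=5
  out = lax.conv_general_dilated(
      lhs, rhs, window_strides=(1, 1), padding=((0, 0), (0, 0)),
      dimension_numbers=("NCHW", "OIHW", "NCHW"))
  assert out.shape == (2, 3, 6, 6)  # H: 9 - 4 + 1, W: 10 - 5 + 1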
for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1,), (1,)),  # no dilation with "VALID" padding
    ("SAME", (1,), (1,)),  # no dilation with "SAME" padding
    ("VALID", (1,), (2,)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1,), (2,)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
]:
  for dimension_numbers, lhs_shape, rhs_shape in [
      (("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)),  # TF default
      # TODO(bchetioui): the NCW data format is not supported on CPU for TF
      # for now. That path is thus disabled to allow the code to use XLA instead.
  ]:
    for enable_xla in [False, True]:
      _make_conv_harness(
          "tf_conversion_path_1d",
          lhs_shape=lhs_shape,
          padding=padding,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          window_strides=(1,),
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          enable_xla=enable_xla)

for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1), (1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1), (1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1), (1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1), (1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
]:
  for dimension_numbers, lhs_shape, rhs_shape in [
      (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
      # TODO(bchetioui): the NCHW data format is not supported on CPU for TF
      # for now. That path is thus disabled to allow the code to use XLA instead.
  ]:
    for enable_xla in [False, True]:
      _make_conv_harness(
          "tf_conversion_path_2d",
          lhs_shape=lhs_shape,
          padding=padding,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          window_strides=(1, 1),
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          enable_xla=enable_xla)

for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1, 1), (1, 1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "VALID"
    ("SAME", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "SAME"
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
]:
  for dimension_numbers, lhs_shape, rhs_shape in [
      # TF default
      (("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
      # TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
      # for now. That path is thus disabled to allow the code to use XLA instead.
  ]:
    for enable_xla in [False, True]:
      _make_conv_harness(
          "tf_conversion_path_3d",
          lhs_shape=lhs_shape,
          padding=padding,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          window_strides=(1, 1, 1),
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          enable_xla=enable_xla)
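# Hedged note on the string paddings used above: lax.padtype_to_pads turns
# "SAME"/"VALID" into explicit (low, high) pairs per spatial dimension, and
# "SAME" with unit strides preserves the spatial size. A hypothetical sketch,
# not used by the harnesses:
def _example_padtype_to_pads():
  pads = lax.padtype_to_pads((28,), (3,), (1,), "SAME")
  assert tuple(pads) == ((1, 1),)
  assert tuple(lax.padtype_to_pads((28,), (3,), (1,), "VALID")) == ((0, 0),)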