# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyformat: disable

from __future__ import annotations

import collections
from collections.abc import Callable, Generator, Iterable, Sequence
from contextlib import ExitStack, contextmanager
import datetime
import functools
from functools import partial
import inspect
import logging
import math
import os
import re
import sys
import tempfile
import textwrap
from typing import Any, TextIO
import unittest
import warnings
import zlib

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import lax
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src import linear_util as lu
from jax._src import monitoring
from jax._src import pjit as pjit_lib
from jax._src import stages
from jax._src import xla_bridge
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
from jax._src.public_test_util import (  # noqa: F401
    _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
    check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance)
from jax._src.util import unzip2
from jax.experimental.compilation_cache import compilation_cache
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten

import numpy as np
import numpy.random as npr

# This submodule includes private test utilities that are not exported to
# jax.test_util. Functionality appearing here is for internal use only, and
# may be changed or removed at any time and without any deprecation cycle.

_TEST_DUT = config.string_flag(
    'jax_test_dut', '',
    help=
    'Describes the device under test in case special consideration is required.'
)

NUM_GENERATED_CASES = config.int_flag(
    'jax_num_generated_cases',
    int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
    help='Number of generated cases to test')

_MAX_CASES_SAMPLING_RETRIES = config.int_flag(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
    'When an unseen case cannot be generated in this many trials, the '
    'sampling process is terminated.'
)

_SKIP_SLOW_TESTS = config.bool_flag(
    'jax_skip_slow_tests',
    config.bool_env('JAX_SKIP_SLOW_TESTS', False),
    help='Skip tests marked as slow (> 5 sec).'
)

_TEST_TARGETS = config.string_flag(
    'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
    'Regular expression specifying which tests to run, called via re.search on '
    'the test name. If empty or unspecified, run all tests.'
)
_EXCLUDE_TEST_TARGETS = config.string_flag(
    'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
    'Regular expression specifying which tests NOT to run, called via re.search '
    'on the test name. If empty or unspecified, run all tests.'
)
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
    'jax_test_with_persistent_compilation_cache',
    config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
    help='If enabled, the persistent compilation cache will be enabled for all '
    'test cases. This can be used to increase compilation cache coverage.')

HYPOTHESIS_PROFILE = config.string_flag(
    'hypothesis_profile',
    os.getenv('JAX_HYPOTHESIS_PROFILE', 'deterministic'),
    help=('Select the hypothesis profile to use for testing. Available values: '
          'deterministic, interactive'),
)

# We sanitize test names to ensure they work with "unittest -k" and
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
# does not. We replace sequences of problematic characters with a single '_'.
kSanitizeNameRE = re.compile(r"[ \"'\[\](){}<>=,._]+")

def sanitize_test_name(s: str) -> str:
  return kSanitizeNameRE.sub("_", s)
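
# Example (illustrative only): parameterized case names become filter-safe,
# e.g. sanitize_test_name("testFoo[dtype=float32, shape=(2, 3)]")
# returns "testFoo_dtype_float32_shape_2_3_".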

def num_float_bits(dtype):
  return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits


def to_default_dtype(arr):
  """Convert a value to an array with JAX's default dtype.

  This is generally used for type conversions of values returned by numpy functions,
  to make their dtypes take into account the state of the ``jax_enable_x64`` and
  ``jax_default_dtype_bits`` flags.
  """
  arr = np.asarray(arr)
  dtype = _dtypes._default_types.get(arr.dtype.kind)
  return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr


def with_jax_dtype_defaults(func, use_defaults=True):
  """Return a version of a function with outputs that match JAX's default dtypes.

  This is generally used to wrap numpy functions within tests, in order to make
  their default output dtypes match those of corresponding JAX functions, taking
  into account the state of the ``jax_enable_x64`` and ``jax_default_dtype_bits``
  flags.

  Args:
    use_defaults : whether to convert any given output to the default dtype. May be
      a single boolean, in which case it specifies the conversion for all outputs,
      or may be a pytree with the same structure as the function output.
  """
  @functools.wraps(func)
  def wrapped(*args, **kwargs):
    result = func(*args, **kwargs)
    if isinstance(use_defaults, bool):
      return tree_map(to_default_dtype, result) if use_defaults else result
    else:
      f = lambda arr, use_default: to_default_dtype(arr) if use_default else arr
      return tree_map(f, result, use_defaults)
  return wrapped
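
# Usage sketch (illustrative only; the dtype shown assumes the default
# jax_enable_x64=False): wrap a NumPy reference function so its output dtype
# matches what the corresponding jnp function would return.
#
#   np_sum = with_jax_dtype_defaults(np.sum)
#   np_sum(np.arange(4))  # dtype int32 rather than NumPy's default int64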

def is_sequence(x):
  try:
    iter(x)
  except TypeError:
    return False
  else:
    return True


def _normalize_tolerance(tol):
  tol = tol or 0
  if isinstance(tol, dict):
    return {np.dtype(k): v for k, v in tol.items()}
  else:
    return dict.fromkeys(_default_tolerance, tol)


def join_tolerance(tol1, tol2):
  tol1 = _normalize_tolerance(tol1)
  tol2 = _normalize_tolerance(tol2)
  out = tol1
  for k, v in tol2.items():
    out[k] = max(v, tol1.get(k, 0))
  return out
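
# Example (illustrative only): tolerances may be scalars or per-dtype dicts;
# join_tolerance normalizes both and takes the per-dtype maximum, e.g.
#
#   join_tolerance(1e-3, {np.dtype('float32'): 1e-2})
#   # -> 1e-2 for float32, 1e-3 for every other dtype in _default_tolerance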

def check_eq(xs, ys, err_msg=''):
  assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
  tree_all(tree_map(assert_close, xs, ys))

@contextmanager
def _capture_output(fp: TextIO) -> Generator[Callable[[], str], None, None]:
  """Context manager to capture all output written to a given file object.

  Unlike ``contextlib.redirect_stdout``, this context manager works for
  any file object and also for both pure Python and native code.

  Example::

    with capture_output(sys.stdout) as get_output:
      print(42)
    print("Captured:", get_output())

  Yields:
    A function returning the captured output. The function must be called
    *after* the context is no longer active.
  """
  # ``None`` means nothing has been captured yet.
  captured = None

  def get_output() -> str:
    if captured is None:
      raise ValueError("get_output() called while the context is active.")
    return captured

  with tempfile.NamedTemporaryFile(mode="w+", encoding='utf-8') as f:
    original_fd = os.dup(fp.fileno())
    os.dup2(f.fileno(), fp.fileno())
    try:
      yield get_output
    finally:
      # Python also has its own buffers, make sure everything is flushed.
      fp.flush()
      os.fsync(fp.fileno())
      f.seek(0)
      captured = f.read()
      os.dup2(original_fd, fp.fileno())


capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)
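
# Usage sketch (illustrative only): the captured text is available once the
# context has exited.
#
#   with capture_stdout() as get_output:
#     print("hello")
#   assert get_output() == "hello\n"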

@contextmanager
def count_device_put():
  batched_device_put = pxla.batched_device_put
  count = [0]

  def make_fn_and_count(fn):
    def fn_and_count(*args, **kwargs):
      count[0] += 1
      # device_put handlers might call `dispatch.device_put` (e.g. on an
      # underlying payload or several). We only want to count these
      # recursive puts once, so we skip counting more than the outermost
      # one in such a call stack.
      pxla.batched_device_put = batched_device_put
      try:
        return fn(*args, **kwargs)
      finally:
        pxla.batched_device_put = batched_device_put_and_count
    return fn_and_count

  batched_device_put_and_count = make_fn_and_count(batched_device_put)

  pxla.batched_device_put = batched_device_put_and_count
  try:
    yield count
  finally:
    pxla.batched_device_put = batched_device_put

@contextmanager
def count_primitive_compiles():
  dispatch.xla_primitive_callable.cache_clear()

  count = [-1]
  try:
    yield count
  finally:
    count[0] = dispatch.xla_primitive_callable.cache_info().misses

@contextmanager
def count_device_put_fast_path_hit():
  original_fn = xc.batched_copy_array_to_devices_with_sharding
  count = [0]

  def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
    count[0] += 1
    return original_fn(*args, **kwargs)

  xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
  try:
    yield count
  finally:
    xc.batched_copy_array_to_devices_with_sharding = original_fn

@contextmanager
def count_pjit_cpp_cache_miss():
  original_pjit_lower = pjit_lib._pjit_lower
  count = [0]

  def pjit_lower_and_count(*args, **kwargs):
    count[0] += 1
    return original_pjit_lower(*args, **kwargs)

  pjit_lib._pjit_lower = pjit_lower_and_count
  try:
    yield count
  finally:
    pjit_lib._pjit_lower = original_pjit_lower
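
# Usage sketch (illustrative only): two calls with the same shapes/dtypes
# should trigger only one lowering, the second call being a C++ cache hit.
#
#   f = jax.jit(lambda x: x + 1)
#   with count_pjit_cpp_cache_miss() as count:
#     f(1.0)
#     f(2.0)
#   # expect count[0] == 1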

@contextmanager
def count_cached_compilation_cache_miss():
  original_cached_compilation = pxla._cached_compilation
  count = [0]

  def cached_compilation_and_count(*args, **kwargs):
    count[0] += 1
    return original_cached_compilation(*args, **kwargs)

  pxla._cached_compilation = cached_compilation_and_count
  try:
    yield count
  finally:
    pxla._cached_compilation = original_cached_compilation

@contextmanager
def count_jit_tracing_cache_miss():
  original_create_pjit_jaxpr = pjit_lib._create_pjit_jaxpr
  count = [0]

  @lu.cache
  def create_pjit_jaxpr_and_count(*args):
    count[0] += 1
    return original_create_pjit_jaxpr(*args)

  pjit_lib._create_pjit_jaxpr = create_pjit_jaxpr_and_count
  try:
    yield count
  finally:
    pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr

@contextmanager
def count_jit_infer_params_cache_miss():
  original_infer_params_impl = pjit_lib._infer_params_impl
  count = collections.defaultdict(int)

  def infer_params_impl_and_count(fun, *args, **kw):
    count[fun] += 1
    return original_infer_params_impl(fun, *args, **kw)

  pjit_lib._infer_params_impl = infer_params_impl_and_count
  try:
    yield count
  finally:
    pjit_lib._infer_params_impl = original_infer_params_impl

@contextmanager
def count_aot_jit_cpp_cache_miss():
  original_call = stages.Compiled.call
  count = [0]

  def compiled_call_count(*args, **kwargs):
    count[0] += 1
    return original_call(*args, **kwargs)

  stages.Compiled.call = compiled_call_count
  try:
    yield count
  finally:
    stages.Compiled.call = original_call

@contextmanager
def count_jit_and_pmap_lowerings():
  # No need to clear any caches since we generally jit and pmap fresh callables
  # in tests.

  mlir_lower = mlir.lower_jaxpr_to_module
  count = [0]

  def mlir_lower_and_count(*args, **kwargs):
    count[0] += 1
    return mlir_lower(*args, **kwargs)

  mlir.lower_jaxpr_to_module = mlir_lower_and_count
  try:
    yield count
  finally:
    mlir.lower_jaxpr_to_module = mlir_lower

@contextmanager
def count_jit_compilation_cache_miss():
  # No need to clear any caches since we generally jit and pmap fresh callables
  # in tests.

  jit_compilation = pxla._cached_compilation
  count = [0]

  def compile_and_count(*args, **kwargs):
    count[0] += 1
    return jit_compilation(*args, **kwargs)

  pxla._cached_compilation = compile_and_count
  try:
    yield count
  finally:
    pxla._cached_compilation = jit_compilation

@contextmanager
def count_subjaxpr_to_hlo_conversion(fun_name: str):
  # No need to clear any caches since we generally jit and pmap fresh callables
  # in tests.

  mlir_lower = mlir.lower_jaxpr_to_fun
  count = [0]

  def mlir_lower_and_count(ctx, name, *args, **kwargs):
    if name == fun_name:
      count[0] += 1
    return mlir_lower(ctx, name, *args, **kwargs)

  mlir.lower_jaxpr_to_fun = mlir_lower_and_count
  try:
    yield count
  finally:
    mlir.lower_jaxpr_to_fun = mlir_lower

@contextmanager
def assert_num_jit_and_pmap_compilations(times):
  with count_jit_and_pmap_lowerings() as count:
    yield
  if count[0] != times:
    raise AssertionError(f"Expected exactly {times} XLA compilations, "
                         f"but executed {count[0]}")

def device_under_test():
  return _TEST_DUT.value or xla_bridge.get_backend().platform


def supported_dtypes():
  if device_under_test() == "tpu":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
  elif device_under_test() == "METAL":
    types = {np.int32, np.uint32, np.float32}
  else:
    types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64,
             _dtypes.bfloat16, np.float16, np.float32, np.float64,
             np.complex64, np.complex128}
  if not config.enable_x64.value:
    types -= {np.uint64, np.int64, np.float64, np.complex128}
  return types

def is_device_rocm():
  return 'rocm' in xla_bridge.get_backend().platform_version


def is_device_cuda():
  return 'cuda' in xla_bridge.get_backend().platform_version


def is_cloud_tpu():
  return running_in_cloud_tpu_vm


# Returns True if it is not cloud TPU. If it is cloud TPU, returns True only if
# the libtpu was built on or after `date`.
# TODO(b/327203806): after libtpu adds an XLA version and the oldest supported
# libtpu contains the XLA version, remove using build time to skip tests.
def if_cloud_tpu_at_least(date: datetime.date):
  if not is_cloud_tpu():
    return True
  # The format of Cloud TPU platform_version is like:
  # PJRT C API
  # TFRT TPU v2
  # Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
  platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1]
  results = re.findall(r'\(.*?\)', platform_version)
  if len(results) != 1:
    return True
  build_date = datetime.date.fromtimestamp(int(results[0][1:-1]))
  return build_date >= date

def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
  pjrt_c_api_versions = xla_bridge.backend_pjrt_c_api_version()
  if pjrt_c_api_versions is None:
    return True
  return pjrt_c_api_versions >= (major_version, minor_version)


def get_tpu_version() -> int:
  if device_under_test() != "tpu":
    raise ValueError("Device is not TPU")
  kind = jax.devices()[0].device_kind
  if kind.endswith(' lite'):
    kind = kind[:-len(' lite')]
  assert kind[:-1] == "TPU v", kind
  return int(kind[-1])


def is_device_tpu_at_least(version: int) -> bool:
  if device_under_test() != "tpu":
    return False
  return get_tpu_version() >= version


def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
  if device_under_test() != "tpu":
    return False
  if version is None:
    return True
  device_kind = jax.devices()[0].device_kind
  expected_version = f"v{version}{variant}"
  # Special case v5e until the name is updated in device_kind
  if expected_version == "v5e":
    return "v5 lite" in device_kind
  return expected_version in device_kind

def is_cuda_compute_capability_at_least(capability: str) -> bool:
  if not is_device_cuda():
    return False
  d, *_ = jax.local_devices(backend="gpu")
  target = tuple(int(x) for x in capability.split("."))
  current = tuple(int(x) for x in d.compute_capability.split("."))
  return current >= target
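
# Example (illustrative only): gate a test on Ampere-or-newer GPUs.
#
#   if not is_cuda_compute_capability_at_least("8.0"):
#     raise unittest.SkipTest("requires CUDA compute capability >= 8.0")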

def _get_device_tags():
  """returns a set of tags defined for the device under test"""
  if is_device_rocm():
    device_tags = {device_under_test(), "rocm"}
  elif is_device_cuda():
    device_tags = {device_under_test(), "cuda"}
  elif device_under_test() == "METAL":
    device_tags = {device_under_test(), "gpu"}
  else:
    device_tags = {device_under_test()}
  return device_tags


def test_device_matches(device_types: Iterable[str]) -> bool:
  assert not isinstance(
      device_types, str
  ), 'device_types should be a list of strings'
  tags = _get_device_tags()
  for device_type in device_types:
    assert isinstance(device_type, str), device_type
    if device_type in tags:
      return True
  return False

test_device_matches.__test__ = False  # This isn't a test case, pytest.


def _device_filter(predicate):
  def skip(test_method):
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      device_tags = _get_device_tags()
      if not predicate():
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
          f"{test_name} not supported on device with tags {device_tags}.")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip


def skip_on_devices(*disabled_devices):
  """A decorator for test methods to skip the test on certain devices."""
  return _device_filter(lambda: not test_device_matches(disabled_devices))


def run_on_devices(*enabled_devices):
  """A decorator for test methods to run the test only on certain devices."""
  return _device_filter(lambda: test_device_matches(enabled_devices))


def device_supports_buffer_donation():
  """A decorator for test methods to run the test only on devices that support
  buffer donation."""
  return _device_filter(
      lambda: test_device_matches(mlir._platforms_with_donation)
  )
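
# Usage sketch (illustrative only; `MyTest` is a hypothetical test class):
#
#   class MyTest(JaxTestCase):
#     @skip_on_devices("tpu")          # skipped when the DUT is a TPU
#     def test_float64_math(self):
#       ...
#
#     @run_on_devices("cuda", "rocm")  # runs only on GPU backends
#     def test_gpu_only_feature(self):
#       ...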

@contextmanager
def set_host_platform_device_count(nr_devices: int):
  """Context manager to set host platform device count if not specified by user.

  This should only be used by tests at the top level in setUpModule(); it will
  not work correctly if applied to individual test cases.
  """
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               f" --xla_force_host_platform_device_count={nr_devices}")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
  try:
    yield
  finally:
    if prev_xla_flags is None:
      del os.environ["XLA_FLAGS"]
    else:
      os.environ["XLA_FLAGS"] = prev_xla_flags
    xla_bridge.get_backend.cache_clear()
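
# Usage sketch (illustrative only): emulate an 8-device CPU topology for a
# whole test module.
#
#   _exit_stack = ExitStack()
#
#   def setUpModule():
#     _exit_stack.enter_context(set_host_platform_device_count(8))
#
#   def tearDownModule():
#     _exit_stack.close()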

def skip_on_flag(flag_name, skip_value):
  """A decorator for test methods to skip the test when flags are set."""
  def skip(test_method):  # pylint: disable=missing-docstring
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      flag_value = config._read(flag_name)
      if flag_value == skip_value:
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
          f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip


def pytest_mark_if_available(marker: str):
  """A decorator for test classes or methods to pytest.mark if installed."""
  def wrap(func_or_class):
    try:
      import pytest
    except ImportError:
      return func_or_class
    return getattr(pytest.mark, marker)(func_or_class)
  return wrap


def is_running_under_pytest():
  return "pytest" in sys.modules


def skip_under_pytest(reason: str):
  """A decorator for test methods to skip the test when run under pytest."""
  reason = "Running under pytest: " + reason
  def skip(test_method):
    return unittest.skipIf(is_running_under_pytest(), reason)(test_method)
  return skip

def format_test_name_suffix(opname, shapes, dtypes):
  arg_descriptions = (format_shape_dtype_string(shape, dtype)
                      for shape, dtype in zip(shapes, dtypes))
  return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions))


# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape:
  def __len__(self): return 0
  def __getitem__(self, i): raise IndexError(f"index {i} out of range.")
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
PYTHON_SCALAR_SHAPE = _PythonScalar()


# Some shape combinations don't make sense.
def is_valid_shape(shape, dtype):
  if shape == PYTHON_SCALAR_SHAPE:
    return dtype == np.dtype(type(np.array(0, dtype=dtype).item()))
  return True


def _dims_of_shape(shape):
  """Converts `shape` to a tuple of dimensions."""
  if type(shape) in (list, tuple):
    return shape
  elif isinstance(shape, ScalarShape):
    return ()
  elif np.ndim(shape) == 0:
    return (shape,)
  else:
    raise TypeError(type(shape))


def _cast_to_shape(value, shape, dtype):
  """Casts `value` to the correct Python type for `shape` and `dtype`."""
  if shape is NUMPY_SCALAR_SHAPE:
    # explicitly cast to NumPy scalar in case `value` is a Python scalar.
    return np.dtype(dtype).type(value)
  elif shape is PYTHON_SCALAR_SHAPE:
    # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998
    return np.asarray(value).item()
  elif type(shape) in (list, tuple):
    assert np.shape(value) == tuple(shape)
    return value
  elif np.ndim(shape) == 0:
    assert np.shape(value) == (shape,)
    return value
  else:
    raise TypeError(type(shape))


def dtype_str(dtype):
  return np.dtype(dtype).name


def format_shape_dtype_string(shape, dtype):
  if isinstance(shape, np.ndarray):
    return f'{dtype_str(dtype)}[{shape}]'
  elif isinstance(shape, list):
    shape = tuple(shape)
  return _format_shape_dtype_string(shape, dtype)

@functools.lru_cache(maxsize=64)
def _format_shape_dtype_string(shape, dtype):
  if shape is NUMPY_SCALAR_SHAPE:
    return dtype_str(dtype)
  elif shape is PYTHON_SCALAR_SHAPE:
    return 'py' + dtype_str(dtype)
  elif type(shape) is tuple:
    shapestr = ','.join(str(dim) for dim in shape)
    return f'{dtype_str(dtype)}[{shapestr}]'
  elif type(shape) is int:
    return f'{dtype_str(dtype)}[{shape},]'
  else:
    raise TypeError(type(shape))

def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
  """Produce random values given shape, dtype, scale, and post-processor.

  Args:
    rand: a function for producing random values of a given shape, e.g. a
      bound version of either np.random.RandomState.randn or
      np.random.RandomState.rand.
    shape: a shape value as a tuple of positive integers.
    dtype: a numpy dtype.
    scale: optional, a multiplicative scale for the random values (default 1).
    post: optional, a callable for post-processing the random values (default
      identity).

  Returns:
    An ndarray of the given shape and dtype using random values based on a call
    to rand but scaled, converted to the appropriate dtype, and post-processed.
  """
  if _dtypes.issubdtype(dtype, np.unsignedinteger):
    r = lambda: np.asarray(scale * abs(rand(*_dims_of_shape(shape)))).astype(dtype)
  else:
    r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape))).astype(dtype)
  if _dtypes.issubdtype(dtype, np.complexfloating):
    vals = r() + 1.0j * r()
  else:
    vals = r()
  return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)

def rand_fullrange(rng, standardize_nans=False):
  """Random numbers that span the full range of available bits."""
  def gen(shape, dtype, post=lambda x: x):
    dtype = np.dtype(dtype)
    size = dtype.itemsize * math.prod(_dims_of_shape(shape))
    vals = rng.randint(0, np.iinfo(np.uint8).max, size=size, dtype=np.uint8)
    vals = post(vals).view(dtype)
    if shape is PYTHON_SCALAR_SHAPE:
      # Sampling from the full range of the largest available uint type
      # leads to overflows in this case; sample from signed ints instead.
      if dtype == np.uint64:
        vals = vals.astype(np.int64)
      elif dtype == np.uint32 and not config.enable_x64.value:
        vals = vals.astype(np.int32)
    vals = vals.reshape(shape)
    # Non-standard NaNs cause errors in numpy equality assertions.
    if standardize_nans and np.issubdtype(dtype, np.floating):
      vals[np.isnan(vals)] = np.nan
    return _cast_to_shape(vals, shape, dtype)
  return gen


def rand_default(rng, scale=3):
  return partial(_rand_dtype, rng.randn, scale=scale)


def rand_nonzero(rng):
  post = lambda x: np.where(x == 0, np.array(1, dtype=x.dtype), x)
  return partial(_rand_dtype, rng.randn, scale=3, post=post)


def rand_positive(rng):
  post = lambda x: x + 1
  return partial(_rand_dtype, rng.rand, scale=2, post=post)


def rand_small(rng):
  return partial(_rand_dtype, rng.randn, scale=1e-3)


def rand_not_small(rng, offset=10.):
  post = lambda x: x + np.where(x > 0, offset, -offset)
  return partial(_rand_dtype, rng.randn, scale=3., post=post)


def rand_small_positive(rng):
  return partial(_rand_dtype, rng.rand, scale=2e-5)

def rand_uniform(rng, low=0.0, high=1.0):
  assert low < high
  post = lambda x: x * (high - low) + low
  return partial(_rand_dtype, rng.rand, post=post)

def rand_some_equal(rng):

  def post(x):
    x_ravel = x.ravel()
    if len(x_ravel) == 0:
      return x
    flips = rng.rand(*np.shape(x)) < 0.5
    return np.where(flips, x_ravel[0], x)

  return partial(_rand_dtype, rng.randn, scale=100., post=post)


def rand_some_inf(rng):
  """Return a random sampler that produces infinities in floating types."""
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested
  # If blocks should be switched in order, and relevant tests should be fixed
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

def rand_some_nan(rng):
  """Return a random sampler that produces nans in floating types."""
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    dims = _dims_of_shape(shape)
    r = rng.rand(*dims)
    nan_flips = r < 0.1
    neg_nan_flips = r < 0.05

    vals = base_rand(shape, dtype)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)
    vals = np.where(neg_nan_flips, np.array(-np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

def rand_some_inf_and_nan(rng):
  """Return a random sampler that produces infinities in floating types."""
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested
  # If blocks should be switched in order, and relevant tests should be fixed
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1
    nan_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

# TODO(mattjj): doesn't handle complex types
def rand_some_zero(rng):
  """Return a random sampler that produces some zeros."""
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    dims = _dims_of_shape(shape)
    zeros = rng.rand(*dims) < 0.5

    vals = base_rand(shape, dtype)
    vals = np.where(zeros, np.array(0, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

def rand_int(rng, low=0, high=None):
  def fn(shape, dtype):
    nonlocal high
    gen_dtype = dtype if np.issubdtype(dtype, np.integer) else np.int64
    if low == 0 and high is None:
      if np.issubdtype(dtype, np.integer):
        high = np.iinfo(dtype).max
      else:
        raise ValueError("rand_int requires an explicit `high` value for "
                         "non-integer types.")
    return rng.randint(low, high=high, size=shape,
                       dtype=gen_dtype).astype(dtype)
  return fn

def rand_unique_int(rng, high=None):
  def fn(shape, dtype):
    return rng.choice(np.arange(high or math.prod(shape), dtype=dtype),
                      size=shape, replace=False)
  return fn

def rand_bool(rng):
  def generator(shape, dtype):
    return _cast_to_shape(
      np.asarray(rng.rand(*_dims_of_shape(shape)) < 0.5, dtype=dtype),
      shape, dtype)
  return generator

def check_raises(thunk, err_type, msg):
  try:
    thunk()
    assert False
  except err_type as e:
    assert str(e).startswith(msg), f"\n{e}\n\n{msg}\n"

def check_raises_regexp(thunk, err_type, pattern):
  try:
    thunk()
    assert False
  except err_type as e:
    assert re.match(pattern, str(e)), f"{e}\n\n{pattern}\n"


def iter_eqns(jaxpr):
  # TODO(necula): why doesn't this search in params?
  yield from jaxpr.eqns
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from iter_eqns(subjaxpr)

def assert_dot_precision(expected_precision, fun, *args):
  jaxpr = api.make_jaxpr(fun)(*args)
  precisions = [eqn.params['precision'] for eqn in iter_eqns(jaxpr.jaxpr)
                if eqn.primitive == lax.dot_general_p]
  for precision in precisions:
    msg = f"Unexpected precision: {expected_precision} != {precision}"
    if isinstance(precision, tuple):
      assert precision[0] == expected_precision, msg
      assert precision[1] == expected_precision, msg
    else:
      assert precision == expected_precision, msg

def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):
  jaxpr = api.make_jaxpr(partial(fun, **kwargs))(*args)
  pref_eltypes = [eqn.params['preferred_element_type'] for eqn in iter_eqns(jaxpr.jaxpr)
                  if eqn.primitive == lax.dot_general_p]
  for pref_eltype in pref_eltypes:
    msg = f"Unexpected preferred_element_type: {expected} != {pref_eltype}"
    assert expected == pref_eltype, msg

def cases_from_gens(*gens):
  sizes = [1, 3, 10]
  cases_per_size = int(NUM_GENERATED_CASES.value / len(sizes)) + 1
  for size in sizes:
    for i in range(cases_per_size):
      yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)

def named_cases_from_sampler(gen):
  seen = set()
  retries = 0
  rng = npr.RandomState(42)
  def choose_one(x):
    if not isinstance(x, (list, tuple)):
      x = list(x)
    return [x[rng.randint(len(x))]]
  while (len(seen) < NUM_GENERATED_CASES.value and
         retries < _MAX_CASES_SAMPLING_RETRIES.value):
    retries += 1
    cases = list(gen(choose_one))
    if not cases:
      continue
    if len(cases) > 1:
      raise RuntimeError("Generator is expected to only return a single case when sampling")
    case = cases[0]
    if case["testcase_name"] in seen:
      continue
    retries = 0
    seen.add(case["testcase_name"])
    yield case

# Random sampling for every parameterized test is expensive. Do it once and
# cache the result.
@functools.cache
def _choice(n, m):
  rng = np.random.RandomState(42)
  return rng.choice(n, size=m, replace=False)

def sample_product_testcases(*args, **kw):
  """Non-decorator form of sample_product."""
  args = [list(arg) for arg in args]
  kw = [(k, list(v)) for k, v in kw.items()]
  n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
  testcases = []
  for i in _choice(n, min(n, NUM_GENERATED_CASES.value)):
    testcase = {}
    for a in args:
      testcase.update(a[i % len(a)])
      i //= len(a)
    for k, v in kw:
      testcase[k] = v[i % len(v)]
      i //= len(v)
    testcases.append(testcase)
  return testcases

def sample_product(*args, **kw):
  """Decorator that samples from a cartesian product of test cases.

  Similar to absltest.parameterized.product(), except that it samples from the
  cartesian product rather than returning the whole thing.

  Arguments:
    *args: each positional argument is a list of dictionaries. The entries
      in a dictionary correspond to name=value argument pairs; one dictionary
      will be chosen for each test case. This allows multiple parameters to be
      correlated.
    **kw: each keyword argument is a list of values. One value will be chosen
      for each test case.
  """
  return parameterized.parameters(*sample_product_testcases(*args, **kw))
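
# Usage sketch (illustrative only; `MyOpTest` and `test_op` are hypothetical):
#
#   class MyOpTest(JaxTestCase):
#     @sample_product(
#       [dict(shape=(3,), axis=0), dict(shape=(2, 4), axis=1)],  # correlated
#       dtype=[np.float32, np.int32],  # sampled independently
#     )
#     def test_op(self, shape, axis, dtype):
#       ...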

class JaxTestLoader(absltest.TestLoader):
  def getTestCaseNames(self, testCaseClass):
    names = super().getTestCaseNames(testCaseClass)
    if _TEST_TARGETS.value:
      pattern = re.compile(_TEST_TARGETS.value)
      names = [name for name in names
               if pattern.search(f"{testCaseClass.__name__}.{name}")]
    if _EXCLUDE_TEST_TARGETS.value:
      pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
      names = [name for name in names
               if not pattern.search(f"{testCaseClass.__name__}.{name}")]
    return names


def with_config(**kwds):
  """Test case decorator for subclasses of JaxTestCase"""
  def decorator(cls):
    assert inspect.isclass(cls) and issubclass(cls, JaxTestCase), "@with_config can only wrap JaxTestCase class definitions."
    cls._default_config = {}
    for b in cls.__bases__:
      cls._default_config.update(b._default_config)
    cls._default_config.update(kwds)
    return cls
  return decorator
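
# Usage sketch (illustrative only; `X64Test` is a hypothetical test class):
#
#   @with_config(jax_enable_x64=True)
#   class X64Test(JaxTestCase):
#     def test_double_precision(self):
#       ...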

def promote_like_jnp(fun, inexact=False):
  """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.

  jnp and np have different type promotion semantics; this decorator allows
  tests to make an np reference implementation act more like a jnp
  implementation.
  """
  _promote = promote_dtypes_inexact if inexact else promote_dtypes
  def wrapper(*args, **kw):
    flat_args, tree = tree_flatten(args)
    args = tree_unflatten(tree, _promote(*flat_args))
    return fun(*args, **kw)
  return wrapper

@contextmanager
def global_config_context(**kwds):
  original_config = {}
  try:
    for key, value in kwds.items():
      original_config[key] = config._read(key)
      config.update(key, value)
    yield
  finally:
    for key, value in original_config.items():
      config.update(key, value)
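
# Usage sketch (illustrative only): temporarily override global config values,
# restoring them even if the body raises.
#
#   with global_config_context(jax_numpy_rank_promotion='warn'):
#     ...  # code that relies on the temporary setting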

class NotPresent:
  def __repr__(self):
    return "<not present>"


@contextmanager
def assert_global_configs_unchanged():
  starting_config = jax.config.values.copy()
  yield
  ending_config = jax.config.values

  if starting_config == ending_config:
    return
  differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
               for k in (starting_config.keys() | ending_config.keys())
               if (k not in starting_config or k not in ending_config
                   or starting_config[k] != ending_config[k])}
  raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
|
|
|
|
|
|
|
|
|
2021-09-24 07:02:08 -07:00
|
|
|
class JaxTestCase(parameterized.TestCase):
|
|
|
|
"""Base class for JAX tests including numerical checks and boilerplate."""
|
2022-02-14 09:22:05 -08:00
|
|
|
_default_config = {
|
|
|
|
'jax_enable_checks': True,
|
2022-06-17 15:46:50 -07:00
|
|
|
'jax_numpy_dtype_promotion': 'strict',
|
2022-02-15 02:42:30 -08:00
|
|
|
'jax_numpy_rank_promotion': 'raise',
|
|
|
|
'jax_traceback_filtering': 'off',
|
2023-08-25 14:11:19 -07:00
|
|
|
'jax_legacy_prng_key': 'error',
|
2022-02-14 09:22:05 -08:00
|
|
|
}
|
2021-09-24 07:02:08 -07:00
|
|
|
|
2023-12-08 12:09:04 +00:00
|
|
|
_compilation_cache_exit_stack: ExitStack | None = None
|
2023-02-07 15:14:53 -08:00
|
|
|
|
2021-09-24 07:02:08 -07:00
|
|
|
# TODO(mattjj): this obscures the error messages from failures, figure out how
|
|
|
|
# to re-enable it
|
|
|
|
# def tearDown(self) -> None:
|
|
|
|
# assert core.reset_trace_state()
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
super().setUp()
|
2024-05-29 14:05:01 -07:00
|
|
|
self.enter_context(assert_global_configs_unchanged())
|
|
|
|
|
2021-09-24 07:02:08 -07:00
|
|
|
# We use the adler32 hash for two reasons.
|
|
|
|
# a) it is deterministic run to run, unlike hash() which is randomized.
|
|
|
|
# b) it returns values in int32 range, which RandomState requires.
|
|
|
|
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
|
|
|
|
|

  @classmethod
  def setUpClass(cls):
    cls._compilation_cache_exit_stack = ExitStack()
    stack = cls._compilation_cache_exit_stack
    stack.enter_context(global_config_context(**cls._default_config))

    if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
      stack.enter_context(config.enable_compilation_cache(True))
      stack.enter_context(config.raise_persistent_cache_errors(True))
      stack.enter_context(config.persistent_cache_min_compile_time_secs(0))
      stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))

      tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
      compilation_cache.set_cache_dir(tmp_dir)
      stack.callback(lambda: compilation_cache.reset_cache())

  @classmethod
  def tearDownClass(cls):
    cls._compilation_cache_exit_stack.close()

  def rng(self):
    return self._rng

  def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
                        allow_object_dtype=False, verbose=True):
    """Assert that x and y arrays are exactly equal."""
    if check_dtypes:
      self.assertDtypesMatch(x, y)
    x = np.asarray(x)
    y = np.asarray(y)

    if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
      # See https://github.com/google/jax/issues/17867
      raise TypeError(
          "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
          "If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. "
          "If comparing collections of arrays, consider using assertAllClose. "
          "To let this test proceed anyway, pass allow_object_dtype=True.")

    # Work around https://github.com/numpy/numpy/issues/18992
    with np.errstate(over='ignore'):
      np.testing.assert_array_equal(x, y, err_msg=err_msg,
                                    verbose=verbose)

  def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
                           rtol=None, err_msg=''):
    """Assert that x and y are close (up to numerical tolerances)."""
    self.assertEqual(x.shape, y.shape)
    atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
    rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))

    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)

    if check_dtypes:
      self.assertDtypesMatch(x, y)

  def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
    if not config.enable_x64.value and canonicalize_dtypes:
      self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True),
                       _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True))
    else:
      self.assertEqual(_dtype(x), _dtype(y))

  def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
                     canonicalize_dtypes=True, err_msg=''):
    """Assert that x and y, either arrays or nested tuples/lists, are close."""
    if isinstance(x, dict):
      self.assertIsInstance(y, dict)
      self.assertEqual(set(x.keys()), set(y.keys()))
      for k in x.keys():
        self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
                            rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
                            err_msg=err_msg)
    elif is_sequence(x) and not hasattr(x, '__array__'):
      self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
      self.assertEqual(len(x), len(y))
      for x_elt, y_elt in zip(x, y):
        self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
                            rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
                            err_msg=err_msg)
    elif hasattr(x, '__array__') or np.isscalar(x):
      self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
      if check_dtypes:
        self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
      x = np.asarray(x)
      y = np.asarray(y)
      self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
                                err_msg=err_msg)
    elif x == y:
      return
    else:
      raise TypeError((type(x), type(y)))

  def assertMultiLineStrippedEqual(self, expected, what):
    """Asserts two strings are equal, after dedenting and stripping each line."""
    expected = textwrap.dedent(expected)
    what = textwrap.dedent(what)
    ignore_space_re = re.compile(r'\s*\n\s*')
    expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
    what_clean = re.sub(ignore_space_re, '\n', what.strip())
    if what_clean != expected_clean:
      # Print it so we can copy-and-paste it into the test
      print(f"Found\n{what}\n")
    self.assertMultiLineEqual(expected_clean, what_clean,
                              msg=f"Found\n{what}\nExpecting\n{expected}")

  @contextmanager
  def assertNoWarnings(self):
    with warnings.catch_warnings():
      warnings.simplefilter("error")
      yield

  def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
                       rtol=None, atol=None, check_cache_misses=True):
    """Helper method for running JAX compilation and allclose assertions."""
    args = args_maker()

    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_should_be_executing = True
    python_ans = fun(*args)

    python_shapes = tree_map(lambda x: np.shape(x), python_ans)
    np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
    self.assertEqual(python_shapes, np_shapes)

    cache_misses = dispatch.xla_primitive_callable.cache_info().misses
    python_ans = fun(*args)
    if check_cache_misses:
      self.assertEqual(
          cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
          "Compilation detected during second call of {} in op-by-op "
          "mode.".format(fun))

    cfun = api.jit(wrapped_fun)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
                        atol=atol or tol, rtol=rtol or tol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol or tol, rtol=rtol or tol)

    args = args_maker()

    python_should_be_executing = True
    python_ans = fun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol or tol, rtol=rtol or tol)

  def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
                         check_dtypes=True, tol=None, atol=None, rtol=None,
                         canonicalize_dtypes=True):
    args = args_maker()
    lax_ans = lax_op(*args)
    numpy_ans = numpy_reference_op(*args)
    self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
                        atol=atol or tol, rtol=rtol or tol,
                        canonicalize_dtypes=canonicalize_dtypes)
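

# Example usage (illustrative sketch): a typical numerical test built on
# JaxTestCase. Assumes `jnp` is imported as jax.numpy in the test module.
#
#   import jax.numpy as jnp
#
#   class SinTest(JaxTestCase):
#     def test_sin(self):
#       args_maker = lambda: [self.rng().randn(3, 4).astype(np.float32)]
#       # Check jnp.sin against the np.sin reference implementation...
#       self._CheckAgainstNumpy(np.sin, jnp.sin, args_maker)
#       # ...and check that eager and jit-compiled results agree.
#       self._CompileAndCheck(jnp.sin, args_maker)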


_PJIT_IMPLEMENTATION = jax.jit
_PJIT_IMPLEMENTATION._name = "jit"
_NOOP_JIT_IMPLEMENTATION = lambda x, *args, **kwargs: x
_NOOP_JIT_IMPLEMENTATION._name = "noop"

JIT_IMPLEMENTATION = (
  _PJIT_IMPLEMENTATION,
  _NOOP_JIT_IMPLEMENTATION,
)


class BufferDonationTestCase(JaxTestCase):
  def assertDeleted(self, x):
    self.assertTrue(x.is_deleted())

  def assertNotDeleted(self, x):
    self.assertFalse(x.is_deleted())


@contextmanager
def ignore_warning(*, message='', category=Warning, **kw):
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=message, category=category, **kw)
    yield
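

# Example usage (illustrative sketch): silencing a known warning from code
# under test without hiding unrelated warnings. The message pattern and the
# warning-raising function are hypothetical.
#
#   with ignore_warning(message="xla_bridge.*", category=DeprecationWarning):
#     code_that_warns()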

# -------------------- Mesh parametrization helpers --------------------

MeshSpec = list[tuple[str, int]]


@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
  """Test utility for setting up meshes given mesh data from `schedules`."""
  # This is similar to the `with_mesh` function above, but isn't a decorator.
  axis_names, shape = unzip2(named_shape)
  size = math.prod(shape)
  local_devices = list(jax.local_devices())
  if len(local_devices) < size:
    raise unittest.SkipTest(f"Test requires {size} local devices")
  mesh_devices = np.array(local_devices[:size]).reshape(shape)  # type: ignore
  with jax.sharding.Mesh(mesh_devices, axis_names):
    yield
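

# Example usage (illustrative sketch): running sharded test logic under a
# 2-device mesh; the test is skipped when too few local devices are available.
#
#   with with_mesh([('x', 2)]):
#     ...  # code here runs inside jax.sharding.Mesh(devices, ('x',))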


def with_mesh_from_kwargs(f):
  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)


def with_and_without_mesh(f):
  return parameterized.named_parameters(
    {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
    for name, mesh, axis_resources in (
      ('', (), ()),
      ('Mesh', (('x', 2),), (('i', 'x'),))
    ))(with_mesh_from_kwargs(f))


def create_mesh(mesh_shape, axis_names, iota_order=False):
  size = math.prod(mesh_shape)
  if len(jax.devices()) < size:
    raise unittest.SkipTest(f"Test requires {size} global devices.")
  if iota_order:
    devices = sorted(jax.devices(), key=lambda d: d.id)
    mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
    return jax.sharding.Mesh(mesh_devices, axis_names)
  else:
    return jax.make_mesh(mesh_shape, axis_names)
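

# Example usage (illustrative sketch): building a 2x2 mesh for a sharding
# test; the test is skipped when fewer than 4 devices are present.
#
#   mesh = create_mesh((2, 2), ('x', 'y'))
#   sharding = jax.sharding.NamedSharding(
#       mesh, jax.sharding.PartitionSpec('x', 'y'))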


class _cached_property:
  null = object()

  def __init__(self, method):
    self._method = method
    self._value = self.null

  def __get__(self, obj, cls):
    if self._value is self.null:
      self._value = self._method(obj)
    return self._value


class _LazyDtypes:
  """A class that unifies lists of supported dtypes.

  These could be module-level constants, but device_under_test() is not always
  known at import time, so we need to define these lists lazily.
  """
  def supported(self, dtypes):
    supported = supported_dtypes()
    return type(dtypes)(d for d in dtypes if d in supported)

  @_cached_property
  def custom_floats(self):
    return [np.dtype(t) for t in [
      _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
      _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz,
      _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]]

  @_cached_property
  def floating(self):
    return self.supported([np.float32, np.float64])

  @_cached_property
  def all_floating(self):
    return self.supported([_dtypes.bfloat16, np.float16, np.float32, np.float64])

  @_cached_property
  def integer(self):
    return self.supported([np.int32, np.int64])

  @_cached_property
  def all_integer(self):
    return self.supported([
      _dtypes.int4, np.int8, np.int16, np.int32, np.int64])

  @_cached_property
  def unsigned(self):
    return self.supported([np.uint32, np.uint64])

  @_cached_property
  def all_unsigned(self):
    return self.supported([
      _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64])

  @_cached_property
  def complex(self):
    return self.supported([np.complex64, np.complex128])

  @_cached_property
  def boolean(self):
    return self.supported([np.bool_])

  @_cached_property
  def inexact(self):
    return self.floating + self.complex

  @_cached_property
  def all_inexact(self):
    return self.all_floating + self.complex

  @_cached_property
  def numeric(self):
    return self.floating + self.integer + self.unsigned + self.complex

  @_cached_property
  def all(self):
    return (self.all_floating + self.all_integer + self.all_unsigned +
            self.complex + self.boolean)


dtypes = _LazyDtypes()
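

# Example usage (illustrative sketch): parameterizing a test over the floating
# dtypes actually supported on the current device.
#
#   @parameterized.parameters(dtypes.all_floating)
#   def test_unary_op(self, dtype):
#     x = np.ones((3,), dtype=dtype)
#     ...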


def strict_promotion_if_dtypes_match(dtypes):
  """
  Context manager to enable strict promotion if all dtypes match,
  and enable standard dtype promotion otherwise.
  """
  if all(dtype == dtypes[0] for dtype in dtypes):
    return jax.numpy_dtype_promotion('strict')
  return jax.numpy_dtype_promotion('standard')
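

# Example usage (illustrative sketch): binary-op tests often mix dtypes, which
# 'strict' promotion rejects; this helper keeps strict checking where possible.
#
#   with strict_promotion_if_dtypes_match([np.float32, np.int32]):
#     ...  # runs under 'standard' promotion here, since the dtypes differ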


_version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?")

def parse_version(v: str) -> tuple[int, ...]:
  m = _version_regex.match(v)
  if m is None:
    raise ValueError(f"Unable to parse version '{v}'")
  return tuple(int(x) for x in m.group(1).split('.'))


def numpy_version():
  return parse_version(np.__version__)
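

# Example usage (illustrative sketch): version-gating a test.
#
#   parse_version("2.1.0rc1")  # -> (2, 1, 0); rc/dev suffixes are dropped
#   if numpy_version() < (2, 0, 0):
#     ...  # legacy code path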


def parameterized_filterable(*,
    kwargs: Sequence[dict[str, Any]],
    testcase_name: Callable[[dict[str, Any]], str] | None = None,
    one_containing: str | None = None,
):
  """Decorator for named parameterized tests, with filtering support.

  Works like ``parameterized.named_parameters``, except that it sanitizes the test
  names so that we can use ``pytest -k`` and ``python test.py -k`` test filtering.
  This means, e.g., that many special characters are replaced with `_`.
  It also supports the ``one_containing`` arg to select one of the tests, while
  leaving the name unchanged, which is useful for IDEs to be able to easily
  pick up the enclosing test name.

  Usage:
    @jtu.parameterized_filterable(
      # one_containing="a_4",
      kwargs=[dict(a=4, b=5),
              dict(a=5, b=4)])
    def test_my_test(self, *, a, b): ...

  Args:
    kwargs: Each entry is a set of kwargs to be passed to the test function.
    testcase_name: Optionally, a function to construct the testcase_name from
      one kwargs dict. If not given, then each ``kwargs`` dict may contain a
      ``testcase_name`` entry; otherwise the test case name is constructed
      from the sorted kwargs items.
      We sanitize the test names to work with -k test filters. See
      ``sanitize_test_name``.
    one_containing: If given, then leaves the test name unchanged, and uses
      only one of the ``kwargs`` whose `testcase_name` includes ``one_containing``.
  """
  # Ensure that all kwargs contain a testcase_name
  kwargs_with_testcase_name: Sequence[dict[str, Any]]
  if testcase_name is not None:
    kwargs_with_testcase_name = [
      dict(testcase_name=sanitize_test_name(str(testcase_name(kw))), **kw)
      for kw in kwargs]
  else:
    for kw in kwargs:
      testcase_name = kw.get("testcase_name")
      if testcase_name is None:
        testcase_name = "_".join(f"{k}={kw[k]}"  # type: ignore
                                 for k in sorted(kw.keys()))
      kw["testcase_name"] = sanitize_test_name(testcase_name)  # type: ignore

    kwargs_with_testcase_name = kwargs
  if one_containing is not None:
    filtered = tuple(kw for kw in kwargs_with_testcase_name
                     if one_containing in kw["testcase_name"])
    assert filtered, (
      f"No testcase_name contains '{one_containing}'. "
      "The testcase_name values are\n  " +
      "\n  ".join(kw["testcase_name"] for kw in kwargs_with_testcase_name))
    kw = filtered[0]
    kw["testcase_name"] = ""
    return parameterized.named_parameters([kw])
  else:
    return parameterized.named_parameters(*kwargs_with_testcase_name)


@contextmanager
def register_event_duration_listener(callback):
  """Manages registering/unregistering an event duration listener callback."""
  try:
    monitoring.register_event_duration_secs_listener(callback)
    yield
  finally:
    monitoring._unregister_event_duration_listener_by_callback(callback)
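

# Example usage (illustrative sketch, assuming a (name, duration_secs)
# callback signature): collecting duration events emitted while compiling.
#
#   durations = []
#   with register_event_duration_listener(
#       lambda name, secs, **kw: durations.append((name, secs))):
#     jax.jit(lambda x: x + 1)(0.0)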


@contextmanager
def set_env(**kwargs):
  """Context manager to temporarily set/unset one or more environment variables.

  Examples:

    >>> import os
    >>> os.environ['my_var'] = 'original'

    >>> with set_env(my_var=None, other_var='some_value'):
    ...   print("my_var is set:", 'my_var' in os.environ)
    ...   print("other_var =", os.environ['other_var'])
    ...
    my_var is set: False
    other_var = some_value

    >>> os.environ['my_var']
    'original'
    >>> 'other_var' in os.environ
    False
  """
  original = {key: os.environ.pop(key, None) for key in kwargs}
  os.environ.update({k: v for k, v in kwargs.items() if v is not None})
  try:
    yield
  finally:
    _ = [os.environ.pop(key, None) for key in kwargs]
    os.environ.update({k: v for k, v in original.items() if v is not None})


def fwd_bwd_jaxprs(f, *example_args):
  fwd_jaxpr, (y_shape, res_shape) = jax.make_jaxpr(
    lambda *args: jax.vjp(f, *args), return_shape=True)(*example_args)
  bwd_jaxpr = jax.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape)
  return fwd_jaxpr, bwd_jaxpr
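

# Example usage (illustrative sketch): inspecting the forward and backward
# jaxprs of a differentiable function.
#
#   fwd, bwd = fwd_bwd_jaxprs(lambda x: x * x, 3.0)
#   print(fwd)  # jaxpr of the VJP forward pass (primal output + residuals)
#   print(bwd)  # jaxpr of the residual-consuming backward pass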


def numpy_vecdot(x, y, axis):
  """Implementation of numpy.vecdot for testing on numpy < 2.0.0"""
  if numpy_version() >= (2, 0, 0):
    raise ValueError("should be calling vecdot directly on numpy >= 2.0.0")
  x = np.moveaxis(x, axis, -1)
  y = np.moveaxis(y, axis, -1)
  x, y = np.broadcast_arrays(x, y)
  return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
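

# Example usage (illustrative sketch): a reference vecdot on numpy < 2.0.0.
#
#   x = np.arange(6.).reshape(2, 3)
#   y = np.ones((2, 3))
#   numpy_vecdot(x, y, axis=1)  # contracts axis 1 -> shape (2,)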


def complex_plane_sample(dtype, size_re=10, size_im=None):
  """Return a 2-D array of complex numbers that covers the complex plane
  with a grid of samples.

  The size of the grid is (3 + 2 * size_im) x (3 + 2 * size_re)
  that includes infinity points, extreme finite points, and the
  specified number of points from real and imaginary axis.

  For example:

  >>> print(complex_plane_sample(np.complex64, 0, 3))
  [[-inf          -infj 0.          -infj inf          -infj]
   [-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
   [-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j]
   [-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
   [-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
   [-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
   [-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j]
   [-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
   [-inf          +infj 0.          +infj inf          +infj]]

  """
  if size_im is None:
    size_im = size_re
  finfo = np.finfo(dtype)

  def make_axis_points(size):
    prec_dps_ratio = 3.3219280948873626
    logmin = logmax = finfo.maxexp / prec_dps_ratio
    logtiny = finfo.minexp / prec_dps_ratio
    axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)

    with warnings.catch_warnings():
      # Silence RuntimeWarning: overflow encountered in cast
      warnings.simplefilter("ignore")
      half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
      half_line = -half_neg_line[::-1]
      axis_points[-size - 1:-1] = half_line
      axis_points[1:size + 1] = half_neg_line

    if size > 1:
      axis_points[1] = finfo.min
      axis_points[-2] = finfo.max
    if size > 0:
      axis_points[size] = -finfo.tiny
      axis_points[-size - 1] = finfo.tiny
    axis_points[0] = -np.inf
    axis_points[-1] = np.inf
    return axis_points

  real_axis_points = make_axis_points(size_re)
  imag_axis_points = make_axis_points(size_im)

  real_part = real_axis_points.reshape((-1, 3 + 2 * size_re)).repeat(3 + 2 * size_im, 0).astype(dtype)

  imag_part = imag_axis_points.repeat(2).view(dtype)
  imag_part.real[:] = 0
  imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)

  return real_part + imag_part
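

# Example usage (illustrative sketch): exhaustively probing a complex function
# over the plane, including infinities and extreme finite values.
#
#   samples = complex_plane_sample(np.complex64, size_re=10, size_im=10)
#   assert samples.shape == (3 + 2 * 10, 3 + 2 * 10)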


class vectorize_with_mpmath(np.vectorize):
  """Same as numpy.vectorize but using mpmath backend for function evaluation.
  """

  map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble')
  map_complex_to_float = {v: k for k, v in map_float_to_complex.items()}

  float_prec = dict(
    # float16=11,
    float32=24,
    float64=53,
    # float128=113,
    # longdouble=113
  )

  float_minexp = dict(
    float16=-14,
    float32=-126,
    float64=-1022,
    float128=-16382
  )

  float_maxexp = dict(
    float16=16,
    float32=128,
    float64=1024,
    float128=16384,
  )

  def __init__(self, *args, **kwargs):
    mpmath = kwargs.pop('mpmath', None)
    if mpmath is None:
      raise ValueError('vectorize_with_mpmath: no mpmath argument specified')
    self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0)
    self.extra_prec = kwargs.pop('extra_prec', 0)
    self.mpmath = mpmath
    self.contexts = dict()
    self.contexts_inv = dict()
    for fp_format, prec in self.float_prec.items():
      ctx = self.mpmath.mp.clone()
      ctx.prec = prec
      self.contexts[fp_format] = ctx
      self.contexts_inv[ctx] = fp_format

    super().__init__(*args, **kwargs)

  def get_context(self, x):
    if isinstance(x, (np.ndarray, np.floating, np.complexfloating)):
      fp_format = str(x.dtype)
      fp_format = self.map_complex_to_float.get(fp_format, fp_format)
      return self.contexts[fp_format]
    raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance')

  def nptomp(self, x):
    """Convert numpy array/scalar to an array/instance of mpmath number type.
    """
    if isinstance(x, np.ndarray):
      return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape)
    elif isinstance(x, np.floating):
      mpmath = self.mpmath
      ctx = self.get_context(x)
      prec, rounding = ctx._prec_rounding
      if np.isposinf(x):
        return ctx.make_mpf(mpmath.libmp.finf)
      elif np.isneginf(x):
        return ctx.make_mpf(mpmath.libmp.fninf)
      elif np.isnan(x):
        return ctx.make_mpf(mpmath.libmp.fnan)
      elif np.isfinite(x):
        mantissa, exponent = np.frexp(x)
        man = int(np.ldexp(mantissa, prec))
        exp = int(exponent - prec)
        r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding))
        assert ctx.isfinite(r), r._mpf_
        return r
    elif isinstance(x, np.complexfloating):
      re, im = self.nptomp(x.real), self.nptomp(x.imag)
      return re.context.make_mpc((re._mpf_, im._mpf_))
    raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type')

  def mptonp(self, x):
    """Convert mpmath instance to numpy array/scalar type.
    """
    if isinstance(x, np.ndarray) and x.dtype.kind == 'O':
      x_flat = x.flatten()
      item = x_flat[0]
      ctx = item.context
      fp_format = self.contexts_inv[ctx]
      if isinstance(item, ctx.mpc):
        dtype = getattr(np, self.map_float_to_complex[fp_format])
      elif isinstance(item, ctx.mpf):
        dtype = getattr(np, fp_format)
      else:
        dtype = None
      if dtype is not None:
        return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape)
    elif isinstance(x, self.mpmath.ctx_mp.mpnumeric):
      ctx = x.context
      if isinstance(x, ctx.mpc):
        fp_format = self.contexts_inv[ctx]
        dtype = getattr(np, self.map_float_to_complex[fp_format])
        r = dtype().reshape(1).view(getattr(np, fp_format))
        r[0] = self.mptonp(x.real)
        r[1] = self.mptonp(x.imag)
        return r.view(dtype)[0]
      elif isinstance(x, ctx.mpf):
        fp_format = self.contexts_inv[ctx]
        dtype = getattr(np, fp_format)
        if ctx.isfinite(x):
          sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding)
          assert bc >= 0, (sign, man, exp, bc, x._mpf_)
          if exp + bc < self.float_minexp[fp_format]:
            return -ctx.zero if sign else ctx.zero
          if exp + bc > self.float_maxexp[fp_format]:
            return ctx.ninf if sign else ctx.inf
          man = dtype(-man if sign else man)
          r = np.ldexp(man, exp)
          assert np.isfinite(r), (x, r, x._mpf_, man)
          return r
        elif ctx.isnan(x):
          return dtype(np.nan)
        elif ctx.isinf(x):
          return dtype(-np.inf if x._mpf_[0] else np.inf)
    raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type')

  def __call__(self, *args, **kwargs):
    mp_args = []
    context = None
    for a in args:
      if isinstance(a, (np.ndarray, np.floating, np.complexfloating)):
        mp_args.append(self.nptomp(a))
        if context is None:
          context = self.get_context(a)
        else:
          assert context is self.get_context(a)
      else:
        mp_args.append(a)

    extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec
    with context.extraprec(extra_prec):
      result = super().__call__(*mp_args, **kwargs)

    if isinstance(result, tuple):
      lst = []
      for r in result:
        if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O')
            or isinstance(r, self.mpmath.ctx_mp.mpnumeric)):
          r = self.mptonp(r)
        lst.append(r)
      return tuple(lst)

    if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O')
        or isinstance(result, self.mpmath.ctx_mp.mpnumeric)):
      return self.mptonp(result)

    return result
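

# Example usage (illustrative sketch; requires the optional mpmath package):
# a high-precision reference for a numpy ufunc.
#
#   import mpmath
#   hp_sin = vectorize_with_mpmath(lambda x: x.context.sin(x), mpmath=mpmath)
#   hp_sin(np.float32(0.5))  # evaluated via mpmath at float32 precision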


class numpy_with_mpmath:
  """Namespace of universal functions on numpy arrays that use mpmath
  backend for evaluation and return numpy arrays as outputs.
  """

  _provides = [
    'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2',
    'log', 'log1p', 'log10', 'log2',
    'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
    'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
    'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc',
    'normalize',
  ]

  _mp_names = dict(
    abs='absmin', absolute='absmin',
    log='ln',
    arcsin='asin', arccos='acos', arctan='atan',
    arcsinh='asinh', arccosh='acosh', arctanh='atanh',
  )

  def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0):
    self.mpmath = mpmath

    for name in self._provides:
      mp_name = self._mp_names.get(name, name)

      if hasattr(self, name):
        op = getattr(self, name)
      else:

        def op(x, mp_name=mp_name):
          return getattr(x.context, mp_name)(x)

      setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec))

  # The following function methods operate on mpmath number instances.
  # The corresponding function names must be listed in the
  # numpy_with_mpmath._provides list.

  def square(self, x):
    return x * x

  def positive(self, x):
    return x

  def negative(self, x):
    return -x

  def sqrt(self, x):
    ctx = x.context
    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in sqrt(+-inf+-infj) evaluation (see mpmath/mpmath#776).
      # TODO(pearu): remove this function when mpmath 1.4 or newer
      # is the required test dependency.
      if ctx.isinf(x.imag):
        return ctx.make_mpc((ctx.inf._mpf_, x.imag._mpf_))
    return ctx.sqrt(x)

  def expm1(self, x):
    return x.context.expm1(x)

  def log1p(self, x):
    ctx = x.context
    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in log(+-inf+-infj) evaluation (see mpmath/mpmath#774).
      # TODO(pearu): remove this function when mpmath 1.4 or newer
      # is the required test dependency.
      if ctx.isinf(x.real) and ctx.isinf(x.imag):
        pi = ctx.pi
        if x.real > 0 and x.imag > 0:
          return ctx.make_mpc((x.real._mpf_, (pi / 4)._mpf_))
        if x.real > 0 and x.imag < 0:
          return ctx.make_mpc((x.real._mpf_, (-pi / 4)._mpf_))
        if x.real < 0 and x.imag < 0:
          return ctx.make_mpc(((-x.real)._mpf_, (-3 * pi / 4)._mpf_))
        if x.real < 0 and x.imag > 0:
          return ctx.make_mpc(((-x.real)._mpf_, (3 * pi / 4)._mpf_))
    return ctx.log1p(x)

  def tan(self, x):
    ctx = x.context
    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in tan(+-inf+-infj) evaluation (see mpmath/mpmath#781).
      # TODO(pearu): remove this function when mpmath 1.4 or newer
      # is the required test dependency.
      if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
        if x.imag > 0:
          return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
        return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
      if ctx.isinf(x.real) and ctx.isfinite(x.imag):
        return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
    return ctx.tan(x)

  def tanh(self, x):
    ctx = x.context
    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in tanh(+-inf+-infj) evaluation (see mpmath/mpmath#781).
      # TODO(pearu): remove this function when mpmath 1.4 or newer
      # is the required test dependency.
      if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
        if x.imag > 0:
          return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
        return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
      if ctx.isinf(x.real) and ctx.isfinite(x.imag):
        return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
    return ctx.tanh(x)

  def log2(self, x):
    return x.context.ln(x) / x.context.ln2

  def log10(self, x):
    return x.context.ln(x) / x.context.ln10

  def exp2(self, x):
    return x.context.exp(x * x.context.ln2)

  def arcsin(self, x):
    ctx = x.context
    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in asin(+-inf+-infj) evaluation (see
      # mpmath/mpmath#793).
      # TODO(pearu): remove the if-block below when mpmath 1.4 or
      # newer is the required test dependency.
      pi = ctx.pi
      inf = ctx.inf
      zero = ctx.zero
      if ctx.isinf(x.real):
        sign_real = -1 if x.real < 0 else 1
        real = sign_real * pi / (4 if ctx.isinf(x.imag) else 2)
        imag = -inf if x.imag < 0 else inf
        return ctx.make_mpc((real._mpf_, imag._mpf_))
      elif ctx.isinf(x.imag):
        return ctx.make_mpc((zero._mpf_, x.imag._mpf_))

      # On branch cut, mpmath.mp.asin returns different value compared
      # to mpmath.fp.asin and numpy.arcsin (see
      # mpmath/mpmath#786). The following if-block ensures
      # compatibility with numpy.arcsin.
      if x.real > 1 and x.imag == 0:
        return ctx.asin(x).conjugate()

    return ctx.asin(x)

  def arccos(self, x):
    ctx = x.context

    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in acos(+-inf+-infj) evaluation (see
      # mpmath/mpmath#793).
      # TODO(pearu): remove the if-block below when mpmath 1.4 or
      # newer is the required test dependency.
      pi = ctx.pi
      inf = ctx.inf
      zero = ctx.zero

      if ctx.isinf(x.imag):
        if ctx.isinf(x.real):
          real = pi / 4 if x.real > 0 else 3 * pi / 4
        else:
          real = pi / 2
        imag = inf if x.imag < 0 else -inf
        return ctx.make_mpc((real._mpf_, imag._mpf_))
      elif ctx.isinf(x.real):
        inf = ctx.inf
        sign_imag = -1 if x.imag < 0 else 1
        real = zero if x.real > 0 else pi
        return ctx.make_mpc((real._mpf_, (-sign_imag * inf)._mpf_))
      # On branch cut, mpmath.mp.acos returns different value
      # compared to mpmath.fp.acos and numpy.arccos. The
      # following if-block ensures compatibility with
      # numpy.arccos.
      if x.imag == 0 and x.real > 1:
        return -ctx.acos(x)

    return ctx.acos(x)

  def arcsinh(self, x):
    ctx = x.context

    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in asinh(+-inf+-infj) evaluation
      # (see mpmath/mpmath#749).
      # TODO(pearu): remove the if-block below when mpmath 1.4 or
      # newer is the required test dependency.
      pi = ctx.pi
      inf = ctx.inf
      zero = ctx.zero
      if ctx.isinf(x.imag):
        sign_imag = -1 if x.imag < 0 else 1
        real = -inf if x.real < 0 else inf
        imag = sign_imag * pi / (4 if ctx.isinf(x.real) else 2)
        return ctx.make_mpc((real._mpf_, imag._mpf_))
      elif ctx.isinf(x.real):
        return ctx.make_mpc((x.real._mpf_, zero._mpf_))

      # On branch cut, mpmath.mp.asinh returns different value
      # compared to mpmath.fp.asinh and numpy.arcsinh (see
      # mpmath/mpmath#786). The following if-block ensures
      # compatibility with numpy.arcsinh.
      if x.real == 0 and x.imag < -1:
        return (-ctx.asinh(x)).conjugate()
    return ctx.asinh(x)

  def arccosh(self, x):
    ctx = x.context

    if isinstance(x, ctx.mpc):
      # Workaround mpmath 1.3 bug in acosh(+-inf+-infj) evaluation
      # (see mpmath/mpmath#749).
      pi = ctx.pi
      inf = ctx.inf
      zero = ctx.zero
      if ctx.isinf(x.real):
        sign_imag = -1 if x.imag < 0 else 1
        imag = (
          (3 if x.real < 0 else 1) * sign_imag * pi / 4
          if ctx.isinf(x.imag)
          else (sign_imag * pi if x.real < 0 else zero)
        )
        return ctx.make_mpc((inf._mpf_, imag._mpf_))
      elif ctx.isinf(x.imag):
        sign_imag = -1 if x.imag < 0 else 1
        imag = sign_imag * pi / 2
        return ctx.make_mpc((inf._mpf_, imag._mpf_))
    return ctx.acosh(x)

  def normalize(self, exact, reference, value):
    """Normalize reference and value using precision defined by the
    difference of exact and reference.
    """
    def worker(ctx, s, e, r, v):
      ss, sm, se, sbc = s._mpf_
      es, em, ee, ebc = e._mpf_
      rs, rm, re, rbc = r._mpf_
      vs, vm, ve, vbc = v._mpf_

      if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)):
        return r, v

      me = min(se, ee, re, ve)

      # transform mantissa parts to the same exponent base
      sm_e = sm << (se - me)
      em_e = em << (ee - me)
      rm_e = rm << (re - me)
      vm_e = vm << (ve - me)

      # find matching higher and non-matching lower bits of e and r
      sm_b = bin(sm_e)[2:] if sm_e else ''
      em_b = bin(em_e)[2:] if em_e else ''
      rm_b = bin(rm_e)[2:] if rm_e else ''
      vm_b = bin(vm_e)[2:] if vm_e else ''

      m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b))
      em_b = '0' * (m - len(em_b)) + em_b
      rm_b = '0' * (m - len(rm_b)) + rm_b

      c1 = 0
      for b0, b1 in zip(em_b, rm_b):
        if b0 != b1:
          break
        c1 += 1
      c0 = m - c1

      # truncate r and v mantissa
      rm_m = rm_e >> c0
      vm_m = vm_e >> c0

      # normalized r and v
      nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero)
      nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero)

      return nr, nv

    ctx = exact.context
    scale = abs(exact)
    if isinstance(exact, ctx.mpc):
      rr, rv = worker(ctx, scale, exact.real, reference.real, value.real)
      ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag)
      return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_))
    elif isinstance(exact, ctx.mpf):
      return worker(ctx, scale, exact, reference, value)
    else:
      assert 0  # unreachable


# Hypothesis testing support
def setup_hypothesis(max_examples=30) -> None:
  """Sets up the hypothesis profiles.

  Sets up the hypothesis testing profiles, and selects the one specified by
  the ``JAX_HYPOTHESIS_PROFILE`` environment variable (or the
  ``--jax_hypothesis_profile`` configuration flag).

  Args:
    max_examples: the maximum number of hypothesis examples to try, when using
      the default "deterministic" profile.
  """
  try:
    import hypothesis as hp
  except (ModuleNotFoundError, ImportError):
    return

  hp.settings.register_profile(
    "deterministic",
    database=None,
    derandomize=True,
    deadline=None,
    max_examples=max_examples,
    print_blob=True,
  )
  hp.settings.register_profile(
    "interactive",
    parent=hp.settings.load_profile("deterministic"),
    max_examples=1,
    report_multiple_bugs=False,
    verbosity=hp.Verbosity.verbose,
    # Don't try to shrink
    phases=(
      hp.Phase.explicit,
      hp.Phase.reuse,
      hp.Phase.generate,
      hp.Phase.target,
      hp.Phase.explain,
    ),
  )
  profile = HYPOTHESIS_PROFILE.value
  logging.info("Using hypothesis profile: %s", profile)
  hp.settings.load_profile(profile)
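

# Example usage (illustrative sketch): a test module opting in to
# hypothesis-based tests; profile selection honors JAX_HYPOTHESIS_PROFILE.
#
#   setup_hypothesis(max_examples=100)
#
#   class MyHypothesisTest(JaxTestCase):
#     ...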