# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for backwards compatibility of custom calls.
|
|
|
|
|
2023-06-22 02:37:39 -07:00
|
|
|
See the back_compat_test_util module docstring for how to setup and update
|
|
|
|
these tests.
|
2023-03-15 23:09:59 -07:00
|
|
|
"""
|
|
|
|
import dataclasses
from functools import partial
import itertools
import math

from absl.testing import absltest, parameterized

import numpy as np

import jax
from jax import config
from jax import lax
from jax.experimental.jax2tf import jax_export
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu

from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_eigh_lapack_syev
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lu_lapack_getrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_qr_lapack_geqrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_eigh_cusolver_syev
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_qr_cusolver_geqrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_threefry2x32
from jax.experimental.jax2tf.tests.back_compat_testdata import stablehlo_dynamic_rng_bit_generator
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_ApproxTopK
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Eigh
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Lu
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Qr
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Sharding
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_stablehlo_dynamic_reduce_window

from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src import test_util as jtu

config.parse_flags_with_absl()


class CompatTest(bctu.CompatTestBase):
  def test_dummy(self):
    # Tests the testing mechanism. Let this test run on all platforms
    dummy_data = self.load_testdata(bctu.dummy_data_dict)
    platform_dummy_data = dataclasses.replace(
        dummy_data, platform=self.default_jax_backend())
    self.run_one_test(jnp.sin, platform_dummy_data)
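    # Roughly: load_testdata turns the versioned dict from a testdata module
    # into a bctu.CompatTestData record; run_one_test then checks the stored
    # expected_outputs and custom_call_targets (see the two detection tests
    # below) and, unless compare_with_current=False, also compares against
    # the current lowering.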

  def test_detect_different_output(self):
    # Test the detection mechanism. Let this test run on all platforms
    dummy_data = self.load_testdata(bctu.dummy_data_dict)
    platform_dummy_data = dataclasses.replace(
        dummy_data,
        platform=self.default_jax_backend(),
        expected_outputs=(np.array(2.0, dtype=np.float32),))
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      self.run_one_test(jnp.sin, platform_dummy_data)

  def test_detect_different_custom_calls(self):
    # Test the detection mechanism. Let this test run on all platforms
    dummy_data = self.load_testdata(bctu.dummy_data_dict)
    platform_dummy_data = dataclasses.replace(
        dummy_data,
        platform=self.default_jax_backend(),
        custom_call_targets=["missing"])
    with self.assertRaisesRegex(AssertionError, "Lists differ"):
      self.run_one_test(jnp.sin, platform_dummy_data)

  def test_custom_call_coverage(self):
    """Tests that the back compat tests cover all the targets declared stable."""
    targets_to_cover = set(jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
    # Add here all the testdatas that should cover the targets guaranteed
    # stable
    covering_testdatas = [
        cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14,
        cpu_eigh_lapack_syev.data_2023_03_17,
        cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
        cpu_lu_lapack_getrf.data_2023_06_14,
        cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
        tf_call_tf_function.data_2023_06_02,  # This is tested in back_compat_tf_test.py
        tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
        tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
        tpu_ApproxTopK.data_2023_05_16,
        tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17,
        tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17,
        stablehlo_dynamic_rng_bit_generator.data_2023_06_17,
    ]
    # Some of the above are nested structures.
    covering_testdatas = itertools.chain(
        *[self.load_testdata_nested(d) for d in covering_testdatas])
    covered_targets = set()
    for data in covering_testdatas:
      self.assertIsInstance(data, bctu.CompatTestData)
      covered_targets = covered_targets.union(data.custom_call_targets)

    covered_targets = covered_targets.union({
        # TODO(necula): add tests for eig on CPU
        "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
        # TODO(necula): add tests for cholesky on CPU in a separate change.
        "lapack_cpotrf", "lapack_dpotrf", "lapack_spotrf", "lapack_zpotrf",
        # TODO(necula): add tests for svd on CPU
        "lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd",
        # TODO(necula): add tests for triangular_solve on CPU
        "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
        # TODO(necula): add tests for schur on CPU
        "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
    })
    not_covered = targets_to_cover.difference(covered_targets)
    self.assertEmpty(not_covered)

  def test_ducc_fft(self):
    def func(x):
      return lax.fft(x, fft_type="fft", fft_lengths=(4,))

    # An old lowering, with ducc_fft. We keep it for 6 months.
    data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
    # We have changed the lowering for fft, do not compare with current.
    self.run_one_test(func, data, compare_with_current=False)

    # A newer lowering, with dynamic_ducc_fft.
    data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
    self.run_one_test(func, data)

  @staticmethod
  def eigh_input(shape, dtype):
    # In order to keep inputs small, we construct the input programmatically
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
    # Make operand self-adjoint
    operand = (operand + jnp.conj(jnp.swapaxes(operand, -1, -2))) / 2.
    return operand

  @staticmethod
  def eigh_harness(shape, dtype):
    operand = CompatTest.eigh_input(shape, dtype)
    return lax.linalg.eigh(jnp.tril(operand), lower=True, symmetrize_input=False)

  def check_eigh_results(self, operand, res_now, res_expected, *,
                         rtol, atol=None):
    v_now, w_now = res_now
    _, w_expected = res_expected
    n, m = operand.shape
    assert n == m
    assert v_now.shape == operand.shape
    assert w_now.shape == (n,)
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(np.swapaxes(v_now, -1, -2)), v_now)),
        rtol)
    # w_now : f64[n] while v_now: c128[n, n]
    w_now_like_v = w_now[np.newaxis, :].astype(v_now.dtype)
    self.assertLessEqual(
        np.linalg.norm(np.matmul(operand, v_now) - w_now_like_v * v_now),
        rtol * np.linalg.norm(operand))
    self.assertAllClose(w_expected, w_now, rtol=rtol, atol=atol)
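
  # A sketch of what check_eigh_results verifies above, assuming
  # res_now = (V, w) approximates an eigendecomposition of the self-adjoint
  # operand A:
  #   orthonormal eigenvectors:  || I - V^H V ||        <= rtol
  #   eigenvalue equation:       || A V - V diag(w) ||  <= rtol * ||A||
  # Only the eigenvalues w are compared elementwise against the stored
  # outputs, since eigenvectors are determined only up to phase/sign and up
  # to ordering among degenerate eigenvalues.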

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
    # For lax.linalg.eigh
    if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    size = 8
    operand = CompatTest.eigh_input((size, size), dtype)
    func = lambda: CompatTest.eigh_harness((size, size), dtype)
    data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name])
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_eigh_results, operand))

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}_{variant}",
           dtype_name=dtype_name, variant=variant)
      for dtype_name in ("f32", "f64")
      # We use different custom calls for sizes <= 32
      for variant in ["syevj", "syevd"])
  def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
    # For lax.linalg.eigh
    dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
    size = dict(syevj=8, syevd=36)[variant]
    rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
    atol = dict(f32=1e-2, f64=1e-10)[dtype_name]
    operand = CompatTest.eigh_input((size, size), dtype)
    func = lambda: CompatTest.eigh_harness((size, size), dtype)
    data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_eigh_results, operand))
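
  # The two matrix sizes exercise both cusolver code paths mentioned above:
  # matrices of size <= 32 go through the Jacobi solver custom call (syevj,
  # size 8 here), larger ones through syevd (size 36).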

  def test_tpu_Eigh(self):
    self.skipTest(
        "TODO(b/280668311): Change input matrix to not be ill-conditioned."
    )
    # For lax.linalg.eigh
    shape = (8, 8)
    dtype = np.float32
    operand = CompatTest.eigh_input(shape, dtype)
    func = lambda: CompatTest.eigh_harness(shape, dtype)
    data = self.load_testdata(tpu_Eigh.data)
    self.run_one_test(func, data, rtol=1e-3,
                      check_results=partial(self.check_eigh_results, operand))

  @staticmethod
  def qr_harness(shape, dtype):
    # In order to keep inputs small, we construct the input programmatically
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
    return lax.linalg.qr(operand, full_matrices=True)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
    # For lax.linalg.qr
    if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    func = lambda: CompatTest.qr_harness((3, 3), dtype)
    data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    self.run_one_test(func, data, rtol=rtol)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}_{batched}",
           dtype_name=dtype_name, batched=batched)
      for dtype_name in ("f32",)
      # For batched qr we use cublas_geqrf_batched
      for batched in ("batched", "unbatched"))
  def test_cuda_qr_cusolver_geqrf(self, dtype_name="f32", batched="unbatched"):
    # For lax.linalg.qr
    dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
    rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
    shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched]
    func = lambda: CompatTest.qr_harness(shape, dtype)
    data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
    self.run_one_test(func, data, rtol=rtol)

  def test_tpu_Qr(self):
    # For lax.linalg.qr
    func = lambda: CompatTest.qr_harness((3, 3), np.float32)
    data = self.load_testdata(tpu_Qr.data_2023_03_17)
    self.run_one_test(func, data, rtol=1e-3)

  @staticmethod
  def lu_harness(shape, dtype):
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
    return lax.linalg.lu(operand)

  def check_lu_results(self, operand, res_now, res_expected, *,
                       dtype, rtol=None, atol=None):
    # Same checker as in linalg_test.py
    del res_expected  # we do not check against expected
    lu_now, pivots_now, _ = res_now

    n, m = operand.shape
    self.assertEqual(n, m)
    l = np.tril(lu_now, -1) + np.eye(n, dtype=dtype)
    u = np.triu(lu_now)
    operand_copy = operand.copy()
    for i in range(n):
      operand_copy[[i, pivots_now[i]],] = operand_copy[[pivots_now[i], i],]
    self.assertAllClose(operand_copy, np.matmul(l, u), rtol=rtol, atol=atol)
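
  # A sketch of the property checked above, assuming LAPACK getrf-style
  # results: lu_now packs the unit lower-triangular L (strictly below the
  # diagonal) and U (on and above it) into one matrix, and pivots_now[i]
  # names the row swapped with row i. Applying those swaps to the operand in
  # order yields P @ A, which must match L @ U up to tolerance.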

  def test_tpu_Lu(self):
    # For lax.linalg.lu on TPU.
    shape = (3, 3)
    dtype = np.float32
    func = lambda: CompatTest.lu_harness(shape, dtype)
    data = self.load_testdata(tpu_Lu.data_2023_03_21)
    operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape)
    self.run_one_test(func, data, rtol=1e-3,
                      check_results=partial(self.check_lu_results, operand,
                                            dtype=dtype))

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}",
           dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_lu_lapack_getrf(self, dtype_name: str):
    # For lax.linalg.lu on CPU.
    if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")
    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (3, 3)
    func = lambda: CompatTest.lu_harness(shape, dtype)
    data = self.load_testdata(cpu_lu_lapack_getrf.data_2023_06_14[dtype_name])
    operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape)
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_lu_results, operand,
                                            dtype=dtype))

  def test_approx_top_k(self):
    def func():
      x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0])
      y = lax.approx_max_k(x, 3)
      z = lax.approx_max_k(x, 3)
      return y + z
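
    # approx_max_k returns a (values, indices) pair, so `y + z` above
    # concatenates the two tuples and the harness returns four arrays,
    # exercising the TPU ApproxTopK custom call.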
    data = self.load_testdata(tpu_ApproxTopK.data_2023_05_16)
    self.run_one_test(func, data)

  def test_cu_threefry2x32(self):
    def func(x):
      return jax.random.uniform(x, (2, 4), dtype=np.float32)
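
    # jax.random.uniform here consumes a threefry2x32 key; on CUDA this is
    # expected to lower to the cu_threefry2x32 custom call captured in the
    # stored test data.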
    data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
    self.run_one_test(func, data)

  def test_sharding(self):
    # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
    if jtu.device_under_test() != "tpu" or len(jax.devices()) < 2:
      self.skipTest("Test runs only on TPU with at least 2 devices")

    # Must use exactly 2 devices for expected outputs from ppermute
    devices = jax.devices()[:2]
    mesh = Mesh(devices, axis_names=('a',))

    @partial(pjit.pjit,
             in_shardings=(P('a', None),), out_shardings=P('a', None))
    @partial(shard_map, mesh=mesh,
             in_specs=(P('a', None),), out_specs=P('a', None))
    def func(x):  # x: f32[2, 4]
      axis_size = lax.psum(1, 'a')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(x, 'a', perm=perm)
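
    # With the 2-device mesh the axis size is 2, so `perm` above is
    # [(0, 1), (1, 0)]: the two shards swap places. The stored expected
    # outputs were recorded for exactly this permutation, hence the
    # requirement of exactly 2 devices.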

    data = self.load_testdata(tpu_Sharding.data_2023_03_16)
    with mesh:
      self.run_one_test(func, data)

  def test_tpu_stablehlo_dynamic_reduce_window_unary(self):
    # stablehlo.dynamic_reduce_window is used temporarily on TPU for a
    # reduce window with dynamic shapes.
    # See https://github.com/openxla/stablehlo/issues/1258 for the long term.
    # The inputs are already in the test data, here only for readability.
    shape = (3, 4)
    _ = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

    def func(x):
      return jnp.cumsum(x, axis=0)
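
    # Under the shape-polymorphic export below ("b, ..."), the cumsum window
    # size depends on the dynamic dimension b, which is expected to produce
    # the stablehlo.dynamic_reduce_window custom call stored in the test data.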

    data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17)
    self.run_one_test(
        func, data,
        polymorphic_shapes=("b, ...",))

  def test_tpu_stablehlo_dynamic_reduce_window_variadic(self):
    # stablehlo.dynamic_reduce_window is used temporarily on TPU for a
    # reduce window with dynamic shapes.
    # See https://github.com/openxla/stablehlo/issues/1258 for the long term.
    # The inputs are already in the test data, here only for readability.
    shape = (3, 4)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    y = 100 + np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
    _ = (x, y)

    def func(x, y):  # x: f32[b, 2] y: i32[b, 2]
      return lax.reduce_window(
          (x, y), (np.array(1., np.float32), np.array(2, np.int32)),
          lambda xy0, xy1: (lax.add(xy0[0], xy1[0]),
                            lax.sub(xy0[1], xy1[1])),
          (2, x.shape[0]), (1, 1), "VALID")
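
    # This exercises the variadic form: reduce_window reduces the (x, y) pair
    # jointly, with one init value per operand and a reducer that adds the
    # f32 components while subtracting the i32 ones.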

    data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17)
    self.run_one_test(
        func, data,
        polymorphic_shapes=("b, ...", "b, ..."))

  def test_stablehlo_dynamic_rbg_bit_generator(self):
    # stablehlo.dynamic_rbg_bit_generator is used temporarily for a
    # rbg_bit_generator with dynamic shapes.
    # See https://github.com/openxla/stablehlo/issues/1344 for the long term.
    key = np.arange(42, 42 + 4, dtype=np.uint32)
    a_shape = (2, 3)
    a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
    inputs = (key, a)
    del inputs  # already in the test data, here only for readability.

    def func(key, a):  # a is only used for its shape
      return jax.random.key_data(jax.random.split(key, a.shape[0] * a.shape[1]))
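
    # Splitting the key a.shape[0] * a.shape[1] ways makes the number of
    # generated random bits depend on the polymorphic dimensions of `a`;
    # with the unsafe_rbg PRNG selected below, this is expected to lower to
    # the stablehlo.dynamic_rbg_bit_generator custom call.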

    # Note that the test currently checks that the generated sequence is the
    # same. According to the StableHLO spec: "The output is guaranteed to be
    # deterministic function of initial_state, but it is not guaranteed to be
    # deterministic between implementations"
    # See https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator
    # This test will fail when the implementation changes. We expect this to
    # be rare, and most users may expect the RNG sequence to be the same
    # upon reloading of a saved model.
    # In case of an intended change in behavior we will have the option to
    # replace this strict check with something else.
    data = self.load_testdata(stablehlo_dynamic_rng_bit_generator.data_2023_06_17)

    prev_default_prng_impl = jax.config.jax_default_prng_impl
    try:
      jax.config.update("jax_default_prng_impl", "unsafe_rbg")

      self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"))
    finally:
      jax.config.update("jax_default_prng_impl", prev_default_prng_impl)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())