rocm_jax/tests/array_api_test.py
2025-01-06 15:19:02 -08:00

355 lines
7.5 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoketest for JAX's array API.
The full test suite for the array API is run via the array-api-tests CI;
this is just a minimal smoke test to catch issues early.
"""
from __future__ import annotations
from types import ModuleType
from absl.testing import absltest, parameterized
import jax
import jax.numpy as jnp
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
# Hand JAX's flags to absl at import time (per the helper's name), then pick
# the namespace under test; every test below reaches the API only through
# this alias, so it is the single place to retarget the suite.
config.parse_flags_with_absl()
ARRAY_API_NAMESPACE = jnp
# Public names the top-level array API namespace is expected to expose;
# ArrayAPISmokeTest checks this is a subset of names(ARRAY_API_NAMESPACE).
# One name per line, sorted, split on whitespace into a set.
MAIN_NAMESPACE = set("""
abs
acos
acosh
add
all
any
arange
argmax
argmin
argsort
asarray
asin
asinh
astype
atan
atan2
atanh
bitwise_and
bitwise_invert
bitwise_left_shift
bitwise_or
bitwise_right_shift
bitwise_xor
bool
broadcast_arrays
broadcast_to
can_cast
ceil
clip
complex128
complex64
concat
conj
copysign
cos
cosh
cumulative_sum
divide
e
empty
empty_like
equal
exp
expand_dims
expm1
eye
fft
finfo
flip
float32
float64
floor
floor_divide
from_dlpack
full
full_like
greater
greater_equal
hypot
iinfo
imag
inf
int16
int32
int64
int8
isdtype
isfinite
isinf
isnan
less
less_equal
linalg
linspace
log
log10
log1p
log2
logaddexp
logical_and
logical_not
logical_or
logical_xor
matmul
matrix_transpose
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
nan
negative
newaxis
nonzero
not_equal
ones
ones_like
permute_dims
pi
positive
pow
prod
real
remainder
repeat
reshape
result_type
roll
round
searchsorted
sign
signbit
sin
sinh
sort
sqrt
square
squeeze
stack
std
subtract
sum
take
tan
tanh
tensordot
tile
tril
triu
trunc
uint16
uint32
uint64
uint8
unique_all
unique_counts
unique_inverse
unique_values
unstack
var
vecdot
where
zeros
zeros_like
""".split())
# Public names expected on ARRAY_API_NAMESPACE.linalg (checked as a subset by
# ArrayAPISmokeTest). One name per line, split on whitespace into a set.
LINALG_NAMESPACE = set("""
cholesky
cross
det
diagonal
eigh
eigvalsh
inv
matmul
matrix_norm
matrix_power
matrix_rank
matrix_transpose
outer
pinv
qr
slogdet
solve
svd
svdvals
tensordot
trace
vecdot
vector_norm
""".split())
# Public names expected on ARRAY_API_NAMESPACE.fft (checked as a subset by
# ArrayAPISmokeTest). One name per line, split on whitespace into a set.
FFT_NAMESPACE = set("""
fft
fftfreq
fftn
fftshift
hfft
ifft
ifftn
ifftshift
ihfft
irfft
irfftn
rfft
rfftfreq
rfftn
""".split())
def names(module: ModuleType) -> set[str]:
  """Collect the public attribute names of ``module``.

  A name counts as public when it does not start with an underscore, so both
  private (``_x``) and dunder (``__x__``) attributes are left out.
  """
  visible = (attr for attr in dir(module) if not attr.startswith('_'))
  return set(visible)
class ArrayAPISmokeTest(absltest.TestCase):
  """Quick checks that the array API namespace exposes its required names."""

  def test_main_namespace(self):
    exposed = names(ARRAY_API_NAMESPACE)
    self.assertContainsSubset(MAIN_NAMESPACE, exposed)

  def test_linalg_namespace(self):
    linalg_module = ARRAY_API_NAMESPACE.linalg
    self.assertContainsSubset(LINALG_NAMESPACE, names(linalg_module))

  def test_fft_namespace(self):
    fft_module = ARRAY_API_NAMESPACE.fft
    self.assertContainsSubset(FFT_NAMESPACE, names(fft_module))

  def test_array_namespace_method(self):
    # An array built by the namespace must report that same namespace
    # object from __array_namespace__().
    arr = ARRAY_API_NAMESPACE.arange(20)
    self.assertIsInstance(arr, jax.Array)
    reported = arr.__array_namespace__()
    self.assertIs(reported, ARRAY_API_NAMESPACE)
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
  """Tests for the inspection object returned by __array_namespace_info__()."""

  # Evaluated once when the class is defined; shared by every test method.
  info = ARRAY_API_NAMESPACE.__array_namespace_info__()

  def setUp(self):
    super().setUp()
    # Expected name -> dtype maps per kind. The 64-bit entries are added only
    # when x64 mode is enabled.
    self._boolean = self.build_dtype_dict(["bool"])
    self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
    self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
    self._floating = self.build_dtype_dict(["float32"])
    self._complex = self.build_dtype_dict(["complex64"])
    if config.enable_x64.value:
      self._signed["int64"] = jnp.dtype("int64")
      self._unsigned["uint64"] = jnp.dtype("uint64")
      self._floating["float64"] = jnp.dtype("float64")
      self._complex["complex128"] = jnp.dtype("complex128")
    self._integral = self._signed | self._unsigned
    self._numeric = (
        self._signed | self._unsigned | self._floating | self._complex
    )

  def build_dtype_dict(self, dtypes):
    """Return a dict mapping each dtype name in ``dtypes`` to jnp.dtype(name)."""
    return {name: jnp.dtype(name) for name in dtypes}

  # Assertions below use unittest methods rather than bare ``assert`` so they
  # still run under ``python -O`` and report a diff when they fail.

  def test_capabilities_info(self):
    capabilities = self.info.capabilities()
    self.assertTrue(capabilities["boolean indexing"])
    self.assertFalse(capabilities["data-dependent shapes"])

  def test_default_device_info(self):
    self.assertIsNone(self.info.default_device())

  def test_devices_info(self):
    self.assertEqual(self.info.devices(), jax.devices())

  def test_default_dtypes_info(self):
    # Array API default-dtype category -> key into jax's _default_types.
    default_dtype_kinds = {
        "real floating": "f",
        "complex floating": "c",
        "integral": "i",
        "indexing": "i",
    }
    expected = {
        dtype_name: canonicalize_dtype(_default_types.get(kind))
        for dtype_name, kind in default_dtype_kinds.items()
    }
    self.assertEqual(self.info.default_dtypes(), expected)

  @parameterized.parameters(
      "bool", "signed integer", "unsigned integer", "real floating",
      "complex floating", "integral", "numeric", None,
      (("real floating", "complex floating"),),
      (("integral", "signed integer"),),
      (("integral", "bool"),),
  )
  def test_dtypes_info(self, kind):
    """info.dtypes(kind=...) must match the expected maps built in setUp.

    ``kind`` may be a single kind name, None (numeric + bool), or a tuple of
    kind names whose maps are unioned.
    """
    control = {
        "bool": self._boolean,
        "signed integer": self._signed,
        "unsigned integer": self._unsigned,
        "real floating": self._floating,
        "complex floating": self._complex,
        "integral": self._integral,
        "numeric": self._numeric,
    }
    if kind is None:
      expected = control["numeric"] | self._boolean
    elif isinstance(kind, tuple):
      expected = {}
      for single_kind in kind:
        expected |= control[single_kind]
    else:
      expected = control[kind]
    self.assertEqual(self.info.dtypes(kind=kind), expected)
class ArrayAPIErrors(absltest.TestCase):
  """Checks that the array API implementations raise where they must."""

  # TODO(micky774): Remove when jnp.clip deprecation is completed
  # (began 2024-4-2) and default behavior is Array API 2023 compliant
  def test_clip_complex(self):
    xp = ARRAY_API_NAMESPACE
    complex_msg = "Complex values have no ordering and cannot be clipped"
    complex_arr = xp.arange(5, dtype=xp.complex64)
    int_arr = xp.arange(5, dtype=xp.int32)
    complex_bound = -1+5j
    # Each case hands clip() a complex value, via the input or a bound.
    cases = [
        (complex_arr, {}),
        (complex_arr, {"max": complex_arr}),
        (int_arr, {"min": complex_bound}),
        (int_arr, {"max": complex_bound}),
    ]
    for arr, bounds in cases:
      with self.assertRaisesRegex(ValueError, complex_msg):
        xp.clip(arr, **bounds)
if __name__ == '__main__':
  # Run through JAX's test loader, matching the jtu-based test classes in
  # this module. NOTE(review): JaxTestLoader is the loader JAX test modules
  # conventionally pass here (it applies JAX's test-target filtering flags);
  # confirm against jax._src.test_util.
  absltest.main(testLoader=jtu.JaxTestLoader())