# Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00).
# Original file: 355 lines, 7.5 KiB, Python.
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""Smoketest for JAX's array API.
|
|
|
|
The full test suite for the array API is run via the array-api-tests CI;
|
|
this is just a minimal smoke test to catch issues early.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from types import ModuleType
|
|
|
|
from absl.testing import absltest, parameterized
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax._src import config, test_util as jtu
|
|
from jax._src.dtypes import _default_types, canonicalize_dtype
|
|
|
|
# The array-API namespace under test. Most checks in this file go through this
# alias rather than referring to `jnp` directly.
ARRAY_API_NAMESPACE: ModuleType = jnp

# Executed at module import time, i.e. before the test classes below are
# defined. NOTE(review): presumably this makes JAX's config flags (such as the
# x64 setting read via `config.enable_x64` in a setUp below) settable through
# absl command-line flags — confirm against jax._src.config.
config.parse_flags_with_absl()
|
|
|
|
# Public names expected in the top-level array API namespace: functions,
# dtypes, constants (e, inf, nan, pi, newaxis) and the `fft`/`linalg`
# extension submodules. Listed alphabetically, whitespace-separated.
MAIN_NAMESPACE = set("""
    abs acos acosh add all any arange argmax argmin argsort
    asarray asin asinh astype atan atan2 atanh
    bitwise_and bitwise_invert bitwise_left_shift bitwise_or
    bitwise_right_shift bitwise_xor bool broadcast_arrays broadcast_to
    can_cast ceil clip complex128 complex64 concat conj copysign
    cos cosh cumulative_sum divide e empty empty_like equal exp
    expand_dims expm1 eye fft finfo flip float32 float64 floor
    floor_divide from_dlpack full full_like greater greater_equal
    hypot iinfo imag inf int16 int32 int64 int8 isdtype isfinite
    isinf isnan less less_equal linalg linspace log log10 log1p log2
    logaddexp logical_and logical_not logical_or logical_xor
    matmul matrix_transpose max maximum mean meshgrid min minimum
    moveaxis multiply nan negative newaxis nonzero not_equal
    ones ones_like permute_dims pi positive pow prod real
    remainder repeat reshape result_type roll round searchsorted
    sign signbit sin sinh sort sqrt square squeeze stack std
    subtract sum take tan tanh tensordot tile tril triu trunc
    uint16 uint32 uint64 uint8 unique_all unique_counts
    unique_inverse unique_values unstack var vecdot where
    zeros zeros_like
""".split())
|
|
|
|
# Public names expected in the array API `linalg` extension namespace.
# Listed alphabetically, whitespace-separated.
LINALG_NAMESPACE = set("""
    cholesky cross det diagonal eigh eigvalsh inv matmul
    matrix_norm matrix_power matrix_rank matrix_transpose outer pinv
    qr slogdet solve svd svdvals tensordot trace vecdot vector_norm
""".split())
|
|
|
|
# Public names expected in the array API `fft` extension namespace.
# Listed alphabetically, whitespace-separated.
FFT_NAMESPACE = set("""
    fft fftfreq fftn fftshift hfft ifft ifftn
    ifftshift ihfft irfft irfftn rfft rfftfreq rfftn
""".split())
|
|
|
|
|
|
def names(module: ModuleType) -> set[str]:
  """Collect the public attribute names exposed by `module`.

  Args:
    module: the module (or other object) whose attributes are listed via
      `dir()`.

  Returns:
    Every name reported by `dir(module)` that does not begin with `_`.
  """
  def is_public(attr: str) -> bool:
    return not attr.startswith('_')

  return set(filter(is_public, dir(module)))
|
|
|
|
|
|
class ArrayAPISmokeTest(absltest.TestCase):
  """Checks that the array API namespaces expose the expected public names."""

  def _assert_exports(self, required, module):
    """Fail unless every name in `required` is a public attribute of `module`."""
    self.assertContainsSubset(required, names(module))

  def test_main_namespace(self):
    """The top-level namespace exports every name in MAIN_NAMESPACE."""
    self._assert_exports(MAIN_NAMESPACE, ARRAY_API_NAMESPACE)

  def test_linalg_namespace(self):
    """The `linalg` submodule exports every name in LINALG_NAMESPACE."""
    self._assert_exports(LINALG_NAMESPACE, ARRAY_API_NAMESPACE.linalg)

  def test_fft_namespace(self):
    """The `fft` submodule exports every name in FFT_NAMESPACE."""
    self._assert_exports(FFT_NAMESPACE, ARRAY_API_NAMESPACE.fft)

  def test_array_namespace_method(self):
    """Arrays report the namespace under test via `__array_namespace__()`."""
    arr = ARRAY_API_NAMESPACE.arange(20)
    self.assertIsInstance(arr, jax.Array)
    reported = arr.__array_namespace__()
    self.assertIs(reported, ARRAY_API_NAMESPACE)
|
|
|
|
|
|
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
  """Tests for the inspection object returned by `__array_namespace_info__()`."""

  # Evaluated once, when the class body executes, and shared by all tests.
  info = ARRAY_API_NAMESPACE.__array_namespace_info__()

  def setUp(self):
    """Build the expected `{dtype name: jnp.dtype}` dictionaries per kind."""
    super().setUp()
    self._boolean = self.build_dtype_dict(["bool"])
    self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
    self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
    self._floating = self.build_dtype_dict(["float32"])
    self._complex = self.build_dtype_dict(["complex64"])
    # 64-bit dtypes are only expected when x64 mode is enabled.
    if config.enable_x64.value:
      self._signed["int64"] = jnp.dtype("int64")
      self._unsigned["uint64"] = jnp.dtype("uint64")
      self._floating["float64"] = jnp.dtype("float64")
      self._complex["complex128"] = jnp.dtype("complex128")
    self._integral = self._signed | self._unsigned
    self._numeric = self._integral | self._floating | self._complex

  def build_dtype_dict(self, dtypes):
    """Return a dict mapping each dtype name in `dtypes` to `jnp.dtype(name)`."""
    return {name: jnp.dtype(name) for name in dtypes}

  # Use TestCase assertion methods rather than bare `assert`: bare asserts are
  # removed under `python -O` (tests would silently pass) and give no
  # diagnostic output on failure.

  def test_capabilities_info(self):
    capabilities = self.info.capabilities()
    self.assertTrue(capabilities["boolean indexing"])
    self.assertFalse(capabilities["data-dependent shapes"])

  def test_default_device_info(self):
    self.assertIsNone(self.info.default_device())

  def test_devices_info(self):
    self.assertEqual(self.info.devices(), jax.devices())

  def test_default_dtypes_info(self):
    # Array API default-dtype category -> key into jax's `_default_types`.
    kind_by_category = {
        "real floating": "f",
        "complex floating": "c",
        "integral": "i",
        "indexing": "i",
    }
    # Subscript (not .get) so an unexpected missing key fails loudly instead
    # of canonicalizing `None`.
    expected = {
        category: canonicalize_dtype(_default_types[kind])
        for category, kind in kind_by_category.items()
    }
    self.assertEqual(self.info.default_dtypes(), expected)

  @parameterized.parameters(
      "bool", "signed integer", "unsigned integer", "real floating",
      "complex floating", "integral", "numeric", None,
      (("real floating", "complex floating"),),
      (("integral", "signed integer"),),
      (("integral", "bool"),),
  )
  def test_dtypes_info(self, kind):
    """`dtypes(kind=...)` returns exactly the dtypes of the requested kind(s).

    Args:
      kind: a single kind name, a tuple of kind names (union of each), or
        None (all supported dtypes).
    """
    by_kind = {
        "bool": self._boolean,
        "signed integer": self._signed,
        "unsigned integer": self._unsigned,
        "real floating": self._floating,
        "complex floating": self._complex,
        "integral": self._integral,
        "numeric": self._numeric,
    }
    if kind is None:
      # No filter: every numeric dtype plus bool.
      expected = by_kind["numeric"] | self._boolean
    elif isinstance(kind, tuple):
      expected = {}
      for single_kind in kind:
        expected |= by_kind[single_kind]
    else:
      expected = by_kind[kind]
    self.assertEqual(self.info.dtypes(kind=kind), expected)
|
|
|
|
class ArrayAPIErrors(absltest.TestCase):
  """Checks that the array API implementation raises errors where required."""

  # TODO(micky774): Remove when jnp.clip deprecation is completed
  # (began 2024-4-2) and default behavior is Array API 2023 compliant
  def test_clip_complex(self):
    """`clip` rejects complex operands as well as complex `min`/`max` bounds."""
    xp = ARRAY_API_NAMESPACE
    complex_msg = "Complex values have no ordering and cannot be clipped"
    complex_arr = xp.arange(5, dtype=xp.complex64)
    int_arr = xp.arange(5, dtype=xp.int32)
    # Each entry is (operand, keyword bounds) for a call that must raise.
    bad_calls = (
        (complex_arr, {}),
        (complex_arr, {"max": complex_arr}),
        (int_arr, {"min": -1+5j}),
        (int_arr, {"max": -1+5j}),
    )
    for operand, bounds in bad_calls:
      with self.assertRaisesRegex(ValueError, complex_msg):
        xp.clip(operand, **bounds)
|
|
|
|
|
|
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
|