mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add __array_namespace_info__ and corresponding utilities
This commit is contained in:
parent
033992867f
commit
e6508a4f47
@ -204,6 +204,7 @@ from jax.experimental.array_api._statistical_functions import (
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._utility_functions import (
|
||||
__array_namespace_info__ as __array_namespace_info__,
|
||||
all as all,
|
||||
any as any,
|
||||
)
|
||||
|
@ -12,8 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import jax
|
||||
from __future__ import annotations
|
||||
|
||||
import jax
|
||||
from typing import Tuple
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import dtypes as _dtypes, config
|
||||
|
||||
def all(x, /, *, axis=None, keepdims=False):
|
||||
"""Tests whether all input array elements evaluate to True along a specified axis."""
|
||||
@ -23,3 +28,66 @@ def all(x, /, *, axis=None, keepdims=False):
|
||||
def any(x, /, *, axis=None, keepdims=False):
|
||||
"""Tests whether any input array element evaluates to True along a specified axis."""
|
||||
return jax.numpy.any(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
class __array_namespace_info__:
|
||||
|
||||
def __init__(self):
|
||||
self._capabilities = {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": False,
|
||||
}
|
||||
|
||||
|
||||
def _build_dtype_dict(self):
|
||||
array_api_types = {
|
||||
"bool", "int8", "int16",
|
||||
"int32", "uint8", "uint16",
|
||||
"uint32", "float32", "complex64"
|
||||
}
|
||||
if config.enable_x64.value:
|
||||
array_api_types |= {"int64", "uint64", "float64", "complex128"}
|
||||
return {category: {t.name: t for t in types if t.name in array_api_types}
|
||||
for category, types in _dtypes._dtype_kinds.items()}
|
||||
|
||||
def default_device(self):
|
||||
# By default JAX arrays are uncommitted (device=None), meaning that
|
||||
# JAX is free to choose the most efficient device placement.
|
||||
return None
|
||||
|
||||
def devices(self):
|
||||
return jax.devices()
|
||||
|
||||
def capabilities(self):
|
||||
return self._capabilities
|
||||
|
||||
def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
|
||||
# Array API supported dtypes are device-independent in JAX
|
||||
del device
|
||||
default_dtypes = {
|
||||
"real floating": "f",
|
||||
"complex floating": "c",
|
||||
"integral": "i",
|
||||
"indexing": "i",
|
||||
}
|
||||
return {
|
||||
dtype_name: _dtypes.canonicalize_dtype(
|
||||
_dtypes._default_types.get(kind)
|
||||
) for dtype_name, kind in default_dtypes.items()
|
||||
}
|
||||
|
||||
def dtypes(
|
||||
self, *,
|
||||
device: xc.Device | Sharding | None = None,
|
||||
kind: str | Tuple[str, ...] | None = None):
|
||||
# Array API supported dtypes are device-independent in JAX
|
||||
del device
|
||||
data_types = self._build_dtype_dict()
|
||||
if kind is None:
|
||||
out_dict = data_types["numeric"] | data_types["bool"]
|
||||
elif isinstance(kind, tuple):
|
||||
out_dict = {}
|
||||
for _kind in kind:
|
||||
out_dict |= data_types[_kind]
|
||||
else:
|
||||
out_dict = data_types[kind]
|
||||
return out_dict
|
||||
|
@ -54,6 +54,7 @@ py_test(
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:experimental_array_api",
|
||||
"//jax:test_util",
|
||||
] + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
|
@ -21,9 +21,11 @@ from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import absltest, parameterized
|
||||
import jax
|
||||
from jax import config
|
||||
import jax.numpy as jnp
|
||||
from jax._src import config, test_util as jtu
|
||||
from jax._src.dtypes import _default_types, canonicalize_dtype
|
||||
from jax.experimental import array_api
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -232,6 +234,86 @@ class ArrayAPISmokeTest(absltest.TestCase):
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
self.assertIs(x.__array_namespace__(), array_api)
|
||||
|
||||
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
|
||||
|
||||
info = array_api.__array_namespace_info__()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._boolean = self.build_dtype_dict(["bool"])
|
||||
self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
|
||||
self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
|
||||
self._floating = self.build_dtype_dict(["float32"])
|
||||
self._complex = self.build_dtype_dict(["complex64"])
|
||||
if config.enable_x64.value:
|
||||
self._signed["int64"] = jnp.dtype("int64")
|
||||
self._unsigned["uint64"] = jnp.dtype("uint64")
|
||||
self._floating["float64"] = jnp.dtype("float64")
|
||||
self._complex["complex128"] = jnp.dtype("complex128")
|
||||
self._integral = self._signed | self._unsigned
|
||||
self._numeric = (
|
||||
self._signed | self._unsigned | self._floating | self._complex
|
||||
)
|
||||
def build_dtype_dict(self, dtypes):
|
||||
out = {}
|
||||
for name in dtypes:
|
||||
out[name] = jnp.dtype(name)
|
||||
return out
|
||||
|
||||
def test_capabilities_info(self):
|
||||
capabilities = self.info.capabilities()
|
||||
assert capabilities["boolean indexing"]
|
||||
assert not capabilities["data-dependent shapes"]
|
||||
|
||||
def test_default_device_info(self):
|
||||
assert self.info.default_device() is None
|
||||
|
||||
def test_devices_info(self):
|
||||
assert self.info.devices() == jax.devices()
|
||||
|
||||
def test_default_dtypes_info(self):
|
||||
_default_dtypes = {
|
||||
"real floating": "f",
|
||||
"complex floating": "c",
|
||||
"integral": "i",
|
||||
"indexing": "i",
|
||||
}
|
||||
target_dict = {
|
||||
dtype_name: canonicalize_dtype(
|
||||
_default_types.get(kind)
|
||||
) for dtype_name, kind in _default_dtypes.items()
|
||||
}
|
||||
assert self.info.default_dtypes() == target_dict
|
||||
|
||||
@parameterized.parameters(
|
||||
"bool", "signed integer", "real floating",
|
||||
"complex floating", "integral", "numeric", None,
|
||||
(("real floating", "complex floating"),),
|
||||
(("integral", "signed integer"),),
|
||||
(("integral", "bool"),),
|
||||
)
|
||||
def test_dtypes_info(self, kind):
|
||||
|
||||
info_dict = self.info.dtypes(kind=kind)
|
||||
control = {
|
||||
"bool":self._boolean,
|
||||
"signed integer":self._signed,
|
||||
"unsigned integer":self._unsigned,
|
||||
"real floating":self._floating,
|
||||
"complex floating":self._complex,
|
||||
"integral": self._integral,
|
||||
"numeric": self._numeric
|
||||
}
|
||||
target_dict = {}
|
||||
if kind is None:
|
||||
target_dict = control["numeric"] | self._boolean
|
||||
elif isinstance(kind, tuple):
|
||||
target_dict = {}
|
||||
for _kind in kind:
|
||||
target_dict |= control[_kind]
|
||||
else:
|
||||
target_dict = control[kind]
|
||||
assert info_dict == target_dict
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user