Merge pull request #20294 from Micky774:array_namespace_info

PiperOrigin-RevId: 623877931
This commit is contained in:
jax authors 2024-04-11 11:09:37 -07:00
commit 301c3518d8
4 changed files with 155 additions and 3 deletions

View File

@ -205,6 +205,7 @@ from jax.experimental.array_api._statistical_functions import (
)
from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)

View File

@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
from __future__ import annotations
import jax
from typing import Tuple
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config
def all(x, /, *, axis=None, keepdims=False):
"""Tests whether all input array elements evaluate to True along a specified axis."""
@ -23,3 +28,66 @@ def all(x, /, *, axis=None, keepdims=False):
def any(x, /, *, axis=None, keepdims=False):
"""Tests whether any input array element evaluates to True along a specified axis."""
return jax.numpy.any(x, axis=axis, keepdims=keepdims)
class __array_namespace_info__:
def __init__(self):
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}
def _build_dtype_dict(self):
array_api_types = {
"bool", "int8", "int16",
"int32", "uint8", "uint16",
"uint32", "float32", "complex64"
}
if config.enable_x64.value:
array_api_types |= {"int64", "uint64", "float64", "complex128"}
return {category: {t.name: t for t in types if t.name in array_api_types}
for category, types in _dtypes._dtype_kinds.items()}
def default_device(self):
# By default JAX arrays are uncommitted (device=None), meaning that
# JAX is free to choose the most efficient device placement.
return None
def devices(self):
return jax.devices()
def capabilities(self):
return self._capabilities
def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
# Array API supported dtypes are device-independent in JAX
del device
default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
return {
dtype_name: _dtypes.canonicalize_dtype(
_dtypes._default_types.get(kind)
) for dtype_name, kind in default_dtypes.items()
}
def dtypes(
self, *,
device: xc.Device | Sharding | None = None,
kind: str | Tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
del device
data_types = self._build_dtype_dict()
if kind is None:
out_dict = data_types["numeric"] | data_types["bool"]
elif isinstance(kind, tuple):
out_dict = {}
for _kind in kind:
out_dict |= data_types[_kind]
else:
out_dict = data_types[kind]
return out_dict

View File

@ -54,6 +54,7 @@ py_test(
deps = [
"//jax",
"//jax:experimental_array_api",
"//jax:test_util",
] + py_deps("absl/testing"),
)

View File

@ -21,9 +21,11 @@ from __future__ import annotations
from types import ModuleType
from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
from jax import config
import jax.numpy as jnp
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax.experimental import array_api
config.parse_flags_with_absl()
@ -233,6 +235,86 @@ class ArrayAPISmokeTest(absltest.TestCase):
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), array_api)
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
info = array_api.__array_namespace_info__()
def setUp(self):
super().setUp()
self._boolean = self.build_dtype_dict(["bool"])
self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
self._floating = self.build_dtype_dict(["float32"])
self._complex = self.build_dtype_dict(["complex64"])
if config.enable_x64.value:
self._signed["int64"] = jnp.dtype("int64")
self._unsigned["uint64"] = jnp.dtype("uint64")
self._floating["float64"] = jnp.dtype("float64")
self._complex["complex128"] = jnp.dtype("complex128")
self._integral = self._signed | self._unsigned
self._numeric = (
self._signed | self._unsigned | self._floating | self._complex
)
def build_dtype_dict(self, dtypes):
out = {}
for name in dtypes:
out[name] = jnp.dtype(name)
return out
def test_capabilities_info(self):
capabilities = self.info.capabilities()
assert capabilities["boolean indexing"]
assert not capabilities["data-dependent shapes"]
def test_default_device_info(self):
assert self.info.default_device() is None
def test_devices_info(self):
assert self.info.devices() == jax.devices()
def test_default_dtypes_info(self):
_default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
target_dict = {
dtype_name: canonicalize_dtype(
_default_types.get(kind)
) for dtype_name, kind in _default_dtypes.items()
}
assert self.info.default_dtypes() == target_dict
@parameterized.parameters(
"bool", "signed integer", "real floating",
"complex floating", "integral", "numeric", None,
(("real floating", "complex floating"),),
(("integral", "signed integer"),),
(("integral", "bool"),),
)
def test_dtypes_info(self, kind):
info_dict = self.info.dtypes(kind=kind)
control = {
"bool":self._boolean,
"signed integer":self._signed,
"unsigned integer":self._unsigned,
"real floating":self._floating,
"complex floating":self._complex,
"integral": self._integral,
"numeric": self._numeric
}
target_dict = {}
if kind is None:
target_dict = control["numeric"] | self._boolean
elif isinstance(kind, tuple):
target_dict = {}
for _kind in kind:
target_dict |= control[_kind]
else:
target_dict = control[kind]
assert info_dict == target_dict
class ArrayAPIErrors(absltest.TestCase):
"""Test that our array API implementations raise errors where required"""