mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[array API] make capabilities more accurate
This commit is contained in:
parent
ec2f0f5913
commit
096810a721
@ -51,8 +51,9 @@ class ArrayNamespaceInfo:
|
||||
.. _Python array API: https://data-apis.org/array-api/
|
||||
"""
|
||||
_capabilities = {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": False,
|
||||
"boolean indexing": False, # within transformations
|
||||
"data-dependent shapes": False, # within transformations
|
||||
"max dimensions": 64, # XLA limitation
|
||||
}
|
||||
|
||||
def _build_dtype_dict(self):
|
||||
|
@ -275,8 +275,9 @@ class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_capabilities_info(self):
|
||||
capabilities = self.info.capabilities()
|
||||
assert capabilities["boolean indexing"]
|
||||
assert not capabilities["boolean indexing"]
|
||||
assert not capabilities["data-dependent shapes"]
|
||||
assert capabilities["max dimensions"] == 64
|
||||
|
||||
def test_default_device_info(self):
|
||||
assert self.info.default_device() is None
|
||||
|
Loading…
x
Reference in New Issue
Block a user