mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add support for bool dlpack values.
PiperOrigin-RevId: 599199196
This commit is contained in:
parent
1bd22b0fe1
commit
c4368351d2
@ -52,6 +52,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
|
||||
of the function in SciPy v1.13.
|
||||
|
||||
* JAX now supports the bool DLPack type for both import and export.
|
||||
Previously bool values could not be imported and were exported as integers.
|
||||
|
||||
* Deprecations & Removals
|
||||
* A number of previously deprecated functions have been removed, following a
|
||||
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
|
||||
|
@ -22,6 +22,7 @@ from jax import numpy as jnp
|
||||
from jax._src import array
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
@ -30,6 +31,9 @@ SUPPORTED_DTYPES = frozenset({
|
||||
jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
|
||||
jnp.float64, jnp.complex64, jnp.complex128})
|
||||
|
||||
if xla_extension_version >= 231:
|
||||
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
|
||||
|
||||
|
||||
# Mirror of dlpack.h enum
|
||||
class DLDeviceType(enum.IntEnum):
|
||||
|
Loading…
x
Reference in New Issue
Block a user