Add support for bool dlpack values.

PiperOrigin-RevId: 599199196
This commit is contained in:
Peter Hawkins 2024-01-17 09:30:04 -08:00 committed by jax authors
parent 1bd22b0fe1
commit c4368351d2
2 changed files with 7 additions and 0 deletions

View File

@ -52,6 +52,9 @@ Remember to align the itemized text with the first line of an item within a list
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
of the function in SciPy v1.13.
* JAX now supports the bool DLPack type for both import and export.
Previously bool values could not be imported and were exported as integers.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).

View File

@ -22,6 +22,7 @@ from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
@ -30,6 +31,9 @@ SUPPORTED_DTYPES = frozenset({
jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
jnp.float64, jnp.complex64, jnp.complex128})
if xla_extension_version >= 231:
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):