diff --git a/CHANGELOG.md b/CHANGELOG.md index 12d33b922..611b51194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,9 @@ Remember to align the itemized text with the first line of an item within a list convention for the complex sign, `x / abs(x)`. This is consistent with the behavior of the function in SciPy v1.13. + * JAX now supports the bool DLPack type for both import and export. + Previously bool values could not be imported and were exported as integers. + * Deprecations & Removals * A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see {ref}`api-compatibility`). diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index a6e4ae4dc..bdab508fd 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -22,6 +22,7 @@ from jax import numpy as jnp from jax._src import array from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax._src.typing import Array @@ -30,6 +31,9 @@ SUPPORTED_DTYPES = frozenset({ jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64, jnp.complex64, jnp.complex128}) +if xla_extension_version >= 231: + SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_}) + # Mirror of dlpack.h enum class DLDeviceType(enum.IntEnum):