Fix cuda array interface with old jaxlib.

arg name was added in xla_extension_version 261.

PiperOrigin-RevId: 629745255
This commit is contained in:
Jieying Luo 2024-05-01 09:25:46 -07:00 committed by jax authors
parent e75e4a5991
commit 26049b1059

View File

@ -2538,9 +2538,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
cai=cai, gpu_backend=backend, device_id=device_id
)
else:
object = xc._xla.cuda_array_interface_to_buffer(
cai=cai, gpu_backend=backend
)
object = xc._xla.cuda_array_interface_to_buffer(cai, backend)
object = tree_map(lambda leaf: leaf.__jax_array__()
if hasattr(leaf, "__jax_array__") else leaf, object)