mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix cuda array interface with old jaxlib.
arg name was added in xla_extension_version 261. PiperOrigin-RevId: 629745255
This commit is contained in:
parent
e75e4a5991
commit
26049b1059
@ -2538,9 +2538,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
cai=cai, gpu_backend=backend, device_id=device_id
|
||||
)
|
||||
else:
|
||||
object = xc._xla.cuda_array_interface_to_buffer(
|
||||
cai=cai, gpu_backend=backend
|
||||
)
|
||||
object = xc._xla.cuda_array_interface_to_buffer(cai, backend)
|
||||
|
||||
object = tree_map(lambda leaf: leaf.__jax_array__()
|
||||
if hasattr(leaf, "__jax_array__") else leaf, object)
|
||||
|
Loading…
x
Reference in New Issue
Block a user