Skye Wanderman-Milne a80cbc5626 [JAX] Implement the stream argument to jax.Array.__dlpack__ for CUDA GPU
Also implements jax.Array.__dlpack_device__. See
https://dmlc.github.io/dlpack/latest/python_spec.html

This requires plumbing the raw CUDA stream pointer through PJRT and
StreamExecutor (since the GPU PJRT implementation is still based on
SE). This is done via the new PJRT method
ExternalReference::WaitUntilBufferReadyOnStream.
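The consumer/producer handshake from the linked DLPack Python spec can be sketched in plain Python without a GPU. This is a minimal illustration under stated assumptions, not JAX's implementation. `choose_stream_arg` and `FakeCudaArray` are hypothetical names. The device-type codes are a subset of those in `dlpack.h`, and the special stream values (None, 0, 1, 2) follow the spec. The consumer calls `__dlpack_device__` to learn where the data lives and, for CUDA, passes its own stream to `__dlpack__`. The producer is then responsible for making that stream wait until the buffer is ready.

```python
from enum import IntEnum
from typing import Any, List, Optional


class DLDeviceType(IntEnum):
    # Subset of the device-type codes defined in dlpack.h.
    kDLCPU = 1
    kDLCUDA = 2


# Per the DLPack Python spec, stream=1 is the CUDA legacy default stream.
LEGACY_DEFAULT_STREAM = 1


def choose_stream_arg(producer: Any, consumer_stream: Optional[int]) -> Optional[int]:
    """Hypothetical helper: the `stream` value a consumer passes to __dlpack__()."""
    device_type, _device_id = producer.__dlpack_device__()
    if device_type == DLDeviceType.kDLCPU:
        # CPU has no streams, so the spec only permits stream=None.
        return None
    if device_type == DLDeviceType.kDLCUDA:
        if consumer_stream is None:
            # None means "legacy default stream"; spell that out explicitly.
            return LEGACY_DEFAULT_STREAM
        if consumer_stream == 0:
            # 0 is disallowed: it could mean None, 1, or 2.
            raise ValueError("stream=0 is ambiguous and disallowed for CUDA")
        # 2 is the per-thread default stream; values > 2 are raw stream pointers.
        return consumer_stream
    raise NotImplementedError(f"device type {device_type!r} not handled in this sketch")


class FakeCudaArray:
    """Hypothetical stand-in for a single-device CUDA array (does no GPU work)."""

    def __init__(self, device_id: int = 0) -> None:
        self.device_id = device_id
        # Streams this producer was asked to make wait on its buffer.
        self.synced_streams: List[Optional[int]] = []

    def __dlpack_device__(self):
        return (DLDeviceType.kDLCUDA, self.device_id)

    def __dlpack__(self, stream: Optional[int] = None) -> object:
        # A real producer would make `stream` wait until this buffer is ready
        # (this change does so via ExternalReference::WaitUntilBufferReadyOnStream)
        # and return a PyCapsule wrapping a DLManagedTensor. Here we only record it.
        self.synced_streams.append(stream)
        return object()  # placeholder for the capsule
```

Usage: `arr = FakeCudaArray(); arr.__dlpack__(stream=choose_stream_arg(arr, ptr))`, where `ptr` is the consumer's stream pointer as an integer. Passing the consumer's stream lets the producer enqueue a wait on that stream instead of blocking the host until the data is ready.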

I haven't plumbed this through the PJRT C API yet, because I'm still
debating whether this should be part of the main API or a GPU-specific
extension (plus either way it should probably be its own change).

PiperOrigin-RevId: 558245360
2023-08-18 14:20:38 -07:00