mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

[JAX] Add support for retaining ownership of DLPack tensors.

Move dlpack.py contents under jax/_src/dlpack.py. Add array interoperability
test between JAX and TensorFlow using DLPack.

Fixes: https://github.com/google/jax/issues/4636
PiperOrigin-RevId: 338120910

This commit is contained in:
parent 41eea2c7c4
commit d001ac6b8a

@@ -13,6 +13,7 @@ jaxlib 0.1.57 (unreleased)
 ------------------------------
 * Fixed a bug where the hash of bfloat16 values was not correctly initialized
   and could change (#4651).
+* Add support for retaining ownership when passing arrays to DLPack (#4636).
 
 jax 0.2.4 (October 19 2020)
 --------------------------

67  jax/_src/dlpack.py  Normal file
@@ -0,0 +1,67 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import core
from jax import lazy
from jax.interpreters import xla
import jax.lib
from jax.lib import xla_client
from jax.lib import xla_bridge


def to_dlpack(x: xla.DeviceArray, take_ownership: bool = False):
  """Returns a DLPack tensor that encapsulates a DeviceArray `x`.

  Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
  state.

  Args:
    x: a `DeviceArray`, on either CPU or GPU.
    take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack,
      and the consumer is free to mutate the buffer; the JAX buffer acts as if
      it were deleted. If ``False``, JAX retains ownership of the buffer; it is
      undefined behavior if the DLPack consumer writes to a buffer that JAX
      owns.
  """
  if not isinstance(x, xla.DeviceArray):
    raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
                    .format(type(x)))
  buf = xla._force(x).device_buffer
  if jax.lib.version >= (0, 1, 57):
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
        buf, take_ownership=take_ownership)
  else:
    # Jaxlibs before 0.1.57 always take ownership.
    if not take_ownership:
      raise ValueError(
          "to_dlpack with take_ownership=False requires jaxlib >= 0.1.57")
    return xla_client._xla.buffer_to_dlpack_managed_tensor(buf)


def from_dlpack(dlpack, backend=None):
  """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

  The returned `DeviceArray` shares memory with `dlpack`.

  Args:
    dlpack: a DLPack tensor, on either CPU or GPU.
    backend: experimental, optional: the platform on which `dlpack` lives.
  """
  # TODO(phawkins): ideally the user wouldn't need to provide a backend and we
  # would be able to figure it out from the DLPack.
  backend = backend or xla_bridge.get_backend()
  client = getattr(backend, "client", backend)
  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
  xla_shape = buf.shape()
  assert not xla_shape.is_tuple()
  aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
  return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf)  # pytype: disable=attribute-error
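
The new flag is easiest to see end to end. Below is a small usage sketch, not
part of the commit, assuming a jaxlib >= 0.1.57 install so that both ownership
modes are available; it uses only APIs that appear in the diff above.

import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(4.0)

# Default (take_ownership=False): JAX keeps the buffer alive; the consumer
# must treat the DLPack tensor as read-only.
capsule = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(capsule)   # shares memory with x
assert not x.device_buffer.is_deleted()

# take_ownership=True: the buffer is handed to DLPack and x acts as deleted.
z = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(z, take_ownership=True)
assert z.device_buffer.is_deleted()
w = jax.dlpack.from_dlpack(capsule)
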
jax/dlpack.py
@@ -12,42 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import core
-from . import lazy
-from .interpreters import xla
-from .lib import xla_client
-from .lib import xla_bridge
-
-def to_dlpack(x: xla.DeviceArray):
-  """Returns a DLPack tensor that encapsulates a DeviceArray `x`.
-
-  Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
-  state.
-
-  Args:
-    x: a `DeviceArray`, on either CPU or GPU.
-  """
-  if not isinstance(x, xla.DeviceArray):
-    raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
-                    .format(type(x)))
-  buf = xla._force(x).device_buffer
-  return xla_client._xla.buffer_to_dlpack_managed_tensor(buf)
-
-def from_dlpack(dlpack, backend=None):
-  """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
-
-  The returned `DeviceArray` shares memory with `dlpack`.
-
-  Args:
-    dlpack: a DLPack tensor, on either CPU or GPU.
-    backend: experimental, optional: the platform on which `dlpack` lives.
-  """
-  # TODO(phawkins): ideally the user wouldn't need to provide a backend and we
-  # would be able to figure it out from the DLPack.
-  backend = backend or xla_bridge.get_backend()
-  client = getattr(backend, "client", backend)
-  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
-  xla_shape = buf.shape()
-  assert not xla_shape.is_tuple()
-  aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
-  return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf)  # pytype: disable=attribute-error
+# flake8: noqa: F401
+from jax._src.dlpack import (to_dlpack, from_dlpack)
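
Since jax/dlpack.py now only re-exports the moved implementation, the public
import path is unchanged. A quick sketch of what downstream code can rely on,
assuming this commit is installed:

# Public path, unchanged by the move:
from jax.dlpack import to_dlpack, from_dlpack

# The names are re-exported from the new private module, so they are the
# same function objects.
import jax._src.dlpack as _impl
assert to_dlpack is _impl.to_dlpack and from_dlpack is _impl.from_dlpack
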
@@ -36,10 +36,15 @@ try:
 except ImportError:
   cupy = None
 
+try:
+  import tensorflow as tf
+except ImportError:
+  tf = None
+
 
 dlpack_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                  jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
                  jnp.float16, jnp.float32, jnp.float64]
 
 torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                 jnp.uint8, jnp.float16, jnp.float32, jnp.float64]

@@ -58,16 +63,21 @@ class DLPackTest(jtu.JaxTestCase):
       self.skipTest("DLPack not supported on TPU")
 
   @parameterized.named_parameters(jtu.cases_from_list(
-     {"testcase_name": "_{}".format(
-        jtu.format_shape_dtype_string(shape, dtype)),
-      "shape": shape, "dtype": dtype}
+     {"testcase_name": "_{}_take_ownership={}".format(
+        jtu.format_shape_dtype_string(shape, dtype),
+        take_ownership),
+      "shape": shape, "dtype": dtype, "take_ownership": take_ownership}
      for shape in all_shapes
-     for dtype in dlpack_dtypes))
-  def testJaxRoundTrip(self, shape, dtype):
+     for dtype in dlpack_dtypes
+     for take_ownership in [False, True]))
+  def testJaxRoundTrip(self, shape, dtype, take_ownership):
+    if jax.lib.version < (0, 1, 57) and not take_ownership:
+      raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
     x = jnp.array(np)
-    dlpack = jax.dlpack.to_dlpack(x)
+    dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
+    self.assertEqual(take_ownership, x.device_buffer.is_deleted())
     y = jax.dlpack.from_dlpack(dlpack)
     self.assertAllClose(np.astype(x.dtype), y)

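
Both the implementation and this test gate the new behavior on jax.lib.version.
A hedged sketch of the same feature check as downstream code might write it;
the helper name export_without_ownership_transfer is made up for illustration:

import jax.dlpack
import jax.lib

def export_without_ownership_transfer(x):
  # Illustrative helper, not part of JAX: take_ownership=False (JAX keeps the
  # buffer alive) needs jaxlib >= 0.1.57; older jaxlibs always hand the
  # buffer over to the DLPack consumer.
  if jax.lib.version >= (0, 1, 57):
    return jax.dlpack.to_dlpack(x, take_ownership=False)
  return jax.dlpack.to_dlpack(x, take_ownership=True)
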
@@ -75,6 +85,53 @@ class DLPackTest(jtu.JaxTestCase):
         "DLPack tensor may be consumed at most once",
         lambda: jax.dlpack.from_dlpack(dlpack))
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+     {"testcase_name": "_{}".format(
+        jtu.format_shape_dtype_string(shape, dtype)),
+      "shape": shape, "dtype": dtype}
+     for shape in all_shapes
+     for dtype in dlpack_dtypes))
+  @unittest.skipIf(not tf, "Test requires TensorFlow")
+  def testTensorFlowToJax(self, shape, dtype):
+    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
+                                              jnp.float64]:
+      raise self.skipTest("x64 types are disabled by jax_enable_x64")
+    if (jtu.device_under_test() == "gpu" and
+        not tf.config.list_physical_devices("GPU")):
+      raise self.skipTest("TensorFlow not configured with GPU support")
+
+    rng = jtu.rand_default(self.rng())
+    np = rng(shape, dtype)
+    with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
+      x = tf.constant(np)
+    dlpack = tf.experimental.dlpack.to_dlpack(x)
+    y = jax.dlpack.from_dlpack(dlpack)
+    self.assertAllClose(np, y)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+     {"testcase_name": "_{}".format(
+        jtu.format_shape_dtype_string(shape, dtype)),
+      "shape": shape, "dtype": dtype}
+     for shape in all_shapes
+     for dtype in dlpack_dtypes))
+  @unittest.skipIf(not tf, "Test requires TensorFlow")
+  def testJaxToTensorFlow(self, shape, dtype):
+    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
+                                              jnp.float64]:
+      self.skipTest("x64 types are disabled by jax_enable_x64")
+    if (jtu.device_under_test() == "gpu" and
+        not tf.config.list_physical_devices("GPU")):
+      raise self.skipTest("TensorFlow not configured with GPU support")
+    rng = jtu.rand_default(self.rng())
+    np = rng(shape, dtype)
+    x = jnp.array(np)
+    # TODO(b/171320191): this line works around a missing context initialization
+    # bug in TensorFlow.
+    _ = tf.add(1, 1)
+    dlpack = jax.dlpack.to_dlpack(x)
+    y = tf.experimental.dlpack.from_dlpack(dlpack)
+    self.assertAllClose(np, y.numpy())
+
   @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
         jtu.format_shape_dtype_string(shape, dtype)),

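
For reference, a condensed sketch of the interoperability the two new tests
exercise, outside the test harness; it assumes a TensorFlow build that provides
tf.experimental.dlpack installed alongside this commit:

import jax.dlpack
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# TensorFlow -> JAX: export a tf.Tensor as a DLPack capsule, import it zero-copy.
x_tf = tf.constant(np.arange(8.0, dtype=np.float32))
x_jax = jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(x_tf))
np.testing.assert_allclose(np.asarray(x_jax), x_tf.numpy())

# JAX -> TensorFlow: the reverse direction.
y_jax = jnp.linspace(0.0, 1.0, 8)
y_tf = tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(y_jax))
np.testing.assert_allclose(y_tf.numpy(), np.asarray(y_jax))
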
@@ -101,6 +158,8 @@ class DLPackTest(jtu.JaxTestCase):
     for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)