[JAX] Add support for retaining ownership of DLPack tensors.

Move dlpack.py contents under jax/_src/dlpack.py.

Add array interoperability test between JAX and TensorFlow using DLPack.

Fixes: https://github.com/google/jax/issues/4636
PiperOrigin-RevId: 338120910
This commit is contained in:
Peter Hawkins 2020-10-20 13:06:37 -07:00 committed by jax authors
parent 41eea2c7c4
commit d001ac6b8a
4 changed files with 137 additions and 47 deletions

View File

@ -13,6 +13,7 @@ jaxlib 0.1.57 (unreleased)
------------------------------
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
and could change (#4651).
* Add support for retaining ownership when passing arrays to DLPack (#4636).
jax 0.2.4 (October 19 2020)
--------------------------

67
jax/_src/dlpack.py Normal file
View File

@ -0,0 +1,67 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import core
from jax import lazy
from jax.interpreters import xla
import jax.lib
from jax.lib import xla_client
from jax.lib import xla_bridge
def to_dlpack(x: xla.DeviceArray, take_ownership: bool = False):
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
state.
Args:
x: a `DeviceArray`, on either CPU or GPU.
take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack,
and the consumer is free to mutate the buffer; the JAX buffer acts as if
it were deleted. If ``False``, JAX retains ownership of the buffer; it is
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
"""
if not isinstance(x, xla.DeviceArray):
raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
.format(type(x)))
buf = xla._force(x).device_buffer
if jax.lib.version >= (0, 1, 57):
return xla_client._xla.buffer_to_dlpack_managed_tensor(
buf, take_ownership=take_ownership)
else:
# Jaxlibs before 0.1.57 always take ownership.
if not take_ownership:
raise ValueError(
"to_dlpack with take_ownership=False requires jaxlib >= 0.1.57")
return xla_client._xla.buffer_to_dlpack_managed_tensor(buf)
def from_dlpack(dlpack, backend=None):
"""Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
The returned `DeviceArray` shares memory with `dlpack`.
Args:
dlpack: a DLPack tensor, on either CPU or GPU.
backend: experimental, optional: the platform on which `dlpack` lives.
"""
# TODO(phawkins): ideally the user wouldn't need to provide a backend and we
# would be able to figure it out from the DLPack.
backend = backend or xla_bridge.get_backend()
client = getattr(backend, "client", backend)
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
xla_shape = buf.shape()
assert not xla_shape.is_tuple()
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf) # pytype: disable=attribute-error

View File

@ -12,42 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
from . import lazy
from .interpreters import xla
from .lib import xla_client
from .lib import xla_bridge
def to_dlpack(x: xla.DeviceArray):
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
state.
Args:
x: a `DeviceArray`, on either CPU or GPU.
"""
if not isinstance(x, xla.DeviceArray):
raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
.format(type(x)))
buf = xla._force(x).device_buffer
return xla_client._xla.buffer_to_dlpack_managed_tensor(buf)
def from_dlpack(dlpack, backend=None):
"""Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
The returned `DeviceArray` shares memory with `dlpack`.
Args:
dlpack: a DLPack tensor, on either CPU or GPU.
backend: experimental, optional: the platform on which `dlpack` lives.
"""
# TODO(phawkins): ideally the user wouldn't need to provide a backend and we
# would be able to figure it out from the DLPack.
backend = backend or xla_bridge.get_backend()
client = getattr(backend, "client", backend)
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
xla_shape = buf.shape()
assert not xla_shape.is_tuple()
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf) # pytype: disable=attribute-error
# flake8: noqa: F401
from jax._src.dlpack import (to_dlpack, from_dlpack)

View File

@ -36,10 +36,15 @@ try:
except ImportError:
cupy = None
try:
import tensorflow as tf
except ImportError:
tf = None
dlpack_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
jnp.float16, jnp.float32, jnp.float64]
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
jnp.float16, jnp.float32, jnp.float64]
torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.float16, jnp.float32, jnp.float64]
@ -58,16 +63,21 @@ class DLPackTest(jtu.JaxTestCase):
self.skipTest("DLPack not supported on TPU")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
{"testcase_name": "_{}_take_ownership={}".format(
jtu.format_shape_dtype_string(shape, dtype),
take_ownership),
"shape": shape, "dtype": dtype, "take_ownership": take_ownership}
for shape in all_shapes
for dtype in dlpack_dtypes))
def testJaxRoundTrip(self, shape, dtype):
for dtype in dlpack_dtypes
for take_ownership in [False, True]))
def testJaxRoundTrip(self, shape, dtype, take_ownership):
if jax.lib.version < (0, 1, 57) and not take_ownership:
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = jnp.array(np)
dlpack = jax.dlpack.to_dlpack(x)
dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
self.assertEqual(take_ownership, x.device_buffer.is_deleted())
y = jax.dlpack.from_dlpack(dlpack)
self.assertAllClose(np.astype(x.dtype), y)
@ -75,6 +85,53 @@ class DLPackTest(jtu.JaxTestCase):
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testTensorFlowToJax(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
raise self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.device_under_test() == "gpu" and
not tf.config.list_physical_devices("GPU")):
raise self.skipTest("TensorFlow not configured with GPU support")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
x = tf.constant(np)
dlpack = tf.experimental.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.device_under_test() == "gpu" and
not tf.config.list_physical_devices("GPU")):
raise self.skipTest("TensorFlow not configured with GPU support")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = jnp.array(np)
# TODO(b/171320191): this line works around a missing context initialization
# bug in TensorFlow.
_ = tf.add(1, 1)
dlpack = jax.dlpack.to_dlpack(x)
y = tf.experimental.dlpack.from_dlpack(dlpack)
self.assertAllClose(np, y.numpy())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
@ -101,6 +158,8 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in torch_dtypes))
@unittest.skipIf(not torch, "Test requires PyTorch")
def testJaxToTorch(self, shape, dtype):
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = jnp.array(np)