# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
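"""Tests for DLPack and __cuda_array_interface__ interoperability."""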
import unittest

from absl.testing import absltest, parameterized

import jax
from jax.config import config
import jax.dlpack
import jax.numpy as jnp
from jax import test_util as jtu

config.parse_flags_with_absl()

FLAGS = config.FLAGS
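# PyTorch, CuPy, and TensorFlow are optional test dependencies; the interop
# tests below are skipped when the corresponding framework is not installed.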
try:
  import torch
  import torch.utils.dlpack
except ImportError:
  torch = None

try:
  import cupy
except ImportError:
  cupy = None

try:
  import tensorflow as tf
except ImportError:
  tf = None
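# Dtypes exercised by the DLPack tests; torch_dtypes is the subset used for
# the PyTorch interop tests.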
dlpack_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                 jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
                 jnp.float16, jnp.float32, jnp.float64]
torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                jnp.uint8, jnp.float16, jnp.float32, jnp.float64]
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
empty_array_shapes = []
empty_array_shapes += [(0,), (0, 4), (3, 0),]
nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]

nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
all_shapes = nonempty_array_shapes + empty_array_shapes
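# DLPack is a framework-neutral in-memory tensor interchange format. The tests
# below round-trip arrays through DLPack within JAX and exchange them with
# TensorFlow and PyTorch.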
class DLPackTest(jtu.JaxTestCase):
  def setUp(self):
    super(DLPackTest, self).setUp()
    if jtu.device_under_test() == "tpu":
      self.skipTest("DLPack not supported on TPU")

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}_take_ownership={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        take_ownership),
      "shape": shape, "dtype": dtype, "take_ownership": take_ownership}
     for shape in all_shapes
     for dtype in dlpack_dtypes
     for take_ownership in [False, True]))
  def testJaxRoundTrip(self, shape, dtype, take_ownership):
    if jax.lib.version < (0, 1, 57) and not take_ownership:
      raise unittest.SkipTest("Requires jaxlib >= 0.1.57")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)
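    # With take_ownership=True, to_dlpack hands the buffer to the DLPack
    # capsule and deletes the JAX array's device buffer; with False, the JAX
    # array remains valid alongside the capsule.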
    dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
    self.assertEqual(take_ownership, x.device_buffer.is_deleted())
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np.astype(x.dtype), y)
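    # A DLPack capsule may be consumed at most once; a second from_dlpack on
    # the same capsule must raise.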
    self.assertRaisesRegex(RuntimeError,
                           "DLPack tensor may be consumed at most once",
                           lambda: jax.dlpack.from_dlpack(dlpack))

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testTensorFlowToJax(self, shape, dtype):
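    # 64-bit dtypes exist in JAX only when the jax_enable_x64 flag is set.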
    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
                                              jnp.float64]:
      raise self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")

    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
      x = tf.constant(np)
    dlpack = tf.experimental.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y)

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testJaxToTensorFlow(self, shape, dtype):
    if jax.lib.version < (0, 1, 57):
      raise unittest.SkipTest("Requires jaxlib >= 0.1.57")
    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
                                              jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    dlpack = jax.dlpack.to_dlpack(x)
    y = tf.experimental.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y.numpy())

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJax(self, shape, dtype):
    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = torch.from_numpy(np)
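    # Keep the torch tensor on the same kind of device JAX is being tested on.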
    x = x.cuda() if jtu.device_under_test() == "gpu" else x
    dlpack = torch.utils.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y)

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    if jax.lib.version < (0, 1, 57):
      raise unittest.SkipTest("Requires jaxlib >= 0.1.57")
    if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)
    dlpack = jax.dlpack.to_dlpack(x)
    y = torch.utils.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y.numpy())
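# __cuda_array_interface__ is the CUDA array-interchange protocol also
# implemented by libraries such as CuPy and Numba; JAX GPU arrays expose it so
# other libraries can view their device memory without a copy.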
class CudaArrayInterfaceTest(jtu.JaxTestCase):

  def setUp(self):
    super(CudaArrayInterfaceTest, self).setUp()
    if jtu.device_under_test() != "gpu":
      self.skipTest("__cuda_array_interface__ is only supported on GPU")

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
      "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in dlpack_dtypes))
  @unittest.skipIf(not cupy, "Test requires CuPy")
  def testJaxToCuPy(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jnp.array(x)
    z = cupy.asarray(y)
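    # cupy.asarray consumes __cuda_array_interface__; matching data pointers
    # show the CuPy array aliases the JAX buffer rather than copying it.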
    self.assertEqual(y.__cuda_array_interface__["data"][0],
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(x, cupy.asnumpy(z))


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())