mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00

* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs. Default to check_dtypes=True. Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense. No functional changes intended. * Fix a number of lax reference implementations to preserve types.
136 lines
4.4 KiB
Python
136 lines
4.4 KiB
Python
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
from absl.testing import absltest, parameterized
|
|
|
|
import jax
|
|
from jax.config import config
|
|
import jax.dlpack
|
|
import jax.numpy as jnp
|
|
from jax import test_util as jtu
|
|
|
|
# Must run at import time, before any test is collected.
config.parse_flags_with_absl()

# PyTorch is an optional dependency. When it (or its DLPack helper module)
# cannot be imported, `torch` is bound to None so the tests that need it can
# be skipped instead of failing at import time.
try:
  import torch
  import torch.utils.dlpack
except ImportError:
  torch = None

# CuPy is optional as well and gets the same None fallback.
try:
  import cupy
except ImportError:
  cupy = None
|
|
|
|
|
|
# Element types used to parameterize the JAX round-trip and CuPy tests.
dlpack_dtypes = [
    jnp.int8, jnp.int16, jnp.int32, jnp.int64,
    jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
    jnp.float16, jnp.float32, jnp.float64,
]
# The DLPack dtypes plus bool and bfloat16.
all_dtypes = [*dlpack_dtypes, jnp.bool_, jnp.bfloat16]
# Element types used to parameterize the PyTorch interop tests. Compared with
# `dlpack_dtypes` this omits uint16/uint32/uint64.
torch_dtypes = [
    jnp.int8, jnp.int16, jnp.int32, jnp.int64,
    jnp.uint8, jnp.float16, jnp.float32, jnp.float64,
]

# Shapes with at least one axis and no zero-length axis, including shapes
# that contain size-1 axes.
nonempty_nonscalar_array_shapes = [
    (4,), (3, 4), (2, 3, 4),
    (3, 1), (1, 4), (2, 1, 4),
]
# Shapes with at least one zero-length axis.
empty_array_shapes = [(0,), (0, 4), (3, 0)]

# The scalar shape () followed by every non-empty, non-scalar shape.
nonempty_array_shapes = [(), *nonempty_nonscalar_array_shapes]
# Every shape the tests iterate over: non-empty first, then empty.
all_shapes = [*nonempty_array_shapes, *empty_array_shapes]
|
|
|
|
class DLPackTest(jtu.JaxTestCase):
  """Tests array exchange through DLPack: JAX to JAX, and JAX to/from PyTorch.

  Every test is parameterized over `all_shapes` and a list of dtypes, and the
  whole class is skipped when the device under test is a TPU.
  """

  def setUp(self):
    super().setUp()
    if jtu.device_under_test() == "tpu":
      self.skipTest("DLPack not supported on TPU")

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in dlpack_dtypes))
  def testJaxRoundTrip(self, shape, dtype):
    # Export a JAX array to a DLPack capsule, import it back, and compare.
    rng = jtu.rand_default(self.rng())
    expected = rng(shape, dtype)
    x = jnp.array(expected)
    dlpack = jax.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    # Compare against the dtype JAX actually stored, which may differ from
    # the requested `dtype`.
    self.assertAllClose(expected.astype(x.dtype), y)

    # The capsule was consumed by the import above; importing it a second
    # time must be rejected.
    with self.assertRaisesRegex(
        RuntimeError, "DLPack tensor may be consumed at most once"):
      jax.dlpack.from_dlpack(dlpack)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJax(self, shape, dtype):
    # Export a PyTorch tensor to DLPack and import it as a JAX array.
    rng = jtu.rand_default(self.rng())
    expected = rng(shape, dtype)
    x = torch.from_numpy(expected)
    # Put the tensor on the GPU when that is the device under test.
    x = x.cuda() if jtu.device_under_test() == "gpu" else x
    dlpack = torch.utils.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(expected, y)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    # Export a JAX array to DLPack and import it as a PyTorch tensor.
    rng = jtu.rand_default(self.rng())
    expected = rng(shape, dtype)
    x = jnp.array(expected)
    dlpack = jax.dlpack.to_dlpack(x)
    y = torch.utils.dlpack.from_dlpack(dlpack)
    # Tensor.numpy() only accepts CPU tensors, and a tensor imported from a
    # GPU-resident JAX array stays on the GPU. .cpu() copies it to host memory
    # and returns the tensor unchanged when it is already on the CPU.
    self.assertAllClose(expected, y.cpu().numpy())
|
|
|
|
|
|
class CudaArrayInterfaceTest(jtu.JaxTestCase):
  """Tests handing JAX arrays to CuPy through `__cuda_array_interface__`.

  The whole class is skipped unless the device under test is a GPU.
  """

  def setUp(self):
    super().setUp()
    running_on_gpu = jtu.device_under_test() == "gpu"
    if not running_on_gpu:
      self.skipTest("__cuda_array_interface__ is only supported on GPU")

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_{}".format(
               jtu.format_shape_dtype_string(shape, dtype)),
           shape=shape, dtype=dtype)
      for shape in all_shapes
      for dtype in dlpack_dtypes))
  @unittest.skipIf(not cupy, "Test requires CuPy")
  def testJaxToCuPy(self, shape, dtype):
    make_values = jtu.rand_default(self.rng())
    host_values = make_values(shape, dtype)
    jax_array = jnp.array(host_values)
    cupy_array = cupy.asarray(jax_array)

    # First entry of the interface's "data" field on each side; equal values
    # mean CuPy is referring to the same buffer as the JAX array.
    jax_data = jax_array.__cuda_array_interface__["data"][0]
    cupy_data = cupy_array.__cuda_array_interface__["data"][0]
    self.assertEqual(jax_data, cupy_data)

    # The contents read back through CuPy must match the host values.
    self.assertAllClose(host_values, cupy.asnumpy(cupy_array))
|
|
|
|
|
|
if __name__ == "__main__":
  # Run the tests through absl's runner when executed as a script.
  absltest.main()
|