# Mirror of https://github.com/ROCm/jax.git
# (synced 2025-04-14 10:56:06 +00:00; 208 lines, 7.3 KiB, Python)
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
|
|
|
|
from absl.testing import absltest
|
|
|
|
import numpy as np
|
|
import numpy.random as npr
|
|
from unittest import SkipTest
|
|
|
|
import jax
|
|
from jax._src import test_util as jtu
|
|
from jax import numpy as jnp
|
|
|
|
# Let JAX pick up its configuration from absl command-line flags before any
# test is collected (see jax.config.parse_flags_with_absl).
jax.config.parse_flags_with_absl()

# Seed NumPy's global RNG (numpy.random, also imported here as `npr`) so the
# random inputs drawn by the tests below are reproducible across runs.
np.random.seed(0)


class MultiBackendTest(jtu.JaxTestCase):
  """Tests jit targeting to different backends.

  Tests parameterized over a platform name skip themselves (through
  ``_skip_unless_available``) when that platform is neither 'cpu' nor the
  platform reported by ``jtu.device_under_test()``. A value of ``None`` means
  "no explicit backend", i.e. jit's default placement.
  """

  @staticmethod
  def _skip_unless_available(*backends):
    """Raise SkipTest unless every entry of `backends` can be targeted here.

    An entry counts as available when it is 'cpu', the platform returned by
    ``jtu.device_under_test()``, or None (no explicit backend requested).

    Raises:
      SkipTest: if any entry is not available.
    """
    available = ('cpu', jtu.device_under_test(), None)
    if any(backend not in available for backend in backends):
      raise SkipTest("Backend is not CPU or the device under test")

  @staticmethod
  def _platform(arr):
    """Return the platform name of the first device in ``arr.devices()``."""
    return next(iter(arr.devices())).platform

  @jtu.sample_product(backend=['cpu', 'gpu', 'tpu', None])
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def testMultiBackend(self, backend):
    self._skip_unless_available(backend)

    @partial(jax.jit, backend=backend)
    def fun(x, y):
      return jnp.matmul(x, y)

    x = npr.uniform(size=(10, 10))
    y = npr.uniform(size=(10, 10))
    z_host = np.matmul(x, y)
    z = fun(x, y)
    self.assertAllClose(z, z_host, rtol=1e-2)
    # With no explicit backend the result is expected on the device under test.
    correct_platform = backend if backend else jtu.device_under_test()
    self.assertEqual(self._platform(z), correct_platform)

  @jtu.sample_product(
      ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)]
  )
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def testMultiBackendNestedJit(self, ordering):
    outer, inner = ordering
    self._skip_unless_available(outer)

    @partial(jax.jit, backend=outer)
    def fun(x, y):

      @partial(jax.jit, backend=inner)
      def infun(x, y):
        return jnp.matmul(x, y)

      return infun(x, y) + jnp.ones_like(x)

    x = npr.uniform(size=(10, 10))
    y = npr.uniform(size=(10, 10))
    z_host = np.matmul(x, y) + np.ones_like(x)
    z = fun(x, y)
    self.assertAllClose(z, z_host, rtol=1e-2)
    # The inner jit has no backend of its own, so the outer one decides.
    correct_platform = outer if outer else jtu.device_under_test()
    self.assertEqual(self._platform(z), correct_platform)

  @jtu.sample_product(
      ordering=[('cpu', 'gpu'), ('gpu', 'cpu'), ('cpu', 'tpu'), ('tpu', 'cpu'),
                (None, 'cpu'), (None, 'gpu'), (None, 'tpu'),
               ],
  )
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def testMultiBackendNestedJitConflict(self, ordering):
    """Nesting jits with different explicit backends must raise ValueError."""
    outer, inner = ordering
    self._skip_unless_available(outer, inner)
    # Every (None, inner) ordering is skipped: without an outer backend there
    # is no conflict to detect, so only explicit-outer pairs are exercised.
    if outer is None:
      if inner == jtu.device_under_test():
        raise SkipTest("(None, device) is allowed")
      raise SkipTest("The inner device will dictate the device assignment for "
                     "the entire computation. So if inner is CPU and outer is "
                     "None, then the computation will be executed on CPU.")

    @partial(jax.jit, backend=outer)
    def fun(x, y):

      @partial(jax.jit, backend=inner)
      def infun(x, y):
        return jnp.matmul(x, y)

      return infun(x, y) + jnp.ones_like(x)

    x = npr.uniform(size=(10, 10))
    y = npr.uniform(size=(10, 10))
    with self.assertRaises(ValueError):
      fun(x, y)

  @jtu.sample_product(backend=['cpu', 'gpu', 'tpu'])
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def testGpuMultiBackendOpByOpReturn(self, backend):
    # `backend` is never None here, so this matches a ('cpu', device) check.
    self._skip_unless_available(backend)

    @partial(jax.jit, backend=backend)
    def fun(x, y):
      return jnp.matmul(x, y)

    x = npr.uniform(size=(10, 10))
    y = npr.uniform(size=(10, 10))
    z = fun(x, y)
    # An op-by-op call on the jit output should run where that output lives.
    w = jnp.sin(z)
    self.assertEqual(self._platform(z), backend)
    self.assertEqual(self._platform(w), backend)

  @jtu.skip_on_devices("cpu")  # test can only fail with non-cpu backends
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def testJitCpu(self):
    @partial(jax.jit, backend='cpu')
    def get_arr(scale):
      return scale + jnp.ones((2, 2))

    x = get_arr(0.1)

    # Follow-up computations on a CPU-resident array must stay on CPU even
    # when they also create fresh constants (ones_like / eye).
    a = x / x.shape[0]
    b = x + jnp.ones_like(x)
    c = x + jnp.eye(2)

    cpu0 = jax.devices('cpu')[0]
    self.assertEqual(a.devices(), {cpu0})
    self.assertEqual(b.devices(), {cpu0})
    self.assertEqual(c.devices(), {cpu0})

  @jtu.skip_on_devices("cpu")  # test can only fail with non-cpu backends
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def test_closed_over_values_device_placement(self):
    # see https://github.com/jax-ml/jax/issues/1431
    def f(): return jnp.add(3., 4.)
    # Without a backend the constants are not placed on CPU; with
    # backend='cpu' they are.
    self.assertNotEqual(jax.jit(f)().devices(),
                        {jax.devices('cpu')[0]})
    self.assertEqual(jax.jit(f, backend='cpu')().devices(),
                     {jax.devices('cpu')[0]})

  @jtu.skip_on_devices("cpu")  # test only makes sense on non-cpu backends
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def test_jit_on_nondefault_backend(self):
    cpus = jax.devices("cpu")
    self.assertNotEmpty(cpus)

    # Since we are not on CPU, some other backend will be the default
    default_dev = jax.devices()[0]
    self.assertNotEqual(default_dev.platform, "cpu")

    data_on_cpu = jax.device_put(1, device=cpus[0])
    self.assertEqual(data_on_cpu.devices(), {cpus[0]})

    def my_sin(x): return jnp.sin(x)
    # jit without any device spec follows the data
    result1 = jax.jit(my_sin)(2)
    self.assertEqual(result1.devices(), {default_dev})
    result2 = jax.jit(my_sin)(data_on_cpu)
    self.assertEqual(result2.devices(), {cpus[0]})

    # jit with `device` spec places the data on the specified device
    result3 = jax.jit(my_sin, device=cpus[0])(2)
    self.assertEqual(result3.devices(), {cpus[0]})

    # jit with `backend` spec places the data on the specified backend
    result4 = jax.jit(my_sin, backend="cpu")(2)
    self.assertEqual(result4.devices(), {cpus[0]})

  @jtu.skip_on_devices("cpu")  # test only makes sense on non-cpu backends
  def test_indexing(self):
    # https://github.com/jax-ml/jax/issues/2905
    cpus = jax.devices("cpu")

    x = jax.device_put(np.ones(2), cpus[0])
    y = x[0]
    # Indexing a CPU-resident array must not move the result elsewhere.
    self.assertEqual(y.devices(), {cpus[0]})

  @jtu.skip_on_devices("cpu")  # test only makes sense on non-cpu backends
  def test_sum(self):
    # https://github.com/jax-ml/jax/issues/2905
    cpus = jax.devices("cpu")

    x = jax.device_put(np.ones(2), cpus[0])
    y = x.sum()
    # Reducing a CPU-resident array must not move the result elsewhere.
    self.assertEqual(y.devices(), {cpus[0]})


if __name__ == "__main__":
  # Executed directly: hand every test case to absltest, using JAX's custom
  # test loader from jtu.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)