mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

Instead, we skip the tests that the PJRT C API doesn't support. We had this tag during feature development so that tests could easily be disabled broadly, but we no longer expect to need that.
131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
# Copyright 2019 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import threading
|
|
from unittest import SkipTest
|
|
|
|
from absl.testing import absltest
|
|
import jax
|
|
from jax import lax, numpy as jnp
|
|
from jax import config
|
|
from jax.experimental import host_callback as hcb
|
|
from jax._src import core
|
|
from jax._src import xla_bridge
|
|
from jax._src.lib import xla_client
|
|
import jax._src.test_util as jtu
|
|
import numpy as np
|
|
|
|
# Hand command-line flags to JAX's config via absl so config options can be
# set when this file is run as a test binary (inferred from the name and the
# absltest.main entry point below — confirm against the jax.config docs).
config.parse_flags_with_absl()
|
|
|
|
|
|
class InfeedTest(jtu.JaxTestCase):
  """Tests for ``lax.infeed`` / ``lax.outfeed`` driven from the host.

  Each test feeds data to (and in some tests drains data from) the first
  local device while a jitted computation consumes it through tokens.
  """

  def setUp(self):
    # Infeed is not implemented in the PJRT C API, so every test in this
    # class is skipped when that API is in use.
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest("infeed not implemented in PJRT C API")
    super().setUp()

  def _start_in_thread(self, fun, *args):
    """Starts ``fun(*args)`` on a background thread.

    A plain ``threading.Thread`` only prints an exception raised by its
    target and discards the target's return value. This helper records the
    return value under ``"result"`` or the raised exception under
    ``"error"`` so the caller can check them after joining.

    Args:
      fun: the callable to run on the background thread.
      *args: positional arguments passed to ``fun``.

    Returns:
      A ``(thread, outcome)`` pair: the already-started thread and the dict
      that holds the outcome once the thread has finished.
    """
    outcome = {}

    def target():
      try:
        outcome["result"] = fun(*args)
      except Exception as e:  # pylint: disable=broad-except
        # Surfaced on the main thread by _join_and_get_result.
        outcome["error"] = e

    thread = threading.Thread(target=target)
    thread.start()
    return thread, outcome

  @staticmethod
  def _join_and_get_result(thread, outcome):
    """Joins ``thread`` and returns its result, re-raising any exception.

    Args:
      thread: a thread returned by ``_start_in_thread``.
      outcome: the matching outcome dict returned by ``_start_in_thread``.

    Returns:
      The value returned by the function run on ``thread``.

    Raises:
      Whatever exception that function raised, if any.
    """
    thread.join()
    if "error" in outcome:
      raise outcome["error"]
    return outcome["result"]

  @jax.numpy_rank_promotion("allow")  # Test explicitly exercises implicit rank promotion.
  def testInfeed(self):

    @jax.jit
    def f(x):
      token = lax.create_token(x)
      (y,), token = lax.infeed(
          token, shape=(core.ShapedArray((3, 4), jnp.float32),))
      (z,), _ = lax.infeed(
          token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),))
      # (3, 4) combined with (3, 1, 1) broadcasts to (3, 3, 4); the rank
      # mismatch is why the decorator allows rank promotion.
      return x + y + z

    x = np.float32(1.5)
    # y uses deterministic values; z is drawn from the test RNG.
    y = np.reshape(np.arange(12, dtype=np.float32), (3, 4))
    z = self.rng().randn(3, 1, 1).astype(np.float32)
    device = jax.local_devices()[0]
    # Transfers are issued in the same order in which f reads them.
    device.transfer_to_infeed((y,))
    device.transfer_to_infeed((z,))
    self.assertAllClose(f(x), x + y + z)

  def testInfeedPytree(self):

    x = np.float32(1.5)
    y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
    to_infeed = dict(a=x, b=y)
    to_infeed_shape = dict(a=core.ShapedArray((), dtype=np.float32),
                           b=core.ShapedArray((3, 4), dtype=np.int16))

    @jax.jit
    def f(x):
      token = lax.create_token(x)
      res, _ = lax.infeed(token, shape=to_infeed_shape)
      return res

    device = jax.local_devices()[0]
    # The host side must transfer the flattened leaves as a tuple, not the
    # dict itself; their order is the one produced by tree_flatten.
    flat_to_infeed, _ = jax.tree_util.tree_flatten(to_infeed)
    device.transfer_to_infeed(tuple(flat_to_infeed))
    self.assertAllClose(f(x), to_infeed)

  @jax.numpy_rank_promotion("allow")  # Test explicitly exercises implicit rank promotion.
  def testInfeedThenOutfeed(self):
    # NOTE(review): presumably stops host_callback's outfeed receiver so it
    # does not consume the outfeed this test reads directly — confirm.
    hcb.stop_outfeed_receiver()

    @jax.jit
    def f(x):
      token = lax.create_token(x)
      y, token = lax.infeed(
          token, shape=core.ShapedArray((3, 4), jnp.float32))
      lax.outfeed(token, y + np.float32(1))
      return x - 1

    x = np.float32(7.5)
    y = self.rng().randn(3, 4).astype(np.float32)
    # f runs on a separate thread while this thread supplies its infeed and
    # drains its outfeed.
    execution, outcome = self._start_in_thread(f, x)
    device = jax.local_devices()[0]
    device.transfer_to_infeed((y,))
    out, = device.transfer_from_outfeed(
        xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
    result = self._join_and_get_result(execution, outcome)
    self.assertAllClose(out, y + np.float32(1))
    # Also check the value f returned, which the thread would otherwise drop.
    self.assertAllClose(result, x - np.float32(1))

  def testInfeedThenOutfeedInALoop(self):
    # NOTE(review): presumably stops host_callback's outfeed receiver so it
    # does not consume the outfeed this test reads directly — confirm.
    hcb.stop_outfeed_receiver()

    def doubler(_, token):
      y, token = lax.infeed(
          token, shape=core.ShapedArray((3, 4), jnp.float32))
      return lax.outfeed(token, y * np.float32(2))

    @jax.jit
    def f(n):
      token = lax.create_token(n)
      lax.fori_loop(0, n, doubler, token)
      return n

    device = jax.local_devices()[0]
    n = 10
    execution, outcome = self._start_in_thread(f, n)
    # One infeed/outfeed round trip per loop iteration of f.
    for _ in range(n):
      x = self.rng().randn(3, 4).astype(np.float32)
      device.transfer_to_infeed((x,))
      y, = device.transfer_from_outfeed(
          xla_client.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
      self.assertAllClose(y, x * np.float32(2))
    result = self._join_and_get_result(execution, outcome)
    self.assertEqual(int(result), n)
|
|
|
|
|
|
if __name__ == '__main__':
  # Run this module's tests through absltest, using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|