mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Disable some jax2tf tests that fail on GPU.
Fix TF/JAX array interoperability test on GPU.
This commit is contained in:
parent
c5bfdccd64
commit
6e48050730
@ -175,6 +175,8 @@ class CallTfTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized_jit
|
||||
def test_with_var_read(self, with_jit=True):
|
||||
if jtu.device_under_test() == "gpu":
|
||||
raise unittest.SkipTest("Test fails on GPU")
|
||||
outer_var = tf.Variable(3., dtype=np.float32)
|
||||
|
||||
def fun_tf(x):
|
||||
@ -211,6 +213,8 @@ class CallTfTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized_jit
|
||||
def test_with_multiple_capture(self, with_jit=True):
|
||||
if jtu.device_under_test() == "gpu":
|
||||
raise unittest.SkipTest("Test fails on GPU")
|
||||
v2 = tf.Variable(2., dtype=np.float32)
|
||||
v3 = tf.Variable(3., dtype=np.float32)
|
||||
t4 = tf.constant(4., dtype=np.float32)
|
||||
|
@ -313,6 +313,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
Jax2TfLimitation(
|
||||
"jax2tf BUG: batch_group_count > 1 not yet converted",
|
||||
enabled=(harness.params["batch_group_count"] > 1)),
|
||||
missing_tf_kernel(dtypes=[np.complex64, np.complex128], devices="gpu"),
|
||||
custom_numeric(devices="gpu", tol=1e-4),
|
||||
custom_numeric(devices="tpu", tol=1e-3),
|
||||
# TODO(bchetioui): significant discrepancies in some float16 cases.
|
||||
@ -723,6 +724,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
tst.assertAllClose(result_jax[~special_cases], result_tf[~special_cases])
|
||||
|
||||
return [
|
||||
# TODO(necula): Produces mismatched outputs on GPU.
|
||||
Jax2TfLimitation("mismatched outputs on GPU",
|
||||
devices=("gpu",), skip_comparison=True),
|
||||
missing_tf_kernel(
|
||||
dtypes=[dtypes.bfloat16, np.float16]),
|
||||
custom_numeric(
|
||||
@ -758,6 +762,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
rtol=tol)
|
||||
|
||||
return [
|
||||
# TODO(necula): Produces mismatched outputs on GPU.
|
||||
Jax2TfLimitation("mismatched outputs on GPU",
|
||||
devices=("gpu",), skip_comparison=True),
|
||||
missing_tf_kernel(
|
||||
dtypes=[dtypes.bfloat16, np.float16]),
|
||||
custom_numeric(dtypes=np.float64, tol=1e-9),
|
||||
|
@ -93,17 +93,19 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in dlpack_dtypes))
|
||||
@unittest.skipIf(not tf, "Test requires TensorFlow")
|
||||
def testTensorFlowToJax(self, shape, dtype):
|
||||
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
|
||||
jnp.float64]:
|
||||
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
|
||||
raise self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
if (jtu.device_under_test() == "gpu" and
|
||||
not tf.config.list_physical_devices("GPU")):
|
||||
raise self.skipTest("TensorFlow not configured with GPU support")
|
||||
|
||||
if jtu.device_under_test() == "gpu" and dtype == jnp.int32:
|
||||
raise self.skipTest("TensorFlow does not place int32 tensors on GPU")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
|
||||
x = tf.constant(np)
|
||||
x = tf.identity(tf.constant(np))
|
||||
dlpack = tf.experimental.dlpack.to_dlpack(x)
|
||||
y = jax.dlpack.from_dlpack(dlpack)
|
||||
self.assertAllClose(np, y)
|
||||
|
Loading…
x
Reference in New Issue
Block a user