[jax2tf] Removed call_tf tests that are not applicable anymore.

A recent change in TensorFlow makes copies of np.ndarray when they
are turned into tf.constant. This means that call_tf cannot guarantee
anymore no-copy. Removing those tests, and the paragraph in the
documentation that describes this property.

PiperOrigin-RevId: 521120090
This commit is contained in:
George Necula 2023-04-01 03:06:32 -07:00 committed by jax authors
parent 2432adefc3
commit 88f77bbcc6
2 changed files with 0 additions and 25 deletions

View File

@ -1360,12 +1360,6 @@ JAX XLA computation.
The TF custom gradients are respected, since it is TF that generates the
gradient computation.
In op-by-op mode, when we call TensorFlow in eager mode, we use
DLPack to try to avoid copying the data. This works for CPU (for
DeviceArray data or for np.ndarray that are aligned on 16-byte
boundaries) and on GPU (for DeviceArray).
The zero-copy does not yet work on TPU.
`call_tf` works even with shape polymorphism, but in that case
the user must pass the `output_shape_dtype` parameter to `call_tf` to declare
the expected output shapes. This allows JAX tracing to know the shape and

View File

@ -101,31 +101,12 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
res = _maybe_jit(with_jit, jax2tf.call_tf(lambda _: x))(x)
self.assertAllClose(x, res)
def test_eval_numpy_no_copy(self):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("no_copy test works only on CPU")
# For ndarray, zero-copy only works for sufficiently-aligned arrays.
x = np.ones((16, 16), dtype=np.float32)
res = jax2tf.call_tf(lambda x: x)(x)
self.assertAllClose(x, res)
self.assertTrue(np.shares_memory(x, res))
@_parameterized_jit
def test_eval_devicearray_arg(self, with_jit=False):
x = jnp.ones((2, 3), dtype=np.float32)
res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
self.assertAllClose(jnp.sin(x), res)
def test_eval_devicearray_no_copy(self):
if jtu.device_under_test() != "cpu":
# TODO(necula): add tests for GPU and TPU
raise unittest.SkipTest("no_copy test works only on CPU")
# For DeviceArray zero-copy works even if not aligned
x = jnp.ones((3, 3))
res = jax2tf.call_tf(lambda x: x)(x)
self.assertAllClose(x, res)
self.assertTrue(np.shares_memory(x, res))
x = jnp.array(3.0, dtype=jnp.bfloat16)
res = jax2tf.call_tf(lambda x: x)(x)
self.assertAllClose(x, res)