mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Removed call_tf tests that are not applicable anymore.
A recent change in TensorFlow makes copies of np.ndarray when they are turned into tf.constant. This means that call_tf cannot guarantee anymore no-copy. Removing those tests, and the paragraph in the documentation that describes this property. PiperOrigin-RevId: 521120090
This commit is contained in:
parent
2432adefc3
commit
88f77bbcc6
@ -1360,12 +1360,6 @@ JAX XLA computation.
|
||||
The TF custom gradients are respected, since it is TF that generates the
|
||||
gradient computation.
|
||||
|
||||
In op-by-op mode, when we call TensorFlow in eager mode, we use
|
||||
DLPack to try to avoid copying the data. This works for CPU (for
|
||||
DeviceArray data or for np.ndarray that are aligned on 16-byte
|
||||
boundaries) and on GPU (for DeviceArray).
|
||||
The zero-copy does not yet work on TPU.
|
||||
|
||||
`call_tf` works even with shape polymorphism, but in that case
|
||||
the user must pass the `output_shape_dtype` parameter to `call_tf` to declare
|
||||
the expected output shapes. This allows JAX tracing to know the shape and
|
||||
|
@ -101,31 +101,12 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
res = _maybe_jit(with_jit, jax2tf.call_tf(lambda _: x))(x)
|
||||
self.assertAllClose(x, res)
|
||||
|
||||
def test_eval_numpy_no_copy(self):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
raise unittest.SkipTest("no_copy test works only on CPU")
|
||||
# For ndarray, zero-copy only works for sufficiently-aligned arrays.
|
||||
x = np.ones((16, 16), dtype=np.float32)
|
||||
res = jax2tf.call_tf(lambda x: x)(x)
|
||||
self.assertAllClose(x, res)
|
||||
self.assertTrue(np.shares_memory(x, res))
|
||||
|
||||
@_parameterized_jit
|
||||
def test_eval_devicearray_arg(self, with_jit=False):
|
||||
x = jnp.ones((2, 3), dtype=np.float32)
|
||||
res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
|
||||
self.assertAllClose(jnp.sin(x), res)
|
||||
|
||||
def test_eval_devicearray_no_copy(self):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
# TODO(necula): add tests for GPU and TPU
|
||||
raise unittest.SkipTest("no_copy test works only on CPU")
|
||||
# For DeviceArray zero-copy works even if not aligned
|
||||
x = jnp.ones((3, 3))
|
||||
res = jax2tf.call_tf(lambda x: x)(x)
|
||||
self.assertAllClose(x, res)
|
||||
self.assertTrue(np.shares_memory(x, res))
|
||||
|
||||
x = jnp.array(3.0, dtype=jnp.bfloat16)
|
||||
res = jax2tf.call_tf(lambda x: x)(x)
|
||||
self.assertAllClose(x, res)
|
||||
|
Loading…
x
Reference in New Issue
Block a user