[jax2tf] Added documentation explaining how to handle undefined TF ops

Added a test case showing how to mix compileable and non-compileable code.
George Necula 2021-05-18 17:13:09 +03:00
parent 6a95a8cf50
commit a27109d1bd
2 changed files with 132 additions and 5 deletions

jax/experimental/jax2tf/README.md

@@ -77,6 +77,14 @@ The Autograph feature of `tf.function` cannot be expected to work on
functions converted from JAX as above, so it is recommended to
set `autograph=False` in order to avoid warnings or outright errors.
It is a good idea to use XLA to compile the converted function; that is
the scenario for which we are optimizing numerical and performance
accuracy w.r.t. the JAX execution:
```python
tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)(x)
```
## Usage: saved model
Since jax2tf provides a regular TensorFlow function, using it with SavedModel
@@ -386,16 +394,86 @@ jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
```
## Known issues
### Incomplete TensorFlow data type coverage
There are a number of cases when the TensorFlow ops that are used by the
jax2tf converter are not supported by TensorFlow for the same data types as in JAX.
There is an
[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
There are two kinds of errors you may see. For the primitives in the
[unimplemented cases](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md)
that are shown to be undefined on all devices and for all execution modes
(`eager`, `graph`, `compiled`), e.g., `lax.min` for booleans,
the conversion typically uses a TensorFlow operator that is not
registered for the data type in question:
```python
jax2tf.convert(lambda x: lax.min(x, x))(np.array([True]))
>>> InvalidArgumentError: Value for attr 'T' of bool is not in the list of allowed values:
>>> bfloat16, half, float, double, uint8, int16, int32, int64;
>>> NodeDef: {{node Minimum}};
>>> Op<name=Minimum; signature=x:T, y:T -> z:T; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT16, DT_INT32, DT_INT64]> [Op:Minimum]
```
In the above cases, you should file a bug with JAX or TensorFlow, or consider
changing your JAX code. We are working on eliminating this kind of problem.
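Until that happens, one possible workaround is to do the computation in a data
type that the underlying TensorFlow op does support. The sketch below is
illustrative only (it is not part of the official documentation and assumes the
same `jax2tf`, `lax`, and `np` imports as the earlier snippets); it casts the
booleans to `uint8`, which `tf.math.minimum` accepts, and casts the result back:
```python
# Hypothetical workaround: compute the minimum in uint8 (supported by the
# underlying TF Minimum op) and cast the result back to bool.
def bool_min(x, y):
  return lax.min(x.astype(np.uint8), y.astype(np.uint8)).astype(np.bool_)

jax2tf.convert(bool_min)(np.array([True]), np.array([False]))  # runs in eager mode
```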
In other cases, the TensorFlow op is registered for the data type, but for the
`eager` or `graph` execution modes there is no TensorFlow kernel defined.
Such primitives appear in the
[unimplemented cases](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md)
as unimplemented for `eager` and `graph`, e.g., `lax.sign` for unsigned integers:
```python
jax2tf.convert(lax.sign)(np.array([5], dtype=np.uint32))
>>> NotFoundError: Could not find device for node: {{node Sign}} = Sign[T=DT_UINT32]
>>> All kernels registered for op Sign:
>>>   device='CPU'; T in [DT_FLOAT]
>>>   device='CPU'; T in [DT_DOUBLE]
>>> ...
```
In this situation, you can still run the converted program if you compile it with
XLA:
```python
tf.function(jax2tf.convert(lax.sign),
            autograph=False, jit_compile=True)(np.array([5], dtype=np.uint32))
```
Our priority is to ensure numerical and performance accuracy for
the converted program **when it is compiled with XLA**.
It is always a good idea to use XLA on the JAX-converted function.
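As a quick sanity check, you can compare the XLA-compiled converted function
against the original JAX function. The snippet below is only an illustrative
sketch (it assumes the usual `tf`, `jnp`, and `np` imports, and `f_jax` is a
placeholder for your own function):
```python
f_jax = lambda x: jnp.sin(jnp.cos(x))  # placeholder JAX function
f_tf = tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)
x = np.arange(8, dtype=np.float32)
# The compiled TF execution should closely match the JAX execution.
np.testing.assert_allclose(f_tf(x).numpy(), f_jax(x), rtol=1e-6)
```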
Sometimes you cannot compile the entire TensorFlow function for your
model, because in addition to the function that is converted from JAX,
it may include some pre-processing TensorFlow code that
is not compileable with XLA, e.g., string parsing. Even in those situations
you can instruct TensorFlow to compile only the portion that originates
from JAX:
```python
def entire_tf_fun(x):
  y = preprocess_tf_fun_not_compileable(x)
  # Compile the code that is converted from JAX
  z = tf.function(jax2tf.convert(compute_jax_fn),
                  autograph=False, jit_compile=True)(y)
  return postprocess_tf_fun_not_compileable(z)
```
You won't be able to compile `entire_tf_fun` as a whole, but you can still execute
it knowing that the JAX-converted code is compiled. You can even save
the function to a SavedModel, knowing that upon restore the
JAX-converted code will be compiled.
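For instance, a minimal sketch of saving such a mixed function (not from the
original documentation; the input signature and the `/tmp/saved_model` path are
placeholders that depend on what `preprocess_tf_fun_not_compileable` expects)
could look like:
```python
module = tf.Module()
# Only the inner jax2tf-converted function is jit-compiled; the outer
# function stays uncompiled.
module.f = tf.function(
    entire_tf_fun,
    input_signature=[tf.TensorSpec([None], dtype=tf.string)])
tf.saved_model.save(module, "/tmp/saved_model")

restored = tf.saved_model.load("/tmp/saved_model")
res = restored.f(tf.constant(["3.14", "2.78"]))
```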
For a more elaborate example, see the test `test_tf_mix_jax_with_uncompileable`
in [savedmodel_test.py](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/tests/savedmodel_test.py).
### Missing converter features
There is currently no support for replicated (e.g. `pmap`) or multi-device
(e.g. `sharded_jit`) functions. The collective operations are not yet handled.
@@ -455,8 +533,8 @@ We use the following TFXLA ops:
* `XlaPad` (wraps XLA Pad operator). We use this instead of `tf.pad` in order to
support `lax.pad` interior padding (dilation) or negative edge padding.
* `XlaConv` and `XlaConvV2` (wrap XLA ConvGeneralDilated operator).
* `XlaDot` and `XlaDotV2` (wrap XLA DotGeneral operator).
* `XlaGather` (wraps XLA Gather operator). We could use `tf.gather` in some
cases but not always. Also, `tf.gather` has different semantics than `lax.gather`
for out-of-bounds indices.
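If you are curious which of these ops a particular converted function uses, one
way to check is to inspect the traced TensorFlow graph. This is only an
illustrative sketch (it assumes the usual `tf`, `jnp`, and `jax2tf` imports; the
exact op names can differ between versions):
```python
f_tf = tf.function(jax2tf.convert(jnp.matmul), autograph=False)
concrete = f_tf.get_concrete_function(tf.TensorSpec((3, 4), tf.float32),
                                      tf.TensorSpec((4, 5), tf.float32))
# Expect an XlaDot variant among the ops of the traced graph.
print(sorted({node.op for node in concrete.graph.as_graph_def().node}))
```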

jax/experimental/jax2tf/tests/savedmodel_test.py

@@ -130,6 +130,55 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
    arr = np.arange(10, dtype=np.float32)
    self._compare_with_saved_model(f_jax, arr)

  # Test does not work on GPU/TPU; would need something like TPU inference
  # converter to separate the model on what needs to run on CPU or accelerator.
  @jtu.skip_on_devices("gpu", "tpu")
  def test_tf_mix_jax_with_uncompileable(self):
    """Show how to combine TF-uncompileable code with compiled JAX-converted code."""
    def tf_fn(x_str, compute_tf_fn=lambda x: x):
      # Some TF preprocessing code that cannot be compiled with XLA because it
      # uses strings.
      numbers_f32 = tf.strings.to_number(x_str, out_type=tf.float32)
      numbers_f16 = tf.cast(numbers_f32, tf.float16)
      return compute_tf_fn(numbers_f16)

    x_str = np.array(["3.14", "2.78"])

    # Test that we get an error if we try to TF-compile `tf_fn`
    with self.assertRaisesRegex(
        Exception,
        "Detected unsupported operations when trying to compile graph"):
      tf.function(tf_fn, jit_compile=True)(x_str)

    def compute_jax_fn(x):
      # A JAX function whose conversion does not run in TF without XLA because
      # tf.math.atan is not supported on float16 without XLA.
      return lax.atan(x) + lax.atan(x)

    with self.assertRaisesRegex(
        tf.errors.NotFoundError,
        "Could not find device for node.*Atan.*DT_HALF"):
      tf_fn(x_str, compute_tf_fn=jax2tf.convert(compute_jax_fn))

    # Plug in the TF-compiled JAX-converted `compute_jax_fn`.
    composed_fn = lambda x_str: tf_fn(
        x_str,
        compute_tf_fn=tf.function(jax2tf.convert(compute_jax_fn),
                                  autograph=True,
                                  jit_compile=True))
    res_tf = composed_fn(x_str)
    self.assertAllClose(res_tf.numpy(),
                        compute_jax_fn(np.array([3.14, 2.78], dtype=np.float16)))

    # Save and restore SavedModel
    model = tf.Module()
    model.f = tf.function(
        composed_fn,
        input_signature=[tf.TensorSpec((2,), dtype=tf.string)])
    restored_model = self.save_and_load_model(model)
    res_tf_restored = restored_model.f(x_str)
    self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())