[jax2tf] Added documentation explaining how to handle undefined TF ops
Added a test case showing how to mix compileable and non-compileable code.
This commit is contained in: parent 6a95a8cf50 · commit a27109d1bd
@@ -77,6 +77,14 @@ The Autograph feature of `tf.function` cannot be expected to work on
functions converted from JAX as above, so it is recommended to
set `autograph=False` in order to avoid warnings or outright errors.

It is a good idea to use XLA to compile the converted function; that is
the scenario for which we optimize numerical and performance
accuracy w.r.t. the JAX execution:

```python
tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)(x)
```

## Usage: saved model

Since jax2tf provides a regular TensorFlow function, using it with SavedModel
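
As a minimal sketch of that pattern (the function `f_jax`, the input signature, and the
save path below are illustrative placeholders, not part of the original text), a converted
function can be wrapped in a `tf.Module` and saved with the standard `tf.saved_model` API:

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = lambda x: x * 2.0  # placeholder JAX function

model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax), autograph=False,
                      input_signature=[tf.TensorSpec([3], tf.float32)])
tf.saved_model.save(model, "/tmp/jax2tf_saved_model")

# Restoring gives back a regular TensorFlow function.
restored = tf.saved_model.load("/tmp/jax2tf_saved_model")
print(restored.f(np.arange(3, dtype=np.float32)))
```
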
@@ -386,16 +394,86 @@ jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
               polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
```

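For context, the shape-polymorphic conversion above is typically used together with a
`tf.function` whose input signature leaves the polymorphic dimension unknown. A minimal
sketch (the `tf.TensorSpec` signature and the second call are illustrative assumptions):

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

mean_fn = jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
                         polymorphic_shapes=["(v, _)"])
# Leave the leading (polymorphic) dimension unspecified in the TF signature.
mean_tf = tf.function(mean_fn, autograph=False,
                      input_signature=[tf.TensorSpec([None, 4], tf.float32)])
print(mean_tf(np.ones((4, 4), dtype=np.float32)))
print(mean_tf(np.ones((6, 4), dtype=np.float32)))
```
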
## Known issues

### Incomplete TensorFlow data type coverage

There are a number of cases in which the TensorFlow ops used by the
jax2tf converter do not support the same data types as in JAX.
There is an
[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).

There are two kinds of errors you may see. For the primitives in the
[unimplemented cases](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md)
that are shown to be undefined on all devices and for all execution modes
(`eager`, `graph`, `compiled`), e.g., `lax.min` for booleans,
the conversion typically uses a TensorFlow operator that is not
registered for the data type in question:

```python
jax2tf.convert(lambda x: lax.min(x, x))(np.array([True]))

>>> InvalidArgumentError: Value for attr 'T' of bool is not in the list of allowed values:
>>> bfloat16, half, float, double, uint8, int16, int32, int64;
>>> NodeDef: {{node Minimum}};
>>> Op<name=Minimum; signature=x:T, y:T -> z:T; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT16, DT_INT32, DT_INT64]> [Op:Minimum]
```

In the above cases, you should file a bug with JAX or TensorFlow, or consider
changing your JAX code. We are working on eliminating this kind of problem.

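As one hedged illustration of what "changing your JAX code" can look like here (the cast
to `int32` is an assumption about what is acceptable for the model, not a recommendation
from the original text), the boolean minimum can be computed on a dtype that TensorFlow's
`Minimum` op does support:

```python
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def min_bool(x, y):
  # Compute the minimum on int32, where TensorFlow has a Minimum kernel,
  # then cast the result back to bool.
  return lax.min(x.astype(np.int32), y.astype(np.int32)).astype(np.bool_)

print(jax2tf.convert(min_bool)(np.array([True]), np.array([False])))
```
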
In other cases, the TensorFlow op is registered for the data type, but for the
`eager` or `graph` execution modes there is no TensorFlow kernel defined.
Such primitives appear in the
[unimplemented cases](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md)
as unimplemented for `eager` and `graph`, e.g., `lax.sign` for unsigned integers:

```python
jax2tf.convert(lax.sign)(np.array([5], dtype=np.uint32))

>>> NotFoundError: Could not find device for node: {{node Sign}} = Sign[T=DT_UINT32]
>>> All kernels registered for op Sign:
>>> device='CPU'; T in [DT_FLOAT]
>>> device='CPU'; T in [DT_DOUBLE]
>>> ...
```

In this situation, you can still run the converted program if you compile it with
XLA:

```python
tf.function(jax2tf.convert(lax.sign),
            autograph=False, jit_compile=True)(np.array([5], dtype=np.uint32))
```

Our priority is to ensure numerical and performance accuracy for
the converted program **when it is compiled with XLA**.
It is always a good idea to use XLA on the JAX-converted function.

Sometimes you cannot compile the entire TensorFlow function for your
model, because in addition to the function that is converted from JAX,
it may include some pre-processing TensorFlow code that
is not compileable with XLA, e.g., string parsing. Even in those situations
you can instruct TensorFlow to compile only the portion that originates
from JAX:

```python
def entire_tf_fun(x):
  y = preprocess_tf_fun_not_compileable(x)
  # Compile the code that is converted from JAX
  z = tf.function(jax2tf.convert(compute_jax_fn),
                  autograph=False, jit_compile=True)(y)
  return postprocess_tf_fun_not_compileable(z)
```

You won't be able to compile `entire_tf_fun`, but you can still execute
it knowing that the JAX-converted code is compiled. You can even save
the function to a SavedModel, knowing that upon restore the
JAX-converted code will be compiled.

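A minimal sketch of that save step, building on `entire_tf_fun` from the block above
(the string input signature and the save path are illustrative assumptions):

```python
import tensorflow as tf

model = tf.Module()
# The outer function stays uncompiled; only the inner jax2tf-converted
# tf.function inside entire_tf_fun is XLA-compiled.
model.f = tf.function(entire_tf_fun, autograph=False,
                      input_signature=[tf.TensorSpec([2], tf.string)])
tf.saved_model.save(model, "/tmp/mixed_jax_tf_model")
```
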
For a more elaborate example, see the test `test_tf_mix_jax_with_uncompileable`
in [savedmodel_test.py](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/tests/savedmodel_test.py).

### Missing converter features

There is currently no support for replicated (e.g. `pmap`) or multi-device
(e.g. `sharded_jit`) functions. The collective operations are not yet handled.
@@ -455,8 +533,8 @@ We use the following TFXLA ops:

* `XlaPad` (wraps XLA Pad operator). We use this instead of `tf.pad` in order to
  support `lax.pad` interior padding (dilation) or negative edge padding.
* `XlaConv` and `XlaConv2` (wrap XLA ConvGeneralDilated operator).
* `XlaDot` and `XlaDotV2` (wrap XLA DotGeneral operator).
* `XlaGather` (wraps XLA Gather operator). We could use `tf.gather` in some
  cases but not always. Also, `tf.gather` has different semantics than `lax.gather`
  for out-of-bounds indices.
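
As a hedged way to see which of these ops a converted function actually uses (the
`lax.pad` example and the graph-inspection approach are illustrative, not from the
original text), one can inspect the concrete TF graph:

```python
import numpy as np
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

# lax.pad with interior padding cannot be expressed with tf.pad, so the
# converter is expected to emit XlaPad here.
pad_fn = jax2tf.convert(lambda x: lax.pad(x, np.float32(0.), [(1, 1, 1)]))
concrete = tf.function(pad_fn, autograph=False).get_concrete_function(
    tf.TensorSpec([4], tf.float32))
# Print the set of TF op types appearing in the converted graph.
print(sorted({op.type for op in concrete.graph.get_operations()}))
```
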
@@ -130,6 +130,55 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
    arr = np.arange(10, dtype=np.float32)
    self._compare_with_saved_model(f_jax, arr)

  # Test does not work on GPU/TPU; it would need something like the TPU
  # inference converter to split the model into the parts that must run on
  # the CPU and the parts that can run on the accelerator.
  @jtu.skip_on_devices("gpu", "tpu")
  def test_tf_mix_jax_with_uncompileable(self):
    """Show how to combine TF-uncompileable code with compiled JAX-converted code."""
    def tf_fn(x_str, compute_tf_fn=lambda x: x):
      # Some TF preprocessing code that cannot be compiled with XLA because it
      # uses strings.
      numbers_f32 = tf.strings.to_number(x_str, out_type=tf.float32)
      numbers_f16 = tf.cast(numbers_f32, tf.float16)
      return compute_tf_fn(numbers_f16)

    x_str = np.array(["3.14", "2.78"])

    # Test that we get an error if we try to TF-compile `tf_fn`
    with self.assertRaisesRegex(
        Exception,
        "Detected unsupported operations when trying to compile graph"):
      tf.function(tf_fn, jit_compile=True)(x_str)

    def compute_jax_fn(x):
      # A JAX function whose conversion does not run in TF without XLA because
      # tf.math.atan is not supported on float16 without XLA.
      return lax.atan(x) + lax.atan(x)

    with self.assertRaisesRegex(
        tf.errors.NotFoundError,
        "Could not find device for node.*Atan.*DT_HALF"):
      tf_fn(x_str, compute_tf_fn=jax2tf.convert(compute_jax_fn))

    # Plug in the TF-compiled JAX-converted `compute_jax_fn`.
    composed_fn = lambda x_str: tf_fn(
        x_str,
        compute_tf_fn=tf.function(jax2tf.convert(compute_jax_fn),
                                  autograph=True,
                                  jit_compile=True))
    res_tf = composed_fn(x_str)
    self.assertAllClose(res_tf.numpy(),
                        compute_jax_fn(np.array([3.14, 2.78], dtype=np.float16)))

    # Save and restore SavedModel
    model = tf.Module()
    model.f = tf.function(
        composed_fn,
        input_signature=[tf.TensorSpec((2,), dtype=tf.string)])
    restored_model = self.save_and_load_model(model)
    res_tf_restored = restored_model.f(x_str)
    self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())