Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)

Merge pull request #12270 from gnecula:tf_readme

PiperOrigin-RevId: 473219612

This commit is contained in commit bc59bd1ddc.
@@ -7,13 +7,16 @@ This package provides experimental support for interoperation between JAX and TensorFlow.

There are two interoperation directions:

- `jax2tf.convert`: for using JAX functions in a TensorFlow context, e.g.,
  for eager or graph TensorFlow execution,
  or for saving as a TensorFlow SavedModel; and
- `jax2tf.call_tf`: for using TensorFlow functions in a JAX context, e.g., to call a
  TensorFlow library or a SavedModel inside a JAX function.

The `jax2tf.convert` mechanism can wrap a function
written in JAX, possibly including JAX transformations, and turn it into
a function that uses only TensorFlow operations.
`jax2tf.convert` directs JAX to use an alternative code
generator (lowering) and emit TensorFlow operations instead of the regular HLO operations
emitted in native JAX lowering. In all other respects the JAX function is
processed as in native JAX execution, e.g., for the JAX transformations.
The resulting function
can be called or traced from TensorFlow and will behave as if it was written in TensorFlow.
In practice this means that you can take some code written in JAX and execute it using
TensorFlow eager mode, or stage it out as a TensorFlow graph, even use it
@@ -26,8 +29,8 @@ or TensorFlow Hub.

This package also contains the `jax2tf.call_tf` mechanism to call TensorFlow functions
from JAX. These functions can be called in JAX's op-by-op execution mode,
in which case the callee is executed in TensorFlow eager mode, or in JAX's jit (staged) context,
in which case the callee is compiled to XLA and embedded in JAX's lowered HLO.

Both interoperation directions rely on the ability of
TensorFlow to use the XLA compiler (`tf.function(jit_compile=True)`). For the
@@ -35,9 +38,10 @@

that the performance characteristics of the code match those of the JAX source.
For the `call_tf` direction, JIT compilation is an essential part of the implementation
mechanism. Only TensorFlow functions that can be JIT-compiled can be called from
JAX in a jit context.
Since the TensorFlow functions that are produced by `jax2tf.convert` can
be JIT-compiled by design, we can call them using `jax2tf.call_tf`, thus achieving
a round-trip from JAX to TensorFlow (e.g., a SavedModel) and back.

We describe below some general concepts and capabilities, first for
`jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax)
@@ -51,13 +55,12 @@ For details on saving a batch-polymorphic SavedModel see [below](#shape-polymorphic-conversion)

See also some internal ongoing design discussions at `go/jax2tf-doc`.

## Usage: basic functions.

As a rule of thumb, if you can `jax.jit` your function then you should be able
to use `jax2tf.convert`:
```python
import jax
from jax.experimental import jax2tf
from jax import numpy as jnp
import tensorflow as tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# jax2tf.convert is a higher-order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
f_tf = jax2tf.convert(f_jax)

# You can then wrap the result in a TensorFlow graph function:
f_tf_graph = tf.function(f_tf, autograph=False)
```
The Autograph feature of `tf.function` cannot be expected to work on
functions lowered from JAX as above, so it is recommended to
set `autograph=False` in order to avoid warnings or outright errors.

It is a good idea to use XLA to compile the lowered function; that is
the scenario for which we are optimizing for numerical and performance
accuracy w.r.t. the JAX execution:
@@ -118,7 +121,7 @@ restored_model = tf.saved_model.load('/some/directory')

An important point is that in the above code snippet **everything after the
jax2tf invocation is standard TensorFlow code.
In particular, the saving of the model is not directly part
of the jax2tf API, and the user has full control over how to create the SavedModel**.
@@ -149,19 +152,19 @@

```python
def model_jax(inputs):
  return param0 + param1 * inputs
```

If you just lower and save the model directly, the values of
`param0` and `param1` will be embedded in the computation graph. In fact, the
value of `param1` is needed for the gradient computation and
will be embedded twice: once in the computation
graph for the forward computation and once for the backward computation,
unless you turn off the staging of gradients or their saving as discussed
further below (e.g., `with_gradient=False`). Note also that if one
views the above function as an ML model parameterized by `param0` and `param1`
then the gradient function will be w.r.t. the inputs, while you probably
want gradients w.r.t. the parameters.

A better way to deal with parameters (or any large constants) is to
pass them as parameters to the function to be lowered:

```python
def model_jax(params, inputs):
  ...  # remainder of the code block elided in this diff
```
@@ -194,19 +197,20 @@ For examples of how to save a Flax model as a SavedModel see the

### Saved model and differentiation

The code lowered from JAX supports differentiation from TensorFlow. In order to
ensure that the result of TensorFlow differentiation is identical to the
one that JAX differentiation would produce, we will
annotate the lowered primal function with a ``tf.custom_gradient`` that,
upon TensorFlow differentiation, will lazily
call into JAX to compute the ``jax.vjp`` of the lowered primal function, followed by
jax2tf lowering of the gradient function.
This ensures that ultimately it is JAX that performs the
differentiation, thus respecting any custom gradients that may be present
in the original function.

The `jax2tf.convert` function has an option ``with_gradient=False`` to skip the
custom gradients and wrap instead the lowered function with
``tf.raw_ops.PreventGradient`` to generate an error in case a gradient
computation is attempted.

SavedModel enables saving custom derivative rules by using the `experimental_custom_gradients` option:
@@ -257,21 +261,21 @@ you will not be able to compute the gradients of the function loaded from the SavedModel.

## Support for partitioning

jax2tf supports JAX functions that use `jax.pjit`, for single-host meshes.
The lowering is actually similar to that for a `jax.jit`, except that the
arguments and results will be wrapped with
`tensorflow.compiler.xla.experimental.xla_sharding.XlaSharding` TensorFlow ops.

Note that when saving a model, the parameters to the model are wrapped with
`tf.Variable` before calling the lowered function (see [above](#saved_model_with_parameters)),
therefore outside of the `XlaSharding` wrapper.
## Shape-polymorphic conversion

**The shape polymorphism support is work in progress. It is meant to be sound,
but it may fail to lower some programs. Please report any bugs you encounter.**

We described above how to include in the SavedModel several specializations
of a lowered function for a few specific input shapes. `jax2tf` can
also produce a shape-polymorphic TensorFlow graph that is usable with inputs
of any shape matching
certain constraints. This is useful, e.g., to allow a single SavedModel
@@ -312,7 +316,7 @@ error messages. The real need for named shape

variables arises when there are
multiple unknown dimensions and there is a relationship between them.
For example,
if the function to be lowered is also polymorphic on the size of each
image while requiring the images to be square,
we would add a dimension variable `d` to stand for
the unknown image size:
@@ -330,7 +334,7 @@ same shape of a batch of square matrices that can be passed to `jnp.matmul`.

### Correctness of shape-polymorphic tracing

We want to trust that the lowered program produces the same results as the
original JAX program. More precisely:

For any function `f_jax` and any input signature `abs_sig` containing partially
@@ -354,22 +358,22 @@ by reusing the same JAX tracing and shape checking mechanism as when the shapes

### Coverage of shape-polymorphic tracing

Besides correctness, a secondary goal is to be able to lower many shape-polymorphic programs,
but at the very
least batch-size-polymorphic programs, so that one SavedModel can be used for any batch sizes.
For example, we want to ensure that any function written using `jax.vmap` at the top level can be
lowered with the batch dimension polymorphic and the remaining dimensions concrete.

It is reasonable to expect that there will be JAX programs for which there is a
shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf.

### Details

In order to be able to use shape polymorphism effectively with jax2tf, it
is worth considering what happens under the hood. When the lowered function
is invoked with a `TensorSpec`, `jax2tf` will combine the
`TensorSpec` from the actual argument with the `polymorphic_shapes` parameter to
obtain a shape abstraction to be used to specialize the lowered function.
Normally, the shape abstraction contains the dimension sizes, but in the
presence of shape polymorphism, some dimensions may be dimension variables.
@@ -406,7 +410,7 @@ A few examples of shape specifications and uses:

* `polymorphic_shapes=["(b, _, _)", None]` can be used for a function with two arguments, the first
  having a batch leading dimension that should be polymorphic. The other dimensions for the
  first argument and the shape of the second argument are specialized based on the actual
  `TensorSpec`, which must be known. The lowered function can be used, e.g.,
  with `TensorSpec`s `[None, 28, 28]` and `[28, 16]` for the first and second argument
  respectively. An alternative `TensorSpec` pair can be `[1, 28, 28]` and `[28, 16]`,
  in which case the JAX tracing is done for the same polymorphic shape given by
@@ -481,13 +485,13 @@ jax2tf.convert(lambda x: 0 if x.shape[0] + 1 == x.shape[1] else 1,

Note that it would be unsound for JAX to compute `x.shape[0] + 1 == x.shape[1]`
as `False` and produce a lowered function that returns `1` just because the dimension polynomials
are not identical: there are some concrete input shapes for which the function
should return `0`.
### Dimension variables appearing in the numeric computation

There are some situations when dimension variables arise in the lowered computation itself.
You can see in the following example how elements from the input shapes
`(1024, 28, 28)` and `(28, 28)` appear in the computation and specifically
in the `shape` parameter of the `broadcast_in_dim` JAX primitive.
@@ -508,12 +512,12 @@

```python
print(jax.make_jaxpr(image_mask_jax)(np.ones((1024, 28, 28)), np.ones((28, 28))))
jax2tf.convert(image_mask_jax, polymorphic_shapes=["(b, w, w)", "(w, w)"])
```

When tracing and lowering with abstract shapes some primitive parameters will be dimension variables
instead of just constants, e.g., the `shape` parameter of `broadcast_in_dim` will be `(1, w, w)`.
Note that JAX primitives distinguish the inputs, which are array values,
e.g., `b` for `broadcast_in_dim` above, and the parameters, e.g., `broadcast_dimensions` and `shape`.

The lowering of `image_mask_jax` would use `tf.shape` to compute the
values of the dimension variables `b` and `w`:
```python
def image_mask_tf(images, mask):
  ...  # body elided in this diff; it ends with a broadcast to shape [b, w, w]
```

To achieve this, when we start lowering a function we construct a shape environment,
mapping the dimension variables in the `polymorphic_shapes` specification to TensorFlow expressions
using `tf.shape` on the input parameters.
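A sketch of what such a lowering can look like in TensorFlow (the body shown is a plausible reconstruction, not the exact generated code):

```python
import tensorflow as tf

def image_mask_tf(images, mask):
  # Shape environment: the dimension variables b and w are computed from
  # the actual arguments with tf.shape.
  b = tf.shape(images)[0]
  w = tf.shape(images)[1]
  return images * tf.broadcast_to(tf.reshape(mask, [1, w, w]),
                                  [b, w, w])
```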
@@ -559,7 +563,7 @@ will want to ensure the size of the two axes is the same (`v == 4`).

Note that `v` can stand for any integer greater than 0, so the value of the
equality expression can be true or false. Since it is not always true
that `v == 4`, the shape checking rules fail with the above error.
Since the lowered function works only for square matrices, the correct
`polymorphic_shapes` is `["(v, v)"]`.
@@ -618,27 +622,97 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),

## Known issues

`jax2tf` has been in use since 2020 and the vast majority of users encounter
no problems. However, there are a few rare corner cases
in which the different conventions of JAX and TensorFlow result in a breakage.
We try to give an exhaustive list below.
### Different 64-bit precision in JAX and TensorFlow

JAX behaves somewhat differently than TensorFlow in the handling
of 32-bit vs. 64-bit values. However, the `jax2tf` lowered function
always behaves like the JAX function.

JAX interprets the type of Python scalars differently based on the
`JAX_ENABLE_X64` flag. (See
[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).)
In the default configuration, the
flag is unset, and JAX interprets Python constants as 32-bit,
e.g., the type of `3.14` is `float32`. This is also what
TensorFlow always does. JAX goes further: it forces
all explicitly-specified 64-bit values to be interpreted as
32-bit:

```python
# with JAX_ENABLE_X64=0
jnp.sin(3.14)  # Has type float32
tf.math.sin(3.14)  # Has type float32

jnp.sin(np.float64(3.14))  # Also has type float32
tf.math.sin(np.float64(3.14))  # Has type float64

# The jax2tf.convert function behaves like the JAX function.
jax2tf.convert(jnp.sin)(3.14)  # Has type float32
jax2tf.convert(jnp.sin)(np.float64(3.14))  # Has type float32

# The following will still compute `sin` in float32 (with a tf.cast on the argument).
tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64))
```

When the `JAX_ENABLE_X64` flag is set, JAX uses 64-bit types
for Python scalars and respects the explicit 64-bit types:

```python
# with JAX_ENABLE_X64=1
jnp.sin(3.14)  # Has type float64
tf.math.sin(3.14)  # Has type float32

# The jax2tf.convert function behaves like the JAX function.
jax2tf.convert(jnp.sin)(3.14)  # Has type float64

# The following will compute `sin` in float64.
tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64))

# The following will compute `sin` in float32.
tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14))
```

This is achieved by inserting `tf.cast` operations
on the input arguments inside the lowered function,
if necessary.

If you want to create a `tf.Variable` or `tf.TensorSpec` with the
same dtype, you should use `jax2tf.dtype_of_val`:

```python
# The following two calls will lower jax_fun at the same dtypes
# independently of the value of JAX_ENABLE_X64.
jax2tf.convert(jax_fun)(3.14)
jax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)))
```
### Incomplete TensorFlow data type coverage

There are a number of cases when the TensorFlow ops that are used by
`jax2tf` are not supported by TensorFlow for the same data types as in JAX.
There is an
[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).

If you try to lower and run in TensorFlow a program with partially supported primitives,
you may see TensorFlow errors that
a TensorFlow op is used with an unsupported data type, or that
there is no supported TensorFlow kernel for the op for the given
data type. The former case can happen even if you `jit_compile`
the TensorFlow program, and it is a priority to fix it. The latter
case only appears in TensorFlow non-compiled mode; you can
avoid the problem if you use XLA to `jit_compile` (always recommended).

Our priority is to ensure numerical and performance accuracy for
the lowered program **when using XLA to compile the lowered program**.
It is always a good idea to use XLA on the lowered function.

Sometimes you cannot compile the entire TensorFlow function for your
model, because in addition to the function that is lowered from JAX,
it may include some pre-processing TensorFlow code that
is not compileable with XLA, e.g., string parsing. Even in those situations
you can instruct TensorFlow to compile only the portion that originates
@@ -647,127 +721,35 @@ from JAX:

```python
def entire_tf_fun(x):
  y = preprocess_tf_fun_not_compileable(x)
  # Compile the code that is lowered from JAX
  z = tf.function(jax2tf.convert(compute_jax_fn),
                  autograph=False, jit_compile=True)(y)
  return postprocess_tf_fun_not_compileable(z)
```

You won't be able to compile the `entire_tf_fun`, but you can still execute
it knowing that the jax2tf-lowered code is compiled. You can even save
the function to a SavedModel, knowing that upon restore the
jax2tf-lowered code will be compiled.

For a more elaborate example, see the test `test_tf_mix_jax_with_uncompileable`
in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py).
### Missing converter features

There is currently no support for `pmap` or `xmap`, nor for the collective
operations. There is support for `pjit`.
### SavedModel may be large

If you suspect that the SavedModel is larger than it should be, check first
that you are not including the parameters as constants in the graph (see [above](#usage-saved-model)).

### SavedModel supports only first-order gradients

The `jax2tf`-converted function supports higher-order gradients, but when the
function is saved in a SavedModel, only the first-order gradient is saved.
### Lowering gradients for functions with integer arguments or unused arguments

When JAX differentiates functions with integer or boolean arguments, the gradients will
be zero-vectors with a special `float0` type (see [PR 4039](https://github.com/google/jax/pull/4039)).
This type is translated to `int32` when lowering to TF.
For example,

```python
x = np.int16(2)
def f_jax(x):  # x: int16
  return x * 2.

jax.grad(f_jax, allow_int=True)(x)
# returns a special `float0`: array((b'',), dtype=[('float0', 'V')])

jax2tf.convert(jax.grad(f_jax, allow_int=True))(x)
# returns a tf.Tensor(0, shape=(), dtype=int32)
```
|
||||
Note that this is different from how TensorFlow handles gradients
for integer or boolean arguments: sometimes the gradient is `None`,
sometimes it is a zero with the same dtype as the argument, and
sometimes it is a one with the same dtype as the argument (e.g.,
for the identity function).

```python
def f_tf(x):  # x: int16
  return tf.cast(x, tf.float32) * 2.

xv = tf.Variable(x)
with tf.GradientTape(persistent=True) as tape:
  print(tape.gradient(f_tf(xv), xv))
  # returns None
  print(tape.gradient(f_tf(xv), xv,
                      unconnected_gradients=tf.UnconnectedGradients.ZERO))
  # returns 0 with the same shape and dtype as x
```
|
||||
When differentiating functions with unused arguments, TF by default
returns the value `None` for the corresponding gradients. The
`tape.gradient` function takes the option `tf.UnconnectedGradients.ZERO`
to ask that gradients for unused arguments be zero.

Functions lowered with `jax2tf.convert` behave the same way under
`tf.UnconnectedGradients.ZERO`, but by default, they will return
`None` only for gradients corresponding to integer arguments.
|
||||
```python
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
  return x0 * 0. + x2 * 2.

xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
with tf.GradientTape(persistent=True) as tape:
  res = fn(*xs)

g_tf_native = tape.gradient(res, xs)
# Returns: 0., None, 2., None

g_tf_native_0 = tape.gradient(res, xs,
                              unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Returns: 0., 0., 2., 0

# Now with jax2tf.convert
with tf.GradientTape(persistent=True) as tape:
  res = jax2tf.convert(fn, with_gradient=True)(*xs)

g_jax2tf = tape.gradient(res, xs)
# Returns: 0., 0., 2., None
# Note that the gradient for x1 is 0.

g_jax2tf_0 = tape.gradient(res, xs,
                           unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Returns: 0., 0., 2., 0
# In this case we get the same result as for TF native.
```
|
||||
### Functions whose arguments and results are nested Python data structures

`jax2tf` can lower functions with arguments and results that are nested
collections (tuples, lists, dictionaries) of numeric values or JAX arrays
([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)). The
resulting TensorFlow function will take the same kind of arguments except the
leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`).

As long as the arguments use only standard Python containers (tuple, list, dictionaries),
both JAX and TensorFlow can flatten and unflatten them and you can use the lowered
function in TensorFlow without limitations.

However, if your JAX function takes a custom container, you can register it with
the JAX `tree_util` module so that JAX will know how to operate with it, and you
can still lower the function to use it in TensorFlow
eager and with `tf.function`, but you won't be able to save it to a SavedModel, nor
will you be able to compute gradients with TensorFlow
(code from `jax2tf_test.test_custom_pytree_readme`):
@@ -829,77 +811,129 @@

```python
self.assertAllClose(grad_jax.a, grad_tf[0])
self.assertAllClose(grad_jax.b, grad_tf[1])
```
|
||||
# Returns: 0., 0., 2., 0
|
||||
# In this case we get the same result as for TF native.
|
||||
```
|
||||
|
||||
|
||||
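The 32-bit default on the JAX side can be checked directly, without TensorFlow (a minimal sketch; it assumes `JAX_ENABLE_X64` is not set):

```python
import jax.numpy as jnp
import numpy as np

# Without JAX_ENABLE_X64, JAX demotes Python floats and explicit
# 64-bit inputs to 32 bits.
assert jnp.sin(3.14).dtype == jnp.float32
assert jnp.sin(np.float64(3.14)).dtype == jnp.float32
```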
### Errors due to tf.Module magic conversion during attribute assignment

`tf.Module` will automatically wrap the standard Python container data types into
trackable classes during attribute assignment.
Python Dict/List/Tuple are changed to `_DictWrapper`/`_ListWrapper`/`_TupleWrapper`
classes.
In most situations, these wrapper classes work exactly like the standard
Python data types. However, the low-level pytree data structures are different
and this can lead to errors.

In such cases, the user can use this workaround:

```python
import jax
import tensorflow as tf

input_data = ...  # any data object

m = tf.Module()
flat, tree_def = jax.tree_util.tree_flatten(input_data)
m.input_data = {"flat": flat, "tree_def": tree_def}
```

Later the user can use `tree_unflatten` for the reverse process:

```python
input_data = jax.tree_util.tree_unflatten(m.input_data['tree_def'], m.input_data['flat'])
```

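The flatten/unflatten round trip at the heart of this workaround can be exercised without TensorFlow (a minimal sketch with a made-up `input_data` value):

```python
import jax

# Hypothetical nested input; any pytree works.
input_data = {"weights": [1.0, 2.0], "config": ("adam", 0.9)}

flat, tree_def = jax.tree_util.tree_flatten(input_data)
# `flat` is a plain list of leaves; `tree_def` records the container structure,
# so neither needs tf.Module's container wrapping.
restored = jax.tree_util.tree_unflatten(tree_def, flat)
assert restored == input_data
```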
### Unimplemented jax2tf features

There is currently no support for `pmap` or `xmap`, nor for the collective
operations. There is support for `pjit`.

### SavedModel supports only first-order gradients

The `jax2tf`-lowered function supports higher-order gradients, but when the
function is saved in a SavedModel, only the first-order gradient is saved.
This is primarily a limitation of the SavedModel support for custom gradients.

### Slow implementation of associative reductions for CPU

Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based
on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)
operation, which has an efficient implementation for the cases when the
reduction function is associative. For CPU and GPU, JAX uses an alternative
lowering using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801).
jax2tf uses the TPU lowering (because it does not support backend-specific lowering)
and hence it can be slow in some cases on CPU and GPU.

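For reference, the associative-scan lowering mentioned above is also exposed directly in JAX as `jax.lax.associative_scan`; a quick sketch showing it agrees with `cumsum`:

```python
import operator

import jax
import jax.numpy as jnp

x = jnp.arange(1.0, 6.0)
# cumsum and an explicit associative scan with `+` compute the same prefix sums.
a = jnp.cumsum(x)
b = jax.lax.associative_scan(operator.add, x)
assert jnp.allclose(a, b)
```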
Use this only if it improves the performance for your application.
Note that this lowering may not work as well as the default one in the presence
of shape polymorphism.
### TensorFlow XLA ops

For most JAX primitives there is a natural TensorFlow op that fits the needed semantics.
There are a few (listed below) JAX primitives for which there is no
single TensorFlow op with matching semantics.
This is not so surprising, because JAX primitives have been designed
to be compiled to [HLO ops](https://www.tensorflow.org/xla/operation_semantics),
while the corresponding TensorFlow ops are sometimes higher-level.
For the cases when there is no matching canonical TensorFlow op,
we use a set of special TensorFlow ops that are thin wrappers over HLO ops
(a subset of those registered in
[tf2xla/ops/xla_ops.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/ops/xla_ops.cc)
and implemented in, e.g.,
[tf2xla/kernels/xla_pad_op.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc)).
We refer to these ops here as the XLA TensorFlow ops. Note that these are
still regular TensorFlow ops, e.g., they can be saved in a SavedModel.

There are several drawbacks of using XLA TensorFlow ops:

* These ops will only be executable by a consumer that has XLA linked in.
  This should not be a problem for TPU execution, since that requires XLA anyway.
* These ops are not yet recognized by tools that process
  tf.Graph, e.g., the TensorFlow.js converter or the TensorFlow Lite converter.

As an experimental feature we implemented alternative conversions to avoid the XLA TensorFlow ops.
You can enable this with the `enable_xla=False` parameter to `jax2tf.convert`.
For more details see [no_xla_limitations.md](g3doc/no_xla_limitations.md).

### Different performance characteristics

The lowered code may have slightly different performance characteristics than
the original JAX code.
We do expect that the performance characteristics of lowered code
should be the same as those of JAX when used with the XLA compiler (`tf.function(jit_compile=True)`).
This is because
during lowering we try to generate one TensorFlow op for one JAX primitive.
We expect that the lowering that XLA does is similar to that done by JAX
before conversion. (This is a hypothesis, we have not yet verified it extensively.)

There is one known case when the performance of the lowered code will be different.
JAX programs use a [stateless
deterministic PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md)
and it has an internal JAX primitive for it.
This primitive is at the moment lowered to a soup of tf.bitwise operations,
which has a clear performance penalty. We plan to look into using the
HLO [RNGBitGenerator](https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator)
(exposed as a TFXLA op), which does implement
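The stateless, deterministic behavior of JAX's PRNG can be seen in pure JAX (a minimal sketch):

```python
import jax

key = jax.random.PRNGKey(0)
# The same key always produces the same values: the PRNG carries no hidden state.
x1 = jax.random.uniform(key, (3,))
x2 = jax.random.uniform(key, (3,))
assert (x1 == x2).all()

# New randomness comes from explicitly splitting the key.
k1, k2 = jax.random.split(key)
```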
a custom C++ “high-level” kernel implementing batch normalization is executed.
In JAX, there is no primitive for batch normalization, and instead the
operation is decomposed into low-level primitives (e.g., [flax.linen.BatchNorm](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html),
or `haiku.BatchNorm`).
Once those primitives are lowered to TensorFlow, and the resulting code is
run without XLA, the ensemble of the kernels executed will quite
possibly behave differently, performance-wise or even numerically,
than either the TensorFlow native or JAX native batch normalization.
A similar example is that of an LSTM cell.
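To illustrate the decomposition, here is a hand-written batch-normalization step in plain JAX (a sketch using only low-level array primitives; this is not the actual flax/haiku implementation, and the names `gamma`/`beta`/`eps` are just illustrative):

```python
import jax.numpy as jnp

def batch_norm(x, gamma, beta, eps=1e-5):
    # Normalize each feature over the batch dimension, then scale and shift.
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = batch_norm(x, gamma=jnp.ones(2), beta=jnp.zeros(2))
```

Each of `mean`, `var`, and the element-wise arithmetic lowers to a separate op, which is why the kernels executed without XLA need not match TensorFlow's fused batch-norm kernel.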

### Unchecked assumption that the dimension variables take strictly positive values

The shape polymorphic conversion is sound with the assumption that the dimension
variables take non-zero values. In the following example, the function to be lowered
has different behavior for empty shapes. The broken assumption is caught by jax2tf if
the lowered function is executed eagerly, but not if it is first traced to a
TensorFlow graph:

```python
def f_jax(x):
  return 0 if x.shape[0] == 0 else 1

x0 = np.array([], np.float32)
self.assertEqual(0, f_jax(x0))  # JAX sees that x.shape[0] == 0

# jax2tf catches the broken assumption b >= 1 if the lowered function is executed
# eagerly.
# Raises: ValueError: Dimension variable b must have integer value >= 1. Found value 0 when solving b == 0
jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x0)

# However, if we first trace to a TensorFlow graph, we may miss the broken assumption:
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["b"])).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
self.assertEqual(1, f_tf(x0))
```

Another possible source of unsoundness is that JAX assumes that all unknown
dimensions represented by the same dimension variable have equal size. As before,
this assumption is checked if the lowered function is executed eagerly, but
it may be missed if it is first traced to a TensorFlow graph:

```python
def f_jax(x):
  return 0 if x.shape[0] != x.shape[1] else 1

x45 = np.ones((4, 5), dtype=np.float32)
self.assertEqual(0, f_jax(x45))  # JAX sees that x.shape[0] != x.shape[1]

# jax2tf catches the broken assumption x.shape[0] == x.shape[1] if the lowered
# function is executed eagerly.
# Raises: ValueError: polymorphic shape ('b, b',) has dimension variable 'b' corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),)
jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])(x45)

# However, if we first trace to a TensorFlow graph, we may miss the broken assumption.
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
self.assertEqual(1, f_tf(x45))
```

# Calling TensorFlow functions from JAX