[jax2tf] Turn on JAX native serialization by default.

See changes to the README.md for mechanisms to override the default.

PiperOrigin-RevId: 554390866
This commit is contained in:
George Necula 2023-08-07 01:03:15 -07:00 committed by jax authors
parent 364a245ab2
commit 8d80e2587b
3 changed files with 23 additions and 14 deletions

View File

@ -17,6 +17,11 @@ Remember to align the itemized text with the first line of an item within a list
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* Breaking changes:
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
## jaxlib 0.4.15
## jax 0.4.14 (July 27, 2023)

View File

@ -679,7 +679,7 @@ jax2tf_associative_scan_reductions = config.define_bool_state(
jax2tf_default_native_serialization = config.define_bool_state(
name='jax2tf_default_native_serialization',
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', False),
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
help=(
'Sets the default value of the native_serialization parameter to '
'jax2tf.convert. Prefer using the parameter instead of the flag, '

View File

@ -18,27 +18,23 @@ These APIs can be combined, e.g., to reload in JAX a program that
has been serialized from JAX to a TensorFlow SavedModel, or to save to
TensorFlow SavedModel a JAX program that uses a TensorFlow library.
Tip: As of version 0.4.7 (March 2023), there is a new option
`native_serialization` to use JAX's native lowering to StableHLO to obtain
one StableHLO module for the entire JAX function instead of lowering each
JAX primitive to a TensorFlow op.
The preferred mode of JAX-TensorFlow interoperation is by way of
**native serialization** in which the target function is lowered to StableHLO
Tip: As of version 0.4.14 (July 2023) the default mode of JAX-TensorFlow
interoperation is by way of **native serialization** in which the target
function is lowered to StableHLO
using standard native JAX or TensorFlow APIs, and then the StableHLO module
is invoked from the other framework.
To enable this mode, set `native_serialization=True` (soon to be the default).
This has several advantages:
The native serialization mode has several advantages:
* supports virtually all operations supported by native execution, e.g.,
`xmap`, `shard_map`, `pmap`, parallel collective operations, and all
primitives at all data types.
* uses standard native code paths in each framework, and thus it is easier
* uses standard native JAX code paths for lowering, and thus it is easier
to trust that the semantics and performance stays faithful to the native
semantics, across platforms. Has optional checking that the code runs on
the platform for which it was serialized.
semantics, across platforms.
* the metadata associated with the operations, e.g., source location, is
identical to what native execution uses.
* includes safety checking that the serialized code is executed on
the platform for which it was serialized.
At the moment when using JAX native serialization the whole
JAX compilation unit is wrapped with a single thin TensorFlow op,
@ -61,7 +57,8 @@ The reasons we wrap the StableHLO in a TensorFlow op are:
For backwards compatibility purposes, and for special uses,
the JAX-TensorFlow interoperation APIs can be used also
in a **graph serialization** mode (the only mode available before version 0.4.7),
in a **graph serialization** mode (the only mode available before version 0.4.7,
and the default mode before JAX version 0.4.15),
without going through StableHLO.
* For calling JAX functions from TensorFlow,
@ -90,6 +87,13 @@ without going through StableHLO.
be useful if the target TensorFlow function is not lowerable to HLO, e.g.,
is using strings.
To disable native serialization, you can do the following, in decreasing
priority order:
* set `native_serialization=False`, or
* use the configuration flag `--jax2tf_default_native_serialization=false`, or
* use the environment variable `JAX2TF_DEFAULT_NATIVE_SERIALIZATION=false`.
We describe below some general concepts and capabilities, first for
`jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax)
for `jax2tf.call_tf`.