mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default. PiperOrigin-RevId: 554390866
This commit is contained in:
parent
364a245ab2
commit
8d80e2587b
@ -17,6 +17,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
separate unfiltered/filtered tracebacks, which was the old behavior) or
|
||||
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
||||
|
||||
* Breaking changes:
|
||||
* jax2tf now uses native serialization by default. See
|
||||
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
|
||||
for details and for mechanisms to override the default.
|
||||
|
||||
## jaxlib 0.4.15
|
||||
|
||||
## jax 0.4.14 (July 27, 2023)
|
||||
|
@ -679,7 +679,7 @@ jax2tf_associative_scan_reductions = config.define_bool_state(
|
||||
|
||||
jax2tf_default_native_serialization = config.define_bool_state(
|
||||
name='jax2tf_default_native_serialization',
|
||||
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', False),
|
||||
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
|
||||
help=(
|
||||
'Sets the default value of the native_serialization parameter to '
|
||||
'jax2tf.convert. Prefer using the parameter instead of the flag, '
|
||||
|
@ -18,27 +18,23 @@ These APIs can be combined, e.g., to reload in JAX a program that
|
||||
has been serialized from JAX to a TensorFlow SavedModel, or to save to
|
||||
TensorFlow SavedModel a JAX program that uses a TensorFlow library.
|
||||
|
||||
Tip: As of version 0.4.7 (March 2023), there is a new option
|
||||
`native_serialization` to use JAX's native lowering to StableHLO to obtain
|
||||
one StableHLO module for the entire JAX function instead of lowering each
|
||||
JAX primitive to a TensorFlow op.
|
||||
|
||||
The preferred mode of JAX-TensorFlow interoperation is by way of
|
||||
**native serialization** in which the target function is lowered to StableHLO
|
||||
Tip: As of version 0.4.14 (July 2023) the default mode of JAX-TensorFlow
|
||||
interoperation is by way of **native serialization** in which the target
|
||||
function is lowered to StableHLO
|
||||
using standard native JAX or TensorFlow APIs, and then the StableHLO module
|
||||
is invoked from the other framework.
|
||||
To enable this mode, set `native_serialization=True` (soon to be the default).
|
||||
This has several advantages:
|
||||
The native serialization mode has several advantages:
|
||||
|
||||
* supports virtually all operations supported by native execution, e.g.,
|
||||
`xmap`, `shard_map`, `pmap`, parallel collective operations, and all
|
||||
primitives at all data types.
|
||||
* uses standard native code paths in each framework, and thus it is easier
|
||||
* uses standard native JAX code paths for lowering, and thus it is easier
|
||||
to trust that the semantics and performance stays faithful to the native
|
||||
semantics, across platforms. Has optional checking that the code runs on
|
||||
the platform for which it was serialized.
|
||||
semantics, across platforms.
|
||||
* the metadata associated with the operations, e.g., source location, is
|
||||
identical to what native execution uses.
|
||||
* includes safety checking that the serialized code is executed on
|
||||
the platform for which it was serialized.
|
||||
|
||||
At the moment when using JAX native serialization the whole
|
||||
JAX compilation unit is wrapped with a single thin TensorFlow op,
|
||||
@ -61,7 +57,8 @@ The reasons we wrap the StableHLO in a TensorFlow op are:
|
||||
|
||||
For backwards compatibility purposes, and for special uses,
|
||||
the JAX-TensorFlow interoperation APIs can be used also
|
||||
in a **graph serialization** mode (the only mode available before version 0.4.7),
|
||||
in a **graph serialization** mode (the only mode available before version 0.4.7,
|
||||
and the default mode before JAX version 0.4.15),
|
||||
without going through StableHLO.
|
||||
|
||||
* For calling JAX functions from TensorFlow,
|
||||
@ -90,6 +87,13 @@ without going through StableHLO.
|
||||
be useful if the target TensorFlow function is not lowerable to HLO, e.g.,
|
||||
is using strings.
|
||||
|
||||
To disable native serialization, you can do the following, in decreasing
|
||||
priority order:
|
||||
|
||||
* set `native_serialization=False`, or
|
||||
* use the configuration flag `--jax2tf_default_native_serialization=false`, or
|
||||
* use the environment variable `JAX2TF_DEFAULT_NATIVE_SERIALIZATION=false`.
|
||||
|
||||
We describe below some general concepts and capabilities, first for
|
||||
`jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax)
|
||||
for `jax2tf.call_tf`.
|
||||
|
Loading…
x
Reference in New Issue
Block a user