mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default. PiperOrigin-RevId: 554390866
This commit is contained in:
parent
364a245ab2
commit
8d80e2587b
@ -17,6 +17,11 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
separate unfiltered/filtered tracebacks, which was the old behavior) or
|
separate unfiltered/filtered tracebacks, which was the old behavior) or
|
||||||
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
||||||
|
|
||||||
|
* Breaking changes:
|
||||||
|
* jax2tf now uses native serialization by default. See
|
||||||
|
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
|
||||||
|
for details and for mechanisms to override the default.
|
||||||
|
|
||||||
## jaxlib 0.4.15
|
## jaxlib 0.4.15
|
||||||
|
|
||||||
## jax 0.4.14 (July 27, 2023)
|
## jax 0.4.14 (July 27, 2023)
|
||||||
|
@ -679,7 +679,7 @@ jax2tf_associative_scan_reductions = config.define_bool_state(
|
|||||||
|
|
||||||
jax2tf_default_native_serialization = config.define_bool_state(
|
jax2tf_default_native_serialization = config.define_bool_state(
|
||||||
name='jax2tf_default_native_serialization',
|
name='jax2tf_default_native_serialization',
|
||||||
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', False),
|
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
|
||||||
help=(
|
help=(
|
||||||
'Sets the default value of the native_serialization parameter to '
|
'Sets the default value of the native_serialization parameter to '
|
||||||
'jax2tf.convert. Prefer using the parameter instead of the flag, '
|
'jax2tf.convert. Prefer using the parameter instead of the flag, '
|
||||||
|
@ -18,27 +18,23 @@ These APIs can be combined, e.g., to reload in JAX a program that
|
|||||||
has been serialized from JAX to a TensorFlow SavedModel, or to save to
|
has been serialized from JAX to a TensorFlow SavedModel, or to save to
|
||||||
TensorFlow SavedModel a JAX program that uses a TensorFlow library.
|
TensorFlow SavedModel a JAX program that uses a TensorFlow library.
|
||||||
|
|
||||||
Tip: As of version 0.4.7 (March 2023), there is a new option
|
Tip: As of version 0.4.14 (July 2023) the default mode of JAX-TensorFlow
|
||||||
`native_serialization` to use JAX's native lowering to StableHLO to obtain
|
interoperation is by way of **native serialization** in which the target
|
||||||
one StableHLO module for the entire JAX function instead of lowering each
|
function is lowered to StableHLO
|
||||||
JAX primitive to a TensorFlow op.
|
|
||||||
|
|
||||||
The preferred mode of JAX-TensorFlow interoperation is by way of
|
|
||||||
**native serialization** in which the target function is lowered to StableHLO
|
|
||||||
using standard native JAX or TensorFlow APIs, and then the StableHLO module
|
using standard native JAX or TensorFlow APIs, and then the StableHLO module
|
||||||
is invoked from the other framework.
|
is invoked from the other framework.
|
||||||
To enable this mode, set `native_serialization=True` (soon to be the default).
|
The native serialization mode has several advantages:
|
||||||
This has several advantages:
|
|
||||||
|
|
||||||
* supports virtually all operations supported by native execution, e.g.,
|
* supports virtually all operations supported by native execution, e.g.,
|
||||||
`xmap`, `shard_map`, `pmap`, parallel collective operations, and all
|
`xmap`, `shard_map`, `pmap`, parallel collective operations, and all
|
||||||
primitives at all data types.
|
primitives at all data types.
|
||||||
* uses standard native code paths in each framework, and thus it is easier
|
* uses standard native JAX code paths for lowering, and thus it is easier
|
||||||
to trust that the semantics and performance stays faithful to the native
|
to trust that the semantics and performance stays faithful to the native
|
||||||
semantics, across platforms. Has optional checking that the code runs on
|
semantics, across platforms.
|
||||||
the platform for which it was serialized.
|
|
||||||
* the metadata associated with the operations, e.g., source location, is
|
* the metadata associated with the operations, e.g., source location, is
|
||||||
identical to what native execution uses.
|
identical to what native execution uses.
|
||||||
|
* includes safety checking that the serialized code is executed on
|
||||||
|
the platform for which it was serialized.
|
||||||
|
|
||||||
At the moment when using JAX native serialization the whole
|
At the moment when using JAX native serialization the whole
|
||||||
JAX compilation unit is wrapped with a single thin TensorFlow op,
|
JAX compilation unit is wrapped with a single thin TensorFlow op,
|
||||||
@ -61,7 +57,8 @@ The reasons we wrap the StableHLO in a TensorFlow op are:
|
|||||||
|
|
||||||
For backwards compatibility purposes, and for special uses,
|
For backwards compatibility purposes, and for special uses,
|
||||||
the JAX-TensorFlow interoperation APIs can be used also
|
the JAX-TensorFlow interoperation APIs can be used also
|
||||||
in a **graph serialization** mode (the only mode available before version 0.4.7),
|
in a **graph serialization** mode (the only mode available before version 0.4.7,
|
||||||
|
and the default mode before JAX version 0.4.15),
|
||||||
without going through StableHLO.
|
without going through StableHLO.
|
||||||
|
|
||||||
* For calling JAX functions from TensorFlow,
|
* For calling JAX functions from TensorFlow,
|
||||||
@ -90,6 +87,13 @@ without going through StableHLO.
|
|||||||
be useful if the target TensorFlow function is not lowerable to HLO, e.g.,
|
be useful if the target TensorFlow function is not lowerable to HLO, e.g.,
|
||||||
is using strings.
|
is using strings.
|
||||||
|
|
||||||
|
To disable native serialization, you can do the following, in decreasing
|
||||||
|
priority order:
|
||||||
|
|
||||||
|
* set `native_serialization=False`, or
|
||||||
|
* use the configuration flag `--jax2tf_default_native_serialization=false`, or
|
||||||
|
* use the environment variable `JAX2TF_DEFAULT_NATIVE_SERIALIZATION=false`.
|
||||||
|
|
||||||
We describe below some general concepts and capabilities, first for
|
We describe below some general concepts and capabilities, first for
|
||||||
`jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax)
|
`jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax)
|
||||||
for `jax2tf.call_tf`.
|
for `jax2tf.call_tf`.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user