From 8d80e2587b613d12533273b2299937a6f93dc893 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 7 Aug 2023 01:03:15 -0700 Subject: [PATCH] [jax2tf] Turn on JAX native serialization by default. See changes to the README.md for mechanisms to override the default. PiperOrigin-RevId: 554390866 --- CHANGELOG.md | 5 +++++ jax/_src/config.py | 2 +- jax/experimental/jax2tf/README.md | 30 +++++++++++++++++------------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb0ee8b36..70dde5574 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,11 @@ Remember to align the itemized text with the first line of an item within a list separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). +* Breaking changes: + * jax2tf now uses native serialization by default. See + the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + for details and for mechanisms to override the default. + ## jaxlib 0.4.15 ## jax 0.4.14 (July 27, 2023) diff --git a/jax/_src/config.py b/jax/_src/config.py index 4c5f01b6a..50f235d18 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -679,7 +679,7 @@ jax2tf_associative_scan_reductions = config.define_bool_state( jax2tf_default_native_serialization = config.define_bool_state( name='jax2tf_default_native_serialization', - default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', False), + default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True), help=( 'Sets the default value of the native_serialization parameter to ' 'jax2tf.convert. Prefer using the parameter instead of the flag, ' diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 4aca22e10..385c28632 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -18,27 +18,23 @@ These APIs can be combined, e.g., to reload in JAX a program that has been serialized from JAX to a TensorFlow SavedModel, or to save to TensorFlow SavedModel a JAX program that uses a TensorFlow library. -Tip: As of version 0.4.7 (March 2023), there is a new option -`native_serialization` to use JAX's native lowering to StableHLO to obtain -one StableHLO module for the entire JAX function instead of lowering each -JAX primitive to a TensorFlow op. - -The preferred mode of JAX-TensorFlow interoperation is by way of -**native serialization** in which the target function is lowered to StableHLO +Tip: As of version 0.4.14 (July 2023) the default mode of JAX-TensorFlow +interoperation is by way of **native serialization** in which the target +function is lowered to StableHLO using standard native JAX or TensorFlow APIs, and then the StableHLO module is invoked from the other framework. -To enable this mode, set `native_serialization=True` (soon to be the default). -This has several advantages: +The native serialization mode has several advantages: * supports virtually all operations supported by native execution, e.g., `xmap`, `shard_map`, `pmap`, parallel collective operations, and all primitives at all data types. - * uses standard native code paths in each framework, and thus it is easier + * uses standard native JAX code paths for lowering, and thus it is easier to trust that the semantics and performance stays faithful to the native - semantics, across platforms. Has optional checking that the code runs on - the platform for which it was serialized. + semantics, across platforms. * the metadata associated with the operations, e.g., source location, is identical to what native execution uses. + * includes safety checking that the serialized code is executed on + the platform for which it was serialized. At the moment when using JAX native serialization the whole JAX compilation unit is wrapped with a single thin TensorFlow op, @@ -61,7 +57,8 @@ The reasons we wrap the StableHLO in a TensorFlow op are: For backwards compatibility purposes, and for special uses, the JAX-TensorFlow interoperation APIs can be used also -in a **graph serialization** mode (the only mode available before version 0.4.7), +in a **graph serialization** mode (the only mode available before version 0.4.7, +and the default mode before JAX version 0.4.15), without going through StableHLO. * For calling JAX functions from TensorFlow, @@ -90,6 +87,13 @@ without going through StableHLO. be useful if the target TensorFlow function is not lowerable to HLO, e.g., is using strings. +To disable native serialization, you can do the following, in decreasing +priority order: + + * set `native_serialization=False`, or + * use the configuration flag `--jax2tf_default_native_serialization=false`, or + * use the environment variable `JAX2TF_DEFAULT_NATIVE_SERIALIZATION=false`. + We describe below some general concepts and capabilities, first for `jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax) for `jax2tf.call_tf`.