diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d84c486d..05e1d7f55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ Remember to align the itemized text with the first line of an item within a list * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. + * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` + parameter to use JAX's native lowering to StableHLO to obtain a + StableHLO module for the entire JAX function instead of lowering each JAX + primitive to a TensorFlow op. This simplifies the internals and increases + the confidence that what you serialize matches the JAX native semantics. + See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + As part of this change the config flag `--jax2tf_default_experimental_native_lowering` + has been renamed to `--jax2tf_native_serialization`. * Deprecations * The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,