[jax2tf] Update CHANGELOG for native serialization.

PiperOrigin-RevId: 517994283
This commit is contained in:
George Necula 2023-03-20 09:42:59 -07:00 committed by jax authors
parent 9472b52273
commit 15acc49451

View File

@ -12,6 +12,14 @@ Remember to align the itemized text with the first line of an item within a list
* As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
`jax.config.jax_array` cannot be disabled anymore.
* `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
* {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization`
parameter to use JAX's native lowering to StableHLO to obtain a
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
has been renamed to `--jax2tf_native_serialization`.
* Deprecations
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,