mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax2tf] Update CHANGELOG for native serialization.
PiperOrigin-RevId: 517994283
This commit is contained in:
parent
9472b52273
commit
15acc49451
@ -12,6 +12,14 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
|
||||
`jax.config.jax_array` cannot be disabled anymore.
|
||||
* `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
|
||||
* {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization`
|
||||
parameter to use JAX's native lowering to StableHLO to obtain a
|
||||
StableHLO module for the entire JAX function instead of lowering each JAX
|
||||
primitive to a TensorFlow op. This simplifies the internals and increases
|
||||
the confidence that what you serialize matches the JAX native semantics.
|
||||
See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
|
||||
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
|
||||
has been renamed to `--jax2tf_native_serialization`.
|
||||
|
||||
* Deprecations
|
||||
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
|
||||
|
Loading…
x
Reference in New Issue
Block a user