[jax2tf] Bump the default JAX serialization version to 7.

This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
This commit is contained in:
George Necula 2023-08-11 22:49:04 -07:00 committed by jax authors
parent 580b860284
commit cf4e1d414b
4 changed files with 13 additions and 7 deletions

View File

@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Breaking changes:
* jax2tf now uses native serialization by default. See

View File

@ -688,8 +688,8 @@ jax_serialization_version = config.define_int_state(
# Note: bump the default serialization version at least one month after
# we update XlaCallModule to support the new version, so that serialized
# modules are forward compatible with deployed versions of XlaCallModule.
# Version 6 of XlaCallModule is supported since June 7th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 6),
# Version 7 of XlaCallModule is supported since July 12th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 7),
help=(
'The version number to use for native serialization. This must be '
'within the range of versions supported by the tf.XlaCallModule '

View File

@ -814,12 +814,14 @@ We list here a history of the serialization version numbers:
for some specialized use cases. Used in JAX from May 3rd, 2023
(cl/529106145).
* Version 6 adds support for the `disabled_checks` attribute. This version
mandates a non-empty `platforms` attribute.
Used in JAX since June 13th, 2023 (JAX 0.4.13).
mandates a non-empty `platforms` attribute. Supported by XlaCallModule
since June 7th, 2023 and available in JAX since
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
Available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522) and
available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
* Version 8 adds support for the `jax.uses_shape_polymorphism` module
attribute and enables the shape refinement pass only when the
attribute is present. Supported by XlaCallModule since July 21st, 2023

View File

@ -699,7 +699,9 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data,
polymorphic_shapes=("_, b",),
check_results=check_top_k_results)
check_results=check_top_k_results,
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
if __name__ == "__main__":