[jax2tf] Bump the default JAX serialization version to 7.

This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
This commit is contained in:
George Necula 2023-08-11 22:49:04 -07:00 committed by jax authors
parent 580b860284
commit cf4e1d414b
4 changed files with 13 additions and 7 deletions

View File

@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Breaking changes: * Breaking changes:
* jax2tf now uses native serialization by default. See * jax2tf now uses native serialization by default. See

View File

@ -688,8 +688,8 @@ jax_serialization_version = config.define_int_state(
# Note: bump the default serialization version at least one month after # Note: bump the default serialization version at least one month after
# we update XlaCallModule to support the new version, so that serialized # we update XlaCallModule to support the new version, so that serialized
# modules are forward compatible with deployed versions of XlaCallModule. # modules are forward compatible with deployed versions of XlaCallModule.
# Version 6 of XlaCallModule is supported since June 7th, 2023. # Version 7 of XlaCallModule is supported since July 12th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 6), default=int_env('JAX_SERIALIZATION_VERSION', 7),
help=( help=(
'The version number to use for native serialization. This must be ' 'The version number to use for native serialization. This must be '
'within the range of versions supported by the tf.XlaCallModule ' 'within the range of versions supported by the tf.XlaCallModule '

View File

@ -814,12 +814,14 @@ We list here a history of the serialization version numbers:
for some specialized use cases. Used in JAX from May 3rd, 2023 for some specialized use cases. Used in JAX from May 3rd, 2023
(cl/529106145). (cl/529106145).
* Version 6 adds support for the `disabled_checks` attribute. This version * Version 6 adds support for the `disabled_checks` attribute. This version
mandates a non-empty `platforms` attribute. mandates a non-empty `platforms` attribute. Supported by XlaCallModule
Used in JAX since June 13th, 2023 (JAX 0.4.13). since June 7th, 2023 and available in JAX since
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and * Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`. for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
Available in JAX serialization since July 20th, 2023 (JAX 0.4.14). since July 12th, 2023 (cl/547482522) and
available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
* Version 8 adds support for the `jax.uses_shape_polymorphism` module * Version 8 adds support for the `jax.uses_shape_polymorphism` module
attribute and enables the shape refinement pass only when the attribute and enables the shape refinement pass only when the
attribute is present. Supported by XlaCallModule since July 21st, 2023 attribute is present. Supported by XlaCallModule since July 21st, 2023

View File

@ -699,7 +699,9 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data, self.run_one_test(func, data,
polymorphic_shapes=("_, b",), polymorphic_shapes=("_, b",),
check_results=check_top_k_results) check_results=check_top_k_results,
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
if __name__ == "__main__": if __name__ == "__main__":