diff --git a/CHANGELOG.md b/CHANGELOG.md index a601a973f..5b9bdae9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). + * jax2tf default serialization version is now 7, which introduces new shape + [safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). * Breaking changes: * jax2tf now uses native serialization by default. See diff --git a/jax/_src/config.py b/jax/_src/config.py index 7bd5150a9..f5826a57d 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -688,8 +688,8 @@ jax_serialization_version = config.define_int_state( # Note: bump the default serialization version at least one month after # we update XlaCallModule to support the new version, so that serialized # modules are forward compatible with deployed versions of XlaCallModule. - # Version 6 of XlaCallModule is supported since June 7th, 2023. - default=int_env('JAX_SERIALIZATION_VERSION', 6), + # Version 7 of XlaCallModule is supported since July 12th, 2023. + default=int_env('JAX_SERIALIZATION_VERSION', 7), help=( 'The version number to use for native serialization. This must be ' 'within the range of versions supported by the tf.XlaCallModule ' diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 385c28632..b4ddb7491 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -814,12 +814,14 @@ We list here a history of the serialization version numbers: for some specialized use cases. Used in JAX from May 3rd, 2023 (cl/529106145). * Version 6 adds support for the `disabled_checks` attribute. This version - mandates a non-empty `platforms` attribute. - Used in JAX since June 13th, 2023 (JAX 0.4.13). + mandates a non-empty `platforms` attribute. Supported by XlaCallModule + since June 7th, 2023 and available in JAX since + June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). - Available in JAX serialization since July 20th, 2023 (JAX 0.4.14). + See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + since July 12th, 2023 (cl/547482522) and + available in JAX serialization since July 20th, 2023 (JAX 0.4.14). * Version 8 adds support for the `jax.uses_shape_polymorphism` module attribute and enables the shape refinement pass only when the attribute is present. Supported by XlaCallModule since July 21st, 2023 diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index a512e836a..7ca059c50 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -699,7 +699,9 @@ class CompatTest(bctu.CompatTestBase): self.run_one_test(func, data, polymorphic_shapes=("_, b",), - check_results=check_top_k_results) + check_results=check_top_k_results, + # TODO(necula): now also includes shape_assertion + compare_with_current=False) if __name__ == "__main__":