mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which landed in XlaCallModule on July 12th, 2023. See the CHANGELOG for details. PiperOrigin-RevId: 556222908
This commit is contained in:
parent
580b860284
commit
cf4e1d414b
@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
|
||||
separate unfiltered/filtered tracebacks, which was the old behavior) or
|
||||
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
||||
* jax2tf default serialization version is now 7, which introduces new shape
|
||||
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
||||
|
||||
* Breaking changes:
|
||||
* jax2tf now uses native serialization by default. See
|
||||
|
@ -688,8 +688,8 @@ jax_serialization_version = config.define_int_state(
|
||||
# Note: bump the default serialization version at least one month after
|
||||
# we update XlaCallModule to support the new version, so that serialized
|
||||
# modules are forward compatible with deployed versions of XlaCallModule.
|
||||
# Version 6 of XlaCallModule is supported since June 7th, 2023.
|
||||
default=int_env('JAX_SERIALIZATION_VERSION', 6),
|
||||
# Version 7 of XlaCallModule is supported since July 12th, 2023.
|
||||
default=int_env('JAX_SERIALIZATION_VERSION', 7),
|
||||
help=(
|
||||
'The version number to use for native serialization. This must be '
|
||||
'within the range of versions supported by the tf.XlaCallModule '
|
||||
|
@ -814,12 +814,14 @@ We list here a history of the serialization version numbers:
|
||||
for some specialized use cases. Used in JAX from May 3rd, 2023
|
||||
(cl/529106145).
|
||||
* Version 6 adds support for the `disabled_checks` attribute. This version
|
||||
mandates a non-empty `platforms` attribute.
|
||||
Used in JAX since June 13th, 2023 (JAX 0.4.13).
|
||||
mandates a non-empty `platforms` attribute. Supported by XlaCallModule
|
||||
since June 7th, 2023 and available in JAX since
|
||||
June 13th, 2023 (JAX 0.4.13).
|
||||
* Version 7 adds support for `stablehlo.shape_assertion` operations and
|
||||
for `shape_assertions` specified in `disabled_checks`.
|
||||
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
||||
Available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
|
||||
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
|
||||
since July 12th, 2023 (cl/547482522) and
|
||||
available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
|
||||
* Version 8 adds support for the `jax.uses_shape_polymorphism` module
|
||||
attribute and enables the shape refinement pass only when the
|
||||
attribute is present. Supported by XlaCallModule since July 21st, 2023
|
||||
|
@ -699,7 +699,9 @@ class CompatTest(bctu.CompatTestBase):
|
||||
|
||||
self.run_one_test(func, data,
|
||||
polymorphic_shapes=("_, b",),
|
||||
check_results=check_top_k_results)
|
||||
check_results=check_top_k_results,
|
||||
# TODO(necula): now also includes shape_assertion
|
||||
compare_with_current=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user