mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which landed in XlaCallModule on July 12th, 2023. See the CHANGELOG for details. PiperOrigin-RevId: 556222908
This commit is contained in:
parent
580b860284
commit
cf4e1d414b
@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
|
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
|
||||||
separate unfiltered/filtered tracebacks, which was the old behavior) or
|
separate unfiltered/filtered tracebacks, which was the old behavior) or
|
||||||
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
||||||
|
* jax2tf default serialization version is now 7, which introduces new shape
|
||||||
|
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
||||||
|
|
||||||
* Breaking changes:
|
* Breaking changes:
|
||||||
* jax2tf now uses native serialization by default. See
|
* jax2tf now uses native serialization by default. See
|
||||||
|
@ -688,8 +688,8 @@ jax_serialization_version = config.define_int_state(
|
|||||||
# Note: bump the default serialization version at least one month after
|
# Note: bump the default serialization version at least one month after
|
||||||
# we update XlaCallModule to support the new version, so that serialized
|
# we update XlaCallModule to support the new version, so that serialized
|
||||||
# modules are forward compatible with deployed versions of XlaCallModule.
|
# modules are forward compatible with deployed versions of XlaCallModule.
|
||||||
# Version 6 of XlaCallModule is supported since June 7th, 2023.
|
# Version 7 of XlaCallModule is supported since July 12th, 2023.
|
||||||
default=int_env('JAX_SERIALIZATION_VERSION', 6),
|
default=int_env('JAX_SERIALIZATION_VERSION', 7),
|
||||||
help=(
|
help=(
|
||||||
'The version number to use for native serialization. This must be '
|
'The version number to use for native serialization. This must be '
|
||||||
'within the range of versions supported by the tf.XlaCallModule '
|
'within the range of versions supported by the tf.XlaCallModule '
|
||||||
|
@ -814,12 +814,14 @@ We list here a history of the serialization version numbers:
|
|||||||
for some specialized use cases. Used in JAX from May 3rd, 2023
|
for some specialized use cases. Used in JAX from May 3rd, 2023
|
||||||
(cl/529106145).
|
(cl/529106145).
|
||||||
* Version 6 adds support for the `disabled_checks` attribute. This version
|
* Version 6 adds support for the `disabled_checks` attribute. This version
|
||||||
mandates a non-empty `platforms` attribute.
|
mandates a non-empty `platforms` attribute. Supported by XlaCallModule
|
||||||
Used in JAX since June 13th, 2023 (JAX 0.4.13).
|
since June 7th, 2023 and available in JAX since
|
||||||
|
June 13th, 2023 (JAX 0.4.13).
|
||||||
* Version 7 adds support for `stablehlo.shape_assertion` operations and
|
* Version 7 adds support for `stablehlo.shape_assertion` operations and
|
||||||
for `shape_assertions` specified in `disabled_checks`.
|
for `shape_assertions` specified in `disabled_checks`.
|
||||||
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
|
||||||
Available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
|
since July 12th, 2023 (cl/547482522) and
|
||||||
|
available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
|
||||||
* Version 8 adds support for the `jax.uses_shape_polymorphism` module
|
* Version 8 adds support for the `jax.uses_shape_polymorphism` module
|
||||||
attribute and enables the shape refinement pass only when the
|
attribute and enables the shape refinement pass only when the
|
||||||
attribute is present. Supported by XlaCallModule since July 21st, 2023
|
attribute is present. Supported by XlaCallModule since July 21st, 2023
|
||||||
|
@ -699,7 +699,9 @@ class CompatTest(bctu.CompatTestBase):
|
|||||||
|
|
||||||
self.run_one_test(func, data,
|
self.run_one_test(func, data,
|
||||||
polymorphic_shapes=("_, b",),
|
polymorphic_shapes=("_, b",),
|
||||||
check_results=check_top_k_results)
|
check_results=check_top_k_results,
|
||||||
|
# TODO(necula): now also includes shape_assertion
|
||||||
|
compare_with_current=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user