Use stablehlo.get_minimum_version in jax_export.py

The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.

PiperOrigin-RevId: 535091927
This commit is contained in:
Eugene Burmako 2023-05-24 21:14:40 -07:00 committed by jax authors
parent 8e397f7f08
commit e25052c6f8
2 changed files with 26 additions and 2 deletions

View File

@ -276,7 +276,26 @@ def export(fun_jax: Callable,
xla_call_module_version = 5
mlir_str = mlir.module_to_bytecode(mlir_module)
target_version = stablehlo.get_earliest_forward_compatible_version()
if stablehlo.get_api_version() < 4:
target_version = stablehlo.get_earliest_forward_compatible_version()
else:
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
# where StableHLO portable artifacts serialized by `producer_version`
# can be deserialized by `consumer_version` within the window.
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = stablehlo.get_minimum_version()
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)

View File

@ -1658,7 +1658,12 @@ def get_serialized_computation(
mlir_module = lowered._lowering.stablehlo()
xla_call_module_version = 5
mlir_str = mlir.module_to_bytecode(mlir_module)
target_version = stablehlo.get_earliest_forward_compatible_version()
if stablehlo.get_api_version() < 4:
target_version = stablehlo.get_earliest_forward_compatible_version()
else:
# See comments next to the usage of stablehlo.get_minimum_version() in
# jax_export.py for an explanation how it works.
target_version = stablehlo.get_minimum_version()
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
return mlir_module_serialized, xla_call_module_version