mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use stablehlo.get_minimum_version in jax_export.py
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API. PiperOrigin-RevId: 535091927
This commit is contained in:
parent
8e397f7f08
commit
e25052c6f8
@ -276,7 +276,26 @@ def export(fun_jax: Callable,
|
||||
|
||||
xla_call_module_version = 5
|
||||
mlir_str = mlir.module_to_bytecode(mlir_module)
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
if stablehlo.get_api_version() < 4:
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
else:
|
||||
# `target_version` is used to manage situations when a StableHLO producer
|
||||
# (in this case, jax2tf) and a StableHLO consumer were built using
|
||||
# different versions of StableHLO.
|
||||
#
|
||||
# Each StableHLO version `producer_version` has a compatibility window,
|
||||
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
||||
# where StableHLO portable artifacts serialized by `producer_version`
|
||||
# can be deserialized by `consumer_version` within the window.
|
||||
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
||||
# for the exact extent of these compatibility guarantees.
|
||||
#
|
||||
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
|
||||
# for the current version of StableHLO. We are using it here to maximize
|
||||
# forward compatibility, i.e. to maximize how far into the past we can go
|
||||
# and still have the payloads produced by `serialize_portable_artifact`
|
||||
# compatible with potential consumers from the past.
|
||||
target_version = stablehlo.get_minimum_version()
|
||||
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||
mlir_str, target_version)
|
||||
|
||||
|
@ -1658,7 +1658,12 @@ def get_serialized_computation(
|
||||
mlir_module = lowered._lowering.stablehlo()
|
||||
xla_call_module_version = 5
|
||||
mlir_str = mlir.module_to_bytecode(mlir_module)
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
if stablehlo.get_api_version() < 4:
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
else:
|
||||
# See comments next to the usage of stablehlo.get_minimum_version() in
|
||||
# jax_export.py for an explanation how it works.
|
||||
target_version = stablehlo.get_minimum_version()
|
||||
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||
mlir_str, target_version)
|
||||
return mlir_module_serialized, xla_call_module_version
|
||||
|
Loading…
x
Reference in New Issue
Block a user