Change mhlo.is_same_data_across_replicas from unit attr to bool attr

Using bool attrs aligns better with StableHLO. Since [VHLO does not define unit attrs](https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/VhloAttrs.td), serializing StableHLO modules containing unit attrs fails. This becomes a problem when we want to serialize MHLO modules containing `mhlo.is_same_data_across_replicas` by converting them into StableHLO then VHLO.

JAX emits `mhlo.is_same_data_across_replicas` as a bool attr only after a new jaxlib version since this requires the jaxlib to understand the new attr type.

PiperOrigin-RevId: 550745955
This commit is contained in:
Junwhan Ahn 2023-07-24 19:49:59 -07:00 committed by jax authors
parent 7821516105
commit 14a6089e89

View File

@ -972,7 +972,10 @@ def lower_jaxpr_to_fun(
in zip(replicated_args, input_types)]
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
if replicated:
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
if xla_extension_version < 172:
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
else:
attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):