mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Change mhlo.is_same_data_across_replicas
from unit attr to bool attr
Using bool attrs aligns better with StableHLO. Since [VHLO does not define unit attrs](https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/VhloAttrs.td), serializing StableHLO modules containing unit attrs fails. This becomes a problem when we want to serialize MHLO modules containing `mhlo.is_same_data_across_replicas` by converting them into StableHLO then VHLO. JAX emits `mhlo.is_same_data_across_replicas` as a bool attr only after a new jaxlib version since this requires the jaxlib to understand the new attr type. PiperOrigin-RevId: 550745955
This commit is contained in:
parent
7821516105
commit
14a6089e89
@ -972,7 +972,10 @@ def lower_jaxpr_to_fun(
|
||||
in zip(replicated_args, input_types)]
|
||||
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
|
||||
if replicated:
|
||||
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
|
||||
if xla_extension_version < 172:
|
||||
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
|
||||
else:
|
||||
attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
|
||||
|
||||
if use_sharding_annotations and ir_arg_shardings is not None:
|
||||
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
|
||||
|
Loading…
x
Reference in New Issue
Block a user