Tweaks the utility function _get_ppspec_from_executable to get the shardings directly from the executable (instead of from its HLO modules).

PiperOrigin-RevId: 549473458
This commit is contained in:
jax authors 2023-07-19 17:38:14 -07:00
parent c006e52f1a
commit e2a49ee297

View File

@ -1997,12 +1997,16 @@ def _get_op_sharding_from_executable(
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
input_op_shardings: Sequence[xc.OpSharding] = executable.hlo_modules()[0].spmd_parameters_shardings
output_op_sharding: xc.OpSharding = executable.hlo_modules()[0].spmd_output_sharding
input_op_shardings, output_op_sharding = _get_op_sharding_from_executable(
executable
)
in_ppspec: list[ParsedPartitionSpec] = []
for s in input_op_shardings:
in_ppspec.extend(parse_flatten_op_sharding(s, mesh))
out_ppspec = parse_flatten_op_sharding(output_op_sharding, mesh)
out_ppspec: list[ParsedPartitionSpec] = []
for s in output_op_sharding:
out_ppspec.extend(parse_flatten_op_sharding(s, mesh))
return in_ppspec, out_ppspec