mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Tweaks the utility function _get_ppspec_from_executable
to get the shardings directly from the executable (instead of from its HLO modules).
PiperOrigin-RevId: 549473458
This commit is contained in:
parent
c006e52f1a
commit
e2a49ee297
@ -1997,12 +1997,16 @@ def _get_op_sharding_from_executable(
|
||||
|
||||
|
||||
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
|
||||
input_op_shardings: Sequence[xc.OpSharding] = executable.hlo_modules()[0].spmd_parameters_shardings
|
||||
output_op_sharding: xc.OpSharding = executable.hlo_modules()[0].spmd_output_sharding
|
||||
input_op_shardings, output_op_sharding = _get_op_sharding_from_executable(
|
||||
executable
|
||||
)
|
||||
in_ppspec: list[ParsedPartitionSpec] = []
|
||||
for s in input_op_shardings:
|
||||
in_ppspec.extend(parse_flatten_op_sharding(s, mesh))
|
||||
out_ppspec = parse_flatten_op_sharding(output_op_sharding, mesh)
|
||||
|
||||
out_ppspec: list[ParsedPartitionSpec] = []
|
||||
for s in output_op_sharding:
|
||||
out_ppspec.extend(parse_flatten_op_sharding(s, mesh))
|
||||
return in_ppspec, out_ppspec
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user