Give auto sharder the mesh information specifically the mesh_shape and the devices

ids of devices in the mesh.

PiperOrigin-RevId: 438882041
This commit is contained in:
Yash Katariya 2022-04-01 12:18:56 -07:00 committed by jax authors
parent 8c3385c542
commit 7b7458b474

View File

@ -2418,6 +2418,9 @@ class MeshExecutable(stages.Executable):
device_assignment=device_assignment,
use_spmd_partitioning=spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
# Set by default. The decision to use them is taken in `xb.get_compile_options`.
auto_spmd_partitioning_mesh_shape=list(mesh.shape.values()),
auto_spmd_partitioning_mesh_ids=mesh.device_ids.reshape(-1),
)
compile_options.parameter_is_tupled_arguments = tuple_args
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \