mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Give auto sharder the mesh information specifically the mesh_shape and the devices
ids of devices in the mesh. PiperOrigin-RevId: 438882041
This commit is contained in:
parent
8c3385c542
commit
7b7458b474
@ -2418,6 +2418,9 @@ class MeshExecutable(stages.Executable):
|
||||
device_assignment=device_assignment,
|
||||
use_spmd_partitioning=spmd_lowering,
|
||||
use_auto_spmd_partitioning=auto_spmd_lowering,
|
||||
# Set by default. The decision to use them is taken in `xb.get_compile_options`.
|
||||
auto_spmd_partitioning_mesh_shape=list(mesh.shape.values()),
|
||||
auto_spmd_partitioning_mesh_ids=mesh.device_ids.reshape(-1),
|
||||
)
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
|
||||
|
Loading…
x
Reference in New Issue
Block a user