Set device_assignment to None when only AbstractMesh exists in the computation

PiperOrigin-RevId: 705899088
This commit is contained in:
Yash Katariya 2024-12-13 08:55:07 -08:00 committed by jax authors
parent 64eae324ee
commit 80cf141863

View File

@ -2000,6 +2000,8 @@ def jaxpr_transfer_mem_kinds(
def are_all_shardings_default_mem_kind(da_object, shardings):
if da_object is None:
return True
try:
default_mem_kind = da_object.default_memory_kind
except:
@ -2081,8 +2083,9 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
return tuple(safe_map(read, jaxpr.outvars))
def _get_num_devices(shardings, device_assignment, lowering_platforms,
prim_requires_devices) -> int:
def _get_num_devices(
shardings, device_assignment, lowering_platforms, prim_requires_devices
) -> tuple[int, tuple[xc.Device, ...] | None]:
ext_abstract_mesh, concrete_sharding = None, False
for s in shardings:
if isinstance(s, UnspecifiedValue):
@ -2100,9 +2103,9 @@ def _get_num_devices(shardings, device_assignment, lowering_platforms,
f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
f" device assignment size: {len(device_assignment)}")
if concrete_sharding:
return len(device_assignment)
return len(device_assignment), device_assignment
if ext_abstract_mesh is None:
return len(device_assignment)
return len(device_assignment), device_assignment
if lowering_platforms is None:
raise ValueError(
"Passing lowering_platforms via"
@ -2112,7 +2115,7 @@ def _get_num_devices(shardings, device_assignment, lowering_platforms,
raise ValueError(
"AbstractMesh cannot be used when jaxpr contains primitives that"
" require devices to be present during lowering.")
return ext_abstract_mesh.size
return ext_abstract_mesh.size, None
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
@ -2262,7 +2265,8 @@ def lower_sharding_computation(
# is good enough to lower with AbstractMesh but cannot be compiled. Once
# I refactor, this will also work well with mesh being provided at
# compile time.
num_devices = _get_num_devices(
# Sets device_assignment to None if only AbstractMesh and unspecified shardings exist.
num_devices, device_assignment = _get_num_devices( # type: ignore
it.chain(unique_in_shardings, unique_out_shardings,
unique_intermediate_shardings),
device_assignment, lowering_platforms, prim_requires_devices)
@ -2273,7 +2277,8 @@ def lower_sharding_computation(
or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))
da_object = _create_da_object(tuple(device_assignment))
da_object = (_create_da_object(tuple(device_assignment))
if device_assignment is not None else None)
transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
all_default_mem_kind = are_all_shardings_default_mem_kind(
@ -2291,6 +2296,7 @@ def lower_sharding_computation(
abstract_mesh = None
if prim_requires_devices:
assert da_object is not None
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
unique_intermediate_shardings):
if isinstance(sharding, NamedSharding):
@ -2863,7 +2869,7 @@ class UnloadedMeshExecutable:
keepalive: Any,
kept_var_idx: set[int],
backend: xb.XlaBackend,
device_assignment: xc.DeviceList | Sequence[xc.Device],
device_assignment: xc.DeviceList | Sequence[xc.Device] | None,
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
@ -2877,8 +2883,9 @@ class UnloadedMeshExecutable:
intermediate_shardings: Sequence[JSharding] | None = None,
context_mesh: Mesh | None = None,
) -> MeshExecutable:
if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
for s in it.chain(in_shardings, out_shardings)):
if (device_assignment is None or
any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
for s in it.chain(in_shardings, out_shardings))):
raise RuntimeError(
"A jitted computation cannot contain AbstractMesh in in_shardings and"
" out_shardings during compilation. You can use `jax.export` to "