mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Set device_assignment to None when only AbstractMesh exist in the computation
PiperOrigin-RevId: 705899088
This commit is contained in:
parent
64eae324ee
commit
80cf141863
@ -2000,6 +2000,8 @@ def jaxpr_transfer_mem_kinds(
|
||||
|
||||
|
||||
def are_all_shardings_default_mem_kind(da_object, shardings):
|
||||
if da_object is None:
|
||||
return True
|
||||
try:
|
||||
default_mem_kind = da_object.default_memory_kind
|
||||
except:
|
||||
@ -2081,8 +2083,9 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
|
||||
return tuple(safe_map(read, jaxpr.outvars))
|
||||
|
||||
|
||||
def _get_num_devices(shardings, device_assignment, lowering_platforms,
|
||||
prim_requires_devices) -> int:
|
||||
def _get_num_devices(
|
||||
shardings, device_assignment, lowering_platforms, prim_requires_devices
|
||||
) -> tuple[int, tuple[xc.Device, ...] | None]:
|
||||
ext_abstract_mesh, concrete_sharding = None, False
|
||||
for s in shardings:
|
||||
if isinstance(s, UnspecifiedValue):
|
||||
@ -2100,9 +2103,9 @@ def _get_num_devices(shardings, device_assignment, lowering_platforms,
|
||||
f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
|
||||
f" device assignment size: {len(device_assignment)}")
|
||||
if concrete_sharding:
|
||||
return len(device_assignment)
|
||||
return len(device_assignment), device_assignment
|
||||
if ext_abstract_mesh is None:
|
||||
return len(device_assignment)
|
||||
return len(device_assignment), device_assignment
|
||||
if lowering_platforms is None:
|
||||
raise ValueError(
|
||||
"Passing lowering_platforms via"
|
||||
@ -2112,7 +2115,7 @@ def _get_num_devices(shardings, device_assignment, lowering_platforms,
|
||||
raise ValueError(
|
||||
"AbstractMesh cannot be used when jaxpr contains primitives that"
|
||||
" require devices to be present during lowering.")
|
||||
return ext_abstract_mesh.size
|
||||
return ext_abstract_mesh.size, None
|
||||
|
||||
|
||||
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
@ -2262,7 +2265,8 @@ def lower_sharding_computation(
|
||||
# is good enough to lower with AbstractMesh but cannot be compiled. Once
|
||||
# I refactor, this will also work well with mesh being provided at
|
||||
# compile time.
|
||||
num_devices = _get_num_devices(
|
||||
# Sets device_assignment to None if only abstractMesh and unspecified exists.
|
||||
num_devices, device_assignment = _get_num_devices( # type: ignore
|
||||
it.chain(unique_in_shardings, unique_out_shardings,
|
||||
unique_intermediate_shardings),
|
||||
device_assignment, lowering_platforms, prim_requires_devices)
|
||||
@ -2273,7 +2277,8 @@ def lower_sharding_computation(
|
||||
or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
|
||||
unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))
|
||||
|
||||
da_object = _create_da_object(tuple(device_assignment))
|
||||
da_object = (_create_da_object(tuple(device_assignment))
|
||||
if device_assignment is not None else None)
|
||||
|
||||
transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
|
||||
all_default_mem_kind = are_all_shardings_default_mem_kind(
|
||||
@ -2291,6 +2296,7 @@ def lower_sharding_computation(
|
||||
|
||||
abstract_mesh = None
|
||||
if prim_requires_devices:
|
||||
assert da_object is not None
|
||||
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
|
||||
unique_intermediate_shardings):
|
||||
if isinstance(sharding, NamedSharding):
|
||||
@ -2863,7 +2869,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive: Any,
|
||||
kept_var_idx: set[int],
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device],
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device] | None,
|
||||
committed: bool,
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
@ -2877,8 +2883,9 @@ class UnloadedMeshExecutable:
|
||||
intermediate_shardings: Sequence[JSharding] | None = None,
|
||||
context_mesh: Mesh | None = None,
|
||||
) -> MeshExecutable:
|
||||
if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
|
||||
for s in it.chain(in_shardings, out_shardings)):
|
||||
if (device_assignment is None or
|
||||
any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
|
||||
for s in it.chain(in_shardings, out_shardings))):
|
||||
raise RuntimeError(
|
||||
"A jitted computation cannot contain AbstractMesh in in_shardings and"
|
||||
" out_shardings during compilation. You can use `jax.export` to "
|
||||
|
Loading…
x
Reference in New Issue
Block a user