diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 6500eca14..c80475e18 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2000,6 +2000,8 @@ def jaxpr_transfer_mem_kinds(
 
 
 def are_all_shardings_default_mem_kind(da_object, shardings):
+  if da_object is None:
+    return True
   try:
     default_mem_kind = da_object.default_memory_kind
   except:
@@ -2081,8 +2083,9 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
   return tuple(safe_map(read, jaxpr.outvars))
 
 
-def _get_num_devices(shardings, device_assignment, lowering_platforms,
-                     prim_requires_devices) -> int:
+def _get_num_devices(
+    shardings, device_assignment, lowering_platforms, prim_requires_devices
+  ) -> tuple[int, tuple[xc.Device, ...] | None]:
   ext_abstract_mesh, concrete_sharding = None, False
   for s in shardings:
     if isinstance(s, UnspecifiedValue):
@@ -2100,9 +2103,9 @@ def _get_num_devices(shardings, device_assignment, lowering_platforms,
           f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
           f" device assignment size: {len(device_assignment)}")
   if concrete_sharding:
-    return len(device_assignment)
+    return len(device_assignment), device_assignment
   if ext_abstract_mesh is None:
-    return len(device_assignment)
+    return len(device_assignment), device_assignment
   if lowering_platforms is None:
     raise ValueError(
         "Passing lowering_platforms via"
@@ -2112,7 +2115,7 @@ def _get_num_devices(shardings, device_assignment, lowering_platforms,
     raise ValueError(
         "AbstractMesh cannot be used when jaxpr contains primitives that"
         " require devices to be present during lowering.")
-  return ext_abstract_mesh.size
+  return ext_abstract_mesh.size, None
 
 
 MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
@@ -2262,7 +2265,8 @@ def lower_sharding_computation(
   # is good enough to lower with AbstractMesh but cannot be compiled. Once
   # I refactor, this will also work well with mesh being provided at
   # compile time.
-  num_devices = _get_num_devices(
+  # Sets device_assignment to None if only AbstractMesh and unspecified exist.
+  num_devices, device_assignment = _get_num_devices(  # type: ignore
       it.chain(unique_in_shardings, unique_out_shardings,
                unique_intermediate_shardings),
       device_assignment, lowering_platforms, prim_requires_devices)
@@ -2273,7 +2277,8 @@ def lower_sharding_computation(
       or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
           unique_in_shardings, unique_out_shardings,
           unique_intermediate_shardings)))
-  da_object = _create_da_object(tuple(device_assignment))
+  da_object = (_create_da_object(tuple(device_assignment))
+               if device_assignment is not None else None)
 
   transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
   all_default_mem_kind = are_all_shardings_default_mem_kind(
@@ -2291,6 +2296,7 @@ def lower_sharding_computation(
     abstract_mesh = None
 
   if prim_requires_devices:
+    assert da_object is not None
     for sharding in it.chain(unique_in_shardings, unique_out_shardings,
                              unique_intermediate_shardings):
       if isinstance(sharding, NamedSharding):
@@ -2863,7 +2869,7 @@ class UnloadedMeshExecutable:
       keepalive: Any,
       kept_var_idx: set[int],
       backend: xb.XlaBackend,
-      device_assignment: xc.DeviceList | Sequence[xc.Device],
+      device_assignment: xc.DeviceList | Sequence[xc.Device] | None,
       committed: bool,
       in_layouts: MaybeLayout,
       out_layouts: MaybeLayout,
@@ -2877,8 +2883,9 @@ class UnloadedMeshExecutable:
       intermediate_shardings: Sequence[JSharding] | None = None,
       context_mesh: Mesh | None = None,
   ) -> MeshExecutable:
-    if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
-           for s in it.chain(in_shardings, out_shardings)):
+    if (device_assignment is None or
+        any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
+            for s in it.chain(in_shardings, out_shardings))):
       raise RuntimeError(
           "A jitted computation cannot contain AbstractMesh in in_shardings and"
           " out_shardings during compilation. You can use `jax.export` to "
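
For reviewers, here is a small, self-contained sketch of the calling pattern this diff introduces: the helper now returns a `(num_devices, device_assignment)` pair, with `device_assignment` set to `None` in the AbstractMesh-only case, and callers guard on `None` before building device-list objects. The names below (`count_devices`, `make_da_object`) are illustrative stand-ins, not JAX APIs; treat this as a sketch of the pattern, not the implementation.

```python
from __future__ import annotations
from typing import Sequence


def count_devices(
    abstract_mesh_size: int | None,
    device_assignment: Sequence[str] | None,
) -> tuple[int, tuple[str, ...] | None]:
  """Returns (num_devices, concrete_device_assignment_or_None)."""
  if device_assignment is not None:
    # Concrete devices are known: return them alongside their count.
    return len(device_assignment), tuple(device_assignment)
  if abstract_mesh_size is None:
    raise ValueError("need either concrete devices or an abstract mesh size")
  # AbstractMesh-only case: a size is known, but no concrete devices exist.
  return abstract_mesh_size, None


def make_da_object(devices: tuple[str, ...]) -> frozenset[str]:
  # Illustrative stand-in for _create_da_object; only valid for concrete devices.
  return frozenset(devices)


# Caller-side guard, mirroring the
# `da_object = (... if device_assignment is not None else None)` change above.
num, da = count_devices(abstract_mesh_size=8, device_assignment=None)
da_object = make_da_object(da) if da is not None else None
assert num == 8 and da_object is None

num, da = count_devices(abstract_mesh_size=None,
                        device_assignment=["cpu:0", "cpu:1"])
da_object = make_da_object(da) if da is not None else None
assert num == 2 and da_object == frozenset({"cpu:0", "cpu:1"})
```

The downstream effect matches the guard added in `UnloadedMeshExecutable`: with only an abstract mesh, lowering can proceed without concrete devices, but compiling such a computation is rejected with the RuntimeError shown in the diff.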