Rename SpecifiedLayout to DeviceLocalLayout

PiperOrigin-RevId: 620934348
This commit is contained in:
Yash Katariya 2024-04-01 13:18:56 -07:00 committed by jax authors
parent 011ced4431
commit 6557f680fd
6 changed files with 24 additions and 29 deletions

View File

@ -533,7 +533,7 @@ class ArrayImpl(basearray.Array):
def layout(self):
# TODO(yashkatariya): Remove the try;except when pathways supports layouts.
try:
return layout.SpecifiedLayout(self._pjrt_layout)
return layout.DeviceLocalLayout(self._pjrt_layout)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):

View File

@ -46,7 +46,7 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, SpecifiedLayout
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@ -834,7 +834,7 @@ def _to_physical_op_sharding(
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
if layout is None:
return "default"
if isinstance(layout, AutoLayout):
@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
num_replicas: int = 1,

View File

@ -60,7 +60,7 @@ from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import SpecifiedLayout, AutoLayout
from jax._src.layout import DeviceLocalLayout, AutoLayout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
@ -1996,7 +1996,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
return False
return True
MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
@ -2607,7 +2607,7 @@ def maybe_get_orig_out_sharding(
def _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts, num_ordered_effects
) -> tuple[Sequence[SpecifiedLayout | None], Sequence[SpecifiedLayout | None]]:
) -> tuple[Sequence[DeviceLocalLayout | None], Sequence[DeviceLocalLayout | None]]:
try:
in_layouts_xla = xla_executable.get_parameter_layouts()
out_layouts_xla = xla_executable.get_output_layouts()
@ -2620,8 +2620,8 @@ def _get_layouts_from_executable(
new_in_layouts = []
for x, i in safe_zip(in_layouts_xla, in_layouts):
x = SpecifiedLayout(x)
if isinstance(i, SpecifiedLayout):
x = DeviceLocalLayout(x)
if isinstance(i, DeviceLocalLayout):
if i != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
@ -2631,8 +2631,8 @@ def _get_layouts_from_executable(
new_out_layouts = []
for x, o in safe_zip(out_layouts_xla, out_layouts):
x = SpecifiedLayout(x)
if isinstance(o, SpecifiedLayout):
x = DeviceLocalLayout(x)
if isinstance(o, DeviceLocalLayout):
if o != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
@ -2640,8 +2640,8 @@ def _get_layouts_from_executable(
else:
new_out_layouts.append(x)
assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts)
assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts)
assert all(isinstance(i, DeviceLocalLayout) for i in new_in_layouts)
assert all(isinstance(o, DeviceLocalLayout) for o in new_out_layouts)
return new_in_layouts, new_out_layouts # type: ignore
@ -2823,8 +2823,8 @@ class UnloadedMeshExecutable:
kept_var_idx: set[int]
mut: MutationData | None
auto_spmd_lowering: bool
in_layouts: Sequence[SpecifiedLayout | None]
out_layouts: Sequence[SpecifiedLayout | None]
in_layouts: Sequence[DeviceLocalLayout | None]
out_layouts: Sequence[DeviceLocalLayout | None]
all_args_info: AllArgsInfo | None
def build_unsafe_call(self):
@ -3219,7 +3219,7 @@ def check_device_backend_on_shardings(shardings) -> bool:
def check_array_xla_sharding_layout_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
in_xla_layouts: Sequence[SpecifiedLayout],
in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
from jax._src.array import ArrayImpl
arg_names = ([''] * len(args) if jaxpr_debug_info is None else

View File

@ -17,12 +17,7 @@ from __future__ import annotations
from jax._src.lib import xla_client as xc
# TODO(yashkatariya): Revist the 3 class hierarchy after ifrt::Layout lands.
class Layout:
pass
class SpecifiedLayout(Layout):
class DeviceLocalLayout:
layout: xc.PjRtLayout
def __init__(self, layout: xc.PjRtLayout):
@ -30,13 +25,13 @@ class SpecifiedLayout(Layout):
self._layout_str = str(self._layout)
def __repr__(self):
return f'SpecifiedLayout({self._layout_str})'
return f'DeviceLocalLayout({self._layout_str})'
def __hash__(self):
return hash(self._layout)
def __eq__(self, other):
if not isinstance(other, SpecifiedLayout):
if not isinstance(other, DeviceLocalLayout):
return False
return self._layout == other._layout

View File

@ -44,7 +44,7 @@ from jax._src import traceback_util
from jax._src import tree_util
from jax._src.tree_util import tree_unflatten, keystr
from jax._src import util
from jax._src.layout import SpecifiedLayout
from jax._src.layout import DeviceLocalLayout
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
@ -513,7 +513,7 @@ class Compiled(Stage):
def _input_layouts(self):
layouts_flat = self._executable.input_layouts()
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
# Some input layouts got DCE'd
if self.in_tree.num_leaves > len(layouts_flat):
iter_layouts_flat = iter(layouts_flat)
@ -523,7 +523,7 @@ class Compiled(Stage):
def _output_layouts(self):
layouts_flat = self._executable.output_layouts()
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
@staticmethod

View File

@ -13,6 +13,6 @@
# limitations under the License.
from jax._src.layout import (
SpecifiedLayout as SpecifiedLayout,
DeviceLocalLayout as DeviceLocalLayout,
AUTO as AUTO,
)