mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Rename SpecifiedLayout to DeviceLocalLayout
PiperOrigin-RevId: 620934348
This commit is contained in:
parent
011ced4431
commit
6557f680fd
@ -533,7 +533,7 @@ class ArrayImpl(basearray.Array):
|
||||
def layout(self):
|
||||
# TODO(yashkatariya): Remove the try;except when pathways supports layouts.
|
||||
try:
|
||||
return layout.SpecifiedLayout(self._pjrt_layout)
|
||||
return layout.DeviceLocalLayout(self._pjrt_layout)
|
||||
except xe.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
|
||||
|
@ -46,7 +46,7 @@ from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import AutoLayout, SpecifiedLayout
|
||||
from jax._src.layout import AutoLayout, DeviceLocalLayout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
@ -834,7 +834,7 @@ def _to_physical_op_sharding(
|
||||
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
|
||||
|
||||
|
||||
def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
|
||||
def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
|
||||
if layout is None:
|
||||
return "default"
|
||||
if isinstance(layout, AutoLayout):
|
||||
@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
|
||||
replicated_args: Sequence[bool] | None = None,
|
||||
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
|
||||
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
|
||||
in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
|
||||
out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
|
||||
in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
|
||||
out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
|
||||
arg_names: Sequence[str | None] | None = None,
|
||||
result_names: Sequence[str | None] | None = None,
|
||||
num_replicas: int = 1,
|
||||
|
@ -60,7 +60,7 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import SpecifiedLayout, AutoLayout
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -1996,7 +1996,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
|
||||
return False
|
||||
return True
|
||||
|
||||
MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]
|
||||
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
|
||||
|
||||
class AllArgsInfo(NamedTuple):
|
||||
@ -2607,7 +2607,7 @@ def maybe_get_orig_out_sharding(
|
||||
|
||||
def _get_layouts_from_executable(
|
||||
xla_executable, in_layouts, out_layouts, num_ordered_effects
|
||||
) -> tuple[Sequence[SpecifiedLayout | None], Sequence[SpecifiedLayout | None]]:
|
||||
) -> tuple[Sequence[DeviceLocalLayout | None], Sequence[DeviceLocalLayout | None]]:
|
||||
try:
|
||||
in_layouts_xla = xla_executable.get_parameter_layouts()
|
||||
out_layouts_xla = xla_executable.get_output_layouts()
|
||||
@ -2620,8 +2620,8 @@ def _get_layouts_from_executable(
|
||||
|
||||
new_in_layouts = []
|
||||
for x, i in safe_zip(in_layouts_xla, in_layouts):
|
||||
x = SpecifiedLayout(x)
|
||||
if isinstance(i, SpecifiedLayout):
|
||||
x = DeviceLocalLayout(x)
|
||||
if isinstance(i, DeviceLocalLayout):
|
||||
if i != x:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
|
||||
@ -2631,8 +2631,8 @@ def _get_layouts_from_executable(
|
||||
|
||||
new_out_layouts = []
|
||||
for x, o in safe_zip(out_layouts_xla, out_layouts):
|
||||
x = SpecifiedLayout(x)
|
||||
if isinstance(o, SpecifiedLayout):
|
||||
x = DeviceLocalLayout(x)
|
||||
if isinstance(o, DeviceLocalLayout):
|
||||
if o != x:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
|
||||
@ -2640,8 +2640,8 @@ def _get_layouts_from_executable(
|
||||
else:
|
||||
new_out_layouts.append(x)
|
||||
|
||||
assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts)
|
||||
assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts)
|
||||
assert all(isinstance(i, DeviceLocalLayout) for i in new_in_layouts)
|
||||
assert all(isinstance(o, DeviceLocalLayout) for o in new_out_layouts)
|
||||
return new_in_layouts, new_out_layouts # type: ignore
|
||||
|
||||
|
||||
@ -2823,8 +2823,8 @@ class UnloadedMeshExecutable:
|
||||
kept_var_idx: set[int]
|
||||
mut: MutationData | None
|
||||
auto_spmd_lowering: bool
|
||||
in_layouts: Sequence[SpecifiedLayout | None]
|
||||
out_layouts: Sequence[SpecifiedLayout | None]
|
||||
in_layouts: Sequence[DeviceLocalLayout | None]
|
||||
out_layouts: Sequence[DeviceLocalLayout | None]
|
||||
all_args_info: AllArgsInfo | None
|
||||
|
||||
def build_unsafe_call(self):
|
||||
@ -3219,7 +3219,7 @@ def check_device_backend_on_shardings(shardings) -> bool:
|
||||
|
||||
def check_array_xla_sharding_layout_match(
|
||||
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
in_xla_layouts: Sequence[SpecifiedLayout],
|
||||
in_xla_layouts: Sequence[DeviceLocalLayout],
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
|
||||
from jax._src.array import ArrayImpl
|
||||
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
|
||||
|
@ -17,12 +17,7 @@ from __future__ import annotations
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
||||
# TODO(yashkatariya): Revist the 3 class hierarchy after ifrt::Layout lands.
|
||||
class Layout:
|
||||
pass
|
||||
|
||||
|
||||
class SpecifiedLayout(Layout):
|
||||
class DeviceLocalLayout:
|
||||
layout: xc.PjRtLayout
|
||||
|
||||
def __init__(self, layout: xc.PjRtLayout):
|
||||
@ -30,13 +25,13 @@ class SpecifiedLayout(Layout):
|
||||
self._layout_str = str(self._layout)
|
||||
|
||||
def __repr__(self):
|
||||
return f'SpecifiedLayout({self._layout_str})'
|
||||
return f'DeviceLocalLayout({self._layout_str})'
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._layout)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, SpecifiedLayout):
|
||||
if not isinstance(other, DeviceLocalLayout):
|
||||
return False
|
||||
return self._layout == other._layout
|
||||
|
||||
|
@ -44,7 +44,7 @@ from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
from jax._src.tree_util import tree_unflatten, keystr
|
||||
from jax._src import util
|
||||
from jax._src.layout import SpecifiedLayout
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -513,7 +513,7 @@ class Compiled(Stage):
|
||||
|
||||
def _input_layouts(self):
|
||||
layouts_flat = self._executable.input_layouts()
|
||||
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
|
||||
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
|
||||
# Some input layouts got DCE'd
|
||||
if self.in_tree.num_leaves > len(layouts_flat):
|
||||
iter_layouts_flat = iter(layouts_flat)
|
||||
@ -523,7 +523,7 @@ class Compiled(Stage):
|
||||
|
||||
def _output_layouts(self):
|
||||
layouts_flat = self._executable.output_layouts()
|
||||
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
|
||||
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
|
||||
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
|
||||
|
||||
@staticmethod
|
||||
|
@ -13,6 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.layout import (
|
||||
SpecifiedLayout as SpecifiedLayout,
|
||||
DeviceLocalLayout as DeviceLocalLayout,
|
||||
AUTO as AUTO,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user