rocm_jax/jax/_src/layout.py
Yash Katariya 0a72e856cf Add **experimental** with_dll_constraint API. This is for cases when the users wants to let SPMD decide the sharding.
But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.

PiperOrigin-RevId: 744888557
2025-04-07 16:21:58 -07:00

146 lines
5.3 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the ific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Union
import numpy as np
from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding
from jax._src.lib import xla_client as xc
Shape = tuple[int, ...]
class AutoLayout:
def __repr__(self):
return "AUTO"
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
_tiling: tuple[tuple[int, ...], ...] | None
_sub_byte_element_size_in_bits: int
AUTO = AutoLayout()
def __init__(self, major_to_minor: tuple[int, ...],
_tiling: tuple[tuple[int, ...], ...] | None = None,
_sub_byte_element_size_in_bits: int = 0):
self.major_to_minor = tuple(major_to_minor)
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling(), # type: ignore[arg-type]
xla_layout.element_size_in_bits())
def __repr__(self):
return (
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' _tiling={self._tiling},'
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
)
def __hash__(self):
return hash((self.major_to_minor, self._tiling,
self._sub_byte_element_size_in_bits))
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
def _to_xla_layout(self, dtype) -> xc.Layout:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
else:
if self._sub_byte_element_size_in_bits != 0:
sub_byte_size = self._sub_byte_element_size_in_bits
elif issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling,
sub_byte_size)
return xla_layout
def check_compatible_aval(self, aval_shape: Shape):
if len(self.major_to_minor) != len(aval_shape):
raise ValueError(
f'Length of major_to_minor and the rank of the value should match.'
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation
ShardingOptions = Union[Sharding, None, AutoSharding]
class Layout:
__slots__ = ['device_local_layout', 'sharding']
def __init__(self, device_local_layout: LayoutOptions = None,
sharding: ShardingOptions = None):
# If layout is concrete and sharding is not, error.
if (isinstance(device_local_layout, DeviceLocalLayout) and
(sharding is None or isinstance(sharding, AutoSharding))):
raise ValueError(
'Sharding has to be concrete when layout is of type'
f' {type(device_local_layout)}. Please pass a'
' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or'
' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
f' sharding {sharding}'
)
if not isinstance(
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
raise TypeError(
'Invalid value received for the device_local_layout argument.'
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an'
f' instance of `DeviceLocalLayout`. Got {device_local_layout} of'
f' type {type(device_local_layout)}'
)
if not isinstance(
sharding, (Sharding, type(None), AutoSharding)):
raise TypeError(
'Invalid value received for the sharding argument. Expected values'
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
f' {sharding} of type {type(sharding)}')
self.device_local_layout = device_local_layout
self.sharding = sharding
@property
def dll(self):
return self.device_local_layout
def __repr__(self):
return (f'Layout(device_local_layout={self.device_local_layout},'
f' sharding={self.sharding})')
def __hash__(self):
return hash((self.device_local_layout, self.sharding))
def __eq__(self, other):
if not isinstance(other, Layout):
return False
return (self.device_local_layout == other.device_local_layout and
self.sharding == other.sharding)