Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 21:06:06 +00:00

This change includes the following:

* Add a temporary flag (`JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE`) controlling whether memory kinds are fetched from the executable. It should not be set by users, but it is needed by the C++ PjRT-IFRT code. If it is set to True, the host runtime dependency needs to be linked in, and that should eventually also work in OSS (more work is required for that). So only the test sets it to True for now, while jax memories is under development.
* Add a `with_memory_kind` method on `Sharding` to allow easier creation of shardings that differ only in memory kind (see the sketch below).
* Add lowering rules for `device_put` and `jax.jit`:
  * For `device_put`, we always add an annotation that describes a transfer to a memory, plus a sharding annotation.
  * For `jax.jit`, if an argument is in host memory, it gets an extra attribute `_xla_buffer_placement`.
* Handle the correct output sharding in pxla.py by extracting the memory kind from the executable.
* Handle pjit cache hits by canonicalizing memory kinds so that `NS(mesh, pspec) == NS(mesh, pspec, memory_kind='tpu_hbm')`. Also canonicalize `memory_kind` in the `__hash__` and `__eq__` of shardings.
* Leave the StableHLO unchanged (no device placement annotations) for now, since the host-aware passes are not enabled by default and work to make them run everywhere is still in progress.

PiperOrigin-RevId: 553833344
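As an illustrative sketch (not part of the change itself): assuming a TPU backend where `'tpu_hbm'` is the default memory kind, `with_memory_kind` and the canonicalization described above behave roughly like this:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), ('x',))
    s = NamedSharding(mesh, PartitionSpec('x'))
    s_hbm = s.with_memory_kind('tpu_hbm')  # same layout, explicit memory kind
    assert s == s_hbm  # memory kinds are canonicalized in __eq__/__hash__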
132 lines
4.8 KiB
Python
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Mapping, Sequence
import functools

from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc

Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]


@functools.lru_cache(maxsize=4096)
def _addressable_devices_indices_map(
    sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
  # Fast path: if every device is addressable, the global map is the answer.
  if sharding.is_fully_addressable:
    return sharding.devices_indices_map(global_shape)
  # Otherwise, keep only the devices owned by the current process.
  return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items()
          if d.process_index == d.client.process_index()}


@util.use_cpp_class(xc.Sharding)
class Sharding:
  """Describes how a :class:`jax.Array` is laid out across devices."""

  # Abstract methods below that subclasses should implement.
  @property
  def device_set(self) -> set[Device]:
    """The set of devices that this :class:`Sharding` spans.

    In multi-controller JAX, the set of devices is global, i.e., includes
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index | None]:
    """Returns a mapping from devices to the array slices each contains.

    The mapping includes all global devices, i.e., including
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the shape of the data on each device.

    The shard shape returned by this function is calculated from
    ``global_shape`` and the properties of the sharding.
    """
    raise NotImplementedError('Subclasses should implement this method.')

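  # For example, for global_shape (8, 2) split four ways along the leading
  # axis, shard_shape((8, 2)) returns (2, 2).
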
  def is_equivalent_to(self, other: Sharding, ndim: int) -> bool:
    """Returns ``True`` if two shardings are equivalent.

    Two shardings are equivalent if they place the same logical array shards on
    the same devices.

    For example, a :class:`NamedSharding` may be equivalent
    to a :class:`PositionalSharding` if both place the same shards of the array
    on the same devices.
    """
    raise NotImplementedError('Subclasses should implement this method.')

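  # Illustration (hypothetical devices): NamedSharding(Mesh(devices, ('x',)),
  # PartitionSpec('x')) and PositionalSharding(devices).reshape(len(devices), 1)
  # place identical shards, so is_equivalent_to between them (ndim=2) is True.
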
  @property
  def is_fully_replicated(self) -> bool:
    """Is this sharding fully replicated?

    A sharding is fully replicated if each device has a complete copy of the
    entire data."""
    raise NotImplementedError('Subclasses should implement this method.')

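  # Example: replicating an (8, 2) array across four devices gives every
  # device the full (8, 2) copy, so shard_shape equals the global shape.
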
  @property
  def memory_kind(self) -> str | None:
    """Returns the memory kind of the sharding."""
    raise NotImplementedError('Subclasses should implement this method.')

  def with_memory_kind(self, kind: str) -> Sharding:
    """Returns a new Sharding instance with the specified memory kind."""
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @functools.cached_property
  def addressable_devices(self) -> set[Device]:
    """The set of devices in the :class:`Sharding` that are addressable by the
    current process.
    """
    # Add a fast path for single controller runtimes.
    if xb.process_count() == 1:
      return self.device_set
    return {d for d in self.device_set
            if d.process_index == d.client.process_index()}

  @functools.cached_property
  def is_fully_addressable(self) -> bool:
    """Is this sharding fully addressable?

    A sharding is fully addressable if the current process can address all of
    the devices named in the :class:`Sharding`.
    """
    # The pytype disable is because pytype can't recognize a cached property.
    return len(self.device_set) == len(self.addressable_devices)  # type: ignore

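  # Example: with two processes of four devices each, device_set has eight
  # devices but only four are addressable here, so this returns False.
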
  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index | None]:
    """A mapping from addressable devices to the slice of array data each contains.

    ``addressable_devices_indices_map`` contains that part of
    ``devices_indices_map`` that applies to the addressable devices.
    """
    return _addressable_devices_indices_map(self, global_shape)