# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for Pallas:TPU-specific JAX primitives and functions."""
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any
|
|
|
|
import jax
|
|
from jax._src import core as jax_core
|
|
from jax._src import dtypes
|
|
from jax._src import pretty_printer as pp
|
|
from jax._src import state
|
|
from jax._src import tree_util
|
|
from jax._src import util
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.pallas import core as pl_core
|
|
from jax._src.pallas import utils as pallas_utils
|
|
from jax._src.pallas.mosaic import core as tpu_core
|
|
from jax._src.state import discharge as state_discharge
|
|
from jax._src.state import indexing
|
|
from jax._src.state import primitives as sp
|
|
from jax._src.state.types import Transform
|
|
from jax._src.typing import DTypeLike
|
|
import jax.numpy as jnp
|
|
|
|
Slice = indexing.Slice
|
|
|
|
map, unsafe_map = util.safe_map, map
|
|
zip, unsafe_zip = util.safe_zip, zip
|
|
|
|
repeat_p = jax_core.Primitive('repeat')
|
|
|
|
def repeat(x, repeats, axis):
|
|
return repeat_p.bind(x, repeats=repeats, axis=axis)
|
|
|
|
@repeat_p.def_abstract_eval
|
|
def _repeat_abstract_eval(x, *, repeats, axis):
|
|
shape = list(x.shape)
|
|
shape[axis] *= repeats
|
|
return jax_core.ShapedArray(shape, x.dtype)
|
|
|
|
|
|
def _repeat_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, repeats, axis):
|
|
def _repeat(x):
|
|
return jnp.repeat(x, repeats, axis)
|
|
return mlir.lower_fun(_repeat, multiple_results=False)(ctx, x)
|
|
mlir.register_lowering(repeat_p, _repeat_lowering_rule)
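# Illustrative sketch (not executed at import time; `x_ref`/`o_ref` are
# hypothetical kernel arguments, not defined in this module): inside a Pallas
# TPU kernel, `repeat` tiles a block along an axis, so an (8, 128) block
# repeated twice along axis 1 yields an (8, 256) block.
#
#   def kernel(x_ref, o_ref):
#     o_ref[...] = repeat(x_ref[...], repeats=2, axis=1)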
bitcast_p = jax_core.Primitive("bitcast")


def bitcast(x, ty: DTypeLike):
  ty = dtypes.canonicalize_dtype(ty)
  if len(x.shape) < 2:
    raise ValueError("Not implemented: bitcast 1D")
  src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype)
  dst_bitwidth = pallas_utils.dtype_bitwidth(ty)
  if x.shape[-2] * src_bitwidth % dst_bitwidth:
    raise ValueError(
        "Not implemented: the 2nd minor dim can not be perfectly packed or"
        " unpacked"
    )
  return bitcast_p.bind(x, ty=ty)


@bitcast_p.def_abstract_eval
def _bitcast_abstract_eval(x, *, ty):
  shape = list(x.shape)
  src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype)
  dst_bitwidth = pallas_utils.dtype_bitwidth(ty)
  shape[-2] = shape[-2] * src_bitwidth // dst_bitwidth
  return jax_core.ShapedArray(shape, ty)


def _bitcast_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, ty):
  def _bitcast(x):
    src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype)
    dst_bitwidth = pallas_utils.dtype_bitwidth(ty)
    if src_bitwidth < dst_bitwidth:
      *leading, m, n = x.shape
      packing = dst_bitwidth // src_bitwidth
      x = x.reshape(*leading, m // packing, packing, n)
      x = jnp.swapaxes(x, -1, -2)
      return jax.lax.bitcast_convert_type(x, ty)
    if src_bitwidth > dst_bitwidth:
      y = jax.lax.bitcast_convert_type(x, ty)
      *leading, m, n, packing = y.shape
      return jnp.swapaxes(y, -1, -2).reshape(*leading, m * packing, n)
    return jax.lax.bitcast_convert_type(x, ty)

  return mlir.lower_fun(_bitcast, multiple_results=False)(ctx, x)


mlir.register_lowering(bitcast_p, _bitcast_lowering_rule)
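# Illustrative sketch (not executed; the shapes below are assumptions):
# `bitcast` reinterprets the bits of a block and rescales the second-minor
# dimension by the bitwidth ratio, so a bfloat16 block of shape (16, 128)
# bitcast to int32 has shape (8, 128), and int32 (8, 128) bitcast back to
# bfloat16 gives (16, 128) again.
#
#   def kernel(x_ref, o_ref):  # x_ref: (16, 128) bfloat16, o_ref: (8, 128) int32
#     o_ref[...] = bitcast(x_ref[...], jnp.int32)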
roll_p = jax_core.Primitive("roll")


def roll(
    x,
    shift,
    axis: int,
    *,
    stride: int | None = None,
    stride_axis: int | None = None,
):
  if isinstance(shift, int) and shift < 0:
    raise ValueError("shift must be non-negative.")
  if axis < 0 or axis >= len(x.shape):
    raise ValueError("axis is out of range.")
  if (stride is None) != (stride_axis is None):
    raise ValueError("stride and stride_axis must be specified together.")
  if stride is not None and stride_axis is not None:
    if stride < 0:
      raise ValueError("stride must be non-negative.")
    if stride_axis < 0 or stride_axis >= len(x.shape):
      raise ValueError("stride_axis is out of range.")
    if axis == stride_axis:
      raise ValueError("axis and stride_axis must be different.")
  return roll_p.bind(
      x, shift, axis=axis, stride=stride, stride_axis=stride_axis
  )


@roll_p.def_abstract_eval
def _roll_abstract_eval(x, shift, **_):
  del shift
  return jax_core.raise_to_shaped(x)


def _roll_lowering_rule(
    ctx: mlir.LoweringRuleContext, x, shift, *, axis, stride, stride_axis
):
  def _roll(x, shift):
    if stride is None:
      return jnp.roll(x, shift, axis)
    outputs = [
        jnp.roll(xs, shift + i * stride, axis)
        for i, xs in enumerate(jnp.split(x, x.shape[stride_axis], stride_axis))
    ]
    return jnp.concatenate(outputs, stride_axis)

  return mlir.lower_fun(_roll, multiple_results=False)(ctx, x, shift)


mlir.register_lowering(roll_p, _roll_lowering_rule)
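# Illustrative sketch (not executed; shapes and names are assumptions): without
# a stride, `roll` behaves like jnp.roll along `axis`; with `stride` and
# `stride_axis`, slice i taken along `stride_axis` is rolled by
# `shift + i * stride`. For example, for an (8, 128) block, row i is rotated by
# 1 + i along the last axis:
#
#   def kernel(x_ref, o_ref):
#     o_ref[...] = roll(x_ref[...], shift=1, axis=1, stride=1, stride_axis=0)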
class DeviceIdType(enum.Enum):
  MESH = "mesh"
  LOGICAL = "logical"


def check_sem_avals(
    sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None
):
  if allowed_semaphore_types is None:
    allowed_semaphore_types = {
        tpu_core.semaphore,
        tpu_core.barrier_semaphore,
        # For interpret mode.
        pl_core.SEMAPHORE_INTERPRET_DTYPE,
    }
  if not isinstance(sem_aval, state.AbstractRef):
    raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}")
  sem_shape = sem_aval.shape
  if sem_transforms_avals:
    sem_shape = sem_transforms_avals[-1].get_indexer_shape()
  if sem_shape:
    raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}")
  sem_dtype = sem_aval.dtype
  if not any(
      jnp.issubdtype(sem_dtype, sem_type)
      for sem_type in allowed_semaphore_types
  ):
    raise ValueError(
        f"Must {name} semaphores of the following types:"
        f" {allowed_semaphore_types}."
    )


def _transform_semaphore(ref_value, transforms, ref_aval):
  """Helper function for indexing into a semaphore during state_discharge."""
  if ref_value.shape == ref_aval.shape:
    return state_discharge.transform_array(ref_value, transforms)
  elif len(ref_value.shape) == 0:
    return ref_value
  else:
    raise ValueError(
        f"Semaphore value shape {ref_value.shape} does not match aval shape"
        f" {ref_aval.shape}"
    )


semaphore_read_p = jax_core.Primitive("semaphore_read")
semaphore_read_p.multiple_results = False


def semaphore_read(sem_or_view):
  ref, transforms = _get_ref_and_transforms(sem_or_view)
  args = [ref, transforms]
  flat_args, args_tree = tree_util.tree_flatten(args)
  return semaphore_read_p.bind(*flat_args, args_tree=args_tree)

@semaphore_read_p.def_abstract_eval
def _semaphore_read_abstract_eval(
    *avals,
    args_tree,
):
  sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals)
  check_sem_avals(
      sem_aval,
      sem_transforms_avals,
      "read",
      allowed_semaphore_types={
          tpu_core.dma_semaphore,
          tpu_core.semaphore,
          tpu_core.barrier_semaphore,
          pl_core.SEMAPHORE_INTERPRET_DTYPE,
      },
  )
  return jax_core.ShapedArray((), jnp.dtype("int32"))

def _semaphore_read_discharge_rule(in_avals,
                                   out_avals,
                                   *flat_args,
                                   args_tree):
  del out_avals
  [ref, transforms] = args_tree.unflatten(flat_args)
  sem_value = _transform_semaphore(ref, transforms, in_avals[0])
  sem_value = sem_value.astype(jnp.int32)
  return (None,) * len(in_avals), sem_value
state_discharge.register_discharge_rule(semaphore_read_p)(
    _semaphore_read_discharge_rule
)
semaphore_signal_p = jax_core.Primitive('semaphore_signal')
semaphore_signal_p.multiple_results = True


def semaphore_signal(
    sem_or_view,
    inc: int | jax.Array = 1,
    *,
    device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None,
    device_id_type: DeviceIdType = DeviceIdType.MESH,
    core_index: int | jax.Array | None = None,
):
  ref, transforms = _get_ref_and_transforms(sem_or_view)
  inc = jnp.asarray(inc, dtype=jnp.int32)
  args = [ref, transforms, inc, device_id, core_index]
  flat_args, args_tree = tree_util.tree_flatten(args)
  semaphore_signal_p.bind(
      *flat_args,
      args_tree=args_tree,
      device_id_type=device_id_type,
  )


@semaphore_signal_p.def_abstract_eval
def _semaphore_signal_abstract_eval(
    *avals,
    args_tree,
    device_id_type: DeviceIdType,
):
  del device_id_type
  (
      sem_aval,
      sem_transforms_avals,
      value_aval,
      device_id_avals,
      core_index_aval,
  ) = tree_util.tree_unflatten(args_tree, avals)
  check_sem_avals(sem_aval, sem_transforms_avals, "signal")
  if value_aval.dtype != jnp.dtype("int32"):
    raise ValueError("Must signal an int32 value.")
  if device_id_avals is not None:
    device_id_flat_avals = tree_util.tree_leaves(device_id_avals)
    for aval in device_id_flat_avals:
      if aval.dtype != jnp.dtype("int32"):
        raise ValueError("`device_id`s must be int32 values.")
  return []


def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn,
                             context: jax_core.JaxprPpContext,
                             settings: jax_core.JaxprPpSettings):
  del settings
  invars = eqn.invars
  tree = eqn.params["args_tree"]
  (
      sem,
      sem_transforms,
      value,
      device_ids,
      _,
  ) = tree_util.tree_unflatten(tree, invars)
  out = pp.concat([
      pp.text("semaphore_signal"),
      pp.text(" "),
      sp.pp_ref_transforms(context, sem, sem_transforms),
      pp.text(" "),
      pp.text(jax_core.pp_var(value, context)),
  ])
  if device_ids is not None:
    flat_device_ids = tree_util.tree_leaves(device_ids)
    if not flat_device_ids:
      return out
    device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))]
    for device_id in flat_device_ids[1:]:
      device_ids_pp.append(pp.text(" "))
      device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context)))
    out = pp.concat([out, pp.concat(device_ids_pp)])
  return out
jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn


def _semaphore_signal_discharge_rule(in_avals,
                                     out_avals,
                                     *flat_args,
                                     args_tree,
                                     device_id_type):
  del out_avals, device_id_type
  [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args)
  if device_id is not None:
    raise NotImplementedError("Remote signal not implemented.")
  if core_index is not None:
    raise NotImplementedError("Multiple core support not implemented.")
  sem_value = _transform_semaphore(ref, transforms, in_avals[0])
  inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE)
  _, new_sem_value = state_discharge.transform_swap_array(
      ref, transforms, sem_value + inc
  )
  return (new_sem_value,) + (None,) * (len(in_avals) - 1), ()
state_discharge.register_discharge_rule(semaphore_signal_p)(
    _semaphore_signal_discharge_rule
)
semaphore_wait_p = jax_core.Primitive('semaphore_wait')
semaphore_wait_p.multiple_results = True

def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
  ref, transforms = _get_ref_and_transforms(sem_or_view)
  dec = jnp.asarray(dec, dtype=jnp.int32)
  args = [ref, transforms, dec]
  flat_args, args_tree = tree_util.tree_flatten(args)
  semaphore_wait_p.bind(*flat_args, args_tree=args_tree)

@semaphore_wait_p.def_abstract_eval
def _semaphore_wait_abstract_eval(*avals, args_tree):
  sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten(
      args_tree, avals
  )
  check_sem_avals(sem_aval, sem_transforms_avals, "wait")
  if value_aval.dtype != jnp.dtype("int32"):
    raise ValueError("Must wait an int32 value.")
  return []

def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn,
                           context: jax_core.JaxprPpContext,
                           settings: jax_core.JaxprPpSettings):
  del settings
  invars = eqn.invars
  tree = eqn.params["args_tree"]
  (
      sem,
      sem_transforms,
      value,
  ) = tree_util.tree_unflatten(tree, invars)
  return pp.concat([
      pp.text("semaphore_wait"),
      pp.text(" "),
      sp.pp_ref_transforms(context, sem, sem_transforms),
      pp.text(" "),
      pp.text(jax_core.pp_var(value, context)),
  ])
jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn

def _semaphore_wait_discharge_rule(in_avals,
                                   out_avals,
                                   *flat_args,
                                   args_tree):
  del out_avals
  [ref, transforms, dec] = args_tree.unflatten(flat_args)
  sem_value = _transform_semaphore(ref, transforms, in_avals[0])
  dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE)
  _, new_sem_value = state_discharge.transform_swap_array(
      ref, transforms, sem_value - dec
  )
  return (new_sem_value,) + (None,) * (len(in_avals) - 1), ()
state_discharge.register_discharge_rule(semaphore_wait_p)(
    _semaphore_wait_discharge_rule
)
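# Illustrative sketch (not executed; `sem` is an assumed scalar semaphore Ref
# passed into a Pallas kernel, e.g. via scratch shapes): signals add to a
# semaphore, waits block until the value can be decremented, and reads return
# the current value as an int32.
#
#   def kernel(..., sem):
#     semaphore_signal(sem, 1)        # sem += 1
#     count = semaphore_read(sem)     # observe the current value
#     semaphore_wait(sem, 1)          # blocks until sem >= 1, then sem -= 1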
@dataclasses.dataclass
class AsyncCopyDescriptor:
  src_ref: Any
  src_transforms: tuple[Transform, ...]
  dst_ref: Any
  dst_transforms: tuple[Transform, ...]
  dst_sem: int | jax.Array
  dst_sem_transforms: tuple[Transform, ...]
  src_sem: int | jax.Array | None
  src_sem_transforms: tuple[Transform, ...] | None
  device_id: int | jax.Array | None
  device_id_type: DeviceIdType = DeviceIdType.MESH

  def __post_init__(self):
    if (self.src_sem is None) ^ (self.device_id is None):
      raise ValueError("Either both or neither of `src_sem` and `device_id` "
                       "must be set.")

  @property
  def is_remote(self):
    return self.src_sem is not None

  def _get_args_and_tree(self, swap_src_and_dst: bool = False):
    if swap_src_and_dst:
      return tree_util.tree_flatten((
          self.dst_ref,
          self.dst_transforms,
          self.src_ref,
          self.src_transforms,
          self.src_sem,
          self.src_sem_transforms,
          self.dst_sem,
          self.dst_sem_transforms,
          self.device_id,
      ))
    else:
      return tree_util.tree_flatten((
          self.src_ref,
          self.src_transforms,
          self.dst_ref,
          self.dst_transforms,
          self.dst_sem,
          self.dst_sem_transforms,
          self.src_sem,
          self.src_sem_transforms,
          self.device_id,
      ))

  def start(self):
    flat_args, tree = self._get_args_and_tree()
    dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)

  def wait(self):
    if self.is_remote:
      self.wait_send()
    self.wait_recv()

  def wait_recv(self):
    flat_args, tree = self._get_args_and_tree()
    dma_wait_p.bind(
        *flat_args, tree=tree, device_id_type=self.device_id_type
    )

  def wait_send(self):
    if not self.is_remote:
      raise ValueError("Cannot `wait_send` on a local copy.")
    # We swap src and dst since, by default, dma_wait_p waits on the dst_sem.
    # As a cleanup, maybe we could modify the primitive to have a
    # `wait_on_send` bool.
    flat_args, tree = self._get_args_and_tree(swap_src_and_dst=True)
    dma_wait_p.bind(
        *flat_args, tree=tree, device_id_type=self.device_id_type
    )
dma_start_p = jax_core.Primitive('dma_start')
dma_start_p.multiple_results = True

@dma_start_p.def_effectful_abstract_eval
def _dma_start_abstract_eval(*args, tree, device_id_type):
  (
      src_ref_aval,
      src_transforms_avals,
      dst_ref_aval,
      dst_transforms_avals,
      dst_sem_aval,
      dst_sem_transforms_avals,
      src_sem_aval,
      src_sem_transforms_avals,
      device_id_aval,
  ) = tree_util.tree_unflatten(tree, args)
  dst_sem_shape = dst_sem_aval.shape
  if dst_sem_transforms_avals:
    dst_sem_shape = dst_sem_transforms_avals[-1].get_indexer_shape()
  if dst_sem_shape:
    raise ValueError(
        f"Cannot signal on a non-()-shaped semaphore: {dst_sem_shape}"
    )
  if src_sem_aval is not None:
    src_sem_shape = src_sem_aval.shape
    if src_sem_transforms_avals:
      src_sem_shape = src_sem_transforms_avals[-1].get_indexer_shape()
    if src_sem_shape:
      raise ValueError(
          f"Cannot signal on a non-()-shaped semaphore: {src_sem_shape}"
      )
  n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals))
  return [], {state.ReadEffect(0), state.WriteEffect(n_src_transforms + 1)}

def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
                      context: jax_core.JaxprPpContext,
                      settings: jax_core.JaxprPpSettings):
  invars = eqn.invars
  tree = eqn.params["tree"]
  (
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      dst_sem,
      dst_sem_transforms,
      src_sem,
      src_sem_transforms,
      device_id,
  ) = tree_util.tree_unflatten(tree, invars)
  del src_sem_transforms
  # TODO(sharadmv): pretty print source semaphores and device id
  if src_sem or device_id:
    return jax_core._pp_eqn(eqn, context, settings)
  return pp.concat([
      pp.text("dma_start"),
      pp.text(" "),
      sp.pp_ref_transforms(context, src_ref, src_transforms),
      pp.text(" -> "),
      sp.pp_ref_transforms(context, dst_ref, dst_transforms),
      pp.text(" "),
      sp.pp_ref_transforms(context, dst_sem, dst_sem_transforms),
  ])

jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
def dma_start_discharge_rule(in_avals, out_avals,
                             *args, tree, device_id_type):
  (
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      dst_sem,
      dst_sem_transforms,
      src_sem,
      src_sem_transforms,
      device_id,
  ) = tree_util.tree_unflatten(tree, args)
  (
      _,
      src_transforms_avals,
      _,
      dst_transforms_avals,
      dst_sem_aval,
      dst_sem_transforms_avals,
      src_sem_aval,
      src_sem_transforms_avals,
      _,
  ) = tree_util.tree_unflatten(tree, in_avals)
  del out_avals
  is_remote = device_id is not None
  if not is_remote:
    # Local async copies only use one semaphore.
    assert src_sem is None
    assert src_sem_transforms is None

  num_src_sem_transforms = len(tree_util.tree_leaves(src_sem_transforms_avals))
  num_dst_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
  num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
  num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))

  updates = state_discharge.transform_array(src_ref, src_transforms)
  local_src = updates

  if is_remote:
    # Note that this code only works in SPMD mode. If not all devices execute
    # the DMA then the devices that do will hang.
    # TODO(justinfu): Verify that code only works in SPMD mode.
    axis_env = jax_core.get_axis_env()
    nonempty_axes = [name for name in axis_env.axis_sizes if name is not None]
    if device_id_type == DeviceIdType.LOGICAL:
      if len(nonempty_axes) > 1:
        raise NotImplementedError("Sharding with more than one named axis not "
                                  "implemented in dma_start_p for LOGICAL "
                                  "device_id_type.")
      shard_axis = nonempty_axes[0]
      my_axis = jax.lax.axis_index(shard_axis)
    elif device_id_type == DeviceIdType.MESH:
      device_id_len = 1
      if isinstance(device_id, jax.Array):
        device_id_len = device_id.size
      elif hasattr(device_id, '__len__'):
        device_id_len = len(device_id)
      if device_id_len != len(axis_env.axis_sizes):
        raise ValueError(
            f"device_id ({device_id_len}) and mesh ({len(axis_env.axis_sizes)}) "
            "must have same length.")
      if device_id_len > 1 or len(nonempty_axes) > 1:
        raise NotImplementedError("Meshes with more than 1 named dimension not "
                                  "implemented in dma_start_p")
      shard_axis = nonempty_axes[0].name
      my_axis = jax.lax.axis_index(shard_axis)
    else:
      raise ValueError(f"Unknown device_id_type: {device_id_type}")
    # Compute the update that is being sent to the current device.
    who_copy_to_me = jax.lax.all_gather(device_id, shard_axis) == my_axis
    # TODO(justinfu): Add a checkify for verifying there is at most one source.
    # TODO(justinfu): Handle the case where no other device is copying to
    # this device.
    index = jnp.argmax(who_copy_to_me, axis=0)
    global_updates = jax.lax.all_gather(updates, shard_axis)
    updates = jax.lax.dynamic_index_in_dim(
        global_updates, index, axis=0, keepdims=False)

    # Handle asymmetrical indexing when devices do not share the same
    # dst_transform.
    global_dst_transforms = tree_util.tree_map(
        lambda x: jax.lax.all_gather(x, shard_axis), dst_transforms
    )
    dst_transforms = tree_util.tree_map(
        lambda x: jax.lax.dynamic_index_in_dim(
            x, index, axis=0, keepdims=False
        ),
        global_dst_transforms,
    )

  _, new_dst = state_discharge.transform_swap_array(
      dst_ref, dst_transforms, updates
  )

  # Update semaphore values.
  # TODO(justinfu): Potentially handle asymmetric copy sizes.
  recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
  recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
  dst_sem_value = _transform_semaphore(
      dst_sem, dst_sem_transforms, dst_sem_aval
  )
  _, new_dst_sem = state_discharge.transform_swap_array(
      dst_sem, dst_sem_transforms, dst_sem_value + recv_size
  )
  if is_remote:
    send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
    send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
    src_sem_value = _transform_semaphore(
        src_sem, src_sem_transforms, src_sem_aval
    )
    _, new_src_sem = state_discharge.transform_swap_array(
        src_sem, src_sem_transforms, src_sem_value + send_size
    )
  else:
    new_src_sem = None

  new_vals = (None,)  # src_val
  new_vals += (None,) * num_src_transform_vals
  new_vals += (new_dst,)  # dst_val
  new_vals += (None,) * num_dst_transform_vals
  new_vals += (new_dst_sem,)  # dst_sem
  new_vals += (None,) * num_dst_sem_transforms
  if is_remote:
    new_vals += (new_src_sem,)  # src_sem
    new_vals += (None,) * num_src_sem_transforms
    new_vals += (None,)  # device_id
  assert (len(new_vals) ==
          len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"
  return new_vals, []

state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
dma_wait_p = jax_core.Primitive('dma_wait')
dma_wait_p.multiple_results = True

@dma_wait_p.def_abstract_eval
def _dma_wait_abstract_eval(*args, tree, device_id_type):
  del args, tree, device_id_type
  return []

def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
                     context: jax_core.JaxprPpContext,
                     settings: jax_core.JaxprPpSettings):
  del settings
  invars = eqn.invars
  tree = eqn.params["tree"]
  (
      _,
      _,
      ref,
      transforms,
      sem,
      sem_transforms,
      _,
      _,
      _,
  ) = tree_util.tree_unflatten(tree, invars)
  return pp.concat([
      pp.text("dma_wait"),
      pp.text(" "),
      sp.pp_ref_transforms(context, ref, transforms),
      pp.text(" "),
      sp.pp_ref_transforms(context, sem, sem_transforms),
  ])

jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

def dma_wait_discharge_rule(in_avals, out_avals,
                            *args, tree, device_id_type):
  # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
  del out_avals, device_id_type
  _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
      tree_util.tree_unflatten(tree, args))
  (_,
   src_ref_transforms_avals,
   _,
   dst_ref_transforms_avals,
   dst_sem_aval,
   dst_sem_transforms_avals,
   src_sem_aval,
   src_sem_transforms_avals,
   device_id_aval,
  ) = tree_util.tree_unflatten(tree, in_avals)
  num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
  num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
  updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
  copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
  copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
  sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval)
  _, new_sem = state_discharge.transform_swap_array(
      dst_sem, dst_sem_transforms, sem_value - copy_size
  )
  new_vals = (None,)  # src_ref
  new_vals += (None,) * len(tree_util.tree_leaves(src_ref_transforms_avals))
  new_vals += (None,)  # ref
  new_vals += (None,) * num_transforms  # ref_transforms
  new_vals += (new_sem,)  # sem
  new_vals += (None,) * num_sem_transforms
  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_aval))  # src_sem
  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
  new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval))  # device_id
  return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
def _get_ref_and_transforms(ref):
  if isinstance(ref, state.TransformedRef):
    return ref.ref, ref.transforms
  return ref, ()

def make_async_copy(src_ref, dst_ref, sem):
  """Creates a descriptor for an async local copy from src_ref to dst_ref."""
  src_ref, src_transforms = _get_ref_and_transforms(src_ref)
  dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref)
  sem, sem_transforms = _get_ref_and_transforms(sem)
  return AsyncCopyDescriptor(
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      sem,
      sem_transforms,
      None,
      None,
      None,
      DeviceIdType.MESH,
  )

def async_copy(src_ref, dst_ref, sem):
  """Issues a DMA copy from src_ref to dst_ref and returns its descriptor."""
  copy_descriptor = make_async_copy(src_ref, dst_ref, sem)
  copy_descriptor.start()
  return copy_descriptor
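# Illustrative sketch (not executed; `hbm_ref`, `vmem_ref`, and `sem` are
# assumed kernel arguments): a local async copy is started, unrelated work can
# overlap with the DMA, and `wait()` blocks until the destination is ready.
#
#   def kernel(hbm_ref, vmem_ref, sem, ...):
#     copy = make_async_copy(hbm_ref, vmem_ref, sem)
#     copy.start()
#     ...          # overlap compute with the DMA
#     copy.wait()  # vmem_ref now holds the data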
def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                           device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Creates a description of a remote copy operation.

  Copies data from src_ref on the current device to dst_ref on the device
  specified by device_id. Both semaphores should be waited on using the
  descriptor on both source and target devices.

  Note that device_id can also refer to the current device.

  Args:
    src_ref: The source Reference.
    dst_ref: The destination Reference.
    send_sem: The semaphore on the source device.
    recv_sem: The semaphore on the destination device.
    device_id: The device id of the destination device.
    device_id_type: The type of the device id.
  Returns:
    An AsyncCopyDescriptor.
  """
  src_ref, src_transforms = _get_ref_and_transforms(src_ref)
  send_sem, send_sem_transforms = _get_ref_and_transforms(send_sem)
  dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref)
  recv_sem, recv_sem_transforms = _get_ref_and_transforms(recv_sem)
  return AsyncCopyDescriptor(
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      recv_sem,
      recv_sem_transforms,
      send_sem,
      send_sem_transforms,
      device_id,
      device_id_type=device_id_type,
  )

def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                      device_id_type: DeviceIdType = DeviceIdType.MESH):
  copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem,
                                           device_id, device_id_type)
  copy_descriptor.start()
  return copy_descriptor
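# Illustrative sketch (not executed; the refs, semaphores, mesh axis name "x",
# and `num_devices` are assumptions): a remote copy sends this device's block
# to its right neighbor on a ring; the sender waits on send_sem and the
# receiver on recv_sem, which `wait()` does in one call.
#
#   def kernel(x_ref, y_ref, send_sem, recv_sem, ...):
#     my_id = jax.lax.axis_index("x")
#     right_neighbor = jax.lax.rem(my_id + 1, num_devices)  # num_devices assumed
#     copy = make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem,
#                                   device_id=(right_neighbor,),
#                                   device_id_type=DeviceIdType.MESH)
#     copy.start()
#     copy.wait()  # waits on both the send and the receive semaphore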
device_id_p = jax_core.Primitive('device_id')

@device_id_p.def_abstract_eval
def _device_id_abstract_eval():
  return jax_core.ShapedArray((), jnp.dtype("int32"))

device_id = device_id_p.bind

get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore')

@get_barrier_semaphore_p.def_abstract_eval
def _get_barrier_semaphore_abstract_eval():
  return pl_core.AbstractMemoryRef(
      jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()),
      tpu_core.TPUMemorySpace.SEMAPHORE,
  )

def get_barrier_semaphore():
  """Returns a barrier semaphore.

  This function returns a barrier semaphore based on the collective_id of the
  current pallas kernel.

  It's very important that the semaphore is waited back down to 0, or else the
  semaphores will become corrupted.

  It's also very important that the collective_id is different for each pallas
  kernel with communication. E.g. if you have two pallas kernels, one that syncs
  across the X axis of the device mesh and the second that syncs across the Y
  axis, they must have different collective_ids.

  It is, however, legal for two kernels that perform the same synchronization
  pattern (e.g. only communicating with neighbours on the same mesh axis)
  to share a collective_id. If in doubt, prefer not sharing collective_ids, as
  doing so incorrectly can lead to silent data corruption or crashes.

  Note that re-using the same collective_id doesn't guarantee that the same
  semaphore is provided by XLA.
  """
  return get_barrier_semaphore_p.bind()
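# Illustrative sketch (not executed; assumes the surrounding pallas_call is
# configured with a distinct collective_id via the TPU compiler params, and
# `neighbor_id` is a hypothetical device id): the barrier semaphore is
# signalled to a neighbor and then waited back down to 0.
#
#   def kernel(...):
#     barrier_sem = get_barrier_semaphore()
#     semaphore_signal(barrier_sem, device_id=neighbor_id)
#     semaphore_wait(barrier_sem, 1)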
delay_p = jax_core.Primitive("delay")
delay_p.multiple_results = True


@delay_p.def_abstract_eval
def _delay_abstract_eval(nanos):
  del nanos
  return []


def delay(nanos):
  """Delays vector execution for the given number of nanoseconds."""
  delay_p.bind(nanos)


# RNG Ops
prng_seed_p = jax_core.Primitive("prng_seed")
prng_seed_p.multiple_results = True

@prng_seed_p.def_abstract_eval
def _(*_):
  return []


def prng_seed(*seeds: int | jax.Array) -> None:
  """Sets the seed for PRNG.

  Args:
    seeds: One or more integer seeds for setting the PRNG seed. If
      more than one seed is passed in, the seed material will be
      mixed before setting the internal PRNG state.
  """
  prng_seed_p.bind(*seeds)

prng_random_bits_p = jax_core.Primitive(
    'prng_random_bits')

@prng_random_bits_p.def_abstract_eval
def _(*, shape):
  return jax_core.ShapedArray(shape, jnp.dtype("int32"))

def prng_random_bits(shape):
  return prng_random_bits_p.bind(shape=shape)
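# Illustrative sketch (not executed; `seed_ref` and `o_ref` are assumed kernel
# arguments): the per-core PRNG is seeded once, then raw int32 bits are drawn
# with the shape of the output block.
#
#   def kernel(seed_ref, o_ref):
#     prng_seed(seed_ref[0])
#     o_ref[...] = prng_random_bits(o_ref.shape)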