mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add experimental JAX roofline API.
This commit is contained in:
parent
a212a29dc6
commit
8c521547b7
@ -227,6 +227,7 @@ py_library_providing_imports_info(
|
||||
"_src/state/**/*.py",
|
||||
"_src/third_party/**/*.py",
|
||||
"experimental/key_reuse/**/*.py",
|
||||
"experimental/roofline/**/*.py",
|
||||
"image/**/*.py",
|
||||
"interpreters/**/*.py",
|
||||
"lax/**/*.py",
|
||||
|
29
jax/experimental/roofline/__init__.py
Normal file
29
jax/experimental/roofline/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax.experimental.roofline.roofline import (
|
||||
RooflineRuleContext as RooflineRuleContext,
|
||||
)
|
||||
from jax.experimental.roofline.roofline import RooflineShape as RooflineShape
|
||||
from jax.experimental.roofline.roofline import RooflineResult as RooflineResult
|
||||
from jax.experimental.roofline.roofline import roofline as roofline
|
||||
from jax.experimental.roofline.roofline import register_roofline as register_roofline
|
||||
from jax.experimental.roofline.roofline import (
|
||||
register_standard_roofline as register_standard_roofline,
|
||||
)
|
||||
from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad
|
||||
|
||||
|
||||
import jax.experimental.roofline.rooflines as rooflines
|
||||
|
||||
del rooflines
|
342
jax/experimental/roofline/roofline.py
Normal file
342
jax/experimental/roofline/roofline.py
Normal file
@ -0,0 +1,342 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Protocol, Sequence
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api import make_jaxpr
|
||||
from jax._src.interpreters.partial_eval import dce_jaxpr
|
||||
from jax._src.interpreters.xla import abstractify
|
||||
from jax._src.mesh import AbstractMesh, Mesh
|
||||
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
|
||||
from jax.experimental import shard_map
|
||||
|
||||
|
||||
ShapeDtypeStructTree = Any
|
||||
|
||||
|
||||
map = util.safe_map
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class RooflineRuleContext:
|
||||
name_stack: source_info_util.NameStack
|
||||
primitive: core.Primitive
|
||||
avals_in: Sequence[core.AbstractValue]
|
||||
avals_out: Sequence[core.AbstractValue]
|
||||
jaxpr_eqn_ctx: core.JaxprEqnContext
|
||||
mesh: Mesh | AbstractMesh
|
||||
pin_lhs_in_vmem: bool
|
||||
pin_rhs_in_vmem: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class RooflineShape:
|
||||
shape: tuple[int, ...]
|
||||
dtype: np.dtype
|
||||
|
||||
@classmethod
|
||||
def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape":
|
||||
if not isinstance(aval, core.ShapedArray):
|
||||
raise TypeError(f"Expected ShapedArray, got {type(aval)}.")
|
||||
if not isinstance(aval.dtype, np.dtype):
|
||||
raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.")
|
||||
return cls(shape=aval.shape, dtype=aval.dtype)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return int(np.prod(self.shape))
|
||||
|
||||
@property
|
||||
def bytes(self) -> int:
|
||||
return int(self.size * self.dtype.itemsize)
|
||||
|
||||
@classmethod
|
||||
def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int:
|
||||
return sum(cls.from_aval(aval).bytes for aval in avals)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class RooflineResult:
|
||||
flops: int = 0
|
||||
ici_bytes: dict[str, int] = field(default_factory=dict)
|
||||
ici_latency: dict[str, int] = field(default_factory=dict)
|
||||
hbm_bytes: int = 0
|
||||
peak_hbm_bytes: int = 0
|
||||
|
||||
@classmethod
|
||||
def zeros(cls) -> "RooflineResult":
|
||||
return cls()
|
||||
|
||||
def __add__(self, other: "RooflineResult") -> "RooflineResult":
|
||||
def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]:
|
||||
return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)}
|
||||
|
||||
return RooflineResult(
|
||||
flops=self.flops + other.flops,
|
||||
ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes),
|
||||
ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency),
|
||||
hbm_bytes=self.hbm_bytes + other.hbm_bytes,
|
||||
peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes),
|
||||
)
|
||||
|
||||
def __mul__(self, constant: int | float) -> "RooflineResult":
|
||||
return RooflineResult(
|
||||
flops=int(self.flops * constant),
|
||||
ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()},
|
||||
ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()},
|
||||
hbm_bytes=int(self.hbm_bytes * constant),
|
||||
peak_hbm_bytes=int(self.peak_hbm_bytes * constant),
|
||||
)
|
||||
|
||||
def __rmul__(self, constant: int | float) -> "RooflineResult":
|
||||
return self.__mul__(constant)
|
||||
|
||||
|
||||
class _RooflineRule(Protocol):
  """Call signature shared by all registered roofline rules.

  A rule receives the per-equation context, one `RooflineShape` per equation
  input as positional arguments, and the equation's params as keyword
  arguments, and returns the cost of that equation.
  """

  def __call__(
      self,
      ctx: RooflineRuleContext,
      *args: RooflineShape,
      **kw,
  ) -> RooflineResult:
    ...


# Registry mapping each primitive to the rule used to cost its equations.
_rooflines: dict[core.Primitive, _RooflineRule] = {}
|
||||
|
||||
|
||||
def _roofline_interpreter(
    f_name: str,
    jaxpr: core.Jaxpr,
    mesh: Mesh | AbstractMesh,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
) -> RooflineResult:
  """Walk `jaxpr` equation by equation and accumulate roofline counters.

  Equations whose params contain a "jaxpr" entry are costed by recursing into
  that jaxpr; every other equation is dispatched to the rule registered for
  its primitive in `_rooflines`. After each equation the bytes of all values
  still live in the environment are summed and folded into `peak_hbm_bytes`
  (which `RooflineResult.__add__` combines with `max`).

  Args:
    f_name: name used to build the name stack for this (sub-)jaxpr.
    jaxpr: the jaxpr to analyse. A `core.ClosedJaxpr` is also accepted and is
      unwrapped to its underlying jaxpr.
    mesh: mesh forwarded to every rule through `RooflineRuleContext`.
    pin_lhs_in_vmem: forwarded to every rule through the context.
    pin_rhs_in_vmem: forwarded to every rule through the context.

  Returns:
    The sum of the results of all equations (including nested jaxprs).

  Raises:
    NotImplementedError: if an equation's primitive has no registered rule.
  """
  # The "jaxpr" param of an equation is not guaranteed to be an open
  # `core.Jaxpr`. When it is a `core.ClosedJaxpr`, unwrap it so the
  # `constvars` / `invars` / `last_used` accesses below see a plain jaxpr.
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr

  name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline"))

  result = RooflineResult.zeros()

  # Shapes of the variables that are currently live.
  env: dict[core.Var, RooflineShape] = {}

  def write(v: core.Var, node: RooflineShape):
    assert node is not None
    env[v] = node

  def read(v: core.Atom) -> RooflineShape:
    if type(v) is core.Literal:
      return RooflineShape.from_aval(abstractify(v.val))
    else:
      assert isinstance(v, core.Var)
      return env[v]

  def aval(v: core.Atom) -> core.AbstractValue:
    if type(v) is core.Literal:
      return abstractify(v.val)
    else:
      return v.aval

  def calculate_peak_hbm_bytes() -> int:
    return int(
        sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values())
    )

  make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x))

  # NOTE: `map` is this module's alias for `util.safe_map`, so these calls
  # run eagerly for their side effect on `env`.
  map(write, jaxpr.constvars, map(make_roofline_shape, jaxpr.constvars))
  map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
  last_used = core.last_used(jaxpr)
  for eqn in jaxpr.eqns:
    source_info = eqn.source_info.replace(
        name_stack=name_stack + eqn.source_info.name_stack
    )
    with source_info_util.user_context(
        eqn.source_info.traceback, name_stack=source_info.name_stack
    ):
      if "jaxpr" in eqn.params:
        # Call-like equation: cost its body instead of the equation itself.
        result += _roofline_interpreter(
            util.wrap_name(f_name, eqn.primitive.name),
            eqn.params["jaxpr"],
            mesh,
            pin_lhs_in_vmem=pin_lhs_in_vmem,
            pin_rhs_in_vmem=pin_rhs_in_vmem,
        )
      else:
        if eqn.primitive not in _rooflines:
          # Dump every public attribute of the equation to ease debugging.
          details = "".join(
              f"\n{attr}: {getattr(eqn, attr)}"
              for attr in dir(eqn)
              if not attr.startswith("_")
          )
          raise NotImplementedError(f"No roofline rule for {eqn.primitive}." + details)
        rule = _rooflines[eqn.primitive]
        result += rule(
            RooflineRuleContext(
                name_stack=source_info.name_stack,
                primitive=eqn.primitive,
                avals_in=map(aval, eqn.invars),
                avals_out=map(aval, eqn.outvars),
                jaxpr_eqn_ctx=eqn.ctx,
                mesh=mesh,
                pin_lhs_in_vmem=pin_lhs_in_vmem,
                pin_rhs_in_vmem=pin_rhs_in_vmem,
            ),
            *map(read, eqn.invars),
            **eqn.params,
        )

    map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars))
    # Drop variables that are never read again so the peak estimate only
    # counts live values.
    core.clean_up_dead_vars(eqn, env, last_used)
    result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes())

  return result
|
||||
|
||||
|
||||
def _f_with_vjp(f: Callable):
  """Wrap `f` so that calling it runs `f`'s VJP instead of `f` itself.

  The wrapper calls `api.vjp(f, *args)` and feeds the pullback a cotangent
  obtained by mapping `jnp.bfloat16` over the primal outputs. It returns
  whatever the pullback returns.
  """

  @util.wraps(f)
  def wrapped(*args):
    vjp_outputs = api.vjp(f, *args)
    primals, f_vjp = vjp_outputs
    cotangents = tree_map(jnp.bfloat16, primals)
    return f_vjp(cotangents)

  return wrapped
|
||||
|
||||
|
||||
def roofline(
    f: Callable,
    mesh: Mesh | AbstractMesh,
    in_specs: shard_map.Specs,
    out_specs: shard_map.Specs,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
    vjp: bool = False,
    print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]:
  """Return a function that computes a roofline estimate for `f`.

  `f` is wrapped in `shard_map` with the given mesh and specs, traced to a
  jaxpr, and the body of the resulting shard_map equation is analysed by
  `_roofline_interpreter`.

  Args:
    f: the function to analyse.
    mesh: mesh passed to `shard_map` and to every roofline rule.
    in_specs: input specs passed to `shard_map`.
    out_specs: output specs passed to `shard_map`; also used (as a tree
      prefix of the outputs) to attach shardings to the returned shapes.
    pin_lhs_in_vmem: forwarded to the roofline rules.
    pin_rhs_in_vmem: forwarded to the roofline rules.
    vjp: if true, analyse the VJP of the shard-mapped function (see
      `_f_with_vjp`) instead of the function itself.
    print_jaxpr: if true, print the analysed jaxpr.

  Returns:
    A callable taking the same positional arguments as `f` and returning
    `(out_shapes, result)`: the output `ShapeDtypeStruct` tree with
    `NamedSharding(mesh, out_spec)` attached, and the `RooflineResult`.

  Raises:
    ValueError: (from the returned callable) if the traced jaxpr contains no
      shard_map equation carrying a "jaxpr" param.
  """

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    wrapped_f = shard_map.shard_map(f, mesh, in_specs, out_specs)
    if vjp:
      wrapped_f = _f_with_vjp(wrapped_f)

    jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args)

    def make_sharded_shape_dtype_struct(
        shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs
    ) -> api.ShapeDtypeStruct:
      return api.ShapeDtypeStruct(
          shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec)
      )

    # Pair every output leaf with its spec, then rebuild the output tree.
    out_specs_flat = broadcast_prefix(out_specs, out_shapes)
    flat_out_shapes, treedef = tree_flatten(out_shapes)
    flat_out_shapes = map(
        make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat
    )
    out_shapes = tree_unflatten(treedef, flat_out_shapes)

    # Dead-code-eliminate while keeping every output.
    used_outputs = (True,) * len(jaxpr.jaxpr.outvars)
    jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs)

    shard_map_eqns = [
        e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p
    ]
    try:
      # The last shard_map equation is the one to analyse.
      jaxpr = shard_map_eqns[-1].params["jaxpr"]
    except (IndexError, KeyError) as err:
      # IndexError: no shard_map equation at all (indexing an empty list);
      # KeyError: the equation has no "jaxpr" param.
      raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.") from err

    if print_jaxpr:
      print(jaxpr)

    return out_shapes, _roofline_interpreter(
        util.fun_qual_name(f),
        jaxpr,
        mesh,
        pin_lhs_in_vmem=pin_lhs_in_vmem,
        pin_rhs_in_vmem=pin_rhs_in_vmem,
    )

  return wrapped
|
||||
|
||||
|
||||
def register_roofline(prim: core.Primitive):
  """Return a decorator that installs a rule for `prim` in `_rooflines`.

  Any rule previously registered for `prim` is replaced. The decorator
  returns the rule unchanged, so the decorated name stays usable.
  """

  def decorator(rule: _RooflineRule):
    _rooflines.update({prim: rule})
    return rule

  return decorator
|
||||
|
||||
|
||||
def register_standard_roofline(prim: core.Primitive):
  """Register a default rule for `prim` that reports an all-zero result.

  Any rule previously registered for `prim` is replaced. Returns `None`.
  """

  def standard_rule(ctx: RooflineRuleContext, *args, **kwargs):
    # Inputs are ignored: every call yields a fresh zero result.
    zero_cost = RooflineResult.zeros()
    return zero_cost

  _rooflines[prim] = standard_rule
|
||||
|
||||
|
||||
def roofline_and_grad(
    f: Callable,
    mesh: Mesh | AbstractMesh,
    in_specs: shard_map.Specs,
    out_specs: shard_map.Specs,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
    print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]:
  """Return a function computing forward and backward roofline estimates.

  The returned callable runs `roofline` twice: once on `f` with the given
  arguments, and once with `vjp=True` on arguments recast to bfloat16 (int32
  arguments keep their dtype). It returns
  `(primal_shapes, forward_result, backward_result)`.

  The arguments are read through `.shape`, `.dtype` and `.sharding`, so they
  must expose those attributes.
  """

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    # Options shared by the forward and the backward analysis.
    common = dict(
        pin_lhs_in_vmem=pin_lhs_in_vmem,
        pin_rhs_in_vmem=pin_rhs_in_vmem,
        print_jaxpr=print_jaxpr,
    )

    forward = roofline(f, mesh, in_specs, out_specs, **common)
    primal_shapes, fwd_result = forward(*args)

    def as_bwd_arg(x):
      bwd_dtype = jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16
      return api.ShapeDtypeStruct(x.shape, bwd_dtype, sharding=x.sharding)

    backward = roofline(f, mesh, in_specs, out_specs, vjp=True, **common)
    bwd_result = backward(*tree_map(as_bwd_arg, args))[1]

    return (primal_shapes, fwd_result, bwd_result)

  return wrapped
|
270
jax/experimental/roofline/rooflines.py
Normal file
270
jax/experimental/roofline/rooflines.py
Normal file
@ -0,0 +1,270 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from dataclasses import replace
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core, util
|
||||
from jax._src import ops
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src.lax import (
|
||||
ann,
|
||||
convolution,
|
||||
fft,
|
||||
lax,
|
||||
linalg,
|
||||
parallel as lax_parallel,
|
||||
slicing,
|
||||
special,
|
||||
windowed_reductions,
|
||||
)
|
||||
from jax.experimental import roofline
|
||||
from jax.experimental import shard_map
|
||||
|
||||
|
||||
# Give every primitive found in these modules the zero-cost default rule.
# Rules registered further down in this file replace the default for the
# primitives they are registered for.
for prim in it.chain.from_iterable(
    vars(module).values()
    for module in (
        ad_util,
        ann,
        convolution,
        fft,
        lax,
        linalg,
        ops,
        prng,
        random,
        shard_map,
        slicing,
        special,
        windowed_reductions,
    )
):
  if not isinstance(prim, core.Primitive):
    continue
  roofline.register_standard_roofline(prim)
|
||||
|
||||
|
||||
@roofline.register_roofline(lax.dot_general_p)
|
||||
def _dot_general_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
dimension_numbers: lax.DotDimensionNumbers,
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
||||
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
|
||||
(lhs_contract, _), (lhs_batch, _) = dimension_numbers
|
||||
|
||||
flops = (
|
||||
2
|
||||
* lhs.size
|
||||
* rhs.size
|
||||
/ np.prod([lhs.shape[i] for i in lhs_contract])
|
||||
/ np.prod([lhs.shape[i] for i in lhs_batch])
|
||||
)
|
||||
|
||||
hbm_bytes = 0
|
||||
if not ctx.pin_lhs_in_vmem:
|
||||
hbm_bytes += lhs.bytes
|
||||
hbm_bytes += out.bytes
|
||||
if not ctx.pin_rhs_in_vmem:
|
||||
hbm_bytes += rhs.bytes
|
||||
|
||||
return roofline.RooflineResult(flops=int(flops), hbm_bytes=hbm_bytes)
|
||||
|
||||
|
||||
def _return_zeros_if_one_sized_axis(
|
||||
ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]
|
||||
) -> roofline.RooflineResult | None:
|
||||
axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes])
|
||||
if axes_size > 1:
|
||||
return None
|
||||
return roofline.RooflineResult(
|
||||
ici_bytes={axis: 0 for axis in axes},
|
||||
ici_latency={axis: 0 for axis in axes},
|
||||
)
|
||||
|
||||
|
||||
def _ring_collective_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    is_reduce: bool = True,
    **kw,
) -> roofline.RooflineResult:
  """Cost model for ring-style collectives over one or more mesh axes.

  The data is split evenly across `axes` and then processed one axis at a
  time, largest axis first; after each axis the per-device shard grows by
  that axis' size. When `is_reduce` is true the starting shard is the input
  bytes divided by the total size of `axes`.

  The same byte count and latency are reported for every axis in `axes`.
  """
  zeros_result = _return_zeros_if_one_sized_axis(ctx, axes)
  if zeros_result:
    return zeros_result

  mesh = ctx.mesh.shape
  shard_bytes = roofline.RooflineShape.total_bytes(ctx.avals_in)
  if is_reduce:
    shard_bytes /= np.prod([mesh[axis] for axis in axes])

  # We model the slowest color as the bottleneck.
  axes_by_size = sorted(axes, key=lambda name: mesh[name], reverse=True)
  num_axes = len(axes_by_size)

  # Phase split: each axis handles an equal slice of the shard.
  shard_bytes //= num_axes

  ici_bytes = 0
  for axis in axes_by_size:
    axis_size = mesh[axis]
    # One phase over this axis ...
    ici_bytes += shard_bytes * (axis_size - 1)
    # ... after which the shard is `axis_size` times larger.
    shard_bytes *= axis_size

  # Bottleneck is the longest axis.
  ici_latency = mesh[axes_by_size[0]] * num_axes

  bytes_per_axis = {axis: int(ici_bytes) for axis in axes_by_size}
  latency_per_axis = {axis: int(ici_latency) for axis in axes_by_size}
  return roofline.RooflineResult(
      ici_bytes=bytes_per_axis,
      ici_latency=latency_per_axis,
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax_parallel.reduce_scatter_p)
def _reduce_scatter_roofline(*args, axis_name, **kw):
  """reduce_scatter: ring collective over `axis_name`, reducing variant."""
  return _ring_collective_roofline(*args, axes=axis_name, **kw)


@roofline.register_roofline(lax_parallel.all_gather_p)
def _all_gather_roofline(*args, axis_name, **kw):
  """all_gather: ring collective over `axis_name`, non-reducing variant."""
  return _ring_collective_roofline(*args, axes=axis_name, is_reduce=False, **kw)
|
||||
|
||||
|
||||
def _scalar_collective_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Cost a collective as if every input were a single element.

  Each input aval is replaced by a shape-`(1,)` array of the same dtype and
  the result is costed as a non-reducing ring collective over `axes`.
  """
  input_shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in]
  one_element_avals = [core.ShapedArray((1,), s.dtype) for s in input_shapes]
  scalar_ctx = replace(ctx, avals_in=one_element_avals)
  return _ring_collective_roofline(
      scalar_ctx, *args, axes=axes, is_reduce=False, **kw
  )


# pmin and pmax share the single-element model above.
roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline)
roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline)
|
||||
|
||||
|
||||
@roofline.register_roofline(shard_map.psum2_p)
def _psum2_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for psum2: twice the cost of one reducing ring collective.

  Only the per-axis ICI byte and latency counters of the ring result are
  carried over (doubled); all other counters keep their defaults.
  """
  ring = _ring_collective_roofline(ctx, *args, axes=axes, **kw)
  doubled_bytes = {axis: count * 2 for axis, count in ring.ici_bytes.items()}
  doubled_latency = {axis: count * 2 for axis, count in ring.ici_latency.items()}
  return roofline.RooflineResult(
      ici_bytes=doubled_bytes,
      ici_latency=doubled_latency,
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax_parallel.all_to_all_p)
def _all_to_all_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis_name: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `all_to_all`, modelled via bisection bandwidth.

  Half of the data (input bytes times the total size of the axes) is assumed
  to cross the bisection, whose width is derived from the smallest axis. The
  latency is the sum over the axes of half the axis size. The same byte count
  and latency are reported for every axis.

  Args:
    ctx: rule context supplying the mesh and the input avals.
    *args: unused.
    axis_name: the mesh axes the collective runs over. A single axis given
      as a bare string is also accepted.
    **kw: remaining equation params, unused.
  """
  # A bare string would otherwise be iterated character by character, which
  # only happens to work for one-character axis names.
  if isinstance(axis_name, str):
    axis_name = (axis_name,)

  if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name):
    return zeros_result

  mesh = ctx.mesh.shape
  size = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod([
      mesh[axis] for axis in axis_name
  ])

  smallest_axis = sorted(axis_name, key=lambda x: mesh[x])[0]
  num_axes = len(axis_name)
  bisection_bw = mesh[smallest_axis] ** (num_axes - 1)
  if mesh[smallest_axis] > 2:
    # Times 2 because of wraparound.
    bisection_bw *= 2

  # Half the data needs to cross the bisection on average.
  ici_bytes = size / 2 / bisection_bw

  # The latency is the max number of hops across the mesh.
  ici_latency = sum(mesh[axis] / 2 for axis in axis_name)

  return roofline.RooflineResult(
      ici_bytes={axis: int(ici_bytes) for axis in axis_name},
      ici_latency={axis: int(ici_latency) for axis in axis_name},
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax_parallel.ppermute_p)
def _ppermute_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis_name: tuple[str, ...],
    perm: tuple[tuple[int, int], ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `ppermute`, modelled via per-link contention.

  Every `(src, dst)` pair is routed one mesh dimension at a time along the
  shorter direction of that dimension's ring, and each traversed link is
  counted. The byte count is the shard size times the busiest link's count;
  the latency is the largest hop count of any single pair. Both are reported
  for every axis in `axis_name`.

  NOTE(review): hops in every dimension start from the source coordinates of
  the other dimensions, not from the position reached after earlier
  dimensions — confirm this is the intended routing model.
  """
  if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name):
    return zeros_result

  mesh = ctx.mesh.shape
  mesh_dims: list[int] = [mesh.get(axis, 1) for axis in axis_name]
  shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in)

  # Number of transfers crossing each link; a link is the sorted pair of its
  # endpoint coordinates.
  ici_contention: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float)
  ici_latency = 0

  for src, dst in perm:
    if src == dst:
      continue
    # Perms are linearized.
    src_coords = tuple(int(c) for c in np.unravel_index(src, mesh_dims))
    dst_coords = tuple(int(c) for c in np.unravel_index(dst, mesh_dims))

    hops_for_pair = 0

    for dim, (dim_size, src_pos, dst_pos) in enumerate(
        zip(mesh_dims, src_coords, dst_coords)
    ):
      if src_pos == dst_pos:
        continue

      # Distances in both ring directions; ties go clockwise.
      clockwise_dist = (dst_pos - src_pos) % dim_size
      counter_dist = (src_pos - dst_pos) % dim_size
      step = 1 if clockwise_dist <= counter_dist else -1

      pos = src_pos
      while pos != dst_pos:
        here = util.tuple_update(src_coords, dim, pos)
        pos = (pos + step) % dim_size
        there = util.tuple_update(here, dim, pos)
        ici_contention[tuple(sorted([here, there]))] += 1

      hops_for_pair += min(clockwise_dist, counter_dist)

    ici_latency = max(ici_latency, hops_for_pair)

  ici_bytes = shard_size * max(ici_contention.values(), default=0)
  return roofline.RooflineResult(
      ici_bytes={axis: int(ici_bytes) for axis in axis_name},
      ici_latency={axis: int(ici_latency) for axis in axis_name},
  )
|
@ -1200,6 +1200,12 @@ jax_multiplatform_test(
|
||||
srcs = ["key_reuse_test.py"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "roofline_test",
|
||||
srcs = ["roofline_test.py"],
|
||||
enable_backends = ["cpu"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "x64_context_test",
|
||||
srcs = ["x64_context_test.py"],
|
||||
|
426
tests/roofline_test.py
Normal file
426
tests/roofline_test.py
Normal file
@ -0,0 +1,426 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import contextlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.experimental import roofline
|
||||
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def create_inputs(
    *shardings: P,
    dtype: jnp.dtype = jnp.float32,
    mesh_shape: tuple[int, ...] = (2, 2, 2),
) -> tuple[jax.sharding.Mesh, tuple[jax.ShapeDtypeStruct, ...]]:
  """Build a test mesh with axes ("x", "y", "z") and one 8x8 input per spec.

  Returns the mesh and a tuple holding, for every partition spec in
  `shardings`, an (8, 8) `ShapeDtypeStruct` of `dtype` sharded with that spec.
  """
  mesh = jtu.create_mesh(mesh_shape, ("x", "y", "z"))
  arrays = tuple(
      jax.ShapeDtypeStruct(
          (8, 8), dtype, sharding=jax.sharding.NamedSharding(mesh, spec)
      )
      for spec in shardings
  )
  return mesh, arrays
|
||||
|
||||
|
||||
# Run all tests with 8 CPU devices. The device-count override is held in a
# module-level ExitStack so it lasts for the whole module.
_exit_stack = contextlib.ExitStack()


def setUpModule():
  """Enter the 8-device host platform context for every test in the module."""
  device_count_ctx = jtu.set_host_platform_device_count(8)
  _exit_stack.enter_context(device_count_ctx)


def tearDownModule():
  """Exit every context entered in `setUpModule`."""
  _exit_stack.close()
|
||||
|
||||
|
||||
class RooflineTest(jtu.JaxTestCase):
|
||||
def test_scalar_collectives(self):
|
||||
a_spec = P("z", ("x", "y"))
|
||||
b_spec = P(("x", "y"), "z")
|
||||
mesh, (a, b) = create_inputs(a_spec, b_spec)
|
||||
|
||||
@partial(
|
||||
roofline.roofline,
|
||||
mesh=mesh,
|
||||
in_specs=(a_spec, b_spec),
|
||||
out_specs=(P("z", None), P(("x", "y"), None)),
|
||||
)
|
||||
def scalar_collectives(a, b):
|
||||
a = lax.pmin(a, ("x", "y"))
|
||||
b = lax.pmax(b, "z")
|
||||
return a, b
|
||||
|
||||
_, results = scalar_collectives(a, b)
|
||||
|
||||
itemsize = 4
|
||||
|
||||
axis_size = 2
|
||||
axis_size_m1 = axis_size - 1
|
||||
|
||||
xy_num_axes = 2
|
||||
xy_ici_bytes = int(
|
||||
itemsize
|
||||
# 2 phases.
|
||||
* (
|
||||
(1 / xy_num_axes * axis_size_m1) + (1 * axis_size / xy_num_axes * axis_size_m1)
|
||||
)
|
||||
)
|
||||
# 2 phases times 2 hops.
|
||||
xy_ici_latency = 2 * 2
|
||||
|
||||
z_ici_bytes = int(itemsize * 1 * axis_size_m1)
|
||||
# 2 hops.
|
||||
z_ici_latency = 2
|
||||
expected = roofline.RooflineResult(
|
||||
ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes},
|
||||
ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency},
|
||||
peak_hbm_bytes=itemsize * 2 * 4 * 2,
|
||||
)
|
||||
self.assertDataclassEqual(results, expected)
|
||||
|
||||
def test_collective_matmul(self):
|
||||
a_spec = P(None, "x")
|
||||
b_spec = P(None, "x")
|
||||
c_spec = P("x", None)
|
||||
mesh, (a, b, c) = create_inputs(a_spec, b_spec, c_spec, dtype=jnp.int8)
|
||||
|
||||
@partial(
|
||||
roofline.roofline,
|
||||
mesh=mesh,
|
||||
in_specs=(a_spec, b_spec, c_spec),
|
||||
out_specs=a_spec,
|
||||
)
|
||||
def collective_matmul(a, b, c):
|
||||
a = lax.all_gather(a, "x", axis=1, tiled=True)
|
||||
# Test broadcasting and slicing works.
|
||||
a = a[None, :, :]
|
||||
b = b[:, None, :]
|
||||
ab = jnp.einsum("bij,jbk->ikb", a, b).astype(jnp.int8)[..., 0]
|
||||
abc = jnp.einsum("ik,kc->ic", ab, c).astype(jnp.int8)
|
||||
abc = lax.psum_scatter(abc, "x", scatter_dimension=1, tiled=True)
|
||||
return abc
|
||||
|
||||
_, results = collective_matmul(a, b, c)
|
||||
|
||||
itemsize = 1
|
||||
m, k, n = 8, 4, 8
|
||||
mk = m * k
|
||||
kn = k * n
|
||||
mn = m * n
|
||||
|
||||
axis_size = 2
|
||||
axis_size_m1 = axis_size - 1
|
||||
sharded_mk = mk
|
||||
|
||||
# Times 2 for ag + rs.
|
||||
ici_bytes = 2 * int(itemsize * sharded_mk * axis_size_m1)
|
||||
ici_latency = 2 * 2
|
||||
expected = roofline.RooflineResult(
|
||||
flops=2 * 2 * m * k * n,
|
||||
ici_bytes={"x": ici_bytes},
|
||||
ici_latency={"x": ici_latency},
|
||||
hbm_bytes=2 * itemsize * (mk + kn + mn),
|
||||
# Right after all_gather.
|
||||
peak_hbm_bytes=itemsize * (mk * axis_size + mk + kn),
|
||||
)
|
||||
self.assertDataclassEqual(results, expected)
|
||||
|
||||
def test_matmul_psum(self):
|
||||
a_spec = P("z", ("x", "y"))
|
||||
b_spec = P(("x", "y"), None)
|
||||
mesh, (a, b) = create_inputs(a_spec, b_spec)
|
||||
|
||||
@partial(
|
||||
roofline.roofline,
|
||||
mesh=mesh,
|
||||
in_specs=(a_spec, b_spec),
|
||||
out_specs=P("z", None),
|
||||
)
|
||||
def matmul_psum(a, b):
|
||||
c = a @ b
|
||||
c = lax.psum(c, ("x", "y"))
|
||||
return c
|
||||
|
||||
_, results = matmul_psum(a, b)
|
||||
|
||||
itemsize = 4
|
||||
m, k, n = 4, 2, 8
|
||||
mk = m * k
|
||||
kn = k * n
|
||||
mn = m * n
|
||||
|
||||
axis_size = 2
|
||||
axis_size_m1 = axis_size - 1
|
||||
num_axes = 2
|
||||
sharded_mn = mn / axis_size / num_axes
|
||||
|
||||
# Times 2 for ag + rs.
|
||||
ici_bytes = 2 * int(
|
||||
itemsize
|
||||
# 2 phases.
|
||||
* (
|
||||
(sharded_mn / num_axes * axis_size_m1)
|
||||
+ (sharded_mn * axis_size / num_axes * axis_size_m1)
|
||||
)
|
||||
)
|
||||
ici_latency = 2 * 2 * 2
|
||||
expected = roofline.RooflineResult(
|
||||
flops=2 * m * k * n,
|
||||
ici_bytes={axis: ici_bytes for axis in ("x", "y")},
|
||||
ici_latency={axis: ici_latency for axis in ("x", "y")},
|
||||
hbm_bytes=itemsize * (mk + kn + mn),
|
||||
peak_hbm_bytes=itemsize * (mn),
|
||||
)
|
||||
self.assertDataclassEqual(results, expected)
|
||||
|
||||
def test_all_to_all(self):
|
||||
a_spec = P("z", ("x", "y"))
|
||||
b_spec = P(("x", "y"), "z")
|
||||
mesh, (a, b) = create_inputs(a_spec, b_spec)
|
||||
|
||||
@partial(
|
||||
roofline.roofline,
|
||||
mesh=mesh,
|
||||
in_specs=(a_spec, b_spec),
|
||||
out_specs=(P(("z", "x", "y"), None), P(("x", "y", "z"), None)),
|
||||
)
|
||||
def all_to_all(a, b):
|
||||
a = lax.all_to_all(a, ("x", "y"), split_axis=0, concat_axis=1, tiled=True)
|
||||
b = lax.all_to_all(b, "z", split_axis=0, concat_axis=1, tiled=True)
|
||||
return a, b
|
||||
|
||||
_, results = all_to_all(a, b)
|
||||
|
||||
itemsize = 4
|
||||
|
||||
xy_size = itemsize * 8 * 8 / 2
|
||||
# Half the data over 2 links.
|
||||
xy_ici_bytes = int(xy_size / 2 / 2)
|
||||
# 2 hops.
|
||||
xy_ici_latency = 2
|
||||
|
||||
z_size = itemsize * 8 * 8 / 2 / 2
|
||||
# Half the data over 1 link.
|
||||
z_ici_bytes = int(z_size / 2)
|
||||
# 1 hop.
|
||||
z_ici_latency = 1
|
||||
expected = roofline.RooflineResult(
|
||||
ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes},
|
||||
ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency},
|
||||
peak_hbm_bytes=itemsize * 2 * 4 * 2,
|
||||
)
|
||||
self.assertDataclassEqual(results, expected)
|
||||
|
||||
def test_ppermute(self):
|
||||
a_spec = P("z", ("x", "y"))
|
||||
b_spec = P(("x", "y"), "z")
|
||||
mesh, (a, b) = create_inputs(a_spec, b_spec)
|
||||
|
||||
@partial(
|
||||
roofline.roofline,
|
||||
mesh=mesh,
|
||||
in_specs=(a_spec, b_spec),
|
||||
out_specs=(a_spec, b_spec),
|
||||
)
|
||||
def ppermute(a, b):
|
||||
a = lax.ppermute(a, ("x", "y"), perm=((0, 3), (3, 0), (1, 2), (2, 1)))
|
||||
b = lax.ppermute(b, "z", perm=((1, 0), (0, 1)))
|
||||
return a, b
|
||||
|
||||
_, results = ppermute(a, b)
|
||||
|
||||
itemsize = 4
|
||||
shard_size = itemsize * 4 * 2
|
||||
|
||||
# At most 2 shards contend for 1 link.
|
||||
xy_ici_bytes = int(shard_size * 2)
|
||||
# 2 hops.
|
||||
xy_ici_latency = 2
|
||||
|
||||
# No contention but there is a single link.
|
||||
z_ici_bytes = int(shard_size * 2)
|
||||
# 1 hop.
|
||||
z_ici_latency = 1
|
||||
expected = roofline.RooflineResult(
|
||||
ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes},
|
||||
ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency},
|
||||
peak_hbm_bytes=itemsize * 2 * 4 * 2,
|
||||
)
|
||||
self.assertDataclassEqual(results, expected)
|
||||
|
||||
def test_grad_matmuls(self):
  """Checks forward and backward rooflines of an all-gather + matmul on int8."""
  lhs_spec = P(None, "x")
  rhs_spec = P(None, None)
  mesh, (a, b) = create_inputs(lhs_spec, rhs_spec, dtype=jnp.int8)

  with_roofline_and_grad = partial(
      roofline.roofline_and_grad,
      mesh=mesh,
      in_specs=(lhs_spec, rhs_spec),
      # Numerically incorrect AD, but tests that we handle it properly.
      out_specs=P("x", None),
  )

  @with_roofline_and_grad
  def collective_matmul(a, b):
    gathered = lax.all_gather(a, "x", axis=1, tiled=True)
    return gathered @ b

  c, fwd_results, bwd_results = collective_matmul(a, b)

  # Problem sizes used to build the expected numbers.
  m, k, n = 8, 8, 8
  mk, kn, mn = m * k, k * n, m * n
  total_elems = mk + kn + mn
  matmul_flops = 2 * m * k * n

  axis_size = 2
  other_devices = axis_size - 1
  mk_per_shard = mk // axis_size

  # Forward pass.
  fwd_itemsize = 1
  fwd_ici_bytes = int(fwd_itemsize * mk_per_shard * other_devices)
  fwd_ici_latency = 2
  expected_fwd = roofline.RooflineResult(
      flops=matmul_flops,
      ici_bytes={"x": fwd_ici_bytes},
      ici_latency={"x": fwd_ici_latency},
      hbm_bytes=fwd_itemsize * total_elems,
      peak_hbm_bytes=fwd_itemsize * (mk + kn),
  )
  self.assertDataclassEqual(fwd_results, expected_fwd)

  # Backward pass.
  bwd_itemsize = 2
  # 2 for psum + 1 for rs.
  collective_count = 3
  per_collective_bytes = int(bwd_itemsize * mk_per_shard * other_devices)
  expected_bwd = roofline.RooflineResult(
      flops=2 * matmul_flops,
      ici_bytes={"x": collective_count * per_collective_bytes},
      ici_latency={"x": collective_count * fwd_ici_latency},
      hbm_bytes=2 * bwd_itemsize * total_elems,
      # Residuals + cotangents.
      peak_hbm_bytes=bwd_itemsize * total_elems,
  )
  self.assertDataclassEqual(bwd_results, expected_bwd)

  # The forward output must be usable as input to a further roofline call.
  c_spec = c.sharding.spec

  @partial(roofline.roofline, mesh=mesh, in_specs=c_spec, out_specs=c_spec)
  def mul_2(c):
    return c * 2

  results = mul_2(c)
  self.assertLen(results, 2)
|
||||
|
||||
def test_one_sized_axis_collectives(self):
  """Checks that a chain of collectives over "x" reports zero ICI cost.

  The mesh is built with mesh_shape=(1, 2, 4); presumably "x" is the axis of
  size 1 (TODO confirm against `create_inputs`).
  """
  spec = P("x")
  mesh, (a,) = create_inputs(spec, mesh_shape=(1, 2, 4))

  with_roofline = partial(
      roofline.roofline, mesh=mesh, in_specs=spec, out_specs=spec
  )

  @with_roofline
  def one_sized_axis_collectives(a):
    # Apply each collective in turn, all over the "x" axis.
    out = lax.pmin(a, "x")
    out = lax.all_gather(out, "x", axis=1, tiled=True)
    out = lax.psum_scatter(out, "x", scatter_dimension=1, tiled=True)
    out = lax.psum(out, "x")
    out = lax.all_to_all(out, "x", split_axis=0, concat_axis=1, tiled=True)
    return lax.ppermute(out, "x", perm=((1, 0), (0, 1)))

  _, results = one_sized_axis_collectives(a)

  no_ici_bytes = {"x": 0}
  no_ici_latency = {"x": 0}
  expected = roofline.RooflineResult(
      ici_bytes=no_ici_bytes,
      ici_latency=no_ici_latency,
      peak_hbm_bytes=4 * 8 * 8,
  )
  self.assertDataclassEqual(results, expected)
|
||||
|
||||
def test_remat(self):
  """Checks fwd/bwd rooflines of a `jax.checkpoint`-wrapped gather + matmul."""
  lhs_spec = P("x", None)
  rhs_spec = P("x", None)
  mesh, (a, b) = create_inputs(lhs_spec, rhs_spec)

  def fsdp_checkpoint_policy(prim, *args, **kwargs):
    # True only for an all_gather whose axis_name is "x".
    return bool(prim == lax.all_gather_p and kwargs["axis_name"] == "x")

  with_roofline_and_grad = partial(
      roofline.roofline_and_grad,
      mesh=mesh,
      in_specs=(lhs_spec, rhs_spec),
      out_specs=P("x", None),
  )
  remat = partial(jax.checkpoint, policy=fsdp_checkpoint_policy)

  @with_roofline_and_grad
  @remat
  def collective_matmul(a, b):
    gathered = lax.all_gather(b, "x", axis=0, tiled=True)
    return a @ gathered

  _, fwd_results, bwd_results = collective_matmul(a, b)

  # Problem sizes used to build the expected numbers.
  m, k, n = 4, 8, 8
  mk, kn, mn = m * k, k * n, m * n
  total_elems = mk + kn + mn
  matmul_flops = 2 * m * k * n

  axis_size = 2
  other_devices = axis_size - 1
  kn_per_shard = kn // axis_size

  # Forward pass.
  fwd_itemsize = 4
  fwd_ici_bytes = int(fwd_itemsize * kn_per_shard * other_devices)
  fwd_ici_latency = 2
  expected_fwd = roofline.RooflineResult(
      flops=matmul_flops,
      ici_bytes={"x": fwd_ici_bytes},
      ici_latency={"x": fwd_ici_latency},
      hbm_bytes=fwd_itemsize * total_elems,
      peak_hbm_bytes=fwd_itemsize * (mk + kn),
  )
  self.assertDataclassEqual(fwd_results, expected_fwd)

  # Backward pass.
  bwd_itemsize = 2
  # Remat ag + rs.
  collective_count = 2
  per_collective_bytes = int(bwd_itemsize * kn_per_shard * other_devices)
  expected_bwd = roofline.RooflineResult(
      flops=2 * matmul_flops,
      ici_bytes={"x": collective_count * per_collective_bytes},
      ici_latency={"x": collective_count * fwd_ici_latency},
      hbm_bytes=2 * bwd_itemsize * total_elems,
      # Residuals + cotangents.
      # We gather kn while computing the kn cotangents.
      peak_hbm_bytes=bwd_itemsize * (kn + kn + mn),
  )
  self.assertDataclassEqual(bwd_results, expected_bwd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run the tests through absltest using the JAX test loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)
|
Loading…
x
Reference in New Issue
Block a user