mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details. PiperOrigin-RevId: 476167538
347 lines
14 KiB
Python
347 lines
14 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Module for JAX debugging primitives and related functionality."""
|
|
import enum
|
|
import functools
|
|
import string
|
|
import sys
|
|
|
|
from typing import Any, Dict, Callable, Sequence, Set, Tuple, Union
|
|
|
|
from jax import core
|
|
from jax import tree_util
|
|
from jax import lax
|
|
from jax._src import ad_checkpoint
|
|
from jax._src import custom_derivatives
|
|
from jax._src import lib as jaxlib
|
|
from jax._src import util
|
|
from jax.config import config
|
|
from jax.experimental.sharding import Sharding
|
|
from jax.interpreters import ad
|
|
from jax.interpreters import batching
|
|
from jax.interpreters import mlir
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax._src.lax import control_flow as lcf
|
|
from jax._src.lib import xla_client as xc
|
|
import jax.numpy as jnp
|
|
|
|
# `rich` is an optional dependency that is only needed for the sharding
# visualizer. Record whether it could be imported so callers can fail with a
# clear message instead of a NameError.
# pytype: disable=import-error
try:
  import rich
  import rich.align
  import rich.box
  import rich.console
  import rich.padding
  import rich.table
  RICH_ENABLED = True
except ImportError:
  # Only a missing/unimportable package is expected here; anything else
  # (e.g. KeyboardInterrupt) should propagate rather than be swallowed.
  RICH_ENABLED = False
# pytype: enable=import-error
|
|
|
|
# The two effect tags a debug callback can carry: `PRINT` for unordered
# callbacks and `ORDERED_PRINT` for callbacks requested with `ordered=True`.
DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])

# Only the ordered variant is added to the set of ordered effects.
core.ordered_effects.add(DebugEffect.ORDERED_PRINT)
# Both variants are added to each of the effect sets below.
# NOTE(review): presumably these sets gate which effects MLIR lowering,
# control flow, remat and custom derivatives will accept -- confirm against
# the respective modules.
mlir.lowerable_effects.add(DebugEffect.PRINT)
mlir.lowerable_effects.add(DebugEffect.ORDERED_PRINT)
lcf.allowed_effects.add(DebugEffect.PRINT)
lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT)
ad_checkpoint.remat_allowed_effects.add(DebugEffect.PRINT)
ad_checkpoint.remat_allowed_effects.add(DebugEffect.ORDERED_PRINT)
custom_derivatives.allowed_effects.add(DebugEffect.PRINT)
custom_derivatives.allowed_effects.add(DebugEffect.ORDERED_PRINT)

# `debug_callback_p` is the main primitive for staging out Python callbacks.
debug_callback_p = core.Primitive('debug_callback')
# The primitive returns a list of outputs rather than a single value.
debug_callback_p.multiple_results = True

# From here on `map` is `util.safe_map`; the builtin remains available as
# `unsafe_map`.
map, unsafe_map = util.safe_map, map
|
|
|
|
@debug_callback_p.def_impl
def debug_callback_impl(*args, callback: Callable[..., Any],
                        effect: DebugEffect):
  """Implementation rule for `debug_callback_p`: runs `callback` directly.

  Args:
    *args: Operands of the primitive, forwarded unchanged to `callback`.
    callback: The Python callable to invoke.
    effect: Effect tag carried by the primitive; not used here.

  Returns:
    Whatever `callback(*args)` returns.
  """
  del effect  # Unused when the callback is simply executed.
  outputs = callback(*args)
  return outputs
|
|
|
|
@debug_callback_p.def_effectful_abstract_eval
def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
                                 effect: DebugEffect):
  """Abstract evaluation rule for `debug_callback_p`.

  Args:
    *flat_avals: Abstract values of the operands; ignored.
    callback: The Python callable; ignored.
    effect: Effect tag to report for this primitive.

  Returns:
    A pair of an empty list of output abstract values and the singleton set
    `{effect}`.
  """
  del flat_avals, callback
  out_avals = []
  effects = {effect}
  return out_avals, effects
|
|
|
|
def debug_callback_batching_rule(args, dims, **params):
  """Unrolls the debug callback across the mapped axis.

  The primitive is bound once per index of the batched axis, each time with
  the corresponding slice of every mapped operand.

  Args:
    args: Operands of the primitive.
    dims: For each operand, the axis it is batched along, or
      `batching.not_mapped` if it is not batched.
    **params: Primitive parameters, forwarded to every `bind` call.

  Returns:
    A pair of the per-output results stacked along a new leading axis and a
    tuple of zeros (one per output) giving their batch dimension.
  """
  # The batch size is read off the first operand that is actually mapped.
  axis_size = next(
      operand.shape[axis] for operand, axis in zip(args, dims)
      if axis is not None)

  # TODO(sharadmv): implement in terms of rolled loop instead of
  # unrolled.
  def slice_operand(index, axis, operand):
    if axis is batching.not_mapped:
      # Broadcast unmapped argument
      return operand
    return lax.index_in_dim(operand, index, axis=axis, keepdims=False)

  per_index_outs = [
      debug_callback_p.bind(
          *map(functools.partial(slice_operand, index), dims, args),
          **params)
      for index in range(axis_size)
  ]
  stacked = [jnp.stack(group) for group in zip(*per_index_outs)]
  return stacked, (0,) * len(stacked)


batching.primitive_batchers[debug_callback_p] = debug_callback_batching_rule
|
|
|
|
def debug_callback_jvp_rule(primals, tangents, **params):
  """JVP rule: re-binds the callback on the primals and emits no tangents.

  Args:
    primals: Primal operands of the primitive.
    tangents: Tangent operands; not used.
    **params: Primitive parameters, forwarded to `bind`.

  Returns:
    A pair of the primal outputs of the primitive and an empty list of
    tangent outputs.
  """
  del tangents  # The callback is only bound on the primal values.
  primal_outs = debug_callback_p.bind(*primals, **params)
  return primal_outs, []


ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule
|
|
|
|
def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
                                  effect: DebugEffect):
  """Transpose rule for `debug_callback_p`; transposition is unsupported.

  Args:
    *flat_args: Operands of the primitive; ignored.
    callback: The Python callable; ignored.
    effect: Effect tag; ignored.

  Raises:
    ValueError: Always.
  """
  # None of the arguments are inspected: the rule unconditionally fails.
  raise ValueError("Transpose doesn't support debugging callbacks.")


ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
|
|
|
|
def debug_callback_lowering(ctx, *args, effect, callback, **params):
  """Lowers `debug_callback_p` to a Python callback emitted via MLIR.

  Args:
    ctx: The lowering rule context; supplies the axis context, input/output
      avals and effect tokens.
    *args: Lowered operands of the primitive.
    effect: Effect tag of this callback; decides whether a token is threaded.
    callback: The Python callable to invoke at run time.
    **params: Any remaining primitive parameters, forwarded to the impl.

  Returns:
    The results produced by `mlir.emit_python_callback`.
  """
  axis_context = ctx.module_context.axis_context
  sharding = None
  if isinstance(axis_context, (mlir.SPMDAxisContext, mlir.ShardingContext)):
    # Apply maximal sharding so pjit only executes the callback on device 0.
    sharding = xc.OpSharding()
    sharding.type = xc.OpSharding.Type.MAXIMAL
    sharding.tile_assignment_dimensions = [1]
    sharding.tile_assignment_devices = [0]

  def _callback(*flat_args):
    # Run the primitive's impl rule and hand the outputs back as a tuple.
    return tuple(
        debug_callback_p.impl(
            *flat_args, effect=effect, callback=callback, **params))

  operands = list(args)
  if effect in core.ordered_effects:
    # Ordered effects thread a token through the callback and publish the
    # updated token on the context.
    token = ctx.tokens_in.get(effect)[0]
    result, token, keepalive = mlir.emit_python_callback(
        ctx, _callback, token, operands, ctx.avals_in, ctx.avals_out, True)
    ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
  else:
    # Unordered callbacks pass no token; the sharding computed above is only
    # supplied on this path.
    result, _, keepalive = mlir.emit_python_callback(
        ctx, _callback, None, operands, ctx.avals_in, ctx.avals_out, True,
        sharding=sharding)
  # Hand the keepalive object to the module context.
  ctx.module_context.add_keepalive(keepalive)
  return result


mlir.register_lowering(debug_callback_p, debug_callback_lowering,
                       platform="cpu")
mlir.register_lowering(
    debug_callback_p, debug_callback_lowering, platform="gpu")
# The TPU lowering is only registered for jaxlib 0.3.15 and newer.
if jaxlib.version >= (0, 3, 15):
  mlir.register_lowering(
      debug_callback_p, debug_callback_lowering, platform="tpu")
|
|
|
|
def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Custom partial-eval rule that prefers staging out debug callbacks.

  Args:
    saveable: Policy predicate, called with the primitive, the input avals
      and the equation params.
    unks_in: Per-input flags marking unknown inputs.
    inst_in: Per-input flags marking instantiated inputs.
    eqn: The `debug_callback` equation being partially evaluated.

  Returns:
    A 5-tuple whose first two entries are the equation to keep on each side
    (or `None`), followed by two empty lists and a list of residual
    variables.
    NOTE(review): presumably (known eqn, staged eqn, unknown-output flags,
    instantiated-output flags, residuals) -- confirm against `partial_eval`.
  """
  # The default behavior for effectful primitives is to not stage them if
  # possible. For debug callback, we actually want it to be staged to
  # provide more information to the user. This rule bypasses partial_eval's
  # regular behavior to do that. Specifically, we will stage the callback
  # if:
  # 1) the policy says debug_callbacks are not saveable
  # 2) the policy says debug_callbacks are saveable BUT all of the input
  #    values are instantiated.
  # The purpose is to call back with as much information as possible while
  # avoiding unnecessarily staging out other values.
  if any(unks_in):
    # The usual case (if we have any unknowns, we need to stage it out)
    residuals = [
        var for var, is_inst in zip(eqn.invars, inst_in) if not is_inst]
    return None, eqn, [], [], residuals

  in_avals = [var.aval for var in eqn.invars]
  policy_allows_saving = saveable(debug_callback_p, *in_avals, **eqn.params)
  if policy_allows_saving and not all(inst_in):
    # The policy lets us save the callback, but some inputs are not
    # instantiated: don't do any extra staging to avoid affecting the
    # computation.
    return eqn, None, [], [], []

  # Either the policy forbids saving the callback, or it allows it and every
  # input is already instantiated; in both cases the callback is staged out
  # as well.
  return eqn, eqn, [], [], []


pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = (
    _debug_callback_partial_eval_custom)
|
|
|
|
def debug_callback(callback: Callable[..., Any], *args: Any,
                   ordered: bool = False, **kwargs: Any):
  """Calls a stageable Python callback.

  `debug_callback` enables you to pass in a Python function that can be called
  inside of a staged JAX program. A `debug_callback` follows existing JAX
  transformation *pure* operational semantics, which are therefore unaware of
  side-effects. This means the effect could be dropped, duplicated, or
  potentially reordered in the presence of higher-order primitives and
  transformations.

  We want this behavior because we'd like `debug_callback` to be "innocuous",
  i.e. we want these primitives to change the JAX computation as little as
  possible while revealing as much about them as possible, such as which parts
  of the computation are duplicated or dropped.

  Args:
    callback: A Python callable. Its return value will be ignored.
    *args: The positional arguments to the callback.
    ordered: A keyword only argument used to indicate whether or not the
      staged out computation will enforce ordering of this callback w.r.t.
      other ordered callbacks.
    **kwargs: The keyword arguments to the callback.

  Returns:
    The outputs of binding `debug_callback_p`. The value returned by
    `callback` itself is discarded: the wrapped callback always reports an
    empty list of outputs.
  """
  # Flatten positional and keyword arguments together so the primitive only
  # ever sees a flat list of operands; `in_tree` remembers the structure.
  flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
  effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT

  def _flat_callback(*flat_operands):
    # Rebuild the caller's (args, kwargs) structure before invoking the
    # user-supplied callback.
    cb_args, cb_kwargs = tree_util.tree_unflatten(in_tree, flat_operands)
    callback(*cb_args, **cb_kwargs)
    return []

  return debug_callback_p.bind(*flat_args, callback=_flat_callback,
                               effect=effect)
|
|
|
|
class _DebugPrintFormatChecker(string.Formatter):
  """A `string.Formatter` that rejects arguments the format string ignores."""

  def check_unused_args(self, used_args, args, kwargs):
    """Raises if any positional or keyword argument went unused.

    Args:
      used_args: The indices and names referenced by the format string.
      args: Positional arguments passed to `format`.
      kwargs: Keyword arguments passed to `format`.

    Raises:
      ValueError: If a positional argument is unused, or otherwise if a
        keyword argument is unused.
    """
    leftover_positional = [
        value for index, value in enumerate(args) if index not in used_args]
    if leftover_positional:
      raise ValueError(
          "Unused positional arguments to `jax.debug.print`: "
          f"{leftover_positional}")
    leftover_keywords = [name for name in kwargs if name not in used_args]
    if leftover_keywords:
      raise ValueError(
          f"Unused keyword arguments to `jax.debug.print`: {leftover_keywords}. "
          "You may be passing an f-string (i.e, `f\"{x}\"`) into "
          "`jax.debug.print` and instead should pass in a regular string.")


# Shared checker instance used to validate format strings.
formatter = _DebugPrintFormatChecker()
|
|
|
|
def _format_print_callback(fmt: str, *args, **kwargs):
  """Formats `fmt` with the given arguments and writes one line to stdout.

  Args:
    fmt: A `str.format`-style format string.
    *args: Positional values substituted into `fmt`.
    **kwargs: Keyword values substituted into `fmt`.
  """
  rendered = fmt.format(*args, **kwargs)
  # Emit the text and the trailing newline in a single write.
  sys.stdout.write(rendered + "\n")
|
|
|
|
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
  """Prints values and works in staged out JAX functions.

  The format string is validated eagerly against the supplied arguments, and
  the actual printing is deferred to a `debug_callback`.

  Args:
    fmt: A format string, e.g. ``"hello {x}"``, that will be used to format
      input arguments.
    *args: A list of positional arguments to be formatted.
    ordered: A keyword only argument used to indicate whether or not the
      staged out computation will enforce ordering of this ``debug_print``
      w.r.t. other ordered ``debug_print`` calls.
    **kwargs: Additional keyword arguments to be formatted.
  """
  # Check that we provide the correct arguments to be formatted
  formatter.format(fmt, *args, **kwargs)

  print_callback = functools.partial(_format_print_callback, fmt)
  debug_callback(print_callback, *args, ordered=ordered, **kwargs)
|
|
|
|
def _slice_to_chunk_idx(size: int, slc: slice) -> int:
  """Returns the index of the equally sized chunk that `slc` selects.

  Args:
    size: Length of the dimension being sliced.
    slc: A slice of that dimension. A missing `start` is treated as 0 and a
      missing `stop` as `size`.

  Returns:
    `start // (stop - start)`, i.e. the position of the chunk when the
    dimension is split into chunks of the slice's length; 0 for a slice with
    neither `start` nor `stop`.
  """
  if slc.start is None and slc.stop is None:
    # The slice covers the whole dimension: there is a single chunk.
    return 0
  # Half-open slices (only one bound given) are normalised against `size`
  # instead of failing on `None` arithmetic.
  start = 0 if slc.start is None else slc.start
  stop = size if slc.stop is None else slc.stop
  slice_size = stop - start
  # The slice must be aligned to a chunk boundary and tile the dimension.
  assert start % slice_size == 0
  assert size % slice_size == 0
  return start // slice_size
|
|
|
|
def _raise_to_slice(slc: Union[slice, int]):
  """Converts an integer index into the length-one slice selecting it.

  Args:
    slc: Either an integer index or a slice.

  Returns:
    `slice(slc, slc + 1)` when `slc` is an `int`; otherwise `slc` unchanged.
  """
  return slice(slc, slc + 1) if isinstance(slc, int) else slc
|
|
|
|
def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
                       use_color: bool = False, scale: float = 1.,
                       min_width: int = 9, max_width: int = 80):
  """Visualizes a `Sharding`.

  Prints a `rich` table in which every cell stands for one chunk of an array
  of the given shape and lists the ids of the devices holding that chunk.

  Args:
    shape: Shape of the array being sharded; must have 1 or 2 dimensions.
    sharding: The sharding to visualize.
    use_color: Not implemented yet; must be left `False`.
    scale: Multiplier applied to the base cell height of 10.
    min_width: Lower bound for the width of a cell.
    max_width: Upper bound for the width of a cell; also used as the width of
      the console the table is printed to.

  Raises:
    ValueError: If `rich` is not installed, or if `shape` does not have 1 or 2
      dimensions.
    NotImplementedError: If `use_color` is set, or if the sharding reports no
      index for some device.
  """
  if not RICH_ENABLED:
    raise ValueError("`visualize_sharding` requires `rich` to be installed.")
  if use_color:
    # TODO(sharadmv): Implement color in the visualizer
    raise NotImplementedError
  if len(shape) > 2 or len(shape) < 1:
    raise ValueError(
        "`visualize_sharding` only works for shapes with 1 and 2 dimensions.")

  base_height = int(10 * scale)
  aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0]
  base_width = int(base_height * aspect_ratio)
  height_to_width_ratio = 2.5

  # Grab the device kind from the first device
  device_kind = next(iter(sharding.device_set)).platform.upper()

  device_indices_map = sharding.devices_indices_map(tuple(shape))
  # All three dicts are keyed by the (row, column) chunk index.
  slices: Dict[Tuple[int, ...], Set[int]] = {}
  heights: Dict[Tuple[int, ...], int] = {}
  widths: Dict[Tuple[int, ...], int] = {}
  for dev, slcs in device_indices_map.items():
    # This has to be checked before the index is normalised: afterwards
    # `slcs` is always a tuple and the check could never fire.
    if slcs is None:
      raise NotImplementedError
    slcs = tuple(map(_raise_to_slice, slcs))
    chunk_idxs = tuple(map(_slice_to_chunk_idx, shape, slcs))
    if len(slcs) == 2:
      vert, horiz = slcs
      vert_size = ((vert.stop - vert.start) if vert.stop is not None
                   else shape[0])
      horiz_size = ((horiz.stop - horiz.start) if horiz.stop is not None
                    else shape[1])
      chunk_height = vert_size / shape[0] * base_height
      chunk_width = (
          horiz_size / shape[1] * base_width *
          height_to_width_ratio)
      heights[chunk_idxs] = int(chunk_height)
      widths[chunk_idxs] = int(chunk_width)
      slices.setdefault(chunk_idxs, set()).add(dev.id)
    else:
      # In the 1D case, we set the height to 1.
      horiz, = slcs
      horiz_size = (
          (horiz.stop - horiz.start) if horiz.stop is not None else shape[0])
      # 1D chunks are stored under a leading row index of 0 so that the keys
      # have the same (row, column) form as in the 2D case.
      heights[(0, *chunk_idxs)] = 1
      widths[(0, *chunk_idxs)] = int(horiz_size / shape[0] * base_width)
      slices.setdefault((0, *chunk_idxs), set()).add(dev.id)
  num_rows = max([a[0] for a in slices.keys()]) + 1
  if len(list(slices.keys())[0]) == 1:
    # Keys built above always have two entries, so this is purely defensive.
    num_cols = 1
  else:
    num_cols = max([a[1] for a in slices.keys()]) + 1

  table = rich.table.Table(show_header=False, show_lines=True, padding=0,
                           highlight=True, pad_edge=False,
                           box=rich.box.SQUARE)
  console = rich.console.Console(width=max_width)
  for i in range(num_rows):
    col = []
    for j in range(num_cols):
      entry = f"{device_kind} "+",".join([str(s) for s in sorted(slices[i, j])])
      width, height = widths[i, j], heights[i, j]
      # Clamp the cell width into [min_width, max_width].
      width = min(max(width, min_width), max_width)
      # Split the space left over around the label evenly; any odd remainder
      # goes to the right/bottom side.
      left_padding, remainder = divmod(width - len(entry) - 2, 2)
      right_padding = left_padding + remainder
      top_padding, remainder = divmod(height - 2, 2)
      bottom_padding = top_padding + remainder
      padding = (top_padding, right_padding, bottom_padding, left_padding)
      # Negative padding (label larger than the cell) is clipped to zero.
      padding = tuple(max(x, 0) for x in padding)  # type: ignore
      col.append(
          rich.padding.Padding(
              rich.align.Align(entry, "center", vertical="middle"), padding))
    table.add_row(*col)
  console.print(table, end='\n\n')
|
|
|
|
def visualize_array_sharding(arr, **kwargs):
  """Visualizes an array's sharding.

  Args:
    arr: An array exposing `shape` and `sharding` attributes.
    **kwargs: Forwarded unchanged to `visualize_sharding`.

  Returns:
    The result of `visualize_sharding(arr.shape, arr.sharding, **kwargs)`.

  Raises:
    NotImplementedError: If the `jax_array` config flag is not enabled.
  """
  if config.jax_array:
    return visualize_sharding(arr.shape, arr.sharding, **kwargs)
  raise NotImplementedError("`visualize_array_sharding` not implemented.")
|