[pallas] Improve error and debugging messages with source locations

Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.

Added some more information to the `if debug` prints.

Set the MLIR module names so that the debug dumps are named properly.

I changed `import pallas.core as pl_core` to `... as pallas_core` in a couple of modules, for consistency.
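
A minimal sketch of the intended usage (the kernel, shapes, and `name` value below are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    x = jnp.zeros((8, 128), jnp.float32)
    # The `name` argument now shows up, together with the kernel's source
    # location, in error messages and in the `debug=True` prints, e.g.:
    #   my_copy for kernel function copy_kernel at my_file.py:5
    out = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        name="my_copy",
        debug=True,
        interpret=True,  # interpret mode so the sketch runs without a TPU/GPU
    )(x)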

PiperOrigin-RevId: 659506675
George Necula 2024-08-05 04:23:15 -07:00 committed by jax authors
parent b2a469b361
commit 252032a368
12 changed files with 205 additions and 101 deletions

View File

@ -36,6 +36,7 @@ from jax._src import mesh as mesh_lib
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
import jax.numpy as jnp
@ -56,9 +57,47 @@ TupleGrid = tuple[GridElement, ...]
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]
SrcInfoStr = str # function_name at filename:linenumber
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.
name: str
#: The source info, and the name of the kernel function if not included in `name`.
src_info: str
def __str__(self):
return f"{self.name}{' ' if self.src_info else ''}{self.src_info}"
__repr__ = __str__
replace = dataclasses.replace
@staticmethod
def from_pallas_call(pallas_call_name: str | None,
src_info: str | None) -> NameAndSrcInfo:
"""Formats the name and the source info.
Args:
pallas_call_name: The `name` argument to pallas_call.
src_info: The result of `api_util.fun_source_info(kernel)`, in the form
"{function_name} at {file_name}:{line_number}".
"""
if pallas_call_name is not None:
pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name)
if src_info is None:
return NameAndSrcInfo(
"unknown" if pallas_call_name is None else pallas_call_name,
"")
if pallas_call_name is not None:
return NameAndSrcInfo(pallas_call_name,
f"for kernel function {src_info}")
src_info_parts = src_info.split(" ")
return NameAndSrcInfo(src_info_parts[0],
" ".join(src_info_parts[1:]))
# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]
@ -268,7 +307,7 @@ class BlockMapping:
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: SrcInfoStr
index_map_src_info: NameAndSrcInfo
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
@ -534,7 +573,8 @@ def _convert_block_spec_to_block_mapping(
lu.wrap_init(index_map_func), index_map_tree)
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
False, "pallas_call index_map")
index_map_src_info = debug.func_src_info or "<unknown>"
index_map_src_info = NameAndSrcInfo.from_pallas_call(None,
debug.func_src_info)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
index_map_avals,

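For reference, a small sketch of how `from_pallas_call` combines its arguments (mirroring the tests added in this change; the file name and line are illustrative):

    from jax._src.pallas import core as pallas_core

    # Both a user-supplied name and kernel source info: the name comes first.
    ns = pallas_core.NameAndSrcInfo.from_pallas_call(
        "my_name", "my_kernel at my_file.py:10")
    assert str(ns) == "my_name for kernel function my_kernel at my_file.py:10"

    # No user-supplied name: the kernel function name becomes the name.
    ns = pallas_core.NameAndSrcInfo.from_pallas_call(
        None, "my_kernel at my_file.py:10")
    assert str(ns) == "my_kernel at my_file.py:10"

    # Neither: fall back to "unknown".
    ns = pallas_core.NameAndSrcInfo.from_pallas_call(None, None)
    assert str(ns) == "unknown"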
View File

@ -49,7 +49,7 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import pallas_call
from jax._src.pallas import core as pl_core
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
@ -71,7 +71,7 @@ import numpy as np
NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
MemorySpace = pl_core.MemorySpace | TPUMemorySpace
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
@ -105,7 +105,7 @@ class LoweringContext:
grid_names: tuple[Hashable, ...] | None
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
user_grid_indices: Sequence[ir.Value] | None
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
block_shapes: list[tuple[int | pallas_core.Mapped, ...]]
name_stack: source_info_util.NameStack
mesh_context: MeshContext | None
replace = dataclasses.replace
@ -136,7 +136,7 @@ class LoweringRuleContext:
lowering_context: LoweringContext
avals_in: Sequence[jax_core.AbstractValue]
avals_out: Sequence[jax_core.AbstractValue]
block_shapes: Sequence[tuple[int | pl_core.Mapped, ...] | None]
block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]
replace = dataclasses.replace
@ -145,9 +145,9 @@ def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None
) -> ir.Attribute:
if memory_space is None:
memory_space = VMEM
elif memory_space == pl_core.MemorySpace.ERROR:
elif memory_space == pallas_core.MemorySpace.ERROR:
memory_space = SMEM
elif memory_space == pl_core.MemorySpace.INDEX:
elif memory_space == pallas_core.MemorySpace.INDEX:
memory_space = SMEM
return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")
@ -252,10 +252,10 @@ def _get_aval_physical_dtype_shape(aval):
def _get_arg_type(
aval,
block_mapping: pl_core.BlockMapping | None,
block_mapping: pallas_core.BlockMapping | None,
):
memory_space = None
if isinstance(aval, pl_core.AbstractMemoryRef):
if isinstance(aval, pallas_core.AbstractMemoryRef):
memory_space = aval.memory_space
# We assume unannotated memory refs are in VMEM
if memory_space is None:
@ -265,7 +265,7 @@ def _get_arg_type(
# TODO(necula): clean this None block_mapping
if block_mapping is None:
return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape)
shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape)
return (
aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
block_mapping.block_shape,
@ -277,7 +277,7 @@ class MosaicGridMapping:
grid: tuple[int, ...] | None
grid_names: tuple[Hashable, ...] | None
jaxpr: jax_core.Jaxpr
block_mappings: tuple[pl_core.BlockMapping | None, ...]
block_mappings: tuple[pallas_core.BlockMapping | None, ...]
mapped_dims: tuple[int, ...]
scalar_prefetch_types: tuple[ir.Type, ...]
operand_types: tuple[ir.Type, ...]
@ -289,7 +289,7 @@ class MosaicGridMapping:
mesh_info: MeshInfo | None
get_grid_indices: Callable | None
def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping,
dimension_semantics: tuple[str, ...] | None,
mesh: mesh_lib.Mesh | None):
self.grid = grid_mapping.grid
@ -340,7 +340,7 @@ class MosaicGridMapping:
for aval in scratch_avals
)
self.grid_types, _ = unzip2([
_get_arg_type(pl_core.index_map_grid_aval, None)
_get_arg_type(pallas_core.index_map_grid_aval, None)
for _ in range(len(self.grid))
])
self._prepare_mesh_info(mesh)
@ -418,9 +418,11 @@ class MeshInfo:
def lower_jaxpr_to_module(
lowering_context: mlir.LoweringRuleContext,
ctx: ir.Context,
grid_mapping: pl_core.GridMapping,
grid_mapping: pallas_core.GridMapping,
jaxpr: jax_core.Jaxpr,
*,
dimension_semantics: tuple[str | None, ...] | None,
name_and_src_info: pallas_core.NameAndSrcInfo,
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
@ -432,7 +434,8 @@ def lower_jaxpr_to_module(
bm.has_trivial_window()):
continue
def err_details():
return (f"Block spec for {bm.origin} has block shape "
return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
"has block shape "
f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
# TODO(necula): add index_map source location info
f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in "
@ -460,7 +463,7 @@ def lower_jaxpr_to_module(
"only blocks having the same block shape as the array shape "
"and a trivial index_map (returning all 0s)." + err_details())
unmapped_bs = [1 if bs is pl_core.mapped else bs for bs in bm.block_shape]
unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
if rank >= 2:
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
@ -507,6 +510,9 @@ def lower_jaxpr_to_module(
jaxpr, grid_mapping, dimension_semantics, mesh)
mosaic_grid_mapping.maybe_compress_grid()
m = ir.Module.create()
attrs = m.operation.attributes
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
sym_tab = ir.SymbolTable(m.operation)
func_op = lower_jaxpr_to_func(
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
@ -534,7 +540,7 @@ def lower_jaxpr_to_module(
)
assert mlir_func.verify(), mlir_func
block_shape = [
1 if b is pl_core.mapped else b for b in bm.block_shape
1 if b is pallas_core.mapped else b for b in bm.block_shape
]
# If we have an extended dtype, we need to add the block shape for the
# remaining physical dtype.
@ -544,7 +550,7 @@ def lower_jaxpr_to_module(
window_bounds=window_shape,
transform_indices=ir.FlatSymbolRefAttr.get(func_name),
)
if isinstance(bm.indexing_mode, pl_core.Unblocked):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
if bm.indexing_mode.padding is None:
pad_low = pad_high = [0] * len(bm.block_shape)
else:
@ -557,7 +563,7 @@ def lower_jaxpr_to_module(
sym_tab.insert(mlir_func)
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
static_grid = [
MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid
MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid
]
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
@ -911,7 +917,7 @@ def _make_index(s):
def _maybe_cast_to_index(cast_to_index, x):
if cast_to_index:
return _make_index(x)
return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval)
return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval)
def _index_to_start_size_stride(
@ -940,7 +946,7 @@ def _index_to_start_size_stride(
def _indexer_to_start_size_stride(
indexer: NDIndexer,
ref_block_shape: tuple[int | pl_core.Mapped, ...],
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
*,
cast_to_index: bool,
) -> tuple[
@ -948,7 +954,7 @@ def _indexer_to_start_size_stride(
tuple[int | ir.Value, ...],
tuple[int, ...],
tuple[bool, ...],
tuple[int | pl_core.Mapped, ...],
tuple[int | pallas_core.Mapped, ...],
]:
indices_iter = iter(indexer.indices)
starts, sizes, strides, squeeze_dims = [], [], [], []
@ -960,7 +966,7 @@ def _indexer_to_start_size_stride(
1,
True,
)
if s is pl_core.mapped
if s is pallas_core.mapped
else _index_to_start_size_stride(next(indices_iter), cast_to_index)
)
starts.append(start)
@ -982,9 +988,9 @@ def _indexer_to_start_size_stride(
def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
indexer: NDIndexer,
ref_block_shape: tuple[int | pl_core.Mapped, ...]
) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...],
tuple[int | pl_core.Mapped, ...]]:
ref_block_shape: tuple[int | pallas_core.Mapped, ...]
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...],
tuple[int | pallas_core.Mapped, ...]]:
assert ref_block_shape is not None
target_shape = indexer.get_indexer_shape()
starts, sizes, strides, squeeze_dims, ref_block_shape = (
@ -1216,7 +1222,7 @@ def _masked_swap_lowering_rule(
mem_slice_shape.insert(i, 1)
mem_slice_shape_iter = iter(mem_slice_shape)
mem_slice_shape = [
1 if b is pl_core.mapped else next(mem_slice_shape_iter)
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
for b in ref_block_shape
]
mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
@ -2126,8 +2132,8 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
if unroll != 1:
raise NotImplementedError(
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval))
lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval))
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
for_op = scf.ForOp(lbd, ubd, step, args)
with ir.InsertionPoint(for_op.body):
@ -2525,7 +2531,7 @@ def _bitcast_convert_type_lowering_rule(
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule
def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
if isinstance(aval, pl_core.AbstractMemoryRef):
if isinstance(aval, pallas_core.AbstractMemoryRef):
memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
assert aval.memory_space == TPUMemorySpace.SEMAPHORE
@ -2574,8 +2580,8 @@ def _device_id_to_logical(
return sum(a * b for a, b in zip(indices, mesh_strides))
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[pl_core.index_map_grid_aval] * len(device_ids),
avals_out=[pl_core.index_map_grid_aval],
avals_in=[pallas_core.index_map_grid_aval] * len(device_ids),
avals_out=[pallas_core.index_map_grid_aval],
block_shapes=(None,) * len(device_ids),
)
return lower_fun(_linearize_mesh_indices, multiple_results=False)(
@ -2855,7 +2861,7 @@ def _shard_map_discharge_rule(
rewrite,
):
del out_avals, auto, in_names, out_names, check_rep, rewrite
if not isinstance(mesh, pl_core.PallasMesh):
if not isinstance(mesh, pallas_core.PallasMesh):
raise NotImplementedError("Mesh must be a PallasMesh")
if len(mesh.shape) > 1:
raise NotImplementedError("Mesh must be 1D")
@ -2867,9 +2873,9 @@ def _shard_map_discharge_rule(
out = pallas_call.pallas_call(
body,
out_shape=in_avals,
in_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
in_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
* len(in_avals),
out_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
out_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
* len(in_avals),
input_output_aliases={i: i for i in range(len(in_avals))},
grid=((core_axis_name, num_cores),),

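The module naming above comes down to setting the standard `sym_name` attribute on the freshly created MLIR module, which the IR dump machinery uses when naming debug dumps. A standalone sketch of the pattern (module name illustrative):

    from jax._src.lib.mlir import ir

    with ir.Context(), ir.Location.unknown():
      m = ir.Module.create()
      # Debug dumps of this module will now be labeled with the kernel
      # name instead of a generic default.
      m.operation.attributes["sym_name"] = ir.StringAttr.get("my_copy_kernel")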
View File

@ -66,7 +66,7 @@ def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext,
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: core.NameAndSrcInfo,
grid_mapping: core.GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
@ -75,6 +75,7 @@ def pallas_call_tpu_lowering_rule(
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
if debug:
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
print(jaxpr)
if "mosaic_params" in compiler_params:
# TODO(slebedev): Remove this branch after July 12th 2024.
@ -106,9 +107,11 @@ def pallas_call_tpu_lowering_rule(
return lowering.lower_jaxpr_to_module(
ctx, mlir_ctx, grid_mapping, jaxpr,
dimension_semantics=dimension_semantics, mesh=mesh,
for_verification=for_verification)
for_verification=for_verification,
name_and_src_info=name_and_src_info)
mosaic_module, extra_args = lower_module(for_verification=False)
if debug:
print(f"\nThe Mosaic module for pallas_call {name_and_src_info}:")
print(mosaic_module)
num_extra_args = len(extra_args)
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
@ -132,6 +135,7 @@ def pallas_call_tpu_lowering_rule(
verification_module, num_devices, num_cores
)
if promela_dump_path == "stdout":
print(f"The Promela model for pallas_call {name_and_src_info}:")
print(model)
else:
if promela_dump_path == "sponge":
@ -142,7 +146,10 @@ def pallas_call_tpu_lowering_rule(
" --jax_pallas_dump_promela_to=sponge"
)
dump_ctx = tempfile.NamedTemporaryFile(
mode="w", prefix=name + "-", suffix=".pml", dir=promela_dump_path, delete=False,
mode="w",
prefix=name_and_src_info.name + "-",
suffix=".pml",
dir=promela_dump_path, delete=False,
)
with dump_ctx as f:
f.write(model)
@ -173,7 +180,7 @@ def pallas_call_tpu_lowering_rule(
module=mosaic_module,
out_type=kernel_out_avals,
backend="tpu",
kernel_name=name,
kernel_name=name_and_src_info.name,
cost_estimate=mosaic_params.get("cost_estimate"),
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
flags=mosaic_params.get("flags"),

View File

@ -31,7 +31,7 @@ from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.pallas import core as pl_core
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
@ -53,7 +53,7 @@ partial = functools.partial
@dataclasses.dataclass
class ModuleContext:
name: str
grid_mapping: pl_core.GridMapping
grid_mapping: pallas_core.GridMapping
runtime_smem: ir.Value # ir.MemRefType
smem_used_bytes: int
@ -117,7 +117,7 @@ class LoweringRuleContext:
module_context: ModuleContext
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]
block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None
block_shapes: list[tuple[int | pallas_core.Mapped, ...]] | None
replace = dataclasses.replace
@ -142,9 +142,9 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name
def lower_jaxpr_to_module(
grid_mapping: pl_core.GridMapping,
grid_mapping: pallas_core.GridMapping,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
compiler_params: dict[str, Any],
) -> LoweringResult:
in_structs = tuple(grid_mapping.in_shapes)
@ -180,7 +180,8 @@ def lower_jaxpr_to_module(
barrier.wait()
module_ctx = ModuleContext(name, grid_mapping, runtime_smem, smem_used_bytes=0)
module_ctx = ModuleContext(name_and_src_info.name,
grid_mapping, runtime_smem, smem_used_bytes=0)
_ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, buffers_smem)
for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem):
@ -210,6 +211,7 @@ def lower_jaxpr_to_module(
*extra_smem_scratch,
mgpu.TMABarrier(),
),
module_name=name_and_src_info.name,
)
return LoweringResult(module, grid, gmem_scratch_bytes, out_structs)

View File

@ -30,7 +30,7 @@ def pallas_call_lowering(
ctx: mlir.LoweringRuleContext,
*args,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
@ -48,16 +48,19 @@ def pallas_call_lowering(
)
if debug:
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
print(jaxpr)
print(f"The grid mapping for pallas_call {name_and_src_info}:")
print(grid_mapping)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
jaxpr,
name,
name_and_src_info,
compiler_params,
)
if debug:
print(f"\nThe Mosaic GPU module for pallas_call {name_and_src_info}:")
print(lowering_result.module.operation)
module = lowering_result.module

View File

@ -167,12 +167,12 @@ def _pallas_call_impl(*args, **kwargs):
def _pallas_call_impl_interpret(
*args,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
compiler_params: Any):
del compiler_params, name
del compiler_params
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list( # type: ignore
@ -188,6 +188,7 @@ def _pallas_call_impl_interpret(
with grid_mapping.trace_env():
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(f"\nJaxpr the the kernel in pallas_call {name_and_src_info}:")
print(discharged_jaxpr)
out = _initialize_output_vals(grid_mapping.block_mappings_output,
args, input_output_aliases)
@ -301,7 +302,7 @@ def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
for bm in grid_mapping.block_mappings_output)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping, debug, interpret, compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
@ -336,8 +337,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
effs.append(eff)
jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
if debug:
print(f"\nThe jaxpr for the jvp of pallas_call {name_and_src_info}:")
print(jvp_jaxpr)
# TODO(necula): does this work with consts?
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
jvp_grid_mapping = grid_mapping.replace(
@ -349,7 +350,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
*primals,
*tangents,
jaxpr=jvp_jaxpr,
name=f"{name}_jvp",
name_and_src_info=name_and_src_info.replace(
name=f"{name_and_src_info.name}_jvp"),
grid_mapping=jvp_grid_mapping,
interpret=interpret,
debug=debug,
@ -428,7 +430,7 @@ def _batch_with_explicit_loop(
dims: Sequence[int | batching.NotMapped],
*,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
@ -493,7 +495,7 @@ def _batch_with_explicit_loop(
batch_out = pallas_call_p.bind(
*batch_args,
jaxpr=jaxpr,
name=name,
name_and_src_info=name_and_src_info,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -520,7 +522,7 @@ def _pallas_call_batching_rule(
dims,
*,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
@ -542,7 +544,7 @@ def _pallas_call_batching_rule(
out = pallas_call_p.bind(
*args,
jaxpr=jaxpr,
name=name,
name_and_src_info=name_and_src_info,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -573,7 +575,7 @@ def _pallas_call_batching_rule(
args=dynamic_grid_args + args,
dims=dynamic_grid_dims + dims,
jaxpr=jaxpr,
name=name,
name_and_src_info=name_and_src_info,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -605,7 +607,7 @@ def _pallas_call_batching_rule(
args=scalar_args + args,
dims=scalar_bdims + bdims,
jaxpr=jaxpr,
name=name,
name_and_src_info=name_and_src_info,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -660,7 +662,8 @@ def _pallas_call_batching_rule(
*dynamic_grid_args,
*args,
jaxpr=jaxpr,
name=f"batched_{name}",
name_and_src_info=name_and_src_info.replace(
name=f"{name_and_src_info.name}_batched"),
grid_mapping=batched_grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -836,7 +839,7 @@ checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
@weakref_lru_cache
def _trace_kernel_to_jaxpr(fun: Callable,
fun_src_info: pallas_core.SrcInfoStr,
name_and_src_info: pallas_core.NameAndSrcInfo,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
@ -855,23 +858,17 @@ def _trace_kernel_to_jaxpr(fun: Callable,
consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
for c in consts]
raise ValueError(
f"The kernel function {fun_src_info} in a "
"pallas_call should not capture constants. You should pass them "
f"as inputs. It captures constants of shapes: {consts_avals}")
f"The kernel function in the pallas_call {name_and_src_info} "
f"captures constants {consts_avals}. "
"You should pass them as inputs")
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
f"The kernel function {fun_src_info} in a "
f"pallas_call should return None. "
f"It returns a PyTree: {kernel_out_tree}")
f"The kernel function in the pallas_call {name_and_src_info} "
f"should return None. It returns a PyTree: {kernel_out_tree}")
return jaxpr
def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
return name
_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
"jax_pallas_use_mosaic_gpu",
@ -1009,7 +1006,11 @@ def pallas_call(
grid whose body is the kernel lowered as a JAX function. This does not
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
This is useful for debugging.
name: TO BE DOCUMENTED.
name: if present, specifies the name to use for this kernel call in
debugging and error messages. To this name we append the file and line
where the kernel function is defined, e.g.:
`{name} for kernel function {kernel_name} at {file}:{line}`.
If missing, then we use `{kernel_name} at {file}:{line}`.
compiler_params: TO BE DOCUMENTED.
Returns:
@ -1017,7 +1018,9 @@ def pallas_call(
invoke the Pallas kernel.
"""
name = _extract_function_name(kernel, name)
kernel_src_info = api_util.fun_sourceinfo(kernel)
name_and_src_info = pallas_core.NameAndSrcInfo.from_pallas_call(
name, kernel_src_info)
if compiler_params is None:
compiler_params = {}
@ -1058,16 +1061,15 @@ def pallas_call(
kernel_fun_sig = api_util.fun_signature(kernel)
arg_names = None
kernel_src_info: pallas_core.SrcInfoStr = "<unknown>"
if kernel_fun_sig:
kernel_debug_info = api_util.debug_info(
"pallas_call kernel",
api_util.fun_sourceinfo(kernel),
kernel_src_info,
kernel_fun_sig,
[1] * len(kernel_fun_sig.parameters), {}, (), ())
if kernel_debug_info:
arg_names = kernel_debug_info.arg_names
kernel_src_info = kernel_debug_info.func_src_info
del kernel_debug_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
@ -1105,7 +1107,8 @@ def pallas_call(
index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *index_args, *rest_args,
jaxpr=jaxpr, name=name,
jaxpr=jaxpr,
name_and_src_info=name_and_src_info,
debug=debug,
interpret=interpret,
grid_mapping=grid_mapping,

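End to end, a kernel that violates the return-None contract now raises an error carrying the `NameAndSrcInfo`. A hedged sketch (message wording per the updated tests below; file and line are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def my_kernel(x_ref, o_ref):
      return 42  # kernel functions must return None

    f = pl.pallas_call(
        my_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
    # Raises at trace time:
    #   ValueError: The kernel function in the pallas_call my_kernel
    #   at my_file.py:5 should return None. It returns a PyTree: ...
    # f(jnp.arange(8, dtype=jnp.int32))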
View File

@ -266,7 +266,7 @@ def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]):
def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
grid_mapping: GridMapping,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
platform: str
) -> LoweringResult:
if grid_mapping.num_dynamic_grid_bounds:
@ -283,6 +283,9 @@ def lower_jaxpr_to_triton_module(
)
with _new_ir_context(), ir.Location.unknown():
module = ir.Module.create()
attrs = module.operation.attributes
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
param_types = [
tt_dialect.PointerType.get(_dtype_to_ir_type(var.aval.dtype), 1)
for var in jaxpr.invars
@ -290,7 +293,7 @@ def lower_jaxpr_to_triton_module(
assert len(jaxpr.outvars) == 0
fn_type = ir.FunctionType.get(param_types, [])
fn = tt_dialect.FuncOp(
name,
name_and_src_info.name,
ir.TypeAttr.get(fn_type),
sym_visibility="public",
res_attrs=ir.DictAttr.get(dict(noinline=ir.BoolAttr.get(False))),
@ -310,7 +313,8 @@ def lower_jaxpr_to_triton_module(
if i not in grid_mapping.vmapped_dims
]
ctx = ModuleContext(
name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
name_and_src_info.name,
grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
)
if grid_mapping.num_index_operands:
raise NotImplementedError(

View File

@ -42,7 +42,7 @@ def pallas_call_lowering(
ctx: mlir.LoweringRuleContext,
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
name_and_src_info: pallas_core.NameAndSrcInfo,
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
@ -67,14 +67,17 @@ def pallas_call_lowering(
num_stages = triton_params.pop("num_stages", 3)
if debug:
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
print(jaxpr)
print("The grid mapping for pallas_call {name_and_src_info}:")
print(grid_mapping)
lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, grid_mapping, name, lowering_platform
jaxpr, grid_mapping, name_and_src_info, lowering_platform
)
module_op = lowering_result.module.operation
if debug:
print(f"\nThe Triton module for pallas_call {name_and_src_info}:")
print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
@ -86,7 +89,7 @@ def pallas_call_lowering(
buf = io.BytesIO()
module_op.write_bytecode(buf)
backend_config = dict(
name=ir.StringAttr.get(name),
name=ir.StringAttr.get(name_and_src_info.name),
ir=ir.StringAttr.get(buf.getvalue()),
num_stages=mlir.i32_attr(num_stages),
num_warps=mlir.i32_attr(num_warps),

View File

@ -690,6 +690,7 @@ def _lower_as_gpu_kernel(
in_shapes: tuple[Any, ...],
out_shape,
smem_scratch_shape: ShapeTree | Union[ShapeTree],
module_name: str,
prof_spec: profiler.ProfilerSpec | None = None,
):
ptr_ty = ir.Type.parse("!llvm.ptr")
@ -714,6 +715,8 @@ def _lower_as_gpu_kernel(
out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block))
module = ir.Module.create()
attrs = module.operation.attributes
attrs["sym_name"] = ir.StringAttr.get(module_name)
with ir.InsertionPoint(module.body):
_declare_runtime_functions()
gmem_scratch_bytes = 0
@ -772,6 +775,7 @@ def as_gpu_kernel(
smem_scratch_shape: ShapeTree | Union[ShapeTree],
prof_spec: profiler.ProfilerSpec | None = None,
cluster: tuple[int, int, int] = (1, 1, 1),
module_name: str = "unknown",
):
if isinstance(in_shape, list):
in_shape = tuple(in_shape)
@ -780,7 +784,8 @@ def as_gpu_kernel(
module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
_lower_as_gpu_kernel(
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, prof_spec
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
module_name, prof_spec
)
)

View File

@ -27,12 +27,14 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import random
from jax._src import api_util
from jax._src import checkify
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.lib import version as jaxlib_version
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
@ -507,7 +509,7 @@ class PallasCallTest(PallasBaseTest):
with self.assertRaisesRegex(
ValueError,
"The kernel function .* should not capture constants"):
"The kernel function .* captures constants"):
kernel(x)
def test_vector_slicing(self):
@ -712,7 +714,7 @@ class ApiErrorTest(PallasBaseTest):
out_shape=(a, a))
with self.assertRaisesRegex(
ValueError,
"The kernel function my_kernel at .*pallas_test.py:.* in a pallas_call should return None"):
"The kernel function .* my_kernel at .*pallas_test.py:.* should return None"):
f(a)
def test_pallas_call_kernel_with_no_signature_returns_something(self):
@ -721,7 +723,7 @@ class ApiErrorTest(PallasBaseTest):
out_shape=a)
with self.assertRaisesRegex(
ValueError,
"The kernel function .* at .*pallas_test.py:.* in a pallas_call should return None"):
"The kernel function .* at .*pallas_test.py:.* should return None"):
f(a)
def test_pallas_call_in_specs_not_a_sequence(self):
@ -825,7 +827,6 @@ class ApiErrorTest(PallasBaseTest):
".* `out_specs` is a tuple of length 1 but `out_shape` is a tuple of length 2.*", re.DOTALL)):
f(a)
def test_pallas_call_block_shape_ndim_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
@ -885,6 +886,40 @@ class ApiErrorTest(PallasBaseTest):
out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)],
input_output_aliases={1: 0})(x, x)
def test_name_and_src_info(self):
def the_kernel(): return None
ns1 = pallas_core.NameAndSrcInfo.from_pallas_call(
"my_name", api_util.fun_sourceinfo(the_kernel))
self.assertEqual("my_name", ns1.name)
self.assertIn("the_kernel", ns1.src_info)
self.assertIn("pallas_test.py:", ns1.src_info)
self.assertRegex(
str(ns1),
"my_name for kernel function the_kernel at .*pallas_test.py:.*")
ns2 = pallas_core.NameAndSrcInfo.from_pallas_call(
None,
api_util.fun_sourceinfo(the_kernel))
self.assertEqual("the_kernel", ns2.name)
self.assertIn("pallas_test.py:", ns2.src_info)
self.assertRegex(
str(ns2),
"the_kernel at .*pallas_test.py:.*")
ns3 = pallas_core.NameAndSrcInfo.from_pallas_call("my_name", None)
self.assertEqual("my_name", ns3.name)
self.assertEqual("", ns3.src_info)
self.assertEqual(str(ns3), "my_name")
ns4 = pallas_core.NameAndSrcInfo.from_pallas_call("my name with spaces",
None)
self.assertEqual("my_name_with_spaces", ns4.name)
self.assertEqual("", ns4.src_info)
ns5 = pallas_core.NameAndSrcInfo.from_pallas_call(None, None)
self.assertEqual("unknown", ns5.name)
self.assertEqual("", ns5.src_info)
class ApiErrorInterpreterTest(ApiErrorTest):
INTERPRET = True

View File

@ -147,7 +147,7 @@ class PallasCallVmapTest(PallasBaseTest):
with self.assertRaisesRegex(
ValueError,
"The kernel function .* should not capture constants"):
"The kernel function .* captures constants"):
kernel(x)
def test_vmap_of_kernel_with_input_output_aliases(self):

View File

@ -1884,10 +1884,6 @@ class PallasCallPrintTest(PallasBaseTest):
class PallasCallTraceTest(PallasBaseTest):
def parse_debug_string(self, debug_string):
jaxpr, mlir = debug_string.split('module')
return {'jaxpr': jaxpr, 'mlir': mlir}
def test_trace_start_stop_match(self):
def kernel(o_ref):
with jax.named_scope('scope1'):
@ -1900,10 +1896,10 @@ class PallasCallTraceTest(PallasBaseTest):
debug=True,
)()
# TODO(justinfu): Add an official lowering API to get the MLIR.
mlir = self.parse_debug_string(msg.getvalue())['mlir']
debug_string = msg.getvalue()
num_start = mlir.count('tpu.trace_start')
num_stop = mlir.count('tpu.trace_stop')
num_start = debug_string.count('tpu.trace_start')
num_stop = debug_string.count('tpu.trace_stop')
self.assertEqual(num_start, 1)
self.assertEqual(num_stop, 1)
@ -1926,10 +1922,10 @@ class PallasCallTraceTest(PallasBaseTest):
debug=True,
)()
# TODO(justinfu): Add an official lowering API to get the MLIR.
mlir = self.parse_debug_string(msg.getvalue())['mlir']
debug_string = msg.getvalue()
num_start = mlir.count('tpu.trace_start')
num_stop = mlir.count('tpu.trace_stop')
num_start = debug_string.count('tpu.trace_start')
num_stop = debug_string.count('tpu.trace_stop')
self.assertEqual(num_start, 2)
self.assertEqual(num_stop, 2)