mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[pallas] Improve error and debugging messages with source locations
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function. Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive. Added some more information to the `if debug` prints. Set the MLIR module names so that the debug dumps are named properly. I changed `import pallas.core as pl_core` to `... as pallas_core` for consistency, in a couple of modules. PiperOrigin-RevId: 659506675
This commit is contained in:
parent
b2a469b361
commit
252032a368
@ -36,6 +36,7 @@ from jax._src import mesh as mesh_lib
|
||||
from jax._src import state
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.state import discharge as state_discharge
|
||||
import jax.numpy as jnp
|
||||
@ -56,9 +57,47 @@ TupleGrid = tuple[GridElement, ...]
|
||||
Grid = Union[NamedGrid, TupleGrid]
|
||||
StaticGrid = tuple[int, ...]
|
||||
GridMappingGrid = tuple[int | DynamicGridDim, ...]
|
||||
SrcInfoStr = str # function_name at filename:linenumber
|
||||
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NameAndSrcInfo:
|
||||
#: The name of the pallas_call or the name of the kernel function.
|
||||
name: str
|
||||
#: the source info, and the name of kernel function if not in `name`.`
|
||||
src_info: str
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}{' ' if self.src_info else ''}{self.src_info}"
|
||||
__repr__ = __str__
|
||||
|
||||
replace = dataclasses.replace
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pallas_call(pallas_call_name: str | None,
|
||||
src_info : str | None) -> NameAndSrcInfo:
|
||||
"""Formats the name and the source info.
|
||||
|
||||
Args:
|
||||
pallas_call_name: The `name` argument to pallas_call.
|
||||
src_info: The result of `api_util.fun_source_info(kernel)`, in the form
|
||||
"{function_name} at {file_name}:{line_number}".
|
||||
"""
|
||||
if pallas_call_name is not None:
|
||||
pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name)
|
||||
if src_info is None:
|
||||
return NameAndSrcInfo(
|
||||
"unknown" if pallas_call_name is None else pallas_call_name,
|
||||
"")
|
||||
if pallas_call_name is not None:
|
||||
return NameAndSrcInfo(pallas_call_name,
|
||||
f"for kernel function {src_info}")
|
||||
src_info_parts = src_info.split(" ")
|
||||
return NameAndSrcInfo(src_info_parts[0],
|
||||
" ".join(src_info_parts[1:]))
|
||||
|
||||
|
||||
# Pytrees of jax.ShapeDtypeStruct
|
||||
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]
|
||||
|
||||
@ -268,7 +307,7 @@ class BlockMapping:
|
||||
block_shape: tuple[Mapped | int, ...]
|
||||
block_aval: AbstractMemoryRef # The block ref aval
|
||||
index_map_jaxpr: jax_core.ClosedJaxpr
|
||||
index_map_src_info: SrcInfoStr
|
||||
index_map_src_info: NameAndSrcInfo
|
||||
indexing_mode: IndexingMode
|
||||
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
|
||||
origin: OriginStr
|
||||
@ -534,7 +573,8 @@ def _convert_block_spec_to_block_mapping(
|
||||
lu.wrap_init(index_map_func), index_map_tree)
|
||||
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
|
||||
False, "pallas_call index_map")
|
||||
index_map_src_info = debug.func_src_info or "<unknown>"
|
||||
index_map_src_info = NameAndSrcInfo.from_pallas_call(None,
|
||||
debug.func_src_info)
|
||||
with tracing_grid_env(grid, mapped_dims):
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
|
||||
index_map_avals,
|
||||
|
@ -49,7 +49,7 @@ from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import core as pl_core
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.pallas.mosaic import core as tpu_core
|
||||
@ -71,7 +71,7 @@ import numpy as np
|
||||
|
||||
NDIndexer = indexing.NDIndexer
|
||||
TPUMemorySpace = tpu_core.TPUMemorySpace
|
||||
MemorySpace = pl_core.MemorySpace | TPUMemorySpace
|
||||
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
|
||||
VMEM = tpu_core.TPUMemorySpace.VMEM
|
||||
SMEM = tpu_core.TPUMemorySpace.SMEM
|
||||
# Booleans are stored as the following type in memrefs.
|
||||
@ -105,7 +105,7 @@ class LoweringContext:
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
|
||||
user_grid_indices: Sequence[ir.Value] | None
|
||||
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
|
||||
block_shapes: list[tuple[int | pallas_core.Mapped, ...]]
|
||||
name_stack: source_info_util.NameStack
|
||||
mesh_context: MeshContext | None
|
||||
replace = dataclasses.replace
|
||||
@ -136,7 +136,7 @@ class LoweringRuleContext:
|
||||
lowering_context: LoweringContext
|
||||
avals_in: Sequence[jax_core.AbstractValue]
|
||||
avals_out: Sequence[jax_core.AbstractValue]
|
||||
block_shapes: Sequence[tuple[int | pl_core.Mapped, ...] | None]
|
||||
block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]
|
||||
|
||||
replace = dataclasses.replace
|
||||
|
||||
@ -145,9 +145,9 @@ def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None
|
||||
) -> ir.Attribute:
|
||||
if memory_space is None:
|
||||
memory_space = VMEM
|
||||
elif memory_space == pl_core.MemorySpace.ERROR:
|
||||
elif memory_space == pallas_core.MemorySpace.ERROR:
|
||||
memory_space = SMEM
|
||||
elif memory_space == pl_core.MemorySpace.INDEX:
|
||||
elif memory_space == pallas_core.MemorySpace.INDEX:
|
||||
memory_space = SMEM
|
||||
return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")
|
||||
|
||||
@ -252,10 +252,10 @@ def _get_aval_physical_dtype_shape(aval):
|
||||
|
||||
def _get_arg_type(
|
||||
aval,
|
||||
block_mapping: pl_core.BlockMapping | None,
|
||||
block_mapping: pallas_core.BlockMapping | None,
|
||||
):
|
||||
memory_space = None
|
||||
if isinstance(aval, pl_core.AbstractMemoryRef):
|
||||
if isinstance(aval, pallas_core.AbstractMemoryRef):
|
||||
memory_space = aval.memory_space
|
||||
# We assume unannotated memory refs are in VMEM
|
||||
if memory_space is None:
|
||||
@ -265,7 +265,7 @@ def _get_arg_type(
|
||||
# TODO(necula): clean this None block_mapping
|
||||
if block_mapping is None:
|
||||
return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
|
||||
shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape)
|
||||
shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape)
|
||||
return (
|
||||
aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
|
||||
block_mapping.block_shape,
|
||||
@ -277,7 +277,7 @@ class MosaicGridMapping:
|
||||
grid: tuple[int, ...] | None
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
jaxpr: jax_core.Jaxpr
|
||||
block_mappings: tuple[pl_core.BlockMapping | None, ...]
|
||||
block_mappings: tuple[pallas_core.BlockMapping | None, ...]
|
||||
mapped_dims: tuple[int, ...]
|
||||
scalar_prefetch_types: tuple[ir.Type, ...]
|
||||
operand_types: tuple[ir.Type, ...]
|
||||
@ -289,7 +289,7 @@ class MosaicGridMapping:
|
||||
mesh_info: MeshInfo | None
|
||||
get_grid_indices: Callable | None
|
||||
|
||||
def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
|
||||
def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping,
|
||||
dimension_semantics: tuple[str, ...] | None,
|
||||
mesh: mesh_lib.Mesh | None):
|
||||
self.grid = grid_mapping.grid
|
||||
@ -340,7 +340,7 @@ class MosaicGridMapping:
|
||||
for aval in scratch_avals
|
||||
)
|
||||
self.grid_types, _ = unzip2([
|
||||
_get_arg_type(pl_core.index_map_grid_aval, None)
|
||||
_get_arg_type(pallas_core.index_map_grid_aval, None)
|
||||
for _ in range(len(self.grid))
|
||||
])
|
||||
self._prepare_mesh_info(mesh)
|
||||
@ -418,9 +418,11 @@ class MeshInfo:
|
||||
def lower_jaxpr_to_module(
|
||||
lowering_context: mlir.LoweringRuleContext,
|
||||
ctx: ir.Context,
|
||||
grid_mapping: pl_core.GridMapping,
|
||||
grid_mapping: pallas_core.GridMapping,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
*,
|
||||
dimension_semantics: tuple[str | None, ...] | None,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
mesh: mesh_lib.Mesh | None = None,
|
||||
for_verification: bool = False,
|
||||
) -> tuple[Module, tuple[Any, ...]]:
|
||||
@ -432,7 +434,8 @@ def lower_jaxpr_to_module(
|
||||
bm.has_trivial_window()):
|
||||
continue
|
||||
def err_details():
|
||||
return (f"Block spec for {bm.origin} has block shape "
|
||||
return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
|
||||
"has block shape "
|
||||
f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
|
||||
# TODO(necula): add index_map source location info
|
||||
f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in "
|
||||
@ -460,7 +463,7 @@ def lower_jaxpr_to_module(
|
||||
"only blocks having the same block shape as the array shape "
|
||||
"and a trivial index_map (returning all 0s)." + err_details())
|
||||
|
||||
unmapped_bs = [1 if bs is pl_core.mapped else bs for bs in bm.block_shape]
|
||||
unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
|
||||
bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
|
||||
if rank >= 2:
|
||||
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
|
||||
@ -507,6 +510,9 @@ def lower_jaxpr_to_module(
|
||||
jaxpr, grid_mapping, dimension_semantics, mesh)
|
||||
mosaic_grid_mapping.maybe_compress_grid()
|
||||
m = ir.Module.create()
|
||||
attrs = m.operation.attributes
|
||||
module_name = name_and_src_info.name
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
sym_tab = ir.SymbolTable(m.operation)
|
||||
func_op = lower_jaxpr_to_func(
|
||||
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
@ -534,7 +540,7 @@ def lower_jaxpr_to_module(
|
||||
)
|
||||
assert mlir_func.verify(), mlir_func
|
||||
block_shape = [
|
||||
1 if b is pl_core.mapped else b for b in bm.block_shape
|
||||
1 if b is pallas_core.mapped else b for b in bm.block_shape
|
||||
]
|
||||
# If we have an extended dtype, we need to add the block shape for the
|
||||
# remaining physical dtype.
|
||||
@ -544,7 +550,7 @@ def lower_jaxpr_to_module(
|
||||
window_bounds=window_shape,
|
||||
transform_indices=ir.FlatSymbolRefAttr.get(func_name),
|
||||
)
|
||||
if isinstance(bm.indexing_mode, pl_core.Unblocked):
|
||||
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
|
||||
if bm.indexing_mode.padding is None:
|
||||
pad_low = pad_high = [0] * len(bm.block_shape)
|
||||
else:
|
||||
@ -557,7 +563,7 @@ def lower_jaxpr_to_module(
|
||||
sym_tab.insert(mlir_func)
|
||||
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
|
||||
static_grid = [
|
||||
MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid
|
||||
MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid
|
||||
]
|
||||
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
|
||||
|
||||
@ -911,7 +917,7 @@ def _make_index(s):
|
||||
def _maybe_cast_to_index(cast_to_index, x):
|
||||
if cast_to_index:
|
||||
return _make_index(x)
|
||||
return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval)
|
||||
return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval)
|
||||
|
||||
|
||||
def _index_to_start_size_stride(
|
||||
@ -940,7 +946,7 @@ def _index_to_start_size_stride(
|
||||
|
||||
def _indexer_to_start_size_stride(
|
||||
indexer: NDIndexer,
|
||||
ref_block_shape: tuple[int | pl_core.Mapped, ...],
|
||||
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
|
||||
*,
|
||||
cast_to_index: bool,
|
||||
) -> tuple[
|
||||
@ -948,7 +954,7 @@ def _indexer_to_start_size_stride(
|
||||
tuple[int | ir.Value, ...],
|
||||
tuple[int, ...],
|
||||
tuple[bool, ...],
|
||||
tuple[int | pl_core.Mapped, ...],
|
||||
tuple[int | pallas_core.Mapped, ...],
|
||||
]:
|
||||
indices_iter = iter(indexer.indices)
|
||||
starts, sizes, strides, squeeze_dims = [], [], [], []
|
||||
@ -960,7 +966,7 @@ def _indexer_to_start_size_stride(
|
||||
1,
|
||||
True,
|
||||
)
|
||||
if s is pl_core.mapped
|
||||
if s is pallas_core.mapped
|
||||
else _index_to_start_size_stride(next(indices_iter), cast_to_index)
|
||||
)
|
||||
starts.append(start)
|
||||
@ -982,9 +988,9 @@ def _indexer_to_start_size_stride(
|
||||
|
||||
def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
|
||||
indexer: NDIndexer,
|
||||
ref_block_shape: tuple[int | pl_core.Mapped, ...]
|
||||
) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...],
|
||||
tuple[int | pl_core.Mapped, ...]]:
|
||||
ref_block_shape: tuple[int | pallas_core.Mapped, ...]
|
||||
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...],
|
||||
tuple[int | pallas_core.Mapped, ...]]:
|
||||
assert ref_block_shape is not None
|
||||
target_shape = indexer.get_indexer_shape()
|
||||
starts, sizes, strides, squeeze_dims, ref_block_shape = (
|
||||
@ -1216,7 +1222,7 @@ def _masked_swap_lowering_rule(
|
||||
mem_slice_shape.insert(i, 1)
|
||||
mem_slice_shape_iter = iter(mem_slice_shape)
|
||||
mem_slice_shape = [
|
||||
1 if b is pl_core.mapped else next(mem_slice_shape_iter)
|
||||
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
|
||||
for b in ref_block_shape
|
||||
]
|
||||
mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
|
||||
@ -2126,8 +2132,8 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
|
||||
if unroll != 1:
|
||||
raise NotImplementedError(
|
||||
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
|
||||
lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval)
|
||||
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval))
|
||||
lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval)
|
||||
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval))
|
||||
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
|
||||
for_op = scf.ForOp(lbd, ubd, step, args)
|
||||
with ir.InsertionPoint(for_op.body):
|
||||
@ -2525,7 +2531,7 @@ def _bitcast_convert_type_lowering_rule(
|
||||
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule
|
||||
|
||||
def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
|
||||
if isinstance(aval, pl_core.AbstractMemoryRef):
|
||||
if isinstance(aval, pallas_core.AbstractMemoryRef):
|
||||
memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
|
||||
if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
|
||||
assert aval.memory_space == TPUMemorySpace.SEMAPHORE
|
||||
@ -2574,8 +2580,8 @@ def _device_id_to_logical(
|
||||
return sum(a * b for a, b in zip(indices, mesh_strides))
|
||||
lower_ctx = LoweringRuleContext(
|
||||
lowering_context=ctx.lowering_context,
|
||||
avals_in=[pl_core.index_map_grid_aval] * len(device_ids),
|
||||
avals_out=[pl_core.index_map_grid_aval],
|
||||
avals_in=[pallas_core.index_map_grid_aval] * len(device_ids),
|
||||
avals_out=[pallas_core.index_map_grid_aval],
|
||||
block_shapes=(None,) * len(device_ids),
|
||||
)
|
||||
return lower_fun(_linearize_mesh_indices, multiple_results=False)(
|
||||
@ -2855,7 +2861,7 @@ def _shard_map_discharge_rule(
|
||||
rewrite,
|
||||
):
|
||||
del out_avals, auto, in_names, out_names, check_rep, rewrite
|
||||
if not isinstance(mesh, pl_core.PallasMesh):
|
||||
if not isinstance(mesh, pallas_core.PallasMesh):
|
||||
raise NotImplementedError("Mesh must be a PallasMesh")
|
||||
if len(mesh.shape) > 1:
|
||||
raise NotImplementedError("Mesh must be 1D")
|
||||
@ -2867,9 +2873,9 @@ def _shard_map_discharge_rule(
|
||||
out = pallas_call.pallas_call(
|
||||
body,
|
||||
out_shape=in_avals,
|
||||
in_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
|
||||
in_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
|
||||
* len(in_avals),
|
||||
out_specs=[pl_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
|
||||
out_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)]
|
||||
* len(in_avals),
|
||||
input_output_aliases={i: i for i in range(len(in_avals))},
|
||||
grid=((core_axis_name, num_cores),),
|
||||
|
@ -66,7 +66,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
ctx: mlir.LoweringRuleContext,
|
||||
*in_nodes,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: core.NameAndSrcInfo,
|
||||
grid_mapping: core.GridMapping,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
debug: bool,
|
||||
@ -75,6 +75,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
"""Lowers a pallas_call to a Mosaic TPU custom call."""
|
||||
del interpret
|
||||
if debug:
|
||||
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
|
||||
print(jaxpr)
|
||||
if "mosaic_params" in compiler_params:
|
||||
# TODO(slebedev): Remove this branch after July 12th 2024.
|
||||
@ -106,9 +107,11 @@ def pallas_call_tpu_lowering_rule(
|
||||
return lowering.lower_jaxpr_to_module(
|
||||
ctx, mlir_ctx, grid_mapping, jaxpr,
|
||||
dimension_semantics=dimension_semantics, mesh=mesh,
|
||||
for_verification=for_verification)
|
||||
for_verification=for_verification,
|
||||
name_and_src_info=name_and_src_info)
|
||||
mosaic_module, extra_args = lower_module(for_verification=False)
|
||||
if debug:
|
||||
print(f"\nThe Mosaic module for pallas_call {name_and_src_info}:")
|
||||
print(mosaic_module)
|
||||
num_extra_args = len(extra_args)
|
||||
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
|
||||
@ -132,6 +135,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
verification_module, num_devices, num_cores
|
||||
)
|
||||
if promela_dump_path == "stdout":
|
||||
print(f"The Promela model for pallas_call {name_and_src_info}:")
|
||||
print(model)
|
||||
else:
|
||||
if promela_dump_path == "sponge":
|
||||
@ -142,7 +146,10 @@ def pallas_call_tpu_lowering_rule(
|
||||
" --jax_pallas_dump_promela_to=sponge"
|
||||
)
|
||||
dump_ctx = tempfile.NamedTemporaryFile(
|
||||
mode="w", prefix=name + "-", suffix=".pml", dir=promela_dump_path, delete=False,
|
||||
mode="w",
|
||||
prefix=name_and_src_info.name + "-",
|
||||
suffix=".pml",
|
||||
dir=promela_dump_path, delete=False,
|
||||
)
|
||||
with dump_ctx as f:
|
||||
f.write(model)
|
||||
@ -173,7 +180,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
module=mosaic_module,
|
||||
out_type=kernel_out_avals,
|
||||
backend="tpu",
|
||||
kernel_name=name,
|
||||
kernel_name=name_and_src_info.name,
|
||||
cost_estimate=mosaic_params.get("cost_estimate"),
|
||||
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
|
||||
flags=mosaic_params.get("flags"),
|
||||
|
@ -31,7 +31,7 @@ from jax._src.lax import lax
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith as arith_dialect
|
||||
from jax._src.lib.mlir.dialects import memref as memref_dialect
|
||||
from jax._src.pallas import core as pl_core
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.state import primitives as sp
|
||||
from jax.experimental.mosaic import gpu as mosaic_gpu
|
||||
@ -53,7 +53,7 @@ partial = functools.partial
|
||||
@dataclasses.dataclass
|
||||
class ModuleContext:
|
||||
name: str
|
||||
grid_mapping: pl_core.GridMapping
|
||||
grid_mapping: pallas_core.GridMapping
|
||||
runtime_smem: ir.Value # ir.MemRefType
|
||||
smem_used_bytes: int
|
||||
|
||||
@ -117,7 +117,7 @@ class LoweringRuleContext:
|
||||
module_context: ModuleContext
|
||||
avals_in: Sequence[jax_core.ShapedArray]
|
||||
avals_out: Sequence[jax_core.ShapedArray]
|
||||
block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None
|
||||
block_shapes: list[tuple[int | pallas_core.Mapped, ...]] | None
|
||||
|
||||
replace = dataclasses.replace
|
||||
|
||||
@ -142,9 +142,9 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name
|
||||
|
||||
|
||||
def lower_jaxpr_to_module(
|
||||
grid_mapping: pl_core.GridMapping,
|
||||
grid_mapping: pallas_core.GridMapping,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
compiler_params: dict[str, Any],
|
||||
) -> LoweringResult:
|
||||
in_structs = tuple(grid_mapping.in_shapes)
|
||||
@ -180,7 +180,8 @@ def lower_jaxpr_to_module(
|
||||
|
||||
barrier.wait()
|
||||
|
||||
module_ctx = ModuleContext(name, grid_mapping, runtime_smem, smem_used_bytes=0)
|
||||
module_ctx = ModuleContext(name_and_src_info.name,
|
||||
grid_mapping, runtime_smem, smem_used_bytes=0)
|
||||
_ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, buffers_smem)
|
||||
|
||||
for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem):
|
||||
@ -210,6 +211,7 @@ def lower_jaxpr_to_module(
|
||||
*extra_smem_scratch,
|
||||
mgpu.TMABarrier(),
|
||||
),
|
||||
module_name=name_and_src_info.name,
|
||||
)
|
||||
|
||||
return LoweringResult(module, grid, gmem_scratch_bytes, out_structs)
|
||||
|
@ -30,7 +30,7 @@ def pallas_call_lowering(
|
||||
ctx: mlir.LoweringRuleContext,
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
interpret: bool,
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
@ -48,16 +48,19 @@ def pallas_call_lowering(
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
|
||||
print(jaxpr)
|
||||
print(f"The grid mapping for pallas_call {name_and_src_info}:")
|
||||
print(grid_mapping)
|
||||
|
||||
lowering_result = lowering.lower_jaxpr_to_module(
|
||||
grid_mapping,
|
||||
jaxpr,
|
||||
name,
|
||||
name_and_src_info,
|
||||
compiler_params,
|
||||
)
|
||||
if debug:
|
||||
print(f"\nThe Mosaic GPU module for pallas_call {name_and_src_info}:")
|
||||
print(lowering_result.module.operation)
|
||||
|
||||
module = lowering_result.module
|
||||
|
@ -167,12 +167,12 @@ def _pallas_call_impl(*args, **kwargs):
|
||||
def _pallas_call_impl_interpret(
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndStrInfo,
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: GridMapping,
|
||||
compiler_params: Any):
|
||||
del compiler_params, name
|
||||
del compiler_params
|
||||
# If we're in interpreter mode, we *scan* over the grid and eval the
|
||||
# discharged jaxpr.
|
||||
dynamic_grid_args, args = split_list( # type: ignore
|
||||
@ -188,6 +188,7 @@ def _pallas_call_impl_interpret(
|
||||
with grid_mapping.trace_env():
|
||||
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
|
||||
if debug:
|
||||
print(f"\nJaxpr the the kernel in pallas_call {name_and_src_info}:")
|
||||
print(discharged_jaxpr)
|
||||
out = _initialize_output_vals(grid_mapping.block_mappings_output,
|
||||
args, input_output_aliases)
|
||||
@ -301,7 +302,7 @@ def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
|
||||
for bm in grid_mapping.block_mappings_output)
|
||||
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
|
||||
|
||||
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
|
||||
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping, debug, interpret, compiler_params: Any):
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
@ -336,8 +337,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
|
||||
effs.append(eff)
|
||||
jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
|
||||
if debug:
|
||||
print(f"\nThe jaxpr for the jvp of pallas_call {name_and_src_info}:")
|
||||
print(jvp_jaxpr)
|
||||
# TODO(necula): does this work with consts?
|
||||
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
|
||||
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
|
||||
jvp_grid_mapping = grid_mapping.replace(
|
||||
@ -349,7 +350,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
|
||||
*primals,
|
||||
*tangents,
|
||||
jaxpr=jvp_jaxpr,
|
||||
name=f"{name}_jvp",
|
||||
name_and_src_info=name_and_src_info.replace(
|
||||
name=f"{name_and_src_info.name}_jvp"),
|
||||
grid_mapping=jvp_grid_mapping,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
@ -428,7 +430,7 @@ def _batch_with_explicit_loop(
|
||||
dims: Sequence[int | batching.NotMapped],
|
||||
*,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
grid_mapping: GridMapping,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
debug: bool,
|
||||
@ -493,7 +495,7 @@ def _batch_with_explicit_loop(
|
||||
batch_out = pallas_call_p.bind(
|
||||
*batch_args,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
name_and_src_info=name_and_src_info,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -520,7 +522,7 @@ def _pallas_call_batching_rule(
|
||||
dims,
|
||||
*,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
grid_mapping: GridMapping,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
debug: bool,
|
||||
@ -542,7 +544,7 @@ def _pallas_call_batching_rule(
|
||||
out = pallas_call_p.bind(
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
name_and_src_info=name_and_src_info,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -573,7 +575,7 @@ def _pallas_call_batching_rule(
|
||||
args=dynamic_grid_args + args,
|
||||
dims=dynamic_grid_dims + dims,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
name_and_src_info=name_and_src_info,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -605,7 +607,7 @@ def _pallas_call_batching_rule(
|
||||
args=scalar_args + args,
|
||||
dims=scalar_bdims + bdims,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
name_and_src_info=name_and_src_info,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -660,7 +662,8 @@ def _pallas_call_batching_rule(
|
||||
*dynamic_grid_args,
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name=f"batched_{name}",
|
||||
name_and_src_info=name_and_src_info.replace(
|
||||
name=f"{name_and_src_info.name}_batched"),
|
||||
grid_mapping=batched_grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -836,7 +839,7 @@ checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
|
||||
|
||||
@weakref_lru_cache
|
||||
def _trace_kernel_to_jaxpr(fun: Callable,
|
||||
fun_src_info: pallas_core.SrcInfoStr,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
grid_mapping: GridMapping,
|
||||
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
|
||||
kernel_in_tree: tree_util.PyTreeDef,
|
||||
@ -855,23 +858,17 @@ def _trace_kernel_to_jaxpr(fun: Callable,
|
||||
consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
|
||||
for c in consts]
|
||||
raise ValueError(
|
||||
f"The kernel function {fun_src_info} in a "
|
||||
"pallas_call should not capture constants. You should pass them "
|
||||
f"as inputs. It captures constants of shapes: {consts_avals}")
|
||||
f"The kernel function in the pallas_call {name_and_src_info} "
|
||||
f"captures constants {consts_avals}. "
|
||||
"You should pass them as inputs")
|
||||
|
||||
kernel_out_tree = out_tree_thunk()
|
||||
if kernel_out_tree != tree_util.tree_structure(None):
|
||||
raise ValueError(
|
||||
f"The kernel function {fun_src_info} in a "
|
||||
f"pallas_call should return None. "
|
||||
f"It returns a PyTree: {kernel_out_tree}")
|
||||
f"The kernel function in the pallas_call {name_and_src_info} "
|
||||
f"should return None. It returns a PyTree: {kernel_out_tree}")
|
||||
return jaxpr
|
||||
|
||||
def _extract_function_name(f: Callable, name: str | None) -> str:
|
||||
if name is None:
|
||||
name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
|
||||
return name
|
||||
|
||||
|
||||
_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
|
||||
"jax_pallas_use_mosaic_gpu",
|
||||
@ -1009,7 +1006,11 @@ def pallas_call(
|
||||
grid whose body is the kernel lowered as a JAX function. This does not
|
||||
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
|
||||
This is useful for debugging.
|
||||
name: TO BE DOCUMENTED.
|
||||
name: if present, specifies the name to use for this kernel call in
|
||||
debugging and error messages. To this name we append the file and line
|
||||
where the kernel function is defined, e.g.:
|
||||
`{name} for kernel function {kernel_name} at {file}:{line}`.
|
||||
If missing, then we use `{kernel_name} at {file}:{line}`.
|
||||
compiler_params: TO BE DOCUMENTED.
|
||||
|
||||
Returns:
|
||||
@ -1017,7 +1018,9 @@ def pallas_call(
|
||||
invoke the Pallas kernel.
|
||||
|
||||
"""
|
||||
name = _extract_function_name(kernel, name)
|
||||
kernel_src_info = api_util.fun_sourceinfo(kernel)
|
||||
name_and_src_info = pallas_core.NameAndSrcInfo.from_pallas_call(
|
||||
name, kernel_src_info)
|
||||
if compiler_params is None:
|
||||
compiler_params = {}
|
||||
|
||||
@ -1058,16 +1061,15 @@ def pallas_call(
|
||||
|
||||
kernel_fun_sig = api_util.fun_signature(kernel)
|
||||
arg_names = None
|
||||
kernel_src_info: pallas_core.SrcInfoStr = "<unknown>"
|
||||
if kernel_fun_sig:
|
||||
kernel_debug_info = api_util.debug_info(
|
||||
"pallas_call kernel",
|
||||
api_util.fun_sourceinfo(kernel),
|
||||
kernel_src_info,
|
||||
kernel_fun_sig,
|
||||
[1] * len(kernel_fun_sig.parameters), {}, (), ())
|
||||
if kernel_debug_info:
|
||||
arg_names = kernel_debug_info.arg_names
|
||||
kernel_src_info = kernel_debug_info.func_src_info
|
||||
del kernel_debug_info
|
||||
in_origins = tuple(in_path_to_input_origin(p, arg_names)
|
||||
for p in in_paths)
|
||||
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
|
||||
@ -1105,7 +1107,8 @@ def pallas_call(
|
||||
index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
|
||||
out_flat = pallas_call_p.bind(
|
||||
*dynamic_grid_bounds, *index_args, *rest_args,
|
||||
jaxpr=jaxpr, name=name,
|
||||
jaxpr=jaxpr,
|
||||
name_and_src_info=name_and_src_info,
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
grid_mapping=grid_mapping,
|
||||
|
@ -266,7 +266,7 @@ def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]):
|
||||
def lower_jaxpr_to_triton_module(
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
grid_mapping: GridMapping,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndStrInfo,
|
||||
platform: str
|
||||
) -> LoweringResult:
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
@ -283,6 +283,9 @@ def lower_jaxpr_to_triton_module(
|
||||
)
|
||||
with _new_ir_context(), ir.Location.unknown():
|
||||
module = ir.Module.create()
|
||||
attrs = module.operation.attributes
|
||||
module_name = name_and_src_info.name
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
param_types = [
|
||||
tt_dialect.PointerType.get(_dtype_to_ir_type(var.aval.dtype), 1)
|
||||
for var in jaxpr.invars
|
||||
@ -290,7 +293,7 @@ def lower_jaxpr_to_triton_module(
|
||||
assert len(jaxpr.outvars) == 0
|
||||
fn_type = ir.FunctionType.get(param_types, [])
|
||||
fn = tt_dialect.FuncOp(
|
||||
name,
|
||||
name_and_src_info.name,
|
||||
ir.TypeAttr.get(fn_type),
|
||||
sym_visibility="public",
|
||||
res_attrs=ir.DictAttr.get(dict(noinline=ir.BoolAttr.get(False))),
|
||||
@ -310,7 +313,8 @@ def lower_jaxpr_to_triton_module(
|
||||
if i not in grid_mapping.vmapped_dims
|
||||
]
|
||||
ctx = ModuleContext(
|
||||
name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
|
||||
name_and_src_info.name,
|
||||
grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
|
||||
)
|
||||
if grid_mapping.num_index_operands:
|
||||
raise NotImplementedError(
|
||||
|
@ -42,7 +42,7 @@ def pallas_call_lowering(
|
||||
ctx: mlir.LoweringRuleContext,
|
||||
*in_nodes,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
interpret: bool,
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
@ -67,14 +67,17 @@ def pallas_call_lowering(
|
||||
num_stages = triton_params.pop("num_stages", 3)
|
||||
|
||||
if debug:
|
||||
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
|
||||
print(jaxpr)
|
||||
print(f"The grid mapping for pallas_call {name_and_src_info}:")
|
||||
print(grid_mapping)
|
||||
|
||||
lowering_result = lowering.lower_jaxpr_to_triton_module(
|
||||
jaxpr, grid_mapping, name, lowering_platform
|
||||
jaxpr, grid_mapping, name_and_src_info, lowering_platform
|
||||
)
|
||||
module_op = lowering_result.module.operation
|
||||
if debug:
|
||||
print(f"\nThe Triton module for pallas_call {name_and_src_info}:")
|
||||
print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
|
||||
|
||||
grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
|
||||
@ -86,7 +89,7 @@ def pallas_call_lowering(
|
||||
buf = io.BytesIO()
|
||||
module_op.write_bytecode(buf)
|
||||
backend_config = dict(
|
||||
name=ir.StringAttr.get(name),
|
||||
name=ir.StringAttr.get(name_and_src_info.name),
|
||||
ir=ir.StringAttr.get(buf.getvalue()),
|
||||
num_stages=mlir.i32_attr(num_stages),
|
||||
num_warps=mlir.i32_attr(num_warps),
|
||||
|
@ -690,6 +690,7 @@ def _lower_as_gpu_kernel(
|
||||
in_shapes: tuple[Any, ...],
|
||||
out_shape,
|
||||
smem_scratch_shape: ShapeTree | Union[ShapeTree],
|
||||
module_name: str,
|
||||
prof_spec: profiler.ProfilerSpec | None = None,
|
||||
):
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr")
|
||||
@ -714,6 +715,8 @@ def _lower_as_gpu_kernel(
|
||||
out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block))
|
||||
|
||||
module = ir.Module.create()
|
||||
attrs = module.operation.attributes
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
with ir.InsertionPoint(module.body):
|
||||
_declare_runtime_functions()
|
||||
gmem_scratch_bytes = 0
|
||||
@ -772,6 +775,7 @@ def as_gpu_kernel(
|
||||
smem_scratch_shape: ShapeTree | Union[ShapeTree],
|
||||
prof_spec: profiler.ProfilerSpec | None = None,
|
||||
cluster: tuple[int, int, int] = (1, 1, 1),
|
||||
module_name: str = "unknown",
|
||||
):
|
||||
if isinstance(in_shape, list):
|
||||
in_shape = tuple(in_shape)
|
||||
@ -780,7 +784,8 @@ def as_gpu_kernel(
|
||||
|
||||
module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
|
||||
_lower_as_gpu_kernel(
|
||||
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, prof_spec
|
||||
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
|
||||
module_name, prof_spec
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -27,12 +27,14 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax._src import api_util
|
||||
from jax._src import checkify
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax.control_flow.for_loop import for_loop
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
@ -507,7 +509,7 @@ class PallasCallTest(PallasBaseTest):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The kernel function .* should not capture constants"):
|
||||
"The kernel function .* captures constants"):
|
||||
kernel(x)
|
||||
|
||||
def test_vector_slicing(self):
|
||||
@ -712,7 +714,7 @@ class ApiErrorTest(PallasBaseTest):
|
||||
out_shape=(a, a))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The kernel function my_kernel at .*pallas_test.py:.* in a pallas_call should return None"):
|
||||
"The kernel function .* my_kernel at .*pallas_test.py:.* should return None"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_kernel_with_no_signature_returns_something(self):
|
||||
@ -721,7 +723,7 @@ class ApiErrorTest(PallasBaseTest):
|
||||
out_shape=a)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The kernel function .* at .*pallas_test.py:.* in a pallas_call should return None"):
|
||||
"The kernel function .* at .*pallas_test.py:.* should return None"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_in_specs_not_a_sequence(self):
|
||||
@ -825,7 +827,6 @@ class ApiErrorTest(PallasBaseTest):
|
||||
".* `out_specs` is a tuple of length 1 but `out_shape` is a tuple of length 2.*", re.DOTALL)):
|
||||
f(a)
|
||||
|
||||
|
||||
def test_pallas_call_block_shape_ndim_mismatch(self):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
f = self.pallas_call(lambda x_ref, o1_ref: None,
|
||||
@ -885,6 +886,40 @@ class ApiErrorTest(PallasBaseTest):
|
||||
out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)],
|
||||
input_output_aliases={1: 0})(x, x)
|
||||
|
||||
def test_name_and_src_info(self):
|
||||
def the_kernel(): return None
|
||||
ns1 = pallas_core.NameAndSrcInfo.from_pallas_call(
|
||||
"my_name", api_util.fun_sourceinfo(the_kernel))
|
||||
self.assertEqual("my_name", ns1.name)
|
||||
self.assertIn("the_kernel", ns1.src_info)
|
||||
self.assertIn("pallas_test.py:", ns1.src_info)
|
||||
self.assertRegex(
|
||||
str(ns1),
|
||||
"my_name for kernel function the_kernel at .*pallas_test.py:.*")
|
||||
|
||||
ns2 = pallas_core.NameAndSrcInfo.from_pallas_call(
|
||||
None,
|
||||
api_util.fun_sourceinfo(the_kernel))
|
||||
self.assertEqual("the_kernel", ns2.name)
|
||||
self.assertIn("pallas_test.py:", ns2.src_info)
|
||||
self.assertRegex(
|
||||
str(ns2),
|
||||
"the_kernel at .*pallas_test.py:.*")
|
||||
|
||||
ns3 = pallas_core.NameAndSrcInfo.from_pallas_call("my_name", None)
|
||||
self.assertEqual("my_name", ns3.name)
|
||||
self.assertEqual("", ns3.src_info)
|
||||
self.assertEqual(str(ns3), "my_name")
|
||||
|
||||
ns4 = pallas_core.NameAndSrcInfo.from_pallas_call("my name with spaces",
|
||||
None)
|
||||
self.assertEqual("my_name_with_spaces", ns4.name)
|
||||
self.assertEqual("", ns4.src_info)
|
||||
|
||||
ns5 = pallas_core.NameAndSrcInfo.from_pallas_call(None, None)
|
||||
self.assertEqual("unknown", ns5.name)
|
||||
self.assertEqual("", ns5.src_info)
|
||||
|
||||
|
||||
class ApiErrorInterpreterTest(ApiErrorTest):
|
||||
INTERPRET = True
|
||||
|
@ -147,7 +147,7 @@ class PallasCallVmapTest(PallasBaseTest):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The kernel function .* should not capture constants"):
|
||||
"The kernel function .* captures constants"):
|
||||
kernel(x)
|
||||
|
||||
def test_vmap_of_kernel_with_input_output_aliases(self):
|
||||
|
@ -1884,10 +1884,6 @@ class PallasCallPrintTest(PallasBaseTest):
|
||||
|
||||
class PallasCallTraceTest(PallasBaseTest):
|
||||
|
||||
def parse_debug_string(self, debug_string):
|
||||
jaxpr, mlir = debug_string.split('module')
|
||||
return {'jaxpr': jaxpr, 'mlir': mlir}
|
||||
|
||||
def test_trace_start_stop_match(self):
|
||||
def kernel(o_ref):
|
||||
with jax.named_scope('scope1'):
|
||||
@ -1900,10 +1896,10 @@ class PallasCallTraceTest(PallasBaseTest):
|
||||
debug=True,
|
||||
)()
|
||||
# TODO(justinfu): Add an official lowering API to get the MLIR.
|
||||
mlir = self.parse_debug_string(msg.getvalue())['mlir']
|
||||
debug_string = msg.getvalue()
|
||||
|
||||
num_start = mlir.count('tpu.trace_start')
|
||||
num_stop = mlir.count('tpu.trace_stop')
|
||||
num_start = debug_string.count('tpu.trace_start')
|
||||
num_stop = debug_string.count('tpu.trace_stop')
|
||||
self.assertEqual(num_start, 1)
|
||||
self.assertEqual(num_stop, 1)
|
||||
|
||||
@ -1926,10 +1922,10 @@ class PallasCallTraceTest(PallasBaseTest):
|
||||
debug=True,
|
||||
)()
|
||||
# TODO(justinfu): Add an official lowering API to get the MLIR.
|
||||
mlir = self.parse_debug_string(msg.getvalue())['mlir']
|
||||
debug_string = msg.getvalue()
|
||||
|
||||
num_start = mlir.count('tpu.trace_start')
|
||||
num_stop = mlir.count('tpu.trace_stop')
|
||||
num_start = debug_string.count('tpu.trace_start')
|
||||
num_stop = debug_string.count('tpu.trace_stop')
|
||||
self.assertEqual(num_start, 2)
|
||||
self.assertEqual(num_stop, 2)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user