mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
This commit is contained in:
parent
152099ee0e
commit
b01111d96c
14
jax/BUILD
14
jax/BUILD
@ -598,6 +598,20 @@ pytype_strict_library(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "source_mapper",
|
||||
srcs = glob(include = ["experimental/source_mapper/**/*.py"]),
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":config",
|
||||
":core",
|
||||
":jax",
|
||||
":source_info_util",
|
||||
] + py_deps("absl/flags"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas",
|
||||
srcs = glob(
|
||||
|
29
jax/experimental/source_mapper/__init__.py
Normal file
29
jax/experimental/source_mapper/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.sourcemap import SourceMap as SourceMap
|
||||
from jax._src.sourcemap import MappingsGenerator as MappingsGenerator
|
||||
from jax.experimental.source_mapper.common import Pass as Pass
|
||||
from jax.experimental.source_mapper.common import register_pass as register_pass
|
||||
from jax.experimental.source_mapper.common import all_passes as all_passes
|
||||
from jax.experimental.source_mapper.common import filter_passes as filter_passes
|
||||
from jax.experimental.source_mapper.common import compile_with_env as compile_with_env
|
||||
from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump
|
||||
from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps
|
||||
from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap
|
||||
|
||||
# We import the jaxpr and hlo passes to register them.
|
||||
import jax.experimental.source_mapper.jaxpr # pylint: disable=unused-import # noqa: F401
|
||||
from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename
|
||||
import jax.experimental.source_mapper.hlo # pylint: disable=unused-import # noqa: F401
|
91
jax/experimental/source_mapper/common.py
Normal file
91
jax/experimental/source_mapper/common.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common utilities for generating source maps."""
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import Any, Protocol, Sequence
|
||||
|
||||
from absl import flags
|
||||
import jax
|
||||
from jax._src import sourcemap
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SourceMapDump:
|
||||
"""A container for a source map and the paired generated code."""
|
||||
source_map: sourcemap.SourceMap
|
||||
generated_code: str
|
||||
pass_name: str
|
||||
|
||||
|
||||
class CompileFn(Protocol):
|
||||
|
||||
def __call__(self, work_dir, fn, f_args, f_kwargs) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class GenerateDumpFn(Protocol):
|
||||
|
||||
def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
|
||||
...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Pass:
|
||||
name: str
|
||||
compile_fn: CompileFn
|
||||
generate_dump: GenerateDumpFn
|
||||
|
||||
|
||||
_pass_registry = {}
|
||||
|
||||
|
||||
def register_pass(pass_: Pass):
|
||||
if pass_.name in _pass_registry:
|
||||
raise ValueError(f"Pass {pass_.name} already registered")
|
||||
_pass_registry[pass_.name] = pass_
|
||||
|
||||
|
||||
def all_passes() -> Sequence[Pass]:
|
||||
return list(_pass_registry.values())
|
||||
|
||||
|
||||
def filter_passes(regex: str) -> Sequence[Pass]:
|
||||
"""Gets all registered passes whose display name matches the given regex."""
|
||||
return [
|
||||
pass_
|
||||
for pass_ in _pass_registry.values()
|
||||
if re.match(regex, pass_.name)
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def flag_env(**kwargs):
|
||||
"""A context manager for setting and restoring flags."""
|
||||
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
|
||||
for kwarg, new_value in kwargs.items():
|
||||
setattr(flags.FLAGS, kwarg, new_value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for kwarg, old_value in old_flags.items():
|
||||
setattr(flags.FLAGS, kwarg, old_value)
|
||||
|
||||
|
||||
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
|
||||
with flag_env(**env_flags):
|
||||
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower( # pylint: disable=unnecessary-lambda
|
||||
*f_args, **f_kwargs
|
||||
).compile(compiler_flags)
|
55
jax/experimental/source_mapper/generate_map.py
Normal file
55
jax/experimental/source_mapper/generate_map.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Generates source maps for JAX functions."""
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Sequence, Protocol
|
||||
|
||||
from jax.experimental.source_mapper import common
|
||||
|
||||
|
||||
class SourceMapGeneratorFn(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]:
|
||||
...
|
||||
|
||||
|
||||
def generate_sourcemaps(
|
||||
f,
|
||||
passes: Sequence[common.Pass],
|
||||
**kwargs
|
||||
) -> SourceMapGeneratorFn:
|
||||
"""Generates a SourceMapBundle for the specified compiler passes.
|
||||
|
||||
Args:
|
||||
f: The function to compile.
|
||||
passes: Which compiler passes to generate sourcemaps for.
|
||||
**kwargs: Keyword arguments for generate_dump passes.
|
||||
"""
|
||||
def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]:
|
||||
pass_results: list[common.SourceMapDump] = []
|
||||
compile_cache = {}
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
for pass_to_eval in passes:
|
||||
if pass_to_eval.compile_fn not in compile_cache:
|
||||
pass_work_dir = os.path.join(work_dir, pass_to_eval.name)
|
||||
os.makedirs(pass_work_dir, exist_ok=False)
|
||||
compile_result = pass_to_eval.compile_fn(
|
||||
pass_work_dir, f, args, kwargs
|
||||
)
|
||||
compile_cache[pass_to_eval.compile_fn] = compile_result
|
||||
compile_result = compile_cache[pass_to_eval.compile_fn]
|
||||
pass_results.append(pass_to_eval.generate_dump(compile_result,
|
||||
**kwargs))
|
||||
return pass_results
|
||||
return wrapper
|
134
jax/experimental/source_mapper/hlo.py
Normal file
134
jax/experimental/source_mapper/hlo.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Source mapping generator for HLO dialects."""
|
||||
import enum
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import sourcemap
|
||||
|
||||
from jax.experimental.source_mapper import common
|
||||
from jax.experimental.source_mapper import mlir
|
||||
|
||||
|
||||
class HloPass(enum.Enum):
|
||||
STABLE_HLO = "hlo:stable-hlo"
|
||||
ORIGINAL = "hlo:original"
|
||||
OPTIMIZED = "hlo:optimized"
|
||||
|
||||
|
||||
METADATA_REGEX = re.compile(
|
||||
r"metadata={op_name=\"(?P<scope>.*)\" source_file=\"(?P<src_file>.*)\""
|
||||
r" source_line=(?P<src_line>[0-9]+)\}"
|
||||
)
|
||||
|
||||
|
||||
def parse_hlo_dump(text: str) -> sourcemap.SourceMap:
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
used_source_files = []
|
||||
for line in text.split("\n"):
|
||||
mappings.new_group()
|
||||
match = METADATA_REGEX.search(line)
|
||||
if match:
|
||||
match_dict = match.groupdict()
|
||||
_ = match_dict["scope"] # Unused
|
||||
src_file = match_dict["src_file"]
|
||||
src_line = int(match_dict["src_line"])
|
||||
if src_file not in used_source_files:
|
||||
used_source_files.append(src_file)
|
||||
src_file_idx = used_source_files.index(src_file)
|
||||
src_line -= 1 # Segments are zero-indexed
|
||||
first_col = line.index(line.strip()[0])
|
||||
mappings.new_segment(first_col, src_file_idx, src_line, 0)
|
||||
mappings.new_group()
|
||||
|
||||
return sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_source_files,
|
||||
sources_content=[],
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
|
||||
|
||||
def trace_and_lower(work_dir, f, f_args, f_kwargs):
|
||||
lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args)
|
||||
return (lowered, work_dir)
|
||||
|
||||
|
||||
def stable_hlo_generate_dump(args: tuple[Any, str],
|
||||
**_) -> common.SourceMapDump:
|
||||
lowered, work_dir = args
|
||||
del work_dir
|
||||
hlo_text = lowered.as_text(debug_info=True)
|
||||
source_map = mlir.create_mlir_sourcemap(hlo_text)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=hlo_text,
|
||||
pass_name=HloPass.STABLE_HLO.value,
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name=HloPass.STABLE_HLO.value,
|
||||
compile_fn=trace_and_lower,
|
||||
generate_dump=stable_hlo_generate_dump, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def original_hlo_generate_dump(args: tuple[Any, str],
|
||||
**_) -> common.SourceMapDump:
|
||||
lowered, work_dir = args
|
||||
del work_dir
|
||||
hlo_text = lowered.as_text(dialect="hlo", debug_info=True)
|
||||
source_map = parse_hlo_dump(hlo_text)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=hlo_text,
|
||||
pass_name=HloPass.ORIGINAL.value,
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name=HloPass.ORIGINAL.value,
|
||||
compile_fn=trace_and_lower,
|
||||
generate_dump=original_hlo_generate_dump, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def optimized_generate_dump(args: tuple[Any, str],
|
||||
**_) -> common.SourceMapDump:
|
||||
lowered, work_dir = args
|
||||
compilation_args = {"xla_dump_to": work_dir}
|
||||
hlo_text = lowered.compile(compilation_args).as_text()
|
||||
source_map = parse_hlo_dump(hlo_text)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=hlo_text,
|
||||
pass_name=HloPass.OPTIMIZED.value,
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name=HloPass.OPTIMIZED.value,
|
||||
compile_fn=trace_and_lower,
|
||||
generate_dump=optimized_generate_dump, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
80
jax/experimental/source_mapper/jaxpr.py
Normal file
80
jax/experimental/source_mapper/jaxpr.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Source mapping generator for Jaxprs."""
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import sourcemap
|
||||
from jax.experimental.source_mapper import common
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
def compile_jaxpr(work_dir, f, f_args, f_kwargs):
|
||||
del work_dir
|
||||
return jax.make_jaxpr(f)(*f_args, **f_kwargs)
|
||||
|
||||
|
||||
def canonicalize_filename(file_name: str):
|
||||
pattern = config.hlo_source_file_canonicalization_regex.value
|
||||
if pattern:
|
||||
file_name = re.sub(pattern, '', file_name)
|
||||
return file_name
|
||||
|
||||
|
||||
def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump:
|
||||
pprint_mappings: list[list[tuple[int, int, Any]]] = []
|
||||
pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)
|
||||
used_source_files = []
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
for pprint_map_line in pprint_mappings:
|
||||
mappings.new_group()
|
||||
for pprint_segment in pprint_map_line:
|
||||
start_col, end_col, frame = pprint_segment
|
||||
del end_col
|
||||
file_name = canonicalize_filename(frame.file_name)
|
||||
if file_name not in used_source_files:
|
||||
used_source_files.append(file_name)
|
||||
file_idx = used_source_files.index(file_name)
|
||||
src_line = frame.start_line - 1 # Zero-indexed
|
||||
src_col = frame.start_column
|
||||
# A segment is a tuple of the form:
|
||||
# (generated_col, src_file_idx, src_line, src_col)
|
||||
mappings.new_segment(start_col, file_idx, src_line, src_col)
|
||||
mappings.new_group()
|
||||
source_map = sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_source_files,
|
||||
sources_content=[],
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=pprint_str,
|
||||
pass_name='jaxpr',
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name='jaxpr',
|
||||
compile_fn=compile_jaxpr,
|
||||
generate_dump=make_jaxpr_dump, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
140
jax/experimental/source_mapper/mlir.py
Normal file
140
jax/experimental/source_mapper/mlir.py
Normal file
@ -0,0 +1,140 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for generating source mappings for MLIR dialects."""
|
||||
import collections
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from jax._src import sourcemap
|
||||
|
||||
|
||||
# TODO(justinfu): Make a proper parser for MLIR dumps.
|
||||
LOC_REGEX = re.compile(r"loc\(#loc(?P<id>[0-9]+)\)")
|
||||
|
||||
SRC_REGEX = re.compile(
|
||||
r"#loc(?P<id>[0-9]+) ="
|
||||
r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)\)"
|
||||
)
|
||||
|
||||
SCOPED_REGEX = re.compile(
|
||||
r"#loc(?P<id>[0-9]+) = loc\(\"(?P<scope>.*)\"\(#loc(?P<tgt_id>[0-9]+)\)\)"
|
||||
)
|
||||
|
||||
CALLSITE_REGEX = re.compile(
|
||||
r"#loc(?P<id>[0-9]+) = loc\(callsite\(#loc(?P<callee>[0-9]+) at"
|
||||
r" #loc(?P<caller>[0-9]+)\)\)"
|
||||
)
|
||||
|
||||
Location = collections.namedtuple("Location", ["file", "line", "col"])
|
||||
Redirect = collections.namedtuple("Redirect", ["tgt_id"])
|
||||
|
||||
|
||||
def create_mlir_sourcemap(mlir_dump: str) -> sourcemap.SourceMap:
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
dump_lines: list[str] = mlir_dump.split("\n")
|
||||
|
||||
segment_dict, sources = parse_mlir_locations(dump_lines)
|
||||
used_sources = []
|
||||
used_sources_filenames = []
|
||||
for line in dump_lines:
|
||||
mappings.new_group()
|
||||
match = LOC_REGEX.search(line)
|
||||
if match:
|
||||
loc_id = int(match.group("id"))
|
||||
if loc_id not in segment_dict:
|
||||
# TODO(justinfu): This happens on fusion locations - need to implement.
|
||||
continue
|
||||
segment = list(segment_dict[loc_id])
|
||||
first_col = line.index(line.strip()[0])
|
||||
segment[0] = first_col
|
||||
# Remap the sourcefile index to only sourcefiles that are used.
|
||||
# This is optional but makes the mapping file smaller by pruning
|
||||
# unused sourcefiles.
|
||||
source_idx = segment[1]
|
||||
if source_idx not in used_sources:
|
||||
used_sources.append(source_idx)
|
||||
used_sources_filenames.append(sources[source_idx])
|
||||
segment[1] = used_sources.index(source_idx)
|
||||
mappings.new_segment(*segment)
|
||||
mappings.new_group()
|
||||
|
||||
return sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_sources_filenames,
|
||||
sources_content=[''] * len(used_sources_filenames),
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
|
||||
|
||||
def parse_mlir_locations(
|
||||
mlir_dump: list[str],
|
||||
) -> tuple[dict[int, sourcemap.Segment], list[str]]:
|
||||
locations: dict[int, Location | Redirect] = {}
|
||||
source_files = []
|
||||
for line in mlir_dump:
|
||||
if line.startswith("#loc"):
|
||||
src_match = SRC_REGEX.match(line)
|
||||
if src_match:
|
||||
match_dict = src_match.groupdict()
|
||||
filename = match_dict["file"]
|
||||
locations[int(match_dict["id"])] = Location(
|
||||
file=filename,
|
||||
line=int(match_dict["line"]),
|
||||
col=int(match_dict["col"]),
|
||||
)
|
||||
if filename not in source_files:
|
||||
source_files.append(filename)
|
||||
continue
|
||||
scoped_match = SCOPED_REGEX.match(line)
|
||||
if scoped_match:
|
||||
match_dict = scoped_match.groupdict()
|
||||
locations[int(match_dict["id"])] = Redirect(
|
||||
tgt_id=int(match_dict["tgt_id"])
|
||||
)
|
||||
continue
|
||||
callsite_match = CALLSITE_REGEX.match(line)
|
||||
if callsite_match:
|
||||
match_dict = callsite_match.groupdict()
|
||||
locations[int(match_dict["id"])] = Redirect(
|
||||
tgt_id=int(match_dict["callee"])
|
||||
)
|
||||
continue
|
||||
if "loc(unknown)" in line:
|
||||
continue
|
||||
# Resolve redirects
|
||||
while True:
|
||||
new_locations: dict[int, Location | Redirect] = {}
|
||||
updated = False
|
||||
for loc_id, loc in locations.items():
|
||||
if isinstance(loc, Redirect):
|
||||
new_locations[loc_id] = locations[loc.tgt_id]
|
||||
updated = True
|
||||
else:
|
||||
new_locations[loc_id] = loc
|
||||
locations = new_locations
|
||||
if not updated:
|
||||
break
|
||||
segment_dict: dict[int, sourcemap.Segment] = {}
|
||||
for id_, loc in locations.items():
|
||||
# A segment is a tuple of the form:
|
||||
# (generated_col, src_file_idx, src_line, src_col)
|
||||
loc = cast(Location, loc)
|
||||
segment_dict[id_] = (
|
||||
0,
|
||||
source_files.index(loc.file),
|
||||
loc.line - 1, # Zero-indexed, so offset by 1.
|
||||
loc.col,
|
||||
)
|
||||
return segment_dict, source_files
|
10
tests/BUILD
10
tests/BUILD
@ -1584,6 +1584,16 @@ jax_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
name = "source_mapper_test",
|
||||
srcs = ["source_mapper_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:source_mapper",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
name = "sourcemap_test",
|
||||
srcs = ["sourcemap_test.py"],
|
||||
|
89
tests/source_mapper_test.py
Normal file
89
tests/source_mapper_test.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import source_mapper
|
||||
|
||||
|
||||
class SourceMapperTest(jtu.JaxTestCase):
|
||||
|
||||
def test_jaxpr_pass(self):
|
||||
def jax_fn(x, y):
|
||||
return x + y
|
||||
test_x = jnp.array([1, 2, 3])
|
||||
test_y = jnp.array([4, 5, 6])
|
||||
source_maps = source_mapper.generate_sourcemaps(
|
||||
jax_fn,
|
||||
passes=source_mapper.filter_passes("jaxpr"))(test_x, test_y)
|
||||
self.assertLen(source_maps, 1)
|
||||
dump = source_maps[0]
|
||||
self.assertEqual(dump.pass_name, "jaxpr")
|
||||
self.assertIn("add a b", dump.generated_code)
|
||||
source_map = dump.source_map
|
||||
self.assertLen(source_map.sources, 1)
|
||||
self.assertEqual(source_map.sources[0],
|
||||
source_mapper.canonicalize_filename(__file__))
|
||||
mappings = source_map.mappings
|
||||
self.assertLen(mappings, len(dump.generated_code.split("\n")) + 1)
|
||||
gen_col, file_idx, src_line, _ = mappings[0][0]
|
||||
# It's hard to guarantee at what column the add instruction will be
|
||||
# generated in the dump. We just sanity-check that it's greater than 0.
|
||||
self.assertGreater(gen_col, 0)
|
||||
# There is only one file, so we should map to that
|
||||
self.assertEqual(file_idx, 0)
|
||||
# These should line up with the function definition of jax_fn above.
|
||||
self.assertEqual(src_line, jax_fn.__code__.co_firstlineno)
|
||||
# TODO(justinfu): This fails on external but not internal builds.
|
||||
# self.assertEqual(src_col, 13)
|
||||
|
||||
@parameterized.parameters(
|
||||
("hlo:stable-hlo", "stablehlo.add", 13),
|
||||
("hlo:original", "add", 0),
|
||||
("hlo:optimized", "add", 0),
|
||||
)
|
||||
def test_hlo_passes(self, pass_name, expected_hlo_op, expected_col):
|
||||
del expected_col
|
||||
def jax_fn(x, y):
|
||||
return x + y
|
||||
test_x = jnp.array([1, 2, 3])
|
||||
test_y = jnp.array([4, 5, 6])
|
||||
source_maps = source_mapper.generate_sourcemaps(
|
||||
jax_fn,
|
||||
passes=source_mapper.filter_passes(pass_name))(test_x, test_y)
|
||||
self.assertLen(source_maps, 1)
|
||||
dump = source_maps[0]
|
||||
self.assertEqual(dump.pass_name, pass_name)
|
||||
self.assertIn(expected_hlo_op, dump.generated_code)
|
||||
source_map = dump.source_map
|
||||
self.assertLen(source_map.sources, 1)
|
||||
self.assertEqual(source_map.sources[0],
|
||||
source_mapper.canonicalize_filename(__file__))
|
||||
mappings = source_map.mappings
|
||||
self.assertLen(mappings, len(dump.generated_code.split("\n")) + 1)
|
||||
nonempty_mappings = [m for m in mappings if m]
|
||||
self.assertLen(nonempty_mappings, 1)
|
||||
gen_col, file_idx, src_line, _ = nonempty_mappings[0][0]
|
||||
self.assertGreater(gen_col, 0)
|
||||
# There is only one file, so we should map to that
|
||||
self.assertEqual(file_idx, 0)
|
||||
# These should line up with the function definition of jax_fn above.
|
||||
self.assertEqual(src_line, jax_fn.__code__.co_firstlineno)
|
||||
# TODO(justinfu): This fails on external but not internal builds.
|
||||
# self.assertEqual(src_col, expected_col)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user