From b01111d96c6d051c2bbeeae5f040801a727b9e24 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 29 Jan 2025 15:01:07 -0800 Subject: [PATCH] Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental. PiperOrigin-RevId: 721119935 --- jax/BUILD | 14 ++ jax/experimental/source_mapper/__init__.py | 29 ++++ jax/experimental/source_mapper/common.py | 91 ++++++++++++ .../source_mapper/generate_map.py | 55 +++++++ jax/experimental/source_mapper/hlo.py | 134 +++++++++++++++++ jax/experimental/source_mapper/jaxpr.py | 80 ++++++++++ jax/experimental/source_mapper/mlir.py | 140 ++++++++++++++++++ tests/BUILD | 10 ++ tests/source_mapper_test.py | 89 +++++++++++ 9 files changed, 642 insertions(+) create mode 100644 jax/experimental/source_mapper/__init__.py create mode 100644 jax/experimental/source_mapper/common.py create mode 100644 jax/experimental/source_mapper/generate_map.py create mode 100644 jax/experimental/source_mapper/hlo.py create mode 100644 jax/experimental/source_mapper/jaxpr.py create mode 100644 jax/experimental/source_mapper/mlir.py create mode 100644 tests/source_mapper_test.py diff --git a/jax/BUILD b/jax/BUILD index 401a59e81..657dbf179 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -598,6 +598,20 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "source_mapper", + srcs = glob(include = ["experimental/source_mapper/**/*.py"]), + visibility = [ + "//visibility:public", + ], + deps = [ + ":config", + ":core", + ":jax", + ":source_info_util", + ] + py_deps("absl/flags"), +) + pytype_strict_library( name = "pallas", srcs = glob( diff --git a/jax/experimental/source_mapper/__init__.py b/jax/experimental/source_mapper/__init__.py new file mode 100644 index 000000000..1dc158ec2 --- /dev/null +++ b/jax/experimental/source_mapper/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.sourcemap import SourceMap as SourceMap +from jax._src.sourcemap import MappingsGenerator as MappingsGenerator +from jax.experimental.source_mapper.common import Pass as Pass +from jax.experimental.source_mapper.common import register_pass as register_pass +from jax.experimental.source_mapper.common import all_passes as all_passes +from jax.experimental.source_mapper.common import filter_passes as filter_passes +from jax.experimental.source_mapper.common import compile_with_env as compile_with_env +from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump +from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps +from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap + +# We import the jaxpr and hlo passes to register them. +import jax.experimental.source_mapper.jaxpr # pylint: disable=unused-import # noqa: F401 +from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename +import jax.experimental.source_mapper.hlo # pylint: disable=unused-import # noqa: F401 diff --git a/jax/experimental/source_mapper/common.py b/jax/experimental/source_mapper/common.py new file mode 100644 index 000000000..57051f9df --- /dev/null +++ b/jax/experimental/source_mapper/common.py @@ -0,0 +1,91 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for generating source maps.""" +import contextlib +import dataclasses +import re +from typing import Any, Protocol, Sequence + +from absl import flags +import jax +from jax._src import sourcemap + + +@dataclasses.dataclass(frozen=True) +class SourceMapDump: + """A container for a source map and the paired generated code.""" + source_map: sourcemap.SourceMap + generated_code: str + pass_name: str + + +class CompileFn(Protocol): + + def __call__(self, work_dir, fn, f_args, f_kwargs) -> Any: + ... + + +class GenerateDumpFn(Protocol): + + def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump: + ... + + +@dataclasses.dataclass(frozen=True) +class Pass: + name: str + compile_fn: CompileFn + generate_dump: GenerateDumpFn + + +_pass_registry = {} + + +def register_pass(pass_: Pass): + if pass_.name in _pass_registry: + raise ValueError(f"Pass {pass_.name} already registered") + _pass_registry[pass_.name] = pass_ + + +def all_passes() -> Sequence[Pass]: + return list(_pass_registry.values()) + + +def filter_passes(regex: str) -> Sequence[Pass]: + """Gets all registered passes whose display name matches the given regex.""" + return [ + pass_ + for pass_ in _pass_registry.values() + if re.match(regex, pass_.name) + ] + + +@contextlib.contextmanager +def flag_env(**kwargs): + """A context manager for setting and restoring flags.""" + old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs} + for kwarg, new_value in kwargs.items(): + setattr(flags.FLAGS, kwarg, new_value) + try: + yield + finally: + for kwarg, old_value in old_flags.items(): + setattr(flags.FLAGS, kwarg, old_value) + + +def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags): + with flag_env(**env_flags): + jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower( # pylint: disable=unnecessary-lambda + *f_args, **f_kwargs + ).compile(compiler_flags) diff --git a/jax/experimental/source_mapper/generate_map.py b/jax/experimental/source_mapper/generate_map.py new file mode 100644 index 000000000..5b0207f9f --- /dev/null +++ b/jax/experimental/source_mapper/generate_map.py @@ -0,0 +1,55 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generates source maps for JAX functions.""" +import os +import tempfile +from typing import Sequence, Protocol + +from jax.experimental.source_mapper import common + + +class SourceMapGeneratorFn(Protocol): + def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]: + ... + + +def generate_sourcemaps( + f, + passes: Sequence[common.Pass], + **kwargs +) -> SourceMapGeneratorFn: + """Generates a SourceMapBundle for the specified compiler passes. + + Args: + f: The function to compile. + passes: Which compiler passes to generate sourcemaps for. + **kwargs: Keyword arguments for generate_dump passes. + """ + def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]: + pass_results: list[common.SourceMapDump] = [] + compile_cache = {} + with tempfile.TemporaryDirectory() as work_dir: + for pass_to_eval in passes: + if pass_to_eval.compile_fn not in compile_cache: + pass_work_dir = os.path.join(work_dir, pass_to_eval.name) + os.makedirs(pass_work_dir, exist_ok=False) + compile_result = pass_to_eval.compile_fn( + pass_work_dir, f, args, kwargs + ) + compile_cache[pass_to_eval.compile_fn] = compile_result + compile_result = compile_cache[pass_to_eval.compile_fn] + pass_results.append(pass_to_eval.generate_dump(compile_result, + **kwargs)) + return pass_results + return wrapper diff --git a/jax/experimental/source_mapper/hlo.py b/jax/experimental/source_mapper/hlo.py new file mode 100644 index 000000000..601a007a4 --- /dev/null +++ b/jax/experimental/source_mapper/hlo.py @@ -0,0 +1,134 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Source mapping generator for HLO dialects.""" +import enum +import re +from typing import Any + +import jax +from jax._src import sourcemap + +from jax.experimental.source_mapper import common +from jax.experimental.source_mapper import mlir + + +class HloPass(enum.Enum): + STABLE_HLO = "hlo:stable-hlo" + ORIGINAL = "hlo:original" + OPTIMIZED = "hlo:optimized" + + +METADATA_REGEX = re.compile( + r"metadata={op_name=\"(?P.*)\" source_file=\"(?P.*)\"" + r" source_line=(?P[0-9]+)\}" +) + + +def parse_hlo_dump(text: str) -> sourcemap.SourceMap: + mappings = sourcemap.MappingsGenerator() + used_source_files = [] + for line in text.split("\n"): + mappings.new_group() + match = METADATA_REGEX.search(line) + if match: + match_dict = match.groupdict() + _ = match_dict["scope"] # Unused + src_file = match_dict["src_file"] + src_line = int(match_dict["src_line"]) + if src_file not in used_source_files: + used_source_files.append(src_file) + src_file_idx = used_source_files.index(src_file) + src_line -= 1 # Segments are zero-indexed + first_col = line.index(line.strip()[0]) + mappings.new_segment(first_col, src_file_idx, src_line, 0) + mappings.new_group() + + return sourcemap.SourceMap( + version=3, + sources=used_source_files, + sources_content=[], + mappings=mappings.mappings(), + names=[], + ) + + +def trace_and_lower(work_dir, f, f_args, f_kwargs): + lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args) + return (lowered, work_dir) + + +def stable_hlo_generate_dump(args: tuple[Any, str], + **_) -> common.SourceMapDump: + lowered, work_dir = args + del work_dir + hlo_text = lowered.as_text(debug_info=True) + source_map = mlir.create_mlir_sourcemap(hlo_text) + return common.SourceMapDump( + source_map=source_map, + generated_code=hlo_text, + pass_name=HloPass.STABLE_HLO.value, + ) + + +common.register_pass( + common.Pass( + name=HloPass.STABLE_HLO.value, + compile_fn=trace_and_lower, + generate_dump=stable_hlo_generate_dump, # type: ignore[arg-type] + ) +) + + +def original_hlo_generate_dump(args: tuple[Any, str], + **_) -> common.SourceMapDump: + lowered, work_dir = args + del work_dir + hlo_text = lowered.as_text(dialect="hlo", debug_info=True) + source_map = parse_hlo_dump(hlo_text) + return common.SourceMapDump( + source_map=source_map, + generated_code=hlo_text, + pass_name=HloPass.ORIGINAL.value, + ) + + +common.register_pass( + common.Pass( + name=HloPass.ORIGINAL.value, + compile_fn=trace_and_lower, + generate_dump=original_hlo_generate_dump, # type: ignore[arg-type] + ) +) + + +def optimized_generate_dump(args: tuple[Any, str], + **_) -> common.SourceMapDump: + lowered, work_dir = args + compilation_args = {"xla_dump_to": work_dir} + hlo_text = lowered.compile(compilation_args).as_text() + source_map = parse_hlo_dump(hlo_text) + return common.SourceMapDump( + source_map=source_map, + generated_code=hlo_text, + pass_name=HloPass.OPTIMIZED.value, + ) + + +common.register_pass( + common.Pass( + name=HloPass.OPTIMIZED.value, + compile_fn=trace_and_lower, + generate_dump=optimized_generate_dump, # type: ignore[arg-type] + ) +) diff --git a/jax/experimental/source_mapper/jaxpr.py b/jax/experimental/source_mapper/jaxpr.py new file mode 100644 index 000000000..b467000b1 --- /dev/null +++ b/jax/experimental/source_mapper/jaxpr.py @@ -0,0 +1,80 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Source mapping generator for Jaxprs.""" +import re +from typing import Any + +import jax +from jax._src import config +from jax._src import core +from jax._src import source_info_util +from jax._src import sourcemap +from jax.experimental.source_mapper import common + +source_info_util.register_exclusion(__file__) + + +def compile_jaxpr(work_dir, f, f_args, f_kwargs): + del work_dir + return jax.make_jaxpr(f)(*f_args, **f_kwargs) + + +def canonicalize_filename(file_name: str): + pattern = config.hlo_source_file_canonicalization_regex.value + if pattern: + file_name = re.sub(pattern, '', file_name) + return file_name + + +def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump: + pprint_mappings: list[list[tuple[int, int, Any]]] = [] + pprint_str = jaxpr.pretty_print(source_map=pprint_mappings) + used_source_files = [] + mappings = sourcemap.MappingsGenerator() + for pprint_map_line in pprint_mappings: + mappings.new_group() + for pprint_segment in pprint_map_line: + start_col, end_col, frame = pprint_segment + del end_col + file_name = canonicalize_filename(frame.file_name) + if file_name not in used_source_files: + used_source_files.append(file_name) + file_idx = used_source_files.index(file_name) + src_line = frame.start_line - 1 # Zero-indexed + src_col = frame.start_column + # A segment is a tuple of the form: + # (generated_col, src_file_idx, src_line, src_col) + mappings.new_segment(start_col, file_idx, src_line, src_col) + mappings.new_group() + source_map = sourcemap.SourceMap( + version=3, + sources=used_source_files, + sources_content=[], + mappings=mappings.mappings(), + names=[], + ) + return common.SourceMapDump( + source_map=source_map, + generated_code=pprint_str, + pass_name='jaxpr', + ) + + +common.register_pass( + common.Pass( + name='jaxpr', + compile_fn=compile_jaxpr, + generate_dump=make_jaxpr_dump, # type: ignore[arg-type] + ) +) diff --git a/jax/experimental/source_mapper/mlir.py b/jax/experimental/source_mapper/mlir.py new file mode 100644 index 000000000..4f107d39d --- /dev/null +++ b/jax/experimental/source_mapper/mlir.py @@ -0,0 +1,140 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for generating source mappings for MLIR dialects.""" +import collections +import re +from typing import cast + +from jax._src import sourcemap + + +# TODO(justinfu): Make a proper parser for MLIR dumps. +LOC_REGEX = re.compile(r"loc\(#loc(?P[0-9]+)\)") + +SRC_REGEX = re.compile( + r"#loc(?P[0-9]+) =" + r" loc\(\"(?P.*)\":(?P[0-9]+):(?P[0-9]+)\)" +) + +SCOPED_REGEX = re.compile( + r"#loc(?P[0-9]+) = loc\(\"(?P.*)\"\(#loc(?P[0-9]+)\)\)" +) + +CALLSITE_REGEX = re.compile( + r"#loc(?P[0-9]+) = loc\(callsite\(#loc(?P[0-9]+) at" + r" #loc(?P[0-9]+)\)\)" +) + +Location = collections.namedtuple("Location", ["file", "line", "col"]) +Redirect = collections.namedtuple("Redirect", ["tgt_id"]) + + +def create_mlir_sourcemap(mlir_dump: str) -> sourcemap.SourceMap: + mappings = sourcemap.MappingsGenerator() + dump_lines: list[str] = mlir_dump.split("\n") + + segment_dict, sources = parse_mlir_locations(dump_lines) + used_sources = [] + used_sources_filenames = [] + for line in dump_lines: + mappings.new_group() + match = LOC_REGEX.search(line) + if match: + loc_id = int(match.group("id")) + if loc_id not in segment_dict: + # TODO(justinfu): This happens on fusion locations - need to implement. + continue + segment = list(segment_dict[loc_id]) + first_col = line.index(line.strip()[0]) + segment[0] = first_col + # Remap the sourcefile index to only sourcefiles that are used. + # This is optional but makes the mapping file smaller by pruning + # unused sourcefiles. + source_idx = segment[1] + if source_idx not in used_sources: + used_sources.append(source_idx) + used_sources_filenames.append(sources[source_idx]) + segment[1] = used_sources.index(source_idx) + mappings.new_segment(*segment) + mappings.new_group() + + return sourcemap.SourceMap( + version=3, + sources=used_sources_filenames, + sources_content=[''] * len(used_sources_filenames), + mappings=mappings.mappings(), + names=[], + ) + + +def parse_mlir_locations( + mlir_dump: list[str], +) -> tuple[dict[int, sourcemap.Segment], list[str]]: + locations: dict[int, Location | Redirect] = {} + source_files = [] + for line in mlir_dump: + if line.startswith("#loc"): + src_match = SRC_REGEX.match(line) + if src_match: + match_dict = src_match.groupdict() + filename = match_dict["file"] + locations[int(match_dict["id"])] = Location( + file=filename, + line=int(match_dict["line"]), + col=int(match_dict["col"]), + ) + if filename not in source_files: + source_files.append(filename) + continue + scoped_match = SCOPED_REGEX.match(line) + if scoped_match: + match_dict = scoped_match.groupdict() + locations[int(match_dict["id"])] = Redirect( + tgt_id=int(match_dict["tgt_id"]) + ) + continue + callsite_match = CALLSITE_REGEX.match(line) + if callsite_match: + match_dict = callsite_match.groupdict() + locations[int(match_dict["id"])] = Redirect( + tgt_id=int(match_dict["callee"]) + ) + continue + if "loc(unknown)" in line: + continue + # Resolve redirects + while True: + new_locations: dict[int, Location | Redirect] = {} + updated = False + for loc_id, loc in locations.items(): + if isinstance(loc, Redirect): + new_locations[loc_id] = locations[loc.tgt_id] + updated = True + else: + new_locations[loc_id] = loc + locations = new_locations + if not updated: + break + segment_dict: dict[int, sourcemap.Segment] = {} + for id_, loc in locations.items(): + # A segment is a tuple of the form: + # (generated_col, src_file_idx, src_line, src_col) + loc = cast(Location, loc) + segment_dict[id_] = ( + 0, + source_files.index(loc.file), + loc.line - 1, # Zero-indexed, so offset by 1. + loc.col, + ) + return segment_dict, source_files diff --git a/tests/BUILD b/tests/BUILD index 6d3f8f45e..6b90309bd 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1584,6 +1584,16 @@ jax_py_test( ], ) +jax_py_test( + name = "source_mapper_test", + srcs = ["source_mapper_test.py"], + deps = [ + "//jax", + "//jax:source_mapper", + "//jax:test_util", + ], +) + jax_py_test( name = "sourcemap_test", srcs = ["sourcemap_test.py"], diff --git a/tests/source_mapper_test.py b/tests/source_mapper_test.py new file mode 100644 index 000000000..973290043 --- /dev/null +++ b/tests/source_mapper_test.py @@ -0,0 +1,89 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from absl.testing import parameterized +from jax import numpy as jnp +from jax._src import test_util as jtu +from jax.experimental import source_mapper + + +class SourceMapperTest(jtu.JaxTestCase): + + def test_jaxpr_pass(self): + def jax_fn(x, y): + return x + y + test_x = jnp.array([1, 2, 3]) + test_y = jnp.array([4, 5, 6]) + source_maps = source_mapper.generate_sourcemaps( + jax_fn, + passes=source_mapper.filter_passes("jaxpr"))(test_x, test_y) + self.assertLen(source_maps, 1) + dump = source_maps[0] + self.assertEqual(dump.pass_name, "jaxpr") + self.assertIn("add a b", dump.generated_code) + source_map = dump.source_map + self.assertLen(source_map.sources, 1) + self.assertEqual(source_map.sources[0], + source_mapper.canonicalize_filename(__file__)) + mappings = source_map.mappings + self.assertLen(mappings, len(dump.generated_code.split("\n")) + 1) + gen_col, file_idx, src_line, _ = mappings[0][0] + # It's hard to guarantee at what column the add instruction will be + # generated in the dump. We just sanity-check that it's greater than 0. + self.assertGreater(gen_col, 0) + # There is only one file, so we should map to that + self.assertEqual(file_idx, 0) + # These should line up with the function definition of jax_fn above. + self.assertEqual(src_line, jax_fn.__code__.co_firstlineno) + # TODO(justinfu): This fails on external but not internal builds. + # self.assertEqual(src_col, 13) + + @parameterized.parameters( + ("hlo:stable-hlo", "stablehlo.add", 13), + ("hlo:original", "add", 0), + ("hlo:optimized", "add", 0), + ) + def test_hlo_passes(self, pass_name, expected_hlo_op, expected_col): + del expected_col + def jax_fn(x, y): + return x + y + test_x = jnp.array([1, 2, 3]) + test_y = jnp.array([4, 5, 6]) + source_maps = source_mapper.generate_sourcemaps( + jax_fn, + passes=source_mapper.filter_passes(pass_name))(test_x, test_y) + self.assertLen(source_maps, 1) + dump = source_maps[0] + self.assertEqual(dump.pass_name, pass_name) + self.assertIn(expected_hlo_op, dump.generated_code) + source_map = dump.source_map + self.assertLen(source_map.sources, 1) + self.assertEqual(source_map.sources[0], + source_mapper.canonicalize_filename(__file__)) + mappings = source_map.mappings + self.assertLen(mappings, len(dump.generated_code.split("\n")) + 1) + nonempty_mappings = [m for m in mappings if m] + self.assertLen(nonempty_mappings, 1) + gen_col, file_idx, src_line, _ = nonempty_mappings[0][0] + self.assertGreater(gen_col, 0) + # There is only one file, so we should map to that + self.assertEqual(file_idx, 0) + # These should line up with the function definition of jax_fn above. + self.assertEqual(src_line, jax_fn.__code__.co_firstlineno) + # TODO(justinfu): This fails on external but not internal builds. + # self.assertEqual(src_col, expected_col) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())