Add a Promela spec generator for Pallas TPU kernels

This adds a simple model extractor for TPU kernels that generates a Promela spec
capturing the semantics of semaphores and DMAs. The model can be fed into SPIN
and used to verify, e.g., the absence of data races or deadlocks. While complete
verification is very expensive, the tool is especially good at finding races that
actually exist.
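
For illustration, here is a minimal sketch of the intended workflow (the kernel
name and paths are hypothetical; the SPIN invocation is the standard one, run
outside of Python):

import os
# The flag default is read from the environment when JAX is imported, so set it
# first (alternatively, pass --jax_pallas_dump_promela_to on the command line).
os.environ["JAX_PALLAS_DUMP_PROMELA_TO"] = "/tmp/pallas_models"
import jax
# ... define and run a Pallas TPU kernel via pl.pallas_call ...
# Lowering writes /tmp/pallas_models/<kernel-name>-<random>.pml, which can then
# be checked with SPIN:
#   spin -a /tmp/pallas_models/<kernel-name>-<random>.pml
#   cc -o pan pan.c
#   ./pan   # reports assertion violations (races) and invalid end states (deadlocks)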

PiperOrigin-RevId: 653198263
Adam Paszke authored 2024-07-17 05:28:34 -07:00; committed by jax authors
parent 6c5583d6aa
commit 2ea222544e
7 changed files with 575 additions and 18 deletions


@@ -635,6 +635,7 @@ pytype_strict_library(
"//jax/_src/pallas/mosaic:pipeline",
"//jax/_src/pallas/mosaic:primitives",
"//jax/_src/pallas/mosaic:random",
"//jax/_src/pallas/mosaic:verification",
],
)


@@ -33,6 +33,16 @@ py_library(
],
)
py_library(
name = "verification",
srcs = ["verification.py"],
deps = [
"//jax",
"//jax:mlir",
"//jax/_src/lib",
],
)
py_library(
name = "primitives",
srcs = ["primitives.py"],


@@ -16,18 +16,22 @@
from __future__ import annotations
import os
import tempfile
from typing import Any
import warnings
import jax
from jax import dtypes
from jax import core as jax_core
from jax._src import config
from jax._src import core as jax_src_core
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.mosaic import verification
from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
@@ -48,6 +52,17 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
return jax_core.ShapedArray(x.shape, lowering.BOOL_MEMREF_TYPE)
return x
_DUMP_PROMELA_TO = config.string_flag(
"jax_pallas_dump_promela_to",
default=os.getenv("JAX_PALLAS_DUMP_PROMELA_TO", ""),
help=(
"If set, dumps a Promela model of the kernel to the specified"
" directory. The model can verify that the kernel is free of data"
" races, deadlocks, etc."
),
)
def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext, *in_nodes,
jaxpr: jax_core.Jaxpr,
@@ -106,6 +121,32 @@ def pallas_call_tpu_lowering_rule(
)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
if promela_dump_path := _DUMP_PROMELA_TO.value:
num_devices = 1 if mesh is None else mesh.devices.size
num_cores = (
jax.devices()[0].num_cores
if mesh is None
else mesh.devices[0].num_cores
)
model = verification.export_promela_model(
mosaic_module, num_devices, num_cores
)
if promela_dump_path == "stdout":
print(model)
else:
if promela_dump_path == "sponge":
promela_dump_path = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "")
if not promela_dump_path:
raise ValueError(
"TEST_UNDECLARED_OUTPUTS_DIR must be set when"
" --jax_pallas_dump_promela_to=sponge"
)
dump_ctx = tempfile.NamedTemporaryFile(
mode="w", prefix=name + "-", suffix=".pml", dir=promela_dump_path, delete=False,
)
with dump_ctx as f:
f.write(model)
# Replace in_avals to physical avals.
# This step is required for mapping logical types to physical types.
# (e.g. PRNG key -> uint32[2])
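
For reference, the dump destination handled above accepts three spellings (a
sketch): stdout prints the model, sponge writes under
$TEST_UNDECLARED_OUTPUTS_DIR, and any other non-empty value is treated as a
directory receiving <kernel-name>-<random>.pml files. For example:

# Print the model instead of writing a file (hypothetical usage):
os.environ["JAX_PALLAS_DUMP_PROMELA_TO"] = "stdout"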


@@ -0,0 +1,502 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import dataclasses
import io
import itertools
import math
import textwrap
from typing import Any, Sequence
from jaxlib.mlir import ir
from jaxlib.mlir.passmanager import PassManager
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import func
from jax._src.lib import tpu
from jax._src.util import split_list
_UNSPECIFIED = object()
Var = str
# TODO(apaszke): Add checks that semaphores are always left at 0.
# TODO(apaszke): Add checks that no remote resources are used while the remote
# device is not in the kernel (both before and after).
# TODO(apaszke): Model 0-sized DMAs faithfully.
PREAMBLE = """
#define buf_readers(index, device, core) _buf_readers[(index)*(NDEVICE*NCORE) + (device)*NCORE + core]
#define buf_written(index, device, core) _buf_written[(index)*(NDEVICE*NCORE) + (device)*NCORE + core]
#define sems(index, device, core) _sems[(index)*(NDEVICE*NCORE) + (device)*NCORE + core]
#define barrier_sems(device, core) _barrier_sems[(device)*NCORE + core]
#ifndef NDMA
#define NDMA 2
#endif
mtype = { DMA };
chan dma_queue = [NDMA] of { mtype, int, int, int, int, int, int, int, int, int, int };
"""
DMA_PROCESS = """
active [NDMA] proctype DmaEngine() {
int src_dev, src_core, src_sem, src_buf_base, src_buf_len;
int dst_dev, dst_core, dst_sem, dst_buf_base, dst_buf_len;
do
:: skip;
end: dma_queue?DMA(src_dev, src_core, src_sem, src_buf_base, src_buf_len, dst_dev, dst_core, dst_sem, dst_buf_base, dst_buf_len);
d_step {
printf("DMA read done: [%d, %d)@{%d, %d} (%d++)\\n", src_buf_base, src_buf_base + src_buf_len, src_dev, src_core, src_sem);
int i;
for (i : src_buf_base .. src_buf_base + src_buf_len - 1) {
buf_readers(i, src_dev, src_core)--;
}
sems(src_sem, src_dev, src_core)++;
} // Read complete
d_step {
printf("DMA write done: [%d, %d)@{%d, %d} (%d++)\\n", dst_buf_base, dst_buf_base + dst_buf_len, dst_dev, dst_core, dst_sem);
int i;
for (i : dst_buf_base .. dst_buf_base + dst_buf_len - 1) {
buf_written(i, dst_dev, dst_core)--;
}
sems(dst_sem, dst_dev, dst_core)++;
} // Write complete
od
}
"""
class PrintCtx:
MAX_REF_UNROLL = 8
def __init__(self, iteration_bounds):
self.level = 1
self.num_semaphores = 0
self.num_buffers = 0
self.locals = []
self.counter = itertools.count()
self.env: dict[ir.Value, Var | int] = {}
self.program_ids = tuple(f"pid{i}" for i in range(len(iteration_bounds)))
self.device_id = "dev_id"
# TODO(apaszke): Clean up core_id! This is not a visible detail in the Mosaic
# programming model.
self.emit(None, "int core_id = 0")
# Reconstruct device id and program ids from the pid.
self.emit(None, "int dev_id")
if iteration_bounds:
self.emit(None, f"int {', '.join(self.program_ids)}")
with self.block("d_step {", "}"):
idx = "_pid"
program_ids = []
for i, b in reversed(list(enumerate(iteration_bounds))):
program_ids.append(self.emit(None, f"pid{i} = {idx} % {b}"))
idx = self.emit("int", f"{idx} / {b}")
self.emit(None, f"dev_id = {idx}")
def emit_global_ref(self, shape: Sequence[int]):
slots = 1
if shape and shape[0] <= self.MAX_REF_UNROLL:
slots = shape[0]
base = self.num_buffers
self.num_buffers += slots
return GlobalRefModel(base, slots)
def emit_global_semaphore_ref(self, shape: Sequence[int]):
count = math.prod(shape)
base = self.num_semaphores
self.num_semaphores += count
return GlobalSemaphoreModel(base, count)
def _indent(self, text: str) -> str:
return textwrap.indent(text, " " * self.level)
def emit(self, ty, expr):
name = None
if ty is not None:
name = "l" + str(next(self.counter))
expr = f"{ty} {name} = {expr}"
self.locals.append(self._indent(expr) + ";\n")
return name
def comment(self, comment):
self.locals.append(self._indent(f"/* {comment} */\n"))
@contextlib.contextmanager
def block(self, begin: str, end: str):
self.locals.append(self._indent(begin) + "\n")
self.level += 1
yield
self.level -= 1
self.locals.append(self._indent(end) + "\n")
def get(self, value: ir.Value, default: Any = _UNSPECIFIED):
if default is _UNSPECIFIED:
return self.env[value]
else:
return self.env.get(value, default)
def set(self, value: ir.Value, model_value: Any):
self.env[value] = model_value
def get_model(
self,
has_barrier_sems: bool,
num_devices: int,
num_cores: int,
parallel_iteration_bounds: Sequence[int],
) -> str:
result = io.StringIO()
result.write(f"#define NDEVICE {num_devices}\n")
result.write("#define NCORE 1\n")
result.write(f"int _buf_readers[{self.num_buffers}*NDEVICE*NCORE] = 0;\n")
result.write(f"bool _buf_written[{self.num_buffers}*NDEVICE*NCORE] = 0;\n")
result.write(f"int _sems[{self.num_semaphores}*NDEVICE*NCORE] = 0;\n")
if has_barrier_sems:
result.write("int _barrier_sems[NDEVICE*NCORE] = 0;\n")
result.write(PREAMBLE)
result.write("\n")
parallel_threads = math.prod(parallel_iteration_bounds)
result.write(f"active [NDEVICE*{parallel_threads}] proctype Kernel() {{\n")
for l in self.locals:
result.write(l)
result.write("}\n")
result.write(DMA_PROCESS)
return result.getvalue()
def resolve_location(location):
if location is None:
location = [None, None]
else:
location = list(location)
if location[0] is None:
location[0] = "dev_id"
if location[1] is None:
location[1] = "core_id"
return tuple(location)
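# E.g. (a sketch): resolve_location(None) yields ("dev_id", "core_id"), the
# executing device and core, while resolve_location(("2", None)) pins the
# device but keeps the local core: ("2", "core_id").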
@dataclasses.dataclass(frozen=True)
class GlobalRefModel:
"""A model of a memory reference.
When a reference has a small leading dimension, it might be represented by
multiple slots in the reference array. Its region starts at base (which can be
dynamic) and has the given length (always static).
"""
base: Any
length: int
def readers_at(self, location):
dev, core = resolve_location(location)
return [f"buf_readers({self.base} + {i}, {dev}, {core})" for i in range(self.length)]
def written_at(self, location):
dev, core = resolve_location(location)
return [f"buf_written({self.base} + {i}, {dev}, {core})" for i in range(self.length)]
@dataclasses.dataclass(frozen=True)
class GlobalSemaphoreModel:
"""A model of a semaphore reference.
Semaphore arrays are always fully unrolled and are represented by a contiguous
subset of the global semaphore array.
"""
base: Any
length: int
def at(self, location):
dev, core = resolve_location(location)
return f"sems({self.base}, {dev}, {core})"
@dataclasses.dataclass(frozen=True)
class GlobalBarrierSemaphoreModel:
def at(self, location):
dev, core = resolve_location(location)
return f"barrier_sems({dev}, {core})"
def _print_op(ctx, op):
match op.OPERATION_NAME:
case "tpu.region":
_print_block(ctx, op.body)
case "tpu.device_id":
return ctx.device_id
case "arith.constant":
if ir.IntegerType.isinstance(op.result.type):
return str(ir.IntegerAttr(op.value).value)
else:
return
case "tpu.sem_signal":
location = resolve_location((ctx.get(op.device_id, None), ctx.get(op.core_id, None)))
sem_model = ctx.get(op.semaphore)
sem = sem_model.at(location)
amount = ctx.get(op.amount)
if isinstance(sem_model, GlobalBarrierSemaphoreModel):
ctx.emit(None, f'printf("Signal: BARRIER@{{%d, %d}} += %d\\n", {location[0]}, {location[1]}, {amount})')
else:
ctx.emit(None, f'printf("Signal: %d@{{%d, %d}} += %d\\n", {sem_model.base}, {location[0]}, {location[1]}, {amount})')
ctx.emit(None, f"d_step {{ {sem} = {sem} + {amount} }}")
case "tpu.sem_wait":
sem_model = ctx.get(op.semaphore)
sem = sem_model.at(location=None)
amount = ctx.get(op.amount)
ctx.emit(None, f"atomic {{ {sem} >= {amount}; {sem} = {sem} - {amount} }}")
if isinstance(sem_model, GlobalBarrierSemaphoreModel):
ctx.emit(None, f'printf("Wait done: BARRIER -= %d\\n", {amount})')
else:
ctx.emit(None, f'printf("Wait done: %d -= %d\\n", {sem_model.base}, {amount})')
case "tpu.enqueue_dma":
dst_location = resolve_location((ctx.get(op.device_id, None), ctx.get(op.core_id, None)))
src = ctx.get(op.source)
src_sem = ctx.get(op.source_semaphore)
dst = ctx.get(op.target)
dst_sem = ctx.get(op.target_semaphore)
src_readonly = "\n && ".join(is_written + " == 0" for is_written in src.written_at(None))
dst_unused = "\n && ".join(
is_written + " == 0"
for is_written in itertools.chain(
dst.written_at(dst_location), dst.readers_at(dst_location)
)
)
ctx.emit(
None,
'printf("DMA: [%d, %d)@{%d, %d} -> [%d, %d)@{%d, %d}\\n",'
f" {src.base}, {src.base} + {src.length}, dev_id, core_id,"
f" {dst.base}, {dst.base} + {dst.length}, {dst_location[0]},"
f" {dst_location[1]})",
)
with ctx.block("d_step {", "}"):
ctx.emit(None, f"assert({src_readonly}); // Source is not written to.")
ctx.emit(None, f"assert({dst_unused}); // Destination is unused.")
for r in src.readers_at(None):
ctx.emit(None, f"{r}++")
for w in dst.written_at(dst_location):
ctx.emit(None, f"{w} = 1")
ctx.emit(
None,
f"dma_queue!DMA(dev_id, core_id, {src_sem.base}, {src.base},"
f" {src.length}, {dst_location[0]}, {dst_location[1]},"
f" {dst_sem.base}, {dst.base}, {dst.length})",
)
case "tpu.wait_dma":
sem_model = ctx.get(op.semaphore)
sem = sem_model.at(location=None)
ctx.emit(None, f"atomic {{ {sem} >= 1; {sem} = {sem} - 1 }}")
ctx.emit(None, f'printf("Awaited DMA: %d\\n", {sem_model.base})')
case "tpu.sem_barrier":
return GlobalBarrierSemaphoreModel()
case "tpu.memref_slice":
result = ctx.get(op.mem_ref, None)
if result is None:
return NotImplemented
src_shape = ir.MemRefType(op.mem_ref.type).shape
dst_shape = ir.MemRefType(op.result.type).shape
dynamic = ir.ShapedType.get_dynamic_size()
# We always unroll semaphore references entirely, and we need to be
# faithful when slicing them.
if isinstance(result, GlobalSemaphoreModel):
# We only support contiguous slices of semaphore arrays at the moment.
seen_nontrivial_unequal = False
for s, d in zip(src_shape, dst_shape):
if d == 1:
continue
if s != d:
if seen_nontrivial_unequal:
raise NotImplementedError("Non-contiguous slices of semaphore arrays")
seen_nontrivial_unequal = True
strides = []
stride = 1
for s in src_shape[::-1]:
strides.append(stride)
stride *= s
strides = reversed(strides)
indices = [ctx.get(idx) for idx in op.base_idx]
linear_offset = " + ".join(f"{idx} * {s}" for idx, s in zip(indices, strides))
return GlobalSemaphoreModel(
base=f"{result.base} + {linear_offset}", length=math.prod(dst_shape)
)
else:
assert isinstance(result, GlobalRefModel)
major_idx = ctx.get(op.base_idx[0], None)
if (not src_shape or src_shape[0] == dynamic or dst_shape[0] == dynamic
or result.length == 1 or major_idx is None):
return result
return GlobalRefModel(f"{result.base} + {major_idx}", dst_shape[0])
case "tpu.memref_squeeze":
result = ctx.get(op.input, None)
return NotImplemented if result is None else result
case "tpu.assume_multiple":
result = ctx.get(op.value, None)
return NotImplemented if result is None else result
case "arith.addi":
return bin_op(ctx, "int", "+", *op.operands)
case "arith.subi":
return bin_op(ctx, "int", "-", *op.operands)
case "arith.muli":
return bin_op(ctx, "int", "*", *op.operands)
case "arith.remsi":
# TODO(apaszke): Make sure this has the right semantics for negative integers.
return bin_op(ctx, "int", "%", *op.operands)
case "arith.divsi":
return bin_op(ctx, "int", "/", *op.operands)
case "arith.cmpi":
match op.predicate.value:
case arith.CmpIPredicate.eq:
return bin_op(ctx, "bool", "==", *op.operands)
case arith.CmpIPredicate.ne:
return bin_op(ctx, "bool", "!=", *op.operands)
case arith.CmpIPredicate.slt:
return bin_op(ctx, "bool", "<", *op.operands)
case arith.CmpIPredicate.sle:
return bin_op(ctx, "bool", "<=", *op.operands)
case arith.CmpIPredicate.sgt:
return bin_op(ctx, "bool", ">", *op.operands)
case arith.CmpIPredicate.sge:
return bin_op(ctx, "bool", ">=", *op.operands)
return bin_op(ctx, "bool", "/", *op.operands)
case "tpu.trace_start":
ctx.comment(op.message.value)
case "scf.for":
carrys = [
ctx.emit("int", ctx.get(arg))
if ir.IntegerType.isinstance(arg.type) else None
for arg in op.initArgs
]
bounds = (op.lowerBound, op.upperBound, op.step)
bound_models = [ctx.get(b, None) for b in bounds]
lower, upper, step = bound_models
for model, v in zip(bound_models, bounds):
if model is None:
raise ValueError(f"Could not model loop bound or step: {v}")
induction_var = ctx.emit("int", lower)
with ctx.block("do", "od"):
ctx.emit(None, f":: {induction_var} < {upper}; ")
ctx.set(op.induction_variable, induction_var)
for c, arg in zip(carrys, op.inner_iter_args, strict=True):
if c is not None:
ctx.set(arg, c)
_print_block(ctx, op.body)
terminator = op.body.operations[len(op.body.operations) - 1]
new_carrys = terminator.operands
with ctx.block("d_step {", "}"):
for c, new in zip(carrys, new_carrys, strict=True):
if c is not None:
ctx.emit(None, f"{c} = {ctx.get(new)}")
ctx.emit(None, f"{induction_var} = {induction_var} + {step}")
ctx.emit(None, ":: else -> break")
if len(carrys) == 1:
return carrys[0]
else:
return tuple(carrys)
case "scf.if":
if op.results:
raise NotImplementedError
if (condition := ctx.get(op.condition, None)) is None:
raise ValueError(f"Could not model branch condition: {op.condition}")
with ctx.block("if", "fi"):
ctx.emit(None, f":: ({condition})")
_print_block(ctx, op.then_block)
if op.regions[1].blocks:
ctx.emit(None, ":: else")
_print_block(ctx, op.else_block)
else:
ctx.emit(None, ":: else -> skip")
case _:
if not op.regions:
return NotImplemented
raise NotImplementedError("Must handle all ops with regions")
def bin_op(ctx, result_ty, op, lhs, rhs):
lhs = ctx.get(lhs, None)
rhs = ctx.get(rhs, None)
if lhs is None or rhs is None:
return NotImplemented
return ctx.emit(result_ty, f"{lhs} {op} {rhs}")
def _print_block(ctx, block):
for op in block:
try:
results = _print_op(ctx, op)
except Exception as e:
raise RuntimeError(f"Failed to print op: {op}") from e
if results is NotImplemented:
continue
if not op.results:
assert results is None
elif len(op.results) > 1:
raise NotImplementedError(op)
else:
ctx.set(op.result, results)
def export_promela_model(
module, num_devices: int, num_cores_per_device: int
) -> str:
with module.context:
_, uses_barrier_semaphores = tpu.private_has_communication(module.operation)
# Clone the module and simplify it to make the model smaller and simpler.
module = ir.Module.parse(module.operation.get_asm(binary=True))
passes = ["canonicalize", "cse"]
pipeline = PassManager.parse(f"builtin.module({','.join(passes)})")
pipeline.run(module.operation)
main_str_attr = ir.StringAttr.get("main")
for f in module.body:
if getattr(f, "name", None) == main_str_attr:
break
else:
raise ValueError("No main function found")
assert isinstance(f, func.FuncOp)
iteration_bounds: Sequence[int] = ()
if "iteration_bounds" in f.attributes:
iteration_bounds = ir.DenseI64ArrayAttr(f.attributes["iteration_bounds"]) # type: ignore
dynamic = ir.ShapedType.get_dynamic_size()
if any(b == dynamic for b in iteration_bounds):
raise ValueError("Dynamic iteration bounds not supported")
dimension_semantics = ir.ArrayAttr(f.attributes["dimension_semantics"])
parallel = ir.Attribute.parse("#tpu.dimension_semantics<parallel>")
if any(s != parallel for s in dimension_semantics):
raise NotImplementedError("Non-parallel dimensions not supported")
num_scalar_prefetch = 0
if "scalar_prefetch" in f.attributes:
num_scalar_prefetch = ir.IntegerAttr(f.attributes["scalar_prefetch"]).value
(entry_block,) = f.body
ctx = PrintCtx(iteration_bounds)
sem_ty = ir.Type.parse("!tpu.semaphore")
dma_sem_ty = ir.Type.parse("!tpu.dma_semaphore")
program_id_args, prefetch_args, other_args = split_list(
entry_block.arguments, [len(iteration_bounds), num_scalar_prefetch]
)
for arg, model in zip(program_id_args, ctx.program_ids, strict=True):
ctx.set(arg, model)
del prefetch_args # We ignore prefetch_args
for arg in other_args:
if ir.MemRefType.isinstance(arg.type):
ty = ir.MemRefType(arg.type)
if ty.element_type == sem_ty or ty.element_type == dma_sem_ty:
ctx.set(arg, ctx.emit_global_semaphore_ref(ty.shape))
else:
ctx.set(arg, ctx.emit_global_ref(ty.shape))
_print_block(ctx, entry_block)
return ctx.get_model(
uses_barrier_semaphores, num_devices, num_cores_per_device, iteration_bounds
)
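
To close the loop, a minimal sketch of calling the exporter directly (assuming
mosaic_module is the Mosaic ir.Module produced by Pallas lowering, as in
pallas_call_registration.py above):

from jax._src.pallas.mosaic import verification

model = verification.export_promela_model(
    mosaic_module, num_devices=2, num_cores_per_device=1)
with open("/tmp/kernel.pml", "w") as f:
  f.write(model)
# Then: spin -a /tmp/kernel.pml && cc -o pan pan.c && ./pan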


@@ -949,14 +949,11 @@ def _pallas_call_lowering(
def tpu_lowering(ctx: mlir.LoweringRuleContext,
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
**params):
try:
from jax._src.pallas.mosaic import pallas_call_registration
except ImportError:
if mosaic_tpu_backend is None:
raise _unsupported_lowering_error("tpu")
else:
return pallas_call_registration.pallas_call_tpu_lowering_rule(
ctx, *in_nodes, **params
)
return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
ctx, *in_nodes, **params
)
def gpu_lowering(ctx: mlir.LoweringRuleContext,
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
@@ -968,10 +965,9 @@ def _pallas_call_lowering(
from jax._src.pallas.triton import pallas_call_registration # type: ignore
except ImportError:
raise _unsupported_lowering_error("gpu")
else:
return pallas_call_registration.pallas_call_lowering(
ctx, *in_nodes, **params
)
return pallas_call_registration.pallas_call_lowering(
ctx, *in_nodes, **params
)
return mlir.lower_per_platform(ctx, "pallas_call",
dict(cpu=cpu_lowering,
@@ -1122,3 +1118,13 @@ def pallas_call(
out = tree_util.tree_unflatten(out_tree, out_flat)
return out
return wrapped
# We import the TPU backend at the top level because it defines flags. Note that
# we can only do that at the bottom of this file, because it also depends on
# this module already being initialized.
try:
from jax._src.pallas.mosaic import pallas_call_registration as mosaic_tpu_backend
except ImportError:
mosaic_tpu_backend = None # type: ignore


@@ -1550,6 +1550,10 @@ tf_not_yet_impl = [
"consume",
"ragged_dot",
"cholesky_update",
# Pallas TPU primitives
"bitcast",
"repeat",
"roll",
]
tf_impl[random_internal.random_clone_p] = lambda x: x


@@ -55,10 +55,3 @@ from jax._src.state.primitives import broadcast_to
from jax._src.deprecations import register as _register_deprecation
_register_deprecation("pallas-block-spec-order")
del _register_deprecation
# Import tpu_custom_call for its flag definitions, needed for cross-platform lowering
try:
from jax._src import tpu_custom_call # pytype: disable=import-error
del tpu_custom_call
except ImportError:
pass