From 2ea222544e2d8560f7a38b8491e52617559638c8 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 17 Jul 2024 05:28:34 -0700 Subject: [PATCH] Add a Promela spec generator for Pallas TPU kernels This adds a simple model extractor for TPU kernels that generates a Promela spec outlining the semantics of semaphores and DMAs. The model can be fed into SPIN and used to e.g. verify the lack of data races or deadlocks. While compelte verification is very expensive, the tool seems especially good at finding races that are really there. PiperOrigin-RevId: 653198263 --- jax/BUILD | 1 + jax/_src/pallas/mosaic/BUILD | 10 + .../pallas/mosaic/pallas_call_registration.py | 41 ++ jax/_src/pallas/mosaic/verification.py | 502 ++++++++++++++++++ jax/_src/pallas/pallas_call.py | 28 +- jax/experimental/jax2tf/jax2tf.py | 4 + jax/experimental/pallas/__init__.py | 7 - 7 files changed, 575 insertions(+), 18 deletions(-) create mode 100644 jax/_src/pallas/mosaic/verification.py diff --git a/jax/BUILD b/jax/BUILD index d6be8cb85..cec3f833b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -635,6 +635,7 @@ pytype_strict_library( "//jax/_src/pallas/mosaic:pipeline", "//jax/_src/pallas/mosaic:primitives", "//jax/_src/pallas/mosaic:random", + "//jax/_src/pallas/mosaic:verification", ], ) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 4c849dfba..57dad7793 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -33,6 +33,16 @@ py_library( ], ) +py_library( + name = "verification", + srcs = ["verification.py"], + deps = [ + "//jax", + "//jax:mlir", + "//jax/_src/lib", + ], +) + py_library( name = "primitives", srcs = ["primitives.py"], diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index e7d6726b5..79ab23c45 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -16,18 +16,22 @@ from __future__ import annotations +import os +import tempfile from typing import Any import warnings import jax from jax import dtypes from jax import core as jax_core +from jax._src import config from jax._src import core as jax_src_core from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core from jax._src.pallas.mosaic import lowering +from jax._src.pallas.mosaic import verification from jax._src.pallas.pallas_call import pallas_call_p from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu @@ -48,6 +52,17 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray): return jax_core.ShapedArray(x.shape, lowering.BOOL_MEMREF_TYPE) return x +_DUMP_PROMELA_TO = config.string_flag( + "jax_pallas_dump_promela_to", + default=os.getenv("JAX_PALLAS_DUMP_PROMELA_TO", ""), + help=( + "If set, dumps a Promela model of the kernel to the specified" + " directory. The model can verify that the kernel is free of data" + " races, deadlocks, etc." + ), +) + + def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, @@ -106,6 +121,32 @@ def pallas_call_tpu_lowering_rule( ) out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] + if promela_dump_path := _DUMP_PROMELA_TO.value: + num_devices = 1 if mesh is None else mesh.devices.size + num_cores = ( + jax.devices()[0].num_cores + if mesh is None + else mesh.devices[0].num_cores + ) + model = verification.export_promela_model( + mosaic_module, num_devices, num_cores + ) + if promela_dump_path == "stdout": + print(model) + else: + if promela_dump_path == "sponge": + promela_dump_path = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "") + if not promela_dump_path: + raise ValueError( + "TEST_UNDECLARED_OUTPUTS_DIR must be set when" + " --jax_pallas_dump_promela_to=sponge" + ) + dump_ctx = tempfile.NamedTemporaryFile( + mode="w", prefix=name + "-", suffix=".pml", dir=promela_dump_path, delete=False, + ) + with dump_ctx as f: + f.write(model) + # Replace in_avals to physical avals. # This step is required for mapping logical types to physical types. # (e.g. PRNG key -> uint32[2]) diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py new file mode 100644 index 000000000..d8a383254 --- /dev/null +++ b/jax/_src/pallas/mosaic/verification.py @@ -0,0 +1,502 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import dataclasses +import io +import itertools +import math +import textwrap +from typing import Any, Sequence +from jaxlib.mlir import ir +from jaxlib.mlir.passmanager import PassManager +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import func +from jax._src.lib import tpu +from jax._src.util import split_list + +_UNSPECIFIED = object() +Var = str + +# TODO(apaszke): Add checks that semaphores are always left at 0. +# TODO(apaszke): Add checks that no remote resources are used while the remote +# device is not in the kernel (both before and after). +# TODO(apaszke): Model 0-sized DMAs faithfully. + + +PREAMBLE = """ +#define buf_readers(index, device, core) _buf_readers[(index)*(NDEVICE*NCORE) + (device)*NCORE + core] +#define buf_written(index, device, core) _buf_written[(index)*(NDEVICE*NCORE) + (device)*NCORE + core] +#define sems(index, device, core) _sems[(index)*(NDEVICE*NCORE) + (device)*NCORE + core] +#define barrier_sems(device, core) _barrier_sems[(device)*NCORE + core] + +#ifndef NDMA +#define NDMA 2 +#endif + +mtype = { DMA }; +chan dma_queue = [NDMA] of { mtype, int, int, int, int, int, int, int, int, int, int }; +""" + +DMA_PROCESS = """ +active [NDMA] proctype DmaEngine() { + int src_dev, src_core, src_sem, src_buf_base, src_buf_len; + int dst_dev, dst_core, dst_sem, dst_buf_base, dst_buf_len; + do + :: skip; + end: dma_queue?DMA(src_dev, src_core, src_sem, src_buf_base, src_buf_len, dst_dev, dst_core, dst_sem, dst_buf_base, dst_buf_len); + d_step { + printf("DMA read done: [%d, %d)@{%d, %d} (%d++)\\n", src_buf_base, src_buf_base + src_buf_len, src_dev, src_core, src_sem); + int i; + for (i : src_buf_base .. src_buf_base + src_buf_len - 1) { + buf_readers(i, src_dev, src_core)--; + } + sems(src_sem, src_dev, src_core)++; + } // Read read + d_step { + printf("DMA write done: [%d, %d)@{%d, %d} (%d++)\\n", dst_buf_base, dst_buf_base + dst_buf_len, dst_dev, dst_core, dst_sem); + int i; + for (i : dst_buf_base .. dst_buf_base + dst_buf_len - 1) { + buf_written(i, dst_dev, dst_core)--; + } + sems(dst_sem, dst_dev, dst_core)++; + } // Write complete + od +} +""" + +class PrintCtx: + MAX_REF_UNROLL = 8 + + def __init__(self, iteration_bounds): + self.level = 1 + self.num_semaphores = 0 + self.num_buffers = 0 + self.locals = [] + self.counter = itertools.count() + self.env: dict[ir.Value, Var | int] = {} + self.program_ids = tuple(f"pid{i}" for i in range(len(iteration_bounds))) + self.device_id = "dev_id" + + # TODO(apaszke): Clean up core_id! This is not a visible detail in the Mosaic + # programming model. + self.emit(None, "int core_id = 0") + # Reconstruct device id and program ids from the pid. + self.emit(None, "int dev_id") + if iteration_bounds: + self.emit(None, f"int {', '.join(self.program_ids)}") + with self.block("d_step {", "}"): + idx = "_pid" + program_ids = [] + for i, b in reversed(list(enumerate(iteration_bounds))): + program_ids.append(self.emit(None, f"pid{i} = {idx} % {b}")) + idx = self.emit("int", f"{idx} / {b}") + self.emit(None, f"dev_id = {idx}") + + def emit_global_ref(self, shape: Sequence[int]): + slots = 1 + if shape and shape[0] <= self.MAX_REF_UNROLL: + slots = shape[0] + base = self.num_buffers + self.num_buffers += slots + return GlobalRefModel(base, slots) + + def emit_global_semaphore_ref(self, shape: Sequence[int]): + count = math.prod(shape) + base = self.num_semaphores + self.num_semaphores += count + return GlobalSemaphoreModel(base, count) + + def _indent(self, text: str) -> str: + return textwrap.indent(text, " " * self.level) + + def emit(self, ty, expr): + name = None + if ty is not None: + name = "l" + str(next(self.counter)) + expr = f"{ty} {name} = {expr}" + self.locals.append(self._indent(expr) + ";\n") + return name + + def comment(self, comment): + self.locals.append(self._indent(f"/* {comment} */\n")) + + @contextlib.contextmanager + def block(self, begin: str, end: str): + self.locals.append(self._indent(begin) + "\n") + self.level += 1 + yield + self.level -= 1 + self.locals.append(self._indent(end) + "\n") + + def get(self, value: ir.Value, default: Any = _UNSPECIFIED): + if default is _UNSPECIFIED: + return self.env[value] + else: + return self.env.get(value, default) + + def set(self, value: ir.Value, model_value: Any): + self.env[value] = model_value + + def get_model( + self, + has_barrier_sems: bool, + num_devices: int, + num_cores: int, + parallel_iteration_bounds: Sequence[int], + ) -> str: + result = io.StringIO() + result.write(f"#define NDEVICE {num_devices}\n") + result.write("#define NCORE 1\n") + result.write(f"int _buf_readers[{self.num_buffers}*NDEVICE*NCORE] = 0;\n") + result.write(f"bool _buf_written[{self.num_buffers}*NDEVICE*NCORE] = 0;\n") + result.write(f"int _sems[{self.num_semaphores}*NDEVICE*NCORE] = 0;\n") + if has_barrier_sems: + result.write("int _barrier_sems[NDEVICE*NCORE] = 0;\n") + result.write(PREAMBLE) + result.write("\n") + parallel_threads = math.prod(parallel_iteration_bounds) + result.write(f"active [NDEVICE*{parallel_threads}] proctype Kernel() {{\n") + for l in self.locals: + result.write(l) + result.write("}\n") + result.write(DMA_PROCESS) + return result.getvalue() + + +def resolve_location(location): + if location is None: + location = [None, None] + else: + location = list(location) + if location[0] is None: + location[0] = "dev_id" + if location[1] is None: + location[1] = "core_id" + return tuple(location) + + +@dataclasses.dataclass(frozen=True) +class GlobalRefModel: + """A model of a memory reference. + + When a reference has a small leading dimension, it might be represented by + multiple slots in the reference array. Its region starts at base (that can be + dynamic) and has the given length (always static). + """ + base: Any + length: int + + def readers_at(self, location): + dev, core = resolve_location(location) + return [f"buf_readers({self.base} + {i}, {dev}, {core})" for i in range(self.length)] + + def written_at(self, location): + dev, core = resolve_location(location) + return [f"buf_written({self.base} + {i}, {dev}, {core})" for i in range(self.length)] + + +@dataclasses.dataclass(frozen=True) +class GlobalSemaphoreModel: + """A model of a semaphore reference. + + Semaphore arrays are always fully unrolled and are represented by a contiguous + subset of the global semaphore array. + """ + base: Any + length: int + + def at(self, location): + dev, core = resolve_location(location) + return f"sems({self.base}, {dev}, {core})" + + +@dataclasses.dataclass(frozen=True) +class GlobalBarrierSemaphoreModel: + def at(self, location): + dev, core = resolve_location(location) + return f"barrier_sems({dev}, {core})" + + +def _print_op(ctx, op): + match op.OPERATION_NAME: + case "tpu.region": + _print_block(ctx, op.body) + case "tpu.device_id": + return ctx.device_id + case "arith.constant": + if ir.IntegerType.isinstance(op.result.type): + return str(ir.IntegerAttr(op.value).value) + else: + return + case "tpu.sem_signal": + location = resolve_location((ctx.get(op.device_id, None), ctx.get(op.core_id, None))) + sem_model = ctx.get(op.semaphore) + sem = sem_model.at(location) + amount = ctx.get(op.amount) + if isinstance(sem_model, GlobalBarrierSemaphoreModel): + ctx.emit(None, f'printf("Signal: BARRIER@{{%d, %d}} += %d\\n", {location[0]}, {location[1]}, {amount})') + else: + ctx.emit(None, f'printf("Signal: %d@{{%d, %d}} += %d\\n", {sem_model.base}, {location[0]}, {location[1]}, {amount})') + ctx.emit(None, f"d_step {{ {sem} = {sem} + {amount} }}") + case "tpu.sem_wait": + sem_model = ctx.get(op.semaphore) + sem = sem_model.at(location=None) + amount = ctx.get(op.amount) + ctx.emit(None, f"atomic {{ {sem} >= {amount}; {sem} = {sem} - {amount} }}") + if isinstance(sem_model, GlobalBarrierSemaphoreModel): + ctx.emit(None, f'printf("Wait done: BARRIER -= %d\\n", {amount})') + else: + ctx.emit(None, f'printf("Wait done: %d -= %d\\n", {sem_model.base}, {amount})') + case "tpu.enqueue_dma": + dst_location = resolve_location((ctx.get(op.device_id, None), ctx.get(op.core_id, None))) + src = ctx.get(op.source) + src_sem = ctx.get(op.source_semaphore) + dst = ctx.get(op.target) + dst_sem = ctx.get(op.target_semaphore) + src_readonly = "\n && ".join(is_written + " == 0" for is_written in src.written_at(None)) + dst_unused = "\n && ".join( + is_written + " == 0" + for is_written in itertools.chain( + dst.written_at(dst_location), dst.readers_at(dst_location) + ) + ) + ctx.emit( + None, + 'printf("DMA: [%d, %d)@{%d, %d} -> [%d, %d)@{%d, %d}\\n",' + f" {src.base}, {src.base} + {src.length}, dev_id, core_id," + f" {dst.base}, {dst.base} + {dst.length}, {dst_location[0]}," + f" {dst_location[1]})", + ) + with ctx.block("d_step {", "}"): + ctx.emit(None, f"assert({src_readonly}); // Source is not written to.") + ctx.emit(None, f"assert({dst_unused}); // Destination is unused.") + for r in src.readers_at(None): + ctx.emit(None, f"{r}++") + for w in dst.written_at(dst_location): + ctx.emit(None, f"{w} = 1") + ctx.emit( + None, + f"dma_queue!DMA(dev_id, core_id, {src_sem.base}, {src.base}," + f" {src.length}, {dst_location[0]}, {dst_location[1]}," + f" {dst_sem.base}, {dst.base}, {dst.length})", + ) + case "tpu.wait_dma": + sem_model = ctx.get(op.semaphore) + sem = sem_model.at(location=None) + ctx.emit(None, f"atomic {{ {sem} >= 1; {sem} = {sem} - 1 }}") + ctx.emit(None, f'printf("Awaited DMA: %d\\n", {sem_model.base})') + case "tpu.sem_barrier": + return GlobalBarrierSemaphoreModel() + case "tpu.memref_slice": + result = ctx.get(op.mem_ref, None) + if result is None: + return NotImplemented + src_shape = ir.MemRefType(op.mem_ref.type).shape + dst_shape = ir.MemRefType(op.result.type).shape + dynamic = ir.ShapedType.get_dynamic_size() + # We always unroll semaphore references entirely, and we need to be + # faithful when slicing them. + if isinstance(result, GlobalSemaphoreModel): + # We only support contiguous slices of semaphore arrays at the moment. + seen_nontrivial_unequal = False + for s, d in zip(src_shape, dst_shape): + if d == 1: + continue + if s != d: + if seen_nontrivial_unequal: + raise NotImplementedError("Non-contiguous slices of semaphore arrays") + seen_nontrivial_unequal = True + strides = [] + stride = 1 + for s in src_shape[::-1]: + strides.append(stride) + stride *= s + strides = reversed(strides) + indices = [ctx.get(idx) for idx in op.base_idx] + linear_offset = " + ".join(f"{idx} * {s}" for idx, s in zip(indices, strides)) + return GlobalSemaphoreModel( + base=f"{result.base} + {linear_offset}", length=math.prod(dst_shape) + ) + else: + assert isinstance(result, GlobalRefModel) + major_idx = ctx.get(op.base_idx[0], None) + if (not src_shape or src_shape[0] == dynamic or dst_shape[0] == dynamic + or result.length == 1 or major_idx is None): + return result + return GlobalRefModel(f"{result.base} + {major_idx}", dst_shape[0]) + case "tpu.memref_squeeze": + result = ctx.get(op.input, None) + return NotImplemented if result is None else result + case "tpu.assume_multiple": + result = ctx.get(op.value, None) + return NotImplemented if result is None else result + case "arith.addi": + return bin_op(ctx, "int", "+", *op.operands) + case "arith.subi": + return bin_op(ctx, "int", "-", *op.operands) + case "arith.muli": + return bin_op(ctx, "int", "*", *op.operands) + case "arith.remsi": + # TODO(apaszke): Make sure this has right semantics for negative integers. + return bin_op(ctx, "int", "%", *op.operands) + case "arith.divsi": + return bin_op(ctx, "int", "/", *op.operands) + case "arith.cmpi": + match op.predicate.value: + case arith.CmpIPredicate.eq: + return bin_op(ctx, "bool", "==", *op.operands) + case arith.CmpIPredicate.ne: + return bin_op(ctx, "bool", "!=", *op.operands) + case arith.CmpIPredicate.slt: + return bin_op(ctx, "bool", "<", *op.operands) + case arith.CmpIPredicate.sle: + return bin_op(ctx, "bool", "<=", *op.operands) + case arith.CmpIPredicate.sgt: + return bin_op(ctx, "bool", ">", *op.operands) + case arith.CmpIPredicate.sge: + return bin_op(ctx, "bool", ">=", *op.operands) + return bin_op(ctx, "bool", "/", *op.operands) + case "tpu.trace_start": + ctx.comment(op.message.value) + case "scf.for": + carrys = [ + ctx.emit("int", ctx.get(arg)) + if ir.IntegerType.isinstance(arg.type) else None + for arg in op.initArgs + ] + bounds = (op.lowerBound, op.upperBound, op.step) + lower, upper, step = bound_models = map(ctx.get, bounds) + for model, v in zip(bound_models, bounds): + if model is None: + raise ValueError(f"Could not model loop bound or step: {v}") + induction_var = ctx.emit("int", lower) + with ctx.block("do", "od"): + ctx.emit(None, f":: {induction_var} < {upper}; ") + ctx.set(op.induction_variable, induction_var) + for c, arg in zip(carrys, op.inner_iter_args, strict=True): + if c is not None: + ctx.set(arg, c) + _print_block(ctx, op.body) + terminator = op.body.operations[len(op.body.operations) - 1] + new_carrys = terminator.operands + with ctx.block("d_step {", "}"): + for c, new in zip(carrys, new_carrys, strict=True): + if c is not None: + ctx.emit(None, f"{c} = {ctx.get(new)}") + ctx.emit(None, f"{induction_var} = {induction_var} + {step}") + ctx.emit(None, ":: else -> break") + if len(carrys) == 1: + return carrys[0] + else: + return tuple(carrys) + case "scf.if": + if op.results: + raise NotImplementedError + if (condition := ctx.get(op.condition, None)) is None: + raise ValueError(f"Could not model branch condition: {op.condition}") + with ctx.block("if", "fi"): + ctx.emit(None, f":: ({condition})") + _print_block(ctx, op.then_block) + if op.regions[1].blocks: + ctx.emit(None, ":: else") + _print_block(ctx, op.else_block) + else: + ctx.emit(None, ":: else -> skip") + case _: + if not op.regions: + return NotImplemented + raise NotImplementedError("Must handle all ops with regions") + + +def bin_op(ctx, result_ty, op, lhs, rhs): + lhs = ctx.get(lhs, None) + rhs = ctx.get(rhs, None) + if lhs is None or rhs is None: + return NotImplemented + return ctx.emit(result_ty, f"{lhs} {op} {rhs}") + + +def _print_block(ctx, block): + for op in block: + try: + results = _print_op(ctx, op) + except Exception as e: + raise RuntimeError(f"Failed to print op: {op}") from e + if results is NotImplemented: + continue + if not op.results: + assert results is None + elif len(op.results) > 1: + raise NotImplementedError(op) + else: + ctx.set(op.result, results) + + +def export_promela_model( + module, num_devices: int, num_cores_per_device: int +) -> str: + with module.context: + _, uses_barrier_semaphores = tpu.private_has_communication(module.operation) + # Clone the module and simplify it to make the model smaller and simpler. + module = ir.Module.parse(module.operation.get_asm(binary=True)) + passes = ["canonicalize", "cse"] + pipeline = PassManager.parse(f"builtin.module({','.join(passes)})") + pipeline.run(module.operation) + main_str_attr = ir.StringAttr.get("main") + for f in module.body: + if getattr(f, "name", None) == main_str_attr: + break + else: + raise ValueError("No main function found") + assert isinstance(f, func.FuncOp) + + iteration_bounds: Sequence[int] = () + if "iteration_bounds" in f.attributes: + iteration_bounds = ir.DenseI64ArrayAttr(f.attributes["iteration_bounds"]) # type: ignore + dynamic = ir.ShapedType.get_dynamic_size() + if any(b == dynamic for b in iteration_bounds): + raise ValueError("Dynamic iteration bounds not supported") + + dimension_semantics = ir.ArrayAttr(f.attributes["dimension_semantics"]) + + parallel = ir.Attribute.parse("#tpu.dimension_semantics") + if any(s != parallel for s in dimension_semantics): + raise NotImplementedError("Non-parallel dimensions not supported") + + num_scalar_prefetch = 0 + if "scalar_prefetch" in f.attributes: + num_scalar_prefetch = ir.IntegerAttr(f.attributes["scalar_prefetch"]).value + + (entry_block,) = f.body + ctx = PrintCtx(iteration_bounds) + sem_ty = ir.Type.parse("!tpu.semaphore") + dma_sem_ty = ir.Type.parse("!tpu.dma_semaphore") + program_id_args, prefetch_args, other_args = split_list( + entry_block.arguments, [len(iteration_bounds), num_scalar_prefetch] + ) + for arg, model in zip(program_id_args, ctx.program_ids, strict=True): + ctx.set(arg, model) + del prefetch_args # We ignore prefetch_args + for arg in other_args: + if ir.MemRefType.isinstance(arg.type): + ty = ir.MemRefType(arg.type) + if ty.element_type == sem_ty or ty.element_type == dma_sem_ty: + ctx.set(arg, ctx.emit_global_semaphore_ref(ty.shape)) + else: + ctx.set(arg, ctx.emit_global_ref(ty.shape)) + _print_block(ctx, entry_block) + return ctx.get_model( + uses_barrier_semaphores, num_devices, num_cores_per_device, iteration_bounds + ) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 0bac322fe..359a42e86 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -949,14 +949,11 @@ def _pallas_call_lowering( def tpu_lowering(ctx: mlir.LoweringRuleContext, *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], **params): - try: - from jax._src.pallas.mosaic import pallas_call_registration - except ImportError: + if mosaic_tpu_backend is None: raise _unsupported_lowering_error("tpu") - else: - return pallas_call_registration.pallas_call_tpu_lowering_rule( - ctx, *in_nodes, **params - ) + return mosaic_tpu_backend.pallas_call_tpu_lowering_rule( + ctx, *in_nodes, **params + ) def gpu_lowering(ctx: mlir.LoweringRuleContext, *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], @@ -968,10 +965,9 @@ def _pallas_call_lowering( from jax._src.pallas.triton import pallas_call_registration # type: ignore except ImportError: raise _unsupported_lowering_error("gpu") - else: - return pallas_call_registration.pallas_call_lowering( - ctx, *in_nodes, **params - ) + return pallas_call_registration.pallas_call_lowering( + ctx, *in_nodes, **params + ) return mlir.lower_per_platform(ctx, "pallas_call", dict(cpu=cpu_lowering, @@ -1122,3 +1118,13 @@ def pallas_call( out = tree_util.tree_unflatten(out_tree, out_flat) return out return wrapped + + +# We import the TPU backend at the top level because it defines flags. Note that +# we can only do that at the bottom of this file, beacuse it also depends on +# this module already being initialized. + +try: + from jax._src.pallas.mosaic import pallas_call_registration as mosaic_tpu_backend +except ImportError: + mosaic_tpu_backend = None # type: ignore diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 0a8867bb6..c0b4b39f6 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1550,6 +1550,10 @@ tf_not_yet_impl = [ "consume", "ragged_dot", "cholesky_update", + # Pallas TPU primitives + "bitcast", + "repeat", + "roll", ] tf_impl[random_internal.random_clone_p] = lambda x: x diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index e2ed7ebe5..b4c402eae 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -55,10 +55,3 @@ from jax._src.state.primitives import broadcast_to from jax._src.deprecations import register as _register_deprecation _register_deprecation("pallas-block-spec-order") del _register_deprecation - -# Import tpu_custom_call for its flag definitions, needed for cross-platform lowering -try: - from jax._src import tpu_custom_call # pytype: disable=import-error - del tpu_custom_call -except ImportError: - pass