Migrated dot to lower directly to Triton IR

PiperOrigin-RevId: 603768074
Sergei Lebedev 2024-02-02 13:08:47 -08:00 committed by jax authors
parent 5867a05cdd
commit 28eff4f9b8

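In short: dot previously deferred to Triton's Python frontend through this
module's wrap_with_builder shim; after this change it validates its operands
itself and emits a tt.dot op directly via the MLIR Triton dialect bindings.
Schematically (both forms appear in the diff below):

    dot = wrap_with_builder(tl.core.dot)     # before: Triton's frontend builds the IR
    tt_dialect.dot(x.handle, y.handle, ...)  # after: emit tt.dot directly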

@@ -22,6 +22,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from functools import partial, wraps
import threading
from typing import Any
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith as arith_dialect
@@ -966,7 +967,70 @@ def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor:
)
dot = wrap_with_builder(tl.core.dot)


def _check_dot_operands(x_dtype: dtype, y_dtype: dtype, options: Any):
  # TODO(slebedev): Ensure that the dtypes are supported by CUDA.
  return


def dot(
    x: tensor,
    y: tensor,
    acc: tensor | None = None,
    allow_tf32: bool = True,
    max_num_imprecise_acc: int | None = None,
    out_dtype: dtype = float32,
) -> tensor:
  x_dims = [dim.__index__() for dim in x.shape]
  y_dims = [dim.__index__() for dim in y.shape]
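  # tt.dot lowers to tensor-core (MMA) instructions, which require every
  # operand dimension to be at least 16.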
  if min(*x_dims, *y_dims) < 16:
    raise ValueError("all dimensions of x and y must be >= 16")
  if out_dtype.is_bf16():
    raise ValueError(f"out_dtype={out_dtype} is unsupported")

  b: builder = builder.current
  _check_dot_operands(x.dtype, y.dtype, b.options)
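
  # Choose the accumulator element type following Triton's promotion rules:
  # int8 accumulates into int32, fp32 and bf16 into fp32, and fp16/fp8 into
  # the requested out_dtype.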
  if x.dtype.is_int():
    if x.dtype != int8:
      raise TypeError(f"unsupported dtype: {x.dtype}")
    zero = tensor(b.get_int32(0), int32)
    element_type = int32
  elif x.dtype.is_fp32() or x.dtype.is_bf16():
    zero = tensor(b.get_fp32(0), float32)
    element_type = float32
  else:
    if out_dtype.is_fp16():
      zero = tensor(b.get_fp16(0), float16)
    else:
      zero = tensor(b.get_fp32(0), float32)
    element_type = out_dtype

  if element_type != out_dtype:
    raise TypeError(
        f"out_dtype={out_dtype} does not match element type {element_type}"
    )

  m, _ = x_dims
  _, n = y_dims
  result_type = block_type(element_type, [m, n])
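
  # Without a caller-supplied accumulator, start from a zero-filled [m, n]
  # block; otherwise the accumulator must already have the result type.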
  if acc is None:
    acc = splat(zero, [m, n])
  else:
    assert acc.type == result_type
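
  # max_num_imprecise_acc only matters for fp8 inputs: it bounds how many
  # products may accumulate in the MMA's lower-precision accumulator before
  # being folded into the full-precision one; 0 forces full precision.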
  if max_num_imprecise_acc is None:
    if x.dtype.is_fp8() and y.dtype.is_fp8():
      max_num_imprecise_acc = b.options.max_num_imprecise_acc_default
    else:
      max_num_imprecise_acc = 0

  return tensor(
      tt_dialect.dot(
          x.handle,
          y.handle,
          acc.handle,  # acc is always materialized by this point
          allow_tf32,
          max_num_imprecise_acc,
      ),
      result_type,
  )
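

# A minimal usage sketch (hypothetical, not part of this change): assuming two
# 16x16 fp32 blocks x and y built with this module's helpers (e.g. via splat)
# under an active builder, the new dot can be chained through its accumulator:
#
#   acc = dot(x, y)                    # fresh zero accumulator, fp32 result
#   acc = dot(x, y, acc=acc)           # reuse acc across a K loop
#   out = dot(x, y, allow_tf32=False)  # force exact fp32 multiplies
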
def atomic_cas(