mirror of https://github.com/ROCm/jax.git
Migrated dot to lower directly to Triton IR
PiperOrigin-RevId: 603768074
parent 5867a05cdd
commit 28eff4f9b8
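Context for the diff below: before this change, `dot` was a thin wrapper over Triton's Python-level `tl.core.dot` (the removed `wrap_with_builder` line); after it, `dot` builds the Triton IR operation itself via `tt_dialect.dot`. A minimal usage sketch of the new entry point, assuming `x` and `y` are already f32 `tensor` blocks with every dimension >= 16 (the shapes here are illustrative assumptions, not part of the commit):

# Hedged usage sketch of the new `dot`; `x` and `y` are assumed to be
# f32 blocks, e.g. [32, 64] and [64, 32], built elsewhere in this module.
out = dot(x, y)                    # zero-splatted f32 accumulator, TF32 allowed
out = dot(x, y, allow_tf32=False)  # force the full-precision f32 path
out = dot(x, y, acc=out)           # accumulate into an existing [32, 32] block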
@@ -22,6 +22,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from functools import partial, wraps
import threading
from typing import Any

from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith as arith_dialect
@@ -966,7 +967,70 @@ def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor:
  )


dot = wrap_with_builder(tl.core.dot)


def _check_dot_operands(x_dtype: dtype, y_dtype: dtype, options: Any):
  # TODO(slebedev): Ensure that the dtypes are supported by CUDA.
  return


def dot(
    x: tensor,
    y: tensor,
    acc: tensor | None = None,
    allow_tf32: bool = True,
    max_num_imprecise_acc: int | None = None,
    out_dtype: dtype = float32,
) -> tensor:
  x_dims = [dim.__index__() for dim in x.shape]
  y_dims = [dim.__index__() for dim in y.shape]
  # Triton requires every dimension participating in a dot to be >= 16.
  if min(*x_dims, *y_dims) < 16:
    raise ValueError("all dimensions of x and y must be >= 16")
  if out_dtype.is_bf16():
    raise ValueError(f"out_dtype={out_dtype} is unsupported")
  b: builder = builder.current
  _check_dot_operands(x.dtype, y.dtype, b.options)
  # Infer the accumulator element type and a matching zero from the operand
  # dtype: int8 accumulates into int32, f32/bf16 into f32, and the remaining
  # (fp16/fp8) cases into out_dtype.
  if x.dtype.is_int():
    if x.dtype != int8:
      raise TypeError(f"unsupported dtype: {x.dtype}")
    zero = tensor(b.get_int32(0), int32)
    element_type = int32
  elif x.dtype.is_fp32() or x.dtype.is_bf16():
    zero = tensor(b.get_fp32(0), float32)
    element_type = float32
  else:
    if out_dtype.is_fp16():
      zero = tensor(b.get_fp16(0), float16)
    else:
      zero = tensor(b.get_fp32(0), float32)
    element_type = out_dtype

  if element_type != out_dtype:
    raise TypeError(
        f"out_dtype={out_dtype} does not match element type {element_type}")

  m, _ = x_dims
  _, n = y_dims
  result_type = block_type(element_type, [m, n])

  # Start from a zero-splatted accumulator unless the caller supplied one.
  if acc is None:
    acc = splat(zero, [m, n])
  else:
    assert acc.type == result_type

  # Only fp8 x fp8 dots pick up the backend's default imprecise-accumulation
  # budget; everything else accumulates precisely.
  if max_num_imprecise_acc is None:
    if x.dtype.is_fp8() and y.dtype.is_fp8():
      max_num_imprecise_acc = b.options.max_num_imprecise_acc_default
    else:
      max_num_imprecise_acc = 0

  return tensor(
      tt_dialect.dot(
          x.handle,
          y.handle,
          acc.handle if acc is not None else None,
          allow_tf32,
          max_num_imprecise_acc,
      ),
      result_type,
  )


def atomic_cas(
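For readers unfamiliar with the accumulator contract above: when `acc` is `None`, `dot` splats a zero block of the inferred element type; otherwise the accumulator's type must equal the [m, n] result block type. A hedged sketch of the loop pattern this enables (`num_k_tiles`, `load_x_tile`, and `load_y_tile` are hypothetical helpers, not part of this module):

# Illustrative only: block matmul accumulated over the K dimension.
acc = None
for k in range(num_k_tiles):
  x_tile = load_x_tile(k)  # [m, k_block] block, all dims >= 16
  y_tile = load_y_tile(k)  # [k_block, n] block
  acc = dot(x_tile, y_tile, acc=acc)  # first iteration splats zeros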