Integrate Triton up to [9f816a7b](9f816a7b98)

PiperOrigin-RevId: 602641874
This commit is contained in:
Goran Flegar 2024-01-30 01:15:23 -08:00 committed by jax authors
parent da6fa63bf3
commit 66308c30ad
2 changed files with 4 additions and 7 deletions

View File

@ -61,8 +61,8 @@ from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
from jax_triton.triton_lib import get_triton_type
import numpy as np
from triton._C.libtriton import ir as tl_ir
import triton.backends.nvidia.compiler as cb
from triton.compiler import code_generator as code_gen
import triton.compiler.backends.cuda as cb
# TODO(sharadmv): Enable type checking.

View File

@ -28,7 +28,7 @@ from jaxlib.mlir.dialects import arith as arith_dialect
from jaxlib.mlir.dialects import math as math_dialect
from jaxlib.mlir.dialects import scf as scf_dialect
import numpy as np
import triton.compiler.backends.cuda as cb
import triton.backends.nvidia.compiler as cb
import triton.language as tl
from . import dialect as tt_dialect
@ -1108,9 +1108,6 @@ def set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
op.attributes[name] = attr
_LIBDEVICE_PATH = tl.math.libdevice_path()
def libdevice_extern_elementwise(
table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
is_pure: bool = True,
@ -1132,8 +1129,8 @@ def libdevice_extern_elementwise(
tt_dialect.extern_elementwise(
return_type.to_ir(builder.current),
[arg.handle for arg in args],
libname="libdevice",
libpath=_LIBDEVICE_PATH,
libname="",
libpath="",
symbol=symbol,
pure=is_pure,
),