mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Integrate Triton up to [9f816a7b](9f816a7b98
)
PiperOrigin-RevId: 602641874
This commit is contained in:
parent
da6fa63bf3
commit
66308c30ad
@ -61,8 +61,8 @@ from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
|
||||
from jax_triton.triton_lib import get_triton_type
|
||||
import numpy as np
|
||||
from triton._C.libtriton import ir as tl_ir
|
||||
import triton.backends.nvidia.compiler as cb
|
||||
from triton.compiler import code_generator as code_gen
|
||||
import triton.compiler.backends.cuda as cb
|
||||
|
||||
|
||||
# TODO(sharadmv): Enable type checking.
|
||||
|
@ -28,7 +28,7 @@ from jaxlib.mlir.dialects import arith as arith_dialect
|
||||
from jaxlib.mlir.dialects import math as math_dialect
|
||||
from jaxlib.mlir.dialects import scf as scf_dialect
|
||||
import numpy as np
|
||||
import triton.compiler.backends.cuda as cb
|
||||
import triton.backends.nvidia.compiler as cb
|
||||
import triton.language as tl
|
||||
|
||||
from . import dialect as tt_dialect
|
||||
@ -1108,9 +1108,6 @@ def set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
|
||||
op.attributes[name] = attr
|
||||
|
||||
|
||||
_LIBDEVICE_PATH = tl.math.libdevice_path()
|
||||
|
||||
|
||||
def libdevice_extern_elementwise(
|
||||
table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
|
||||
is_pure: bool = True,
|
||||
@ -1132,8 +1129,8 @@ def libdevice_extern_elementwise(
|
||||
tt_dialect.extern_elementwise(
|
||||
return_type.to_ir(builder.current),
|
||||
[arg.handle for arg in args],
|
||||
libname="libdevice",
|
||||
libpath=_LIBDEVICE_PATH,
|
||||
libname="",
|
||||
libpath="",
|
||||
symbol=symbol,
|
||||
pure=is_pure,
|
||||
),
|
||||
|
Loading…
x
Reference in New Issue
Block a user