Export PointerType and register_dialect from jaxlib.triton.dialect

The `... as ...` form tells the type checker that the name is exported.
See #7570.

PiperOrigin-RevId: 671318047
This commit is contained in:
Sergei Lebedev 2024-09-05 04:14:48 -07:00 committed by jax authors
parent fb77c73ba8
commit f3b91b2042

View File

@ -21,9 +21,9 @@ from __future__ import annotations
from collections.abc import Sequence
from jaxlib.mlir._mlir_libs._triton_ext import (
PointerType,
infer_reduce_op_encoding,
register_dialect,
PointerType as PointerType,
register_dialect as register_dialect,
infer_reduce_op_encoding as _infer_reduce_op_encoding,
)
from jaxlib.mlir import ir
@ -86,7 +86,7 @@ def _infer_reduce_op_return_types(
if not shape:
return_types.append(op_type.element_type)
elif op_encoding := op_type.encoding:
encoding = infer_reduce_op_encoding(op_encoding, axis)
encoding = _infer_reduce_op_encoding(op_encoding, axis)
if encoding is not None:
raise RuntimeError("Failed to infer return type encoding for ReduceOp")
return_types.append(