mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Export PointerType
and register_dialect
from jaxlib.triton.dialect
The `... as ...` form tells the type checker that the name is exported. See #7570. PiperOrigin-RevId: 671318047
This commit is contained in:
parent
fb77c73ba8
commit
f3b91b2042
@ -21,9 +21,9 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir._mlir_libs._triton_ext import (
|
||||
PointerType,
|
||||
infer_reduce_op_encoding,
|
||||
register_dialect,
|
||||
PointerType as PointerType,
|
||||
register_dialect as register_dialect,
|
||||
infer_reduce_op_encoding as _infer_reduce_op_encoding,
|
||||
)
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
@ -86,7 +86,7 @@ def _infer_reduce_op_return_types(
|
||||
if not shape:
|
||||
return_types.append(op_type.element_type)
|
||||
elif op_encoding := op_type.encoding:
|
||||
encoding = infer_reduce_op_encoding(op_encoding, axis)
|
||||
encoding = _infer_reduce_op_encoding(op_encoding, axis)
|
||||
if encoding is not None:
|
||||
raise RuntimeError("Failed to infer return type encoding for ReduceOp")
|
||||
return_types.append(
|
||||
|
Loading…
x
Reference in New Issue
Block a user