[Mosaic GPU][NFC] Address some previous stylistic comments.

PiperOrigin-RevId: 715772455
This commit is contained in:
Benjamin Chetioui 2025-01-15 06:20:29 -08:00 committed by jax authors
parent aa19f9c4c4
commit cdf490a5d0

View File

@ -23,29 +23,16 @@ from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import vector
from .fragmented_array import WGSplatFragLayout, WGStridedFragLayout
from .layouts import from_splat_fragmented_layout_attr
from .layouts import from_strided_fragmented_layout_attr
from .layouts import has_any_layout_set, has_in_layouts_set, has_out_layouts_set
from .layouts import in_layouts, out_layouts, should_have_layout
from .layouts import is_splat_fragmented_layout
from .layouts import is_strided_fragmented_layout
from .layouts import to_splat_fragmented_layout_attr
from .layouts import to_strided_fragmented_layout_attr
from . import layouts as layouts_lib
# mypy: ignore-errors
_layout_inference_rules: dict[
str,
Callable[[ir.OpView], tuple[list[ir.Attribute], list[ir.Attribute]] | None],
] = {}
OptionalLayouts = tuple[list[ir.Attribute], list[ir.Attribute]] | None
LayoutInferenceRule = Callable[[ir.OpView], OptionalLayouts]
_layout_inference_rules: dict[str, LayoutInferenceRule,] = {}
def _add_layout_inference_rule(
op: type[ir.OpView],
rule: Callable[
[ir.OpView], tuple[list[ir.Attribute], list[ir.Attribute]] | None
],
):
def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule):
_layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
@ -76,15 +63,15 @@ def _choose_representative_layout(
return None
strided_layouts: list[WGStridedFragLayout] = [
from_strided_fragmented_layout_attr(layout)
layouts_lib.from_strided_fragmented_layout_attr(layout)
for layout in layouts
if is_strided_fragmented_layout(layout)
if layouts_lib.is_strided_fragmented_layout(layout)
]
splat_layouts: list[WGSplatFragLayout] = list(
map(
from_splat_fragmented_layout_attr,
filter(is_splat_fragmented_layout, layouts),
layouts_lib.from_splat_fragmented_layout_attr,
filter(layouts_lib.is_splat_fragmented_layout, layouts),
)
)
@ -104,13 +91,13 @@ def _choose_representative_layout(
)
if not splat_layouts:
return to_strided_fragmented_layout_attr(strided_layouts[0])
return layouts_lib.to_strided_fragmented_layout_attr(strided_layouts[0])
if not strided_layouts:
return to_splat_fragmented_layout_attr(splat_layouts[0])
return layouts_lib.to_splat_fragmented_layout_attr(splat_layouts[0])
[strided_layout] = strided_layouts
return to_strided_fragmented_layout_attr(strided_layout)
return layouts_lib.to_strided_fragmented_layout_attr(strided_layout)
def _in_layout_for_operand(
@ -130,10 +117,10 @@ def _in_layout_for_operand(
o for o in op.operands if ir.VectorType.isinstance(o.type)
].index(operand)
if not has_in_layouts_set(op):
if not layouts_lib.has_in_layouts_set(op):
return None
return in_layouts(op)[operand_number]
return layouts_lib.in_layouts(op)[operand_number]
def _out_layout_for_result(
@ -153,15 +140,13 @@ def _out_layout_for_result(
r for r in op.results if ir.VectorType.isinstance(r.type)
].index(result)
if not has_out_layouts_set(op):
if not layouts_lib.has_out_layouts_set(op):
return None
return out_layouts(op)[result_number]
return layouts_lib.out_layouts(op)[result_number]
def _infer_pointwise_op_layouts(
op: ir.OpView,
) -> tuple[list[ir.Attribute], list[ir.Attribute]] | None:
def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
def is_array(v: ir.Value) -> bool:
return ir.VectorType.isinstance(v.type)
@ -169,14 +154,14 @@ def _infer_pointwise_op_layouts(
num_vector_operands = len([o for o in op.operands if is_array(o)])
num_vector_results = len([r for r in op.results if is_array(r)])
if has_in_layouts_set(op):
op_in_layouts = in_layouts(op)
if layouts_lib.has_in_layouts_set(op):
op_in_layouts = layouts_lib.in_layouts(op)
if op_in_layouts:
layout = op_in_layouts[0]
return (num_vector_operands * [layout], num_vector_results * [layout])
if has_out_layouts_set(op):
op_out_layouts = out_layouts(op)
if layouts_lib.has_out_layouts_set(op):
op_out_layouts = layouts_lib.out_layouts(op)
if op_out_layouts:
layout = op_out_layouts[0]
return (num_vector_operands * [layout], num_vector_results * [layout])
@ -260,7 +245,7 @@ def traverse_op(
def infer_layout(module: ir.Module):
def inference_step(op: ir.Operation):
if not should_have_layout(op):
if not layouts_lib.should_have_layout(op):
return
elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error
pass
@ -300,10 +285,11 @@ def infer_layout(module: ir.Module):
if not ir.VectorType.isinstance(ty):
return None
layout = WGStridedFragLayout.from_shaped_type(ty)
return to_strided_fragmented_layout_attr(layout)
return layouts_lib.to_strided_fragmented_layout_attr(layout)
def set_default_layout(op: ir.OpView):
if should_have_layout(op) and not has_any_layout_set(op):
if (layouts_lib.should_have_layout(op) and
not layouts_lib.has_any_layout_set(op)):
in_layouts = []
for operand in op.operands:
if (layout := to_default_layout(operand.type)) is not None: