mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU][NFC] Address some previous stylistic comments.
PiperOrigin-RevId: 715772455
This commit is contained in:
parent
aa19f9c4c4
commit
cdf490a5d0
@ -23,29 +23,16 @@ from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
|
||||
from .fragmented_array import WGSplatFragLayout, WGStridedFragLayout
|
||||
from .layouts import from_splat_fragmented_layout_attr
|
||||
from .layouts import from_strided_fragmented_layout_attr
|
||||
from .layouts import has_any_layout_set, has_in_layouts_set, has_out_layouts_set
|
||||
from .layouts import in_layouts, out_layouts, should_have_layout
|
||||
from .layouts import is_splat_fragmented_layout
|
||||
from .layouts import is_strided_fragmented_layout
|
||||
from .layouts import to_splat_fragmented_layout_attr
|
||||
from .layouts import to_strided_fragmented_layout_attr
|
||||
from . import layouts as layouts_lib
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
_layout_inference_rules: dict[
|
||||
str,
|
||||
Callable[[ir.OpView], tuple[list[ir.Attribute], list[ir.Attribute]] | None],
|
||||
] = {}
|
||||
OptionalLayouts = tuple[list[ir.Attribute], list[ir.Attribute]] | None
|
||||
LayoutInferenceRule = Callable[[ir.OpView], OptionalLayouts]
|
||||
_layout_inference_rules: dict[str, LayoutInferenceRule,] = {}
|
||||
|
||||
|
||||
def _add_layout_inference_rule(
|
||||
op: type[ir.OpView],
|
||||
rule: Callable[
|
||||
[ir.OpView], tuple[list[ir.Attribute], list[ir.Attribute]] | None
|
||||
],
|
||||
):
|
||||
def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule):
|
||||
_layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
|
||||
|
||||
|
||||
@ -76,15 +63,15 @@ def _choose_representative_layout(
|
||||
return None
|
||||
|
||||
strided_layouts: list[WGStridedFragLayout] = [
|
||||
from_strided_fragmented_layout_attr(layout)
|
||||
layouts_lib.from_strided_fragmented_layout_attr(layout)
|
||||
for layout in layouts
|
||||
if is_strided_fragmented_layout(layout)
|
||||
if layouts_lib.is_strided_fragmented_layout(layout)
|
||||
]
|
||||
|
||||
splat_layouts: list[WGSplatFragLayout] = list(
|
||||
map(
|
||||
from_splat_fragmented_layout_attr,
|
||||
filter(is_splat_fragmented_layout, layouts),
|
||||
layouts_lib.from_splat_fragmented_layout_attr,
|
||||
filter(layouts_lib.is_splat_fragmented_layout, layouts),
|
||||
)
|
||||
)
|
||||
|
||||
@ -104,13 +91,13 @@ def _choose_representative_layout(
|
||||
)
|
||||
|
||||
if not splat_layouts:
|
||||
return to_strided_fragmented_layout_attr(strided_layouts[0])
|
||||
return layouts_lib.to_strided_fragmented_layout_attr(strided_layouts[0])
|
||||
|
||||
if not strided_layouts:
|
||||
return to_splat_fragmented_layout_attr(splat_layouts[0])
|
||||
return layouts_lib.to_splat_fragmented_layout_attr(splat_layouts[0])
|
||||
|
||||
[strided_layout] = strided_layouts
|
||||
return to_strided_fragmented_layout_attr(strided_layout)
|
||||
return layouts_lib.to_strided_fragmented_layout_attr(strided_layout)
|
||||
|
||||
|
||||
def _in_layout_for_operand(
|
||||
@ -130,10 +117,10 @@ def _in_layout_for_operand(
|
||||
o for o in op.operands if ir.VectorType.isinstance(o.type)
|
||||
].index(operand)
|
||||
|
||||
if not has_in_layouts_set(op):
|
||||
if not layouts_lib.has_in_layouts_set(op):
|
||||
return None
|
||||
|
||||
return in_layouts(op)[operand_number]
|
||||
return layouts_lib.in_layouts(op)[operand_number]
|
||||
|
||||
|
||||
def _out_layout_for_result(
|
||||
@ -153,15 +140,13 @@ def _out_layout_for_result(
|
||||
r for r in op.results if ir.VectorType.isinstance(r.type)
|
||||
].index(result)
|
||||
|
||||
if not has_out_layouts_set(op):
|
||||
if not layouts_lib.has_out_layouts_set(op):
|
||||
return None
|
||||
|
||||
return out_layouts(op)[result_number]
|
||||
return layouts_lib.out_layouts(op)[result_number]
|
||||
|
||||
|
||||
def _infer_pointwise_op_layouts(
|
||||
op: ir.OpView,
|
||||
) -> tuple[list[ir.Attribute], list[ir.Attribute]] | None:
|
||||
def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
|
||||
|
||||
def is_array(v: ir.Value) -> bool:
|
||||
return ir.VectorType.isinstance(v.type)
|
||||
@ -169,14 +154,14 @@ def _infer_pointwise_op_layouts(
|
||||
num_vector_operands = len([o for o in op.operands if is_array(o)])
|
||||
num_vector_results = len([r for r in op.results if is_array(r)])
|
||||
|
||||
if has_in_layouts_set(op):
|
||||
op_in_layouts = in_layouts(op)
|
||||
if layouts_lib.has_in_layouts_set(op):
|
||||
op_in_layouts = layouts_lib.in_layouts(op)
|
||||
if op_in_layouts:
|
||||
layout = op_in_layouts[0]
|
||||
return (num_vector_operands * [layout], num_vector_results * [layout])
|
||||
|
||||
if has_out_layouts_set(op):
|
||||
op_out_layouts = out_layouts(op)
|
||||
if layouts_lib.has_out_layouts_set(op):
|
||||
op_out_layouts = layouts_lib.out_layouts(op)
|
||||
if op_out_layouts:
|
||||
layout = op_out_layouts[0]
|
||||
return (num_vector_operands * [layout], num_vector_results * [layout])
|
||||
@ -260,7 +245,7 @@ def traverse_op(
|
||||
|
||||
def infer_layout(module: ir.Module):
|
||||
def inference_step(op: ir.Operation):
|
||||
if not should_have_layout(op):
|
||||
if not layouts_lib.should_have_layout(op):
|
||||
return
|
||||
elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error
|
||||
pass
|
||||
@ -300,10 +285,11 @@ def infer_layout(module: ir.Module):
|
||||
if not ir.VectorType.isinstance(ty):
|
||||
return None
|
||||
layout = WGStridedFragLayout.from_shaped_type(ty)
|
||||
return to_strided_fragmented_layout_attr(layout)
|
||||
return layouts_lib.to_strided_fragmented_layout_attr(layout)
|
||||
|
||||
def set_default_layout(op: ir.OpView):
|
||||
if should_have_layout(op) and not has_any_layout_set(op):
|
||||
if (layouts_lib.should_have_layout(op) and
|
||||
not layouts_lib.has_any_layout_set(op)):
|
||||
in_layouts = []
|
||||
for operand in op.operands:
|
||||
if (layout := to_default_layout(operand.type)) is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user