[Mosaic GPU] Fix layout API bugs.

PiperOrigin-RevId: 715077057
This commit is contained in:
Justin Fu 2025-01-13 12:58:54 -08:00 committed by jax authors
parent dabe27bc1b
commit f69592ae78

View File

@ -652,6 +652,16 @@ class ParameterizedLayout:
kwargs: Any
def _get_mgpu_layout(layout: Layout | ParameterizedLayout
) -> mgpu.FragmentedLayout:
if isinstance(layout, Layout):
return layout.value()
elif isinstance(layout, ParameterizedLayout):
return layout.layout_cls.value(*layout.args,
**layout.kwargs)
else:
raise TypeError(f"Unsupported layout: {layout}")
layout_cast_p = jax_core.Primitive("layout_cast")
@ -664,14 +674,7 @@ def _layout_cast_abstract_eval(x, new_layout):
@lowering.register_lowering_rule(layout_cast_p)
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
del ctx # Unused.
if isinstance(new_layout, Layout):
return x.to_layout(new_layout.value())
elif isinstance(new_layout, ParameterizedLayout):
layout = new_layout.layout_cls(*new_layout.args,
**new_layout.kwargs)
return x.to_layout(layout)
else:
raise TypeError(f"Unsupported layout: {new_layout}")
return x.to_layout(_get_mgpu_layout(new_layout))
def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout):
@ -753,7 +756,7 @@ def _broadcasted_iota_lowering(
return mgpu.FragmentedArray.splat(
llvm_dialect.mlir_undef(mlir_dtype),
shape,
layout.value,
_get_mgpu_layout(layout),
is_signed=is_signed,
).foreach(
lambda _, idx: cast(idx[dimension]),