From f69592ae78cc302e5b5789172759b8548addeaf3 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 13 Jan 2025 12:58:54 -0800 Subject: [PATCH] [Mosaic GPU] Fix layout API bugs. PiperOrigin-RevId: 715077057 --- jax/_src/pallas/mosaic_gpu/primitives.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 37eec3b06..58955c6ec 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -652,6 +652,16 @@ class ParameterizedLayout: kwargs: Any +def _get_mgpu_layout(layout: Layout | ParameterizedLayout + ) -> mgpu.FragmentedLayout: + if isinstance(layout, Layout): + return layout.value() + elif isinstance(layout, ParameterizedLayout): + return layout.layout_cls.value(*layout.args, + **layout.kwargs) + else: + raise TypeError(f"Unsupported layout: {layout}") + layout_cast_p = jax_core.Primitive("layout_cast") @@ -664,14 +674,7 @@ def _layout_cast_abstract_eval(x, new_layout): @lowering.register_lowering_rule(layout_cast_p) def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): del ctx # Unused. - if isinstance(new_layout, Layout): - return x.to_layout(new_layout.value()) - elif isinstance(new_layout, ParameterizedLayout): - layout = new_layout.layout_cls(*new_layout.args, - **new_layout.kwargs) - return x.to_layout(layout) - else: - raise TypeError(f"Unsupported layout: {new_layout}") + return x.to_layout(_get_mgpu_layout(new_layout)) def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout): @@ -753,7 +756,7 @@ def _broadcasted_iota_lowering( return mgpu.FragmentedArray.splat( llvm_dialect.mlir_undef(mlir_dtype), shape, - layout.value, + _get_mgpu_layout(layout), is_signed=is_signed, ).foreach( lambda _, idx: cast(idx[dimension]),