mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[Mosaic GPU] Fix layout API bugs.
PiperOrigin-RevId: 715077057
This commit is contained in:
parent
dabe27bc1b
commit
f69592ae78
@ -652,6 +652,16 @@ class ParameterizedLayout:
|
||||
kwargs: Any
|
||||
|
||||
|
||||
def _get_mgpu_layout(layout: Layout | ParameterizedLayout
|
||||
) -> mgpu.FragmentedLayout:
|
||||
if isinstance(layout, Layout):
|
||||
return layout.value()
|
||||
elif isinstance(layout, ParameterizedLayout):
|
||||
return layout.layout_cls.value(*layout.args,
|
||||
**layout.kwargs)
|
||||
else:
|
||||
raise TypeError(f"Unsupported layout: {layout}")
|
||||
|
||||
layout_cast_p = jax_core.Primitive("layout_cast")
|
||||
|
||||
|
||||
@ -664,14 +674,7 @@ def _layout_cast_abstract_eval(x, new_layout):
|
||||
@lowering.register_lowering_rule(layout_cast_p)
|
||||
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
|
||||
del ctx # Unused.
|
||||
if isinstance(new_layout, Layout):
|
||||
return x.to_layout(new_layout.value())
|
||||
elif isinstance(new_layout, ParameterizedLayout):
|
||||
layout = new_layout.layout_cls(*new_layout.args,
|
||||
**new_layout.kwargs)
|
||||
return x.to_layout(layout)
|
||||
else:
|
||||
raise TypeError(f"Unsupported layout: {new_layout}")
|
||||
return x.to_layout(_get_mgpu_layout(new_layout))
|
||||
|
||||
|
||||
def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout):
|
||||
@ -753,7 +756,7 @@ def _broadcasted_iota_lowering(
|
||||
return mgpu.FragmentedArray.splat(
|
||||
llvm_dialect.mlir_undef(mlir_dtype),
|
||||
shape,
|
||||
layout.value,
|
||||
_get_mgpu_layout(layout),
|
||||
is_signed=is_signed,
|
||||
).foreach(
|
||||
lambda _, idx: cast(idx[dimension]),
|
||||
|
Loading…
x
Reference in New Issue
Block a user