[Mosaic GPU] Allow unions nested inside the smem ref tree

We don't use this capability just yet, but I want to start allocating barriers
as part of the scratch and this will push the unions deeper into the tree.

PiperOrigin-RevId: 655475839
Adam Paszke, 2024-07-24 01:42:49 -07:00 (committed by jax authors)
parent 6bc7929376
commit 832eb2d8d2
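
To illustrate the capability this commit unlocks, here is a minimal sketch of a scratch spec with a union nested inside the pytree rather than at its root. `jax.ShapeDtypeStruct` and the `Union` container (the dataclass edited in the diff below, assumed to be in scope) are real; the entry names, shapes, and the barrier note are hypothetical.

import jax
import jax.numpy as jnp

# Hypothetical scratch layout, not taken from this commit:
smem_scratch = {
    "persistent": jax.ShapeDtypeStruct((128, 128), jnp.float32),  # always live
    "staging": Union([  # members are never live together; they alias the same SMEM bytes
        jax.ShapeDtypeStruct((256,), jnp.float32),
        jax.ShapeDtypeStruct((2048,), jnp.int8),
    ]),
    # Per the commit message, barrier allocations would eventually sit alongside
    # entries like these, which is what pushes unions deeper into the tree.
}
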


@@ -512,23 +512,46 @@ def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
 def _construct_smem_reftree(
-    dynamic_smem: ir.Value, smem_buffers: ShapeTree) -> RefTree:
+    dynamic_smem: ir.Value, smem_buffers: ShapeTree, dynamic_smem_offset: int = 0) -> RefTree:
   index = ir.IndexType.get()
   smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
-  flat_ref_tys, smem_buffer_tree = jax.tree.flatten(smem_buffers)
+  flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
+      smem_buffers, is_leaf=lambda x: isinstance(x, Union)
+  )
   smem_refs = []
-  dynamic_smem_offset = 0
   for ref_ty in flat_ref_tys:
-    mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
-    tile_smem = memref.view(
-        ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
-        dynamic_smem, c(dynamic_smem_offset, index), [],
-    )
-    dynamic_smem_offset += _count_buffer_bytes(ref_ty)
-    smem_refs.append(tile_smem)
+    if isinstance(ref_ty, Union):
+      member_trees = [
+          _construct_smem_reftree(dynamic_smem, m, dynamic_smem_offset)
+          for m in ref_ty.members
+      ]
+      # TODO(apaszke): This is quadratic, but it shouldn't matter for now...
+      dynamic_smem_offset += _smem_tree_size(ref_ty)
+      smem_refs.append(Union(member_trees))
+    else:
+      mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
+      tile_smem = memref.view(
+          ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
+          dynamic_smem, c(dynamic_smem_offset, index), [],
+      )
+      dynamic_smem_offset += _count_buffer_bytes(ref_ty)
+      smem_refs.append(tile_smem)
   return jax.tree.unflatten(smem_buffer_tree, smem_refs)
 
 
+def _smem_tree_size(smem_buffers: ShapeTree) -> int:
+  leaves = jax.tree.leaves(
+      smem_buffers, is_leaf=lambda x: isinstance(x, Union)
+  )
+  size = 0
+  for l in leaves:
+    if isinstance(l, Union):
+      size += max(_smem_tree_size(s) for s in l.members)
+    else:
+      size += _count_buffer_bytes(l)
+  return size
+
+
 # TODO(apaszke): Inline this
 @contextlib.contextmanager
 def _launch(
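
To make the sizing rule concrete, here is a small, MLIR-free model of what `_smem_tree_size` computes. It is a sketch under the assumptions that leaves behave like `jax.ShapeDtypeStruct` and that unions expose a `members` sequence (the `Union` class below is a stand-in, not the module's own): leaves are packed one after another and contribute their byte size, while a union contributes only the maximum of its members' sizes, because its members can alias the same region.

import math
from dataclasses import dataclass

import jax
import numpy as np


@dataclass(frozen=True)
class Union:  # stand-in for the Mosaic GPU Union container in the diff above
  members: list


def _leaf_bytes(leaf: jax.ShapeDtypeStruct) -> int:
  # Same accounting as _count_buffer_bytes: element count times element width.
  return math.prod(leaf.shape) * np.dtype(leaf.dtype).itemsize


def tree_smem_bytes(tree) -> int:
  is_union = lambda x: isinstance(x, Union)
  size = 0
  for leaf in jax.tree.leaves(tree, is_leaf=is_union):
    if is_union(leaf):
      # Union members are never live at the same time, so only the largest
      # member's footprint has to be reserved.
      size += max(tree_smem_bytes(m) for m in leaf.members)
    else:
      size += _leaf_bytes(leaf)
  return size
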
@@ -549,17 +572,9 @@ def _launch(
   grid_vals = [c(i, index) for i in grid]
   block_vals = [c(i, index) for i in block]
 
-  if isinstance(smem_buffers, Union):
-    smem_disjoint_live_buffers_collections = smem_buffers.members
-    compute_smem_bytes = max(
-        sum(_count_buffer_bytes(l) for l in jax.tree.leaves(s))
-        for s in smem_buffers.members)
-  else:
-    smem_disjoint_live_buffers_collections = [smem_buffers]
-    compute_smem_bytes = sum(
-        _count_buffer_bytes(l) for l in jax.tree.leaves(smem_buffers))
+  user_smem_bytes = _smem_tree_size(smem_buffers)
 
-  smem_bytes = compute_smem_bytes
+  smem_bytes = user_smem_bytes
   if profiler_spec is not None:
     smem_bytes += profiler_spec.smem_bytes(block=block)
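
Applying the toy model from the sketch above (reusing its stand-in `Union` and `tree_smem_bytes`) to the hypothetical spec from the top of the page shows what `user_smem_bytes` now accounts for; the numbers are made up.

import jax
import jax.numpy as jnp

spec = {
    "persistent": jax.ShapeDtypeStruct((128, 128), jnp.float32),  # 65536 bytes
    "staging": Union([
        jax.ShapeDtypeStruct((256,), jnp.float32),                #  1024 bytes
        jax.ShapeDtypeStruct((2048,), jnp.int8),                  #  2048 bytes
    ]),
}
assert tree_smem_bytes(spec) == 128 * 128 * 4 + max(256 * 4, 2048)  # == 67584
# If profiling is enabled, the profiler buffer is carved out right after these
# user bytes, which is why the memref.view further below offsets the view by
# user_smem_bytes.
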
@@ -592,11 +607,7 @@
           )
       )
 
-  smem_ref_trees = []
-  for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
-    smem_ref_tree = _construct_smem_reftree(
-        dynamic_smem, smem_live_buffers_collection)
-    smem_ref_trees.append(smem_ref_tree)
+  smem_ref_tree = _construct_smem_reftree(dynamic_smem, smem_buffers)
 
   if profiler_spec:
     prof_smem = memref.view(
@@ -604,7 +615,7 @@
             (profiler_spec.smem_i32_elements(block=block),),
             i32, memory_space=smem,
         ),
-        dynamic_smem, c(compute_smem_bytes, index), [],
+        dynamic_smem, c(user_smem_bytes, index), [],
     )
     prof = profiler.OnDeviceProfiler(
         profiler_spec, prof_smem, maybe_prof_buffer
@@ -612,11 +623,6 @@
   else:
     prof = None
 
-  if isinstance(smem_buffers, Union):
-    smem_ref_tree: Union[RefTree] = Union(smem_ref_trees)
-  else:
-    smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []
-
   ptr_ty = ir.Type.parse("!llvm.ptr")
   scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
   yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree