mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic GPU] Allow unions nested inside the smem ref tree
We don't use this capability just yet, but I want to start allocating barriers as part of the scratch, and this will push the unions deeper into the tree.

PiperOrigin-RevId: 655475839
This commit is contained in:
parent
6bc7929376
commit
832eb2d8d2
@ -512,23 +512,46 @@ def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
|
||||
|
||||
|
||||
def _construct_smem_reftree(
    dynamic_smem: ir.Value,
    smem_buffers: ShapeTree,
    dynamic_smem_offset: int = 0,
) -> RefTree:
  """Builds a tree of memref views into `dynamic_smem` mirroring `smem_buffers`.

  The input tree is flattened with `Union` nodes kept as leaves. Every
  non-Union leaf becomes a `memref.view` of `dynamic_smem` placed at the
  running offset, which is then advanced by `_count_buffer_bytes(leaf)`. For a
  `Union` leaf, each member tree is built recursively starting at the *same*
  offset (so the members overlap), and the offset is then advanced by
  `_smem_tree_size(union)`.

  Args:
    dynamic_smem: The value that all produced views are taken from.
    smem_buffers: Tree of buffer descriptions (leaves expose `.shape` and
      `.dtype`), possibly containing `Union` nodes at any depth.
    dynamic_smem_offset: Offset at which the first leaf is placed. Defaults
      to 0.

  Returns:
    A tree with the same structure as `smem_buffers`, where plain leaves are
    replaced by views and `Union` leaves by a `Union` of member ref trees.
  """
  index_ty = ir.IndexType.get()
  workgroup_space = ir.Attribute.parse("#gpu.address_space<workgroup>")

  def _is_union(node):
    return isinstance(node, Union)

  leaves, treedef = jax.tree.flatten(smem_buffers, is_leaf=_is_union)
  offset = dynamic_smem_offset
  refs = []
  for leaf in leaves:
    if not _is_union(leaf):
      view_ty = ir.MemRefType.get(
          leaf.shape,
          mlir.dtype_to_ir_type(leaf.dtype),
          memory_space=workgroup_space,
      )
      refs.append(memref.view(view_ty, dynamic_smem, c(offset, index_ty), []))
      offset += _count_buffer_bytes(leaf)
      continue
    # All members of a union start at the same offset and alias each other.
    aliased_members = [
        _construct_smem_reftree(dynamic_smem, member, offset)
        for member in leaf.members
    ]
    # TODO(apaszke): This is quadratic, but it shouldn't matter for now...
    offset += _smem_tree_size(leaf)
    refs.append(Union(aliased_members))
  return jax.tree.unflatten(treedef, refs)
|
||||
|
||||
|
||||
def _smem_tree_size(smem_buffers: ShapeTree) -> int:
  """Returns the total size of `smem_buffers` as counted by `_count_buffer_bytes`.

  The tree is flattened with `Union` nodes kept as leaves. A plain leaf
  contributes `_count_buffer_bytes(leaf)`. A `Union` contributes the size of
  its largest member tree, computed recursively, so unions may appear at any
  depth of the tree.

  Args:
    smem_buffers: Tree of buffer descriptions, possibly containing `Union`
      nodes (including nested inside containers or other unions).

  Returns:
    The sum of the sizes of all leaves, with each `Union` counted as the
    maximum over its members (0 for a `Union` without members).
  """
  leaves = jax.tree.leaves(
      smem_buffers, is_leaf=lambda x: isinstance(x, Union)
  )
  size = 0
  for leaf in leaves:
    if isinstance(leaf, Union):
      # Use the members of *this* Union leaf, not of the root argument: the
      # root is only a Union when the whole tree is a single union, so reading
      # `.members` off it breaks for unions nested inside containers.
      size += max(
          (_smem_tree_size(member) for member in leaf.members), default=0
      )
    else:
      size += _count_buffer_bytes(leaf)
  return size
|
||||
|
||||
|
||||
# TODO(apaszke): Inline this
|
||||
@contextlib.contextmanager
|
||||
def _launch(
|
||||
@ -549,17 +572,9 @@ def _launch(
|
||||
grid_vals = [c(i, index) for i in grid]
|
||||
block_vals = [c(i, index) for i in block]
|
||||
|
||||
if isinstance(smem_buffers, Union):
|
||||
smem_disjoint_live_buffers_collections = smem_buffers.members
|
||||
compute_smem_bytes = max(
|
||||
sum(_count_buffer_bytes(l) for l in jax.tree.leaves(s))
|
||||
for s in smem_buffers.members)
|
||||
else:
|
||||
smem_disjoint_live_buffers_collections = [smem_buffers]
|
||||
compute_smem_bytes = sum(
|
||||
_count_buffer_bytes(l) for l in jax.tree.leaves(smem_buffers))
|
||||
user_smem_bytes = _smem_tree_size(smem_buffers)
|
||||
|
||||
smem_bytes = compute_smem_bytes
|
||||
smem_bytes = user_smem_bytes
|
||||
if profiler_spec is not None:
|
||||
smem_bytes += profiler_spec.smem_bytes(block=block)
|
||||
|
||||
@ -592,11 +607,7 @@ def _launch(
|
||||
)
|
||||
)
|
||||
|
||||
smem_ref_trees = []
|
||||
for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
|
||||
smem_ref_tree = _construct_smem_reftree(
|
||||
dynamic_smem, smem_live_buffers_collection)
|
||||
smem_ref_trees.append(smem_ref_tree)
|
||||
smem_ref_tree = _construct_smem_reftree(dynamic_smem, smem_buffers)
|
||||
|
||||
if profiler_spec:
|
||||
prof_smem = memref.view(
|
||||
@ -604,7 +615,7 @@ def _launch(
|
||||
(profiler_spec.smem_i32_elements(block=block),),
|
||||
i32, memory_space=smem,
|
||||
),
|
||||
dynamic_smem, c(compute_smem_bytes, index), [],
|
||||
dynamic_smem, c(user_smem_bytes, index), [],
|
||||
)
|
||||
prof = profiler.OnDeviceProfiler(
|
||||
profiler_spec, prof_smem, maybe_prof_buffer
|
||||
@ -612,11 +623,6 @@ def _launch(
|
||||
else:
|
||||
prof = None
|
||||
|
||||
if isinstance(smem_buffers, Union):
|
||||
smem_ref_tree: Union[RefTree] = Union(smem_ref_trees)
|
||||
else:
|
||||
smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []
|
||||
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr")
|
||||
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
|
||||
yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree
|
||||
|
Loading…
x
Reference in New Issue
Block a user