[pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge.

This CL lays the groundwork for a future CL that stops the run_scoped discharge from requesting the discharge of the temporary buffers it creates. That change causes issues because:

a) dma_start can't discharge some but not all of its references
b) run_scoped() lowering depends on the run_scoped discharge to remove the run_scoped operation (otherwise it goes into an infinite loop).

PiperOrigin-RevId: 722126566
This commit is contained in:
Christos Perivolaropoulos 2025-02-01 08:22:43 -08:00 committed by jax authors
parent eb04fcbe5a
commit 8649132d86
2 changed files with 78 additions and 37 deletions

View File

@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
def dma_start_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
*args, tree, device_id_type):
(
src_ref,
src_transforms,
@ -575,7 +575,22 @@ def dma_start_discharge_rule(in_avals, out_avals,
_,
) = tree_util.tree_unflatten(tree, in_avals)
del out_avals
(
_,
_,
dst_discharge,
_,
dst_sem_discharge,
_,
*maybe_src_sem_discharge,
) = tree_util.tree_unflatten(tree, should_discharge)
is_remote = device_id is not None
src_sem_discharge = None
if is_remote:
src_sem_discharge = maybe_src_sem_discharge[0]
if not is_remote:
# Local async copies only use one semaphore.
assert src_sem is None
@ -586,7 +601,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))
updates = state_discharge.transform_array(src_ref, src_transforms)
updates = state_discharge.transform_array(src_ref[...], src_transforms)
local_src = updates
if is_remote:
@ -641,47 +656,61 @@ def dma_start_discharge_rule(in_avals, out_avals,
global_dst_transforms,
)
_, new_dst = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
def do_discharge_dst(dst_ref=dst_ref):
_, ret = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
return ret
# Update semaphore values.
# TODO(justinfu): Potentially handle asymmetric copy sizes.
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, new_dst_sem = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
if is_remote:
def do_discharge_dst_sem(dst_sem=dst_sem):
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, ret = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value[...] + recv_size
)
return ret
def do_discharge_src_sem(src_sem=src_sem):
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
src_sem_value = _transform_semaphore(
src_sem, src_sem_transforms, src_sem_aval
)
_, new_src_sem = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value + send_size
_, ret = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value[...] + send_size
)
else:
new_src_sem = None
return ret
new_vals = (None,) # src_val
new_vals += (None,) * num_src_transform_vals
new_vals += (new_dst,) # dst_val
new_vals += (do_discharge_dst() if dst_discharge else None,) # dst_val
new_vals += (None,) * num_dst_transform_vals
new_vals += (new_dst_sem,) # dst_sem
new_vals += (do_discharge_dst_sem() if dst_sem_discharge else None,) # dst_sem
new_vals += (None,) * num_dst_sem_transforms
if is_remote:
new_vals += (new_src_sem,) # src_sem
new_vals += (do_discharge_src_sem() if src_sem_discharge else None,) # src_sem
new_vals += (None,) * num_src_sem_transforms
new_vals += (None,) # device_id
assert (len(new_vals) ==
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"
# If we didn't discharge every reference, we must still perform the
# writes to the references that were left undischarged.
if not dst_discharge:
sp.ref_set(dst_ref, None, do_discharge_dst(dst_ref=dst_ref[...]))
if not dst_sem_discharge:
sp.ref_set(dst_sem, None, do_discharge_dst_sem(dst_sem=dst_sem[...]))
if is_remote and not src_sem_discharge:
sp.ref_set(src_sem, None, do_discharge_src_sem(src_sem=src_sem[...]))
return new_vals, []
state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)
dma_wait_p = jax_core.Primitive('dma_wait')
@ -719,8 +748,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_wait_partial_discharge_rule(should_discharge,
in_avals, out_avals,
*args, tree, device_id_type):
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
del out_avals, device_id_type
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
@ -735,6 +765,14 @@ def dma_wait_discharge_rule(in_avals, out_avals,
src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, in_avals)
# The only ref we can discharge is the dst semaphore. The provided
# buffers are passed only for their types, not their values, so
# whether they are discharged is irrelevant here.
should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
if not should_discharge_unflattened[4]:
return (None,) * len(in_avals), []
num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
@ -754,7 +792,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)
def _get_ref_and_transforms(ref):
if isinstance(ref, state.TransformedRef):

View File

@ -931,17 +931,20 @@ state_discharge.register_partial_discharge_rule(run_scoped_p)(
@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
# This lowering rule gets triggered when run_scoped is not discharged.
# In this case there are no stateful effects to handle.
should_discharge = [
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
]
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=True)
if new_consts: raise NotImplementedError(
"Cannot handle new consts created by state discharge.")
def _lower_fun(*lower_fun_args):
updates, out = _run_scoped_discharge_rule(
should_discharge,
[], [], *lower_fun_args,
jaxpr=jaxpr)
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
return out
# Create inputs filled with uninitialized values to the body.
num_consts = len(lower_fun_args)
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
return out[:num_return_values]
return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)