mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge.
This CL lays the groundwork for a future CL that makes the run_scoped discharge rule not request the discharge of the temporary buffers it creates. That future change causes issues because a) dma_start can't discharge some but not all of its references, and b) the run_scoped() lowering depends on the run_scoped discharge to remove the run_scoped operation (otherwise it goes into an infinite loop). PiperOrigin-RevId: 722126566
This commit is contained in:
parent
eb04fcbe5a
commit
8649132d86
@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
|
||||
|
||||
jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
|
||||
|
||||
def dma_start_discharge_rule(in_avals, out_avals,
|
||||
*args, tree, device_id_type):
|
||||
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
|
||||
*args, tree, device_id_type):
|
||||
(
|
||||
src_ref,
|
||||
src_transforms,
|
||||
@ -575,7 +575,22 @@ def dma_start_discharge_rule(in_avals, out_avals,
|
||||
_,
|
||||
) = tree_util.tree_unflatten(tree, in_avals)
|
||||
del out_avals
|
||||
|
||||
(
|
||||
_,
|
||||
_,
|
||||
dst_discharge,
|
||||
_,
|
||||
dst_sem_discharge,
|
||||
_,
|
||||
*maybe_src_sem_discharge,
|
||||
) = tree_util.tree_unflatten(tree, should_discharge)
|
||||
is_remote = device_id is not None
|
||||
src_sem_discharge = None
|
||||
|
||||
if is_remote:
|
||||
src_sem_discharge = maybe_src_sem_discharge[0]
|
||||
|
||||
if not is_remote:
|
||||
# Local async copies only use one semaphore.
|
||||
assert src_sem is None
|
||||
@ -586,7 +601,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
|
||||
num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
|
||||
num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))
|
||||
|
||||
updates = state_discharge.transform_array(src_ref, src_transforms)
|
||||
updates = state_discharge.transform_array(src_ref[...], src_transforms)
|
||||
local_src = updates
|
||||
|
||||
if is_remote:
|
||||
@ -641,47 +656,61 @@ def dma_start_discharge_rule(in_avals, out_avals,
|
||||
global_dst_transforms,
|
||||
)
|
||||
|
||||
_, new_dst = state_discharge.transform_swap_array(
|
||||
dst_ref, dst_transforms, updates
|
||||
)
|
||||
def do_discharge_dst(dst_ref=dst_ref):
|
||||
_, ret = state_discharge.transform_swap_array(
|
||||
dst_ref, dst_transforms, updates
|
||||
)
|
||||
return ret
|
||||
|
||||
# Update semaphore values.
|
||||
# TODO(justinfu): Potentially handle asymmetric copy sizes.
|
||||
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
|
||||
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
|
||||
dst_sem_value = _transform_semaphore(
|
||||
dst_sem, dst_sem_transforms, dst_sem_aval
|
||||
)
|
||||
_, new_dst_sem = state_discharge.transform_swap_array(
|
||||
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
|
||||
)
|
||||
if is_remote:
|
||||
def do_discharge_dst_sem(dst_sem=dst_sem):
|
||||
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
|
||||
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
|
||||
dst_sem_value = _transform_semaphore(
|
||||
dst_sem, dst_sem_transforms, dst_sem_aval
|
||||
)
|
||||
_, ret = state_discharge.transform_swap_array(
|
||||
dst_sem, dst_sem_transforms, dst_sem_value[...] + recv_size
|
||||
)
|
||||
return ret
|
||||
|
||||
def do_discharge_src_sem(src_sem=src_sem):
|
||||
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
|
||||
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
|
||||
src_sem_value = _transform_semaphore(
|
||||
src_sem, src_sem_transforms, src_sem_aval
|
||||
)
|
||||
_, new_src_sem = state_discharge.transform_swap_array(
|
||||
src_sem, src_sem_transforms, src_sem_value + send_size
|
||||
_, ret = state_discharge.transform_swap_array(
|
||||
src_sem, src_sem_transforms, src_sem_value[...] + send_size
|
||||
)
|
||||
else:
|
||||
new_src_sem = None
|
||||
return ret
|
||||
|
||||
new_vals = (None,) # src_val
|
||||
new_vals += (None,) * num_src_transform_vals
|
||||
new_vals += (new_dst,) # dst_val
|
||||
new_vals += (do_discharge_dst() if dst_discharge else None,) # dst_val
|
||||
new_vals += (None,) * num_dst_transform_vals
|
||||
new_vals += (new_dst_sem,) # dst_sem
|
||||
new_vals += (do_discharge_dst_sem() if dst_sem_discharge else None,) # dst_sem
|
||||
new_vals += (None,) * num_dst_sem_transforms
|
||||
if is_remote:
|
||||
new_vals += (new_src_sem,) # src_sem
|
||||
new_vals += (do_discharge_src_sem() if src_sem_discharge else None,) # src_sem
|
||||
new_vals += (None,) * num_src_sem_transforms
|
||||
new_vals += (None,) # device_id
|
||||
assert (len(new_vals) ==
|
||||
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"
|
||||
|
||||
# If we didn't discharge everything we could we should keep writes
|
||||
# to the references that are left over.
|
||||
if not dst_discharge:
|
||||
sp.ref_set(dst_ref, None, do_discharge_dst(dst_ref=dst_ref[...]))
|
||||
if not dst_sem_discharge:
|
||||
sp.ref_set(dst_sem, None, do_discharge_dst_sem(dst_sem=dst_sem[...]))
|
||||
if is_remote and not src_sem_discharge:
|
||||
sp.ref_set(src_sem, None, do_discharge_src_sem(src_sem=src_sem[...]))
|
||||
|
||||
return new_vals, []
|
||||
|
||||
state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
|
||||
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)
|
||||
|
||||
|
||||
dma_wait_p = jax_core.Primitive('dma_wait')
|
||||
@ -719,8 +748,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
|
||||
|
||||
jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
|
||||
|
||||
def dma_wait_discharge_rule(in_avals, out_avals,
|
||||
*args, tree, device_id_type):
|
||||
def dma_wait_partial_discharge_rule(should_discharge,
|
||||
in_avals, out_avals,
|
||||
*args, tree, device_id_type):
|
||||
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
|
||||
del out_avals, device_id_type
|
||||
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
|
||||
@ -735,6 +765,14 @@ def dma_wait_discharge_rule(in_avals, out_avals,
|
||||
src_sem_transforms_avals,
|
||||
device_id_aval,
|
||||
) = tree_util.tree_unflatten(tree, in_avals)
|
||||
|
||||
# The only one we can discharge is the dst semaphore. The provided
|
||||
# buffers are only specified for their types and not their value so
|
||||
# it's completely irrelevant for us here if they are discharged.
|
||||
should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
|
||||
if not should_discharge_unflattened[4]:
|
||||
return (None,) * len(in_avals), []
|
||||
|
||||
num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
|
||||
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
|
||||
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
|
||||
@ -754,7 +792,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
|
||||
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
|
||||
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
|
||||
return new_vals, []
|
||||
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
|
||||
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)
|
||||
|
||||
def _get_ref_and_transforms(ref):
|
||||
if isinstance(ref, state.TransformedRef):
|
||||
|
@ -931,17 +931,20 @@ state_discharge.register_partial_discharge_rule(run_scoped_p)(
|
||||
|
||||
@functools.partial(mlir.register_lowering, run_scoped_p)
|
||||
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
|
||||
# This lowering rule gets triggered when run_scoped is not discharged.
|
||||
# In this case there are no stateful effects to handle.
|
||||
should_discharge = [
|
||||
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
|
||||
]
|
||||
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
|
||||
num_return_values = len(jaxpr_noconst.outvars)
|
||||
discharged_body, new_consts = state_discharge.discharge_state(
|
||||
jaxpr_noconst, [], should_discharge=True)
|
||||
if new_consts: raise NotImplementedError(
|
||||
"Cannot handle new consts created by state discharge.")
|
||||
|
||||
def _lower_fun(*lower_fun_args):
|
||||
updates, out = _run_scoped_discharge_rule(
|
||||
should_discharge,
|
||||
[], [], *lower_fun_args,
|
||||
jaxpr=jaxpr)
|
||||
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
|
||||
return out
|
||||
# Create inputs filled with uninitialized values to the body.
|
||||
num_consts = len(lower_fun_args)
|
||||
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
|
||||
init_vals = [uninitialized_value(
|
||||
aval.shape, aval.dtype) for aval in body_avals]
|
||||
out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
|
||||
return out[:num_return_values]
|
||||
|
||||
return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user