mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Only donate if memory kinds match. This shouldn't break any existing behavior. Also, only emit the "Some donated buffers were not usable" warning when there actually are unused donations.
PiperOrigin-RevId: 557231080
This commit is contained in:
parent
47651c6a59
commit
b7796710e4
@ -682,9 +682,14 @@ def lower_jaxpr_to_module(
|
||||
map(sharded_aval, jaxpr.in_avals, arg_shardings))
|
||||
out_avals = (jaxpr.out_avals if result_shardings is None else
|
||||
map(sharded_aval, jaxpr.out_avals, result_shardings))
|
||||
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
|
||||
if arg_shardings is not None else None)
|
||||
result_memory_kinds = (map(_get_mem_kind, result_shardings)
|
||||
if result_shardings is not None else None)
|
||||
|
||||
if platform in _platforms_with_donation:
|
||||
input_output_aliases, donated_args = _set_up_aliases(
|
||||
in_avals, out_avals, donated_args)
|
||||
in_avals, out_avals, donated_args, arg_memory_kinds, result_memory_kinds)
|
||||
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
|
||||
if unlowerable_effects:
|
||||
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
|
||||
@ -693,7 +698,9 @@ def lower_jaxpr_to_module(
|
||||
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
|
||||
if platform not in _platforms_with_donation:
|
||||
msg = f"Donation is not implemented for {platform}.\n{msg}"
|
||||
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
|
||||
if unused_donations:
|
||||
warnings.warn("Some donated buffers were not usable:"
|
||||
f" {', '.join(unused_donations)}.\n{msg}")
|
||||
|
||||
# HLO channels need to start at 1
|
||||
channel_iter = itertools.count(1)
|
||||
@ -718,11 +725,6 @@ def lower_jaxpr_to_module(
|
||||
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
|
||||
if result_shardings is not None else result_shardings)
|
||||
|
||||
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
|
||||
if arg_shardings is not None else None)
|
||||
result_memory_kinds = (map(_get_mem_kind, result_shardings)
|
||||
if result_shardings is not None else None)
|
||||
|
||||
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
|
||||
keepalives, channel_iter, host_callbacks,
|
||||
override_lowering_rules=override_lowering_rules,
|
||||
@ -772,7 +774,8 @@ def module_to_bytecode(module: ir.Module) -> bytes:
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def _set_up_aliases(avals_in, avals_out, donated_args):
|
||||
def _set_up_aliases(avals_in, avals_out, donated_args, arg_memory_kinds,
|
||||
result_memory_kinds):
|
||||
input_output_aliases = [None] * len(avals_in)
|
||||
# To match-up in-avals to out-avals we only care about the number of
|
||||
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
|
||||
@ -780,15 +783,24 @@ def _set_up_aliases(avals_in, avals_out, donated_args):
|
||||
avals_in = map(strip_metadata, avals_in)
|
||||
avals_out = map(strip_metadata, avals_out)
|
||||
|
||||
if arg_memory_kinds is None:
|
||||
arg_memory_kinds = [None] * len(avals_in)
|
||||
if result_memory_kinds is None:
|
||||
result_memory_kinds = [None] * len(avals_out)
|
||||
|
||||
donations = collections.defaultdict(collections.deque)
|
||||
for i, (aval, donated) in enumerate(zip(avals_in, donated_args)):
|
||||
for i, (aval, am, donated) in enumerate(
|
||||
zip(avals_in, arg_memory_kinds, donated_args)):
|
||||
if donated:
|
||||
donations[aval].append(i)
|
||||
donations[(aval, am)].append(i)
|
||||
|
||||
out_donated_args = list(donated_args)
|
||||
for i, aval in enumerate(avals_out):
|
||||
if donations.get(aval, ()):
|
||||
input_id = donations[aval].popleft()
|
||||
for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)):
|
||||
# Only donate if memory kinds match. Relax this when the compiler can
|
||||
# donate across memories.
|
||||
key = (aval, rm)
|
||||
if donations.get(key, ()):
|
||||
input_id = donations[key].popleft()
|
||||
input_output_aliases[input_id] = i
|
||||
out_donated_args[input_id] = False
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user