Only donate an input buffer to an output when their memory kinds match. This shouldn't break any existing behavior. Also, emit the "donated buffers were not usable" warning only when at least one donation actually went unused.

PiperOrigin-RevId: 557231080
This commit is contained in:
Yash Katariya 2023-08-15 13:25:37 -07:00 committed by jax authors
parent 47651c6a59
commit b7796710e4

View File

@ -682,9 +682,14 @@ def lower_jaxpr_to_module(
map(sharded_aval, jaxpr.in_avals, arg_shardings))
out_avals = (jaxpr.out_avals if result_shardings is None else
map(sharded_aval, jaxpr.out_avals, result_shardings))
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
if arg_shardings is not None else None)
result_memory_kinds = (map(_get_mem_kind, result_shardings)
if result_shardings is not None else None)
if platform in _platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
in_avals, out_avals, donated_args, arg_memory_kinds, result_memory_kinds)
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
@ -693,7 +698,9 @@ def lower_jaxpr_to_module(
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
if platform not in _platforms_with_donation:
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
if unused_donations:
warnings.warn("Some donated buffers were not usable:"
f" {', '.join(unused_donations)}.\n{msg}")
# HLO channels need to start at 1
channel_iter = itertools.count(1)
@ -718,11 +725,6 @@ def lower_jaxpr_to_module(
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
if arg_shardings is not None else None)
result_memory_kinds = (map(_get_mem_kind, result_shardings)
if result_shardings is not None else None)
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
override_lowering_rules=override_lowering_rules,
@ -772,7 +774,8 @@ def module_to_bytecode(module: ir.Module) -> bytes:
return output.getvalue()
def _set_up_aliases(avals_in, avals_out, donated_args):
def _set_up_aliases(avals_in, avals_out, donated_args, arg_memory_kinds,
result_memory_kinds):
input_output_aliases = [None] * len(avals_in)
# To match-up in-avals to out-avals we only care about the number of
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
@ -780,15 +783,24 @@ def _set_up_aliases(avals_in, avals_out, donated_args):
avals_in = map(strip_metadata, avals_in)
avals_out = map(strip_metadata, avals_out)
if arg_memory_kinds is None:
arg_memory_kinds = [None] * len(avals_in)
if result_memory_kinds is None:
result_memory_kinds = [None] * len(avals_out)
donations = collections.defaultdict(collections.deque)
for i, (aval, donated) in enumerate(zip(avals_in, donated_args)):
for i, (aval, am, donated) in enumerate(
zip(avals_in, arg_memory_kinds, donated_args)):
if donated:
donations[aval].append(i)
donations[(aval, am)].append(i)
out_donated_args = list(donated_args)
for i, aval in enumerate(avals_out):
if donations.get(aval, ()):
input_id = donations[aval].popleft()
for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)):
# Only donate if memory kinds match. Relax this when the compiler can
# donate across memories.
key = (aval, rm)
if donations.get(key, ()):
input_id = donations[key].popleft()
input_output_aliases[input_id] = i
out_donated_args[input_id] = False