Merge pull request #19587 from jakevdp:key-reuse-info

PiperOrigin-RevId: 602838627
This commit is contained in:
jax authors 2024-01-30 14:17:17 -08:00
commit cce6520dfa
3 changed files with 50 additions and 15 deletions

View File

@ -24,11 +24,23 @@ class Sink(NamedTuple):
idx: int
mask: bool | np.ndarray = True
def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Sink({self.idx})"
else:
return f"Sink({self.idx}, mask={self.mask})"
class Source(NamedTuple):
idx: int
mask: bool | np.ndarray = True
def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Source({self.idx})"
else:
return f"Source({self.idx}, mask={self.mask})"
class KeyReuseSignature(NamedTuple):
sinks: list[Sink]

View File

@ -254,18 +254,23 @@ def _scan_key_type_signature(eqn, args_consumed):
jaxpr = eqn.params['jaxpr'].jaxpr
num_consts = eqn.params['num_consts']
num_carry = eqn.params['num_carry']
length = eqn.params['length']
signature = get_jaxpr_type_signature(jaxpr, args_consumed)
# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
# scan carry should only consume keys that are sourced on output.
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
return signature
key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature
@ -283,10 +288,12 @@ def _while_key_type_signature(eqn, args_consumed):
# Error if there are sinks among consts.
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
f"{cond_signature=}")
f" {cond_signature=}\n"
f" {eqn=}")
if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts):
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
f"{body_signature=}")
f" {body_signature=}\n"
f" {eqn=}")
# carry should only consume keys that are sourced on output.
body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts}
@ -295,13 +302,17 @@ def _while_key_type_signature(eqn, args_consumed):
# TODO(jakevdp): check masks at each index?
if not (cond_carry_sinks.keys() <= carry_sources.keys()):
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
f"{cond_signature=}")
f" {cond_signature=}\n"
f" {eqn=}")
if not (body_carry_sinks.keys() <= carry_sources.keys()):
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
f"{body_signature=}")
f" {body_signature=}\n"
f" {eqn=}")
if body_carry_sinks.keys() & cond_carry_sinks.keys():
raise KeyReuseError("while_loop cond and body functions both use the same key: "
f"{cond_signature=} {body_signature=}")
f" {cond_signature=}\n"
f" {body_signature=}\n"
f" {eqn=}")
return body_signature
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature

View File

@ -228,13 +228,19 @@ def _scan_key_type_signature(eqn, args_consumed):
# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
# scan carry should only consume keys that are sourced on output.
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
return signature
key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature
@ -252,10 +258,12 @@ def _while_key_type_signature(eqn, args_consumed):
# Error if there are sinks among consts.
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
f"{cond_signature=}")
f" {cond_signature=}\n"
f" {eqn=}")
if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts):
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
f"{body_signature=}")
f" {body_signature=}\n"
f" {eqn=}")
# carry should only consume keys that are sourced on output.
body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts}
@ -264,13 +272,17 @@ def _while_key_type_signature(eqn, args_consumed):
# TODO(jakevdp): check masks at each index?
if not (cond_carry_sinks.keys() <= carry_sources.keys()):
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
f"{cond_signature=}")
f"{ cond_signature=}\n"
f" {eqn=}")
if not (body_carry_sinks.keys() <= carry_sources.keys()):
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
f"{body_signature=}")
f" {body_signature=}\n"
f" {eqn=}")
if body_carry_sinks.keys() & cond_carry_sinks.keys():
raise KeyReuseError("while_loop cond and body functions both use the same key: "
f"{cond_signature=} {body_signature=}")
f" {cond_signature=}\n"
f" {body_signature=}\n"
f" {eqn=}")
return body_signature
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature