mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19587 from jakevdp:key-reuse-info
PiperOrigin-RevId: 602838627
This commit is contained in:
commit
cce6520dfa
@ -24,11 +24,23 @@ class Sink(NamedTuple):
|
||||
idx: int
|
||||
mask: bool | np.ndarray = True
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.mask, bool) and self.mask:
|
||||
return f"Sink({self.idx})"
|
||||
else:
|
||||
return f"Sink({self.idx}, mask={self.mask})"
|
||||
|
||||
|
||||
class Source(NamedTuple):
|
||||
idx: int
|
||||
mask: bool | np.ndarray = True
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.mask, bool) and self.mask:
|
||||
return f"Source({self.idx})"
|
||||
else:
|
||||
return f"Source({self.idx}, mask={self.mask})"
|
||||
|
||||
|
||||
class KeyReuseSignature(NamedTuple):
|
||||
sinks: list[Sink]
|
||||
|
@ -254,18 +254,23 @@ def _scan_key_type_signature(eqn, args_consumed):
|
||||
jaxpr = eqn.params['jaxpr'].jaxpr
|
||||
num_consts = eqn.params['num_consts']
|
||||
num_carry = eqn.params['num_carry']
|
||||
length = eqn.params['length']
|
||||
signature = get_jaxpr_type_signature(jaxpr, args_consumed)
|
||||
|
||||
# scan body should not consume key in constants
|
||||
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
|
||||
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed:\n"
|
||||
f" {signature=}\n"
|
||||
f" {eqn=}\n"
|
||||
f" {jaxpr=}")
|
||||
|
||||
# scan carry should only consume keys that are sourced on output.
|
||||
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
|
||||
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
|
||||
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
|
||||
f" {signature=}\n"
|
||||
f" {eqn=}\n"
|
||||
f" {jaxpr=}")
|
||||
return signature
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature
|
||||
@ -283,10 +288,12 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
# Error if there are sinks among consts.
|
||||
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
|
||||
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
|
||||
f"{cond_signature=}")
|
||||
f" {cond_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts):
|
||||
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
|
||||
f"{body_signature=}")
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
|
||||
# carry should only consume keys that are sourced on output.
|
||||
body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts}
|
||||
@ -295,13 +302,17 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
# TODO(jakevdp): check masks at each index?
|
||||
if not (cond_carry_sinks.keys() <= carry_sources.keys()):
|
||||
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
|
||||
f"{cond_signature=}")
|
||||
f" {cond_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if not (body_carry_sinks.keys() <= carry_sources.keys()):
|
||||
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
|
||||
f"{body_signature=}")
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if body_carry_sinks.keys() & cond_carry_sinks.keys():
|
||||
raise KeyReuseError("while_loop cond and body functions both use the same key: "
|
||||
f"{cond_signature=} {body_signature=}")
|
||||
f" {cond_signature=}\n"
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
return body_signature
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
|
||||
|
@ -228,13 +228,19 @@ def _scan_key_type_signature(eqn, args_consumed):
|
||||
|
||||
# scan body should not consume key in constants
|
||||
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
|
||||
f" {signature=}\n"
|
||||
f" {eqn=}\n"
|
||||
f" {jaxpr=}")
|
||||
|
||||
# scan carry should only consume keys that are sourced on output.
|
||||
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
|
||||
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
|
||||
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}")
|
||||
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
|
||||
f" {signature=}\n"
|
||||
f" {eqn=}\n"
|
||||
f" {jaxpr=}")
|
||||
return signature
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature
|
||||
@ -252,10 +258,12 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
# Error if there are sinks among consts.
|
||||
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
|
||||
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
|
||||
f"{cond_signature=}")
|
||||
f" {cond_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts):
|
||||
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
|
||||
f"{body_signature=}")
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
|
||||
# carry should only consume keys that are sourced on output.
|
||||
body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts}
|
||||
@ -264,13 +272,17 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
# TODO(jakevdp): check masks at each index?
|
||||
if not (cond_carry_sinks.keys() <= carry_sources.keys()):
|
||||
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
|
||||
f"{cond_signature=}")
|
||||
f"{ cond_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if not (body_carry_sinks.keys() <= carry_sources.keys()):
|
||||
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
|
||||
f"{body_signature=}")
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if body_carry_sinks.keys() & cond_carry_sinks.keys():
|
||||
raise KeyReuseError("while_loop cond and body functions both use the same key: "
|
||||
f"{cond_signature=} {body_signature=}")
|
||||
f" {cond_signature=}\n"
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
return body_signature
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
|
||||
|
Loading…
x
Reference in New Issue
Block a user