[key reuse] improve repr for signatures

This commit is contained in:
Jake VanderPlas 2024-03-11 15:17:08 -07:00
parent b6e985ffe7
commit 6cf740ceb1
2 changed files with 26 additions and 1 deletions

View File

@ -15,7 +15,7 @@
from __future__ import annotations
from collections import defaultdict
from functools import partial, reduce, wraps
from functools import partial, reduce, total_ordering, wraps
from typing import Any, Callable, Iterator, NamedTuple
import jax
@ -43,6 +43,7 @@ import numpy as np
# Create Source() and Sink() objects which validate inputs, have
# correct equality semantics, and are hashable & immutable.
@total_ordering
class _SourceSinkBase:
idx: int
mask: bool | np.ndarray
@ -74,6 +75,15 @@ class _SourceSinkBase:
and np.shape(self.mask) == np.shape(other.mask)
and np.all(self.mask == other.mask))
def __lt__(self, other):
if isinstance(other, Forward):
return True
elif isinstance(other, _SourceSinkBase):
return ((self.__class__.__name__, self.idx)
< (other.__class__.__name__, other.idx))
else:
return NotImplemented
def __hash__(self):
if isinstance(self.mask, bool):
return hash((self.__class__, self.idx, self.mask))
@ -100,6 +110,9 @@ class Forward(NamedTuple):
in_idx: int
out_idx: int
def __repr__(self):
return f"Forward({self.in_idx}, {self.out_idx})"
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
# objects, with a few convenience methods related to key reuse checking.
@ -109,6 +122,9 @@ class KeyReuseSignature:
def __init__(self, *args):
self._args = frozenset(args)
def __repr__(self):
return f"KeyReuseSignature{tuple(sorted(self._args))}"
def __eq__(self, other):
return isinstance(other, KeyReuseSignature) and self._args == other._args

View File

@ -692,6 +692,15 @@ class KeyReuseImplementationTest(jtu.JaxTestCase):
self.assertNotEquivalent(
KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0)))
def test_reprs(self):
self.assertEqual(repr(Sink(0)), "Sink(0)")
self.assertEqual(repr(Source(0)), "Source(0)")
self.assertEqual(repr(Forward(0, 1)), "Forward(0, 1)")
self.assertEqual(repr(KeyReuseSignature(Sink(1), Source(0))),
"KeyReuseSignature(Sink(1), Source(0))")
self.assertEqual(repr(KeyReuseSignature(Sink(1), Sink(0))),
"KeyReuseSignature(Sink(0), Sink(1))")
@jtu.with_config(jax_enable_checks=False)