mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] improve repr for signatures
This commit is contained in:
parent
b6e985ffe7
commit
6cf740ceb1
@ -15,7 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import partial, reduce, wraps
|
||||
from functools import partial, reduce, total_ordering, wraps
|
||||
from typing import Any, Callable, Iterator, NamedTuple
|
||||
|
||||
import jax
|
||||
@ -43,6 +43,7 @@ import numpy as np
|
||||
|
||||
# Create Source() and Sink() objects which validate inputs, have
|
||||
# correct equality semantics, and are hashable & immutable.
|
||||
@total_ordering
|
||||
class _SourceSinkBase:
|
||||
idx: int
|
||||
mask: bool | np.ndarray
|
||||
@ -74,6 +75,15 @@ class _SourceSinkBase:
|
||||
and np.shape(self.mask) == np.shape(other.mask)
|
||||
and np.all(self.mask == other.mask))
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, Forward):
|
||||
return True
|
||||
elif isinstance(other, _SourceSinkBase):
|
||||
return ((self.__class__.__name__, self.idx)
|
||||
< (other.__class__.__name__, other.idx))
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
if isinstance(self.mask, bool):
|
||||
return hash((self.__class__, self.idx, self.mask))
|
||||
@ -100,6 +110,9 @@ class Forward(NamedTuple):
|
||||
in_idx: int
|
||||
out_idx: int
|
||||
|
||||
def __repr__(self):
|
||||
return f"Forward({self.in_idx}, {self.out_idx})"
|
||||
|
||||
|
||||
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
|
||||
# objects, with a few convenience methods related to key reuse checking.
|
||||
@ -109,6 +122,9 @@ class KeyReuseSignature:
|
||||
def __init__(self, *args):
|
||||
self._args = frozenset(args)
|
||||
|
||||
def __repr__(self):
|
||||
return f"KeyReuseSignature{tuple(sorted(self._args))}"
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, KeyReuseSignature) and self._args == other._args
|
||||
|
||||
|
@ -692,6 +692,15 @@ class KeyReuseImplementationTest(jtu.JaxTestCase):
|
||||
self.assertNotEquivalent(
|
||||
KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0)))
|
||||
|
||||
def test_reprs(self):
|
||||
self.assertEqual(repr(Sink(0)), "Sink(0)")
|
||||
self.assertEqual(repr(Source(0)), "Source(0)")
|
||||
self.assertEqual(repr(Forward(0, 1)), "Forward(0, 1)")
|
||||
self.assertEqual(repr(KeyReuseSignature(Sink(1), Source(0))),
|
||||
"KeyReuseSignature(Sink(1), Source(0))")
|
||||
self.assertEqual(repr(KeyReuseSignature(Sink(1), Sink(0))),
|
||||
"KeyReuseSignature(Sink(0), Sink(1))")
|
||||
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_checks=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user