mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Speed up name stack printing.
If we repeatedly form tuples by concatenation during printing, we make what should be a linear time operation quadratic. Also simplify the API contract of extend() to only add a single element, and remove the unused method wrap_name. PiperOrigin-RevId: 718570432
This commit is contained in:
parent
0fccabcd49
commit
cd51e9dd14
@ -86,32 +86,22 @@ def register_inclusion(path: str):
|
||||
class Scope(NamedTuple):
|
||||
name: str
|
||||
|
||||
def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
|
||||
return (self.name, *stack)
|
||||
def wrap(self, stack: list[str]):
|
||||
stack.append(self.name)
|
||||
|
||||
class Transform(NamedTuple):
|
||||
name: str
|
||||
|
||||
def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
|
||||
def wrap(self, stack: list[str]):
|
||||
if stack:
|
||||
return (f'{self.name}({stack[0]})', *stack[1:])
|
||||
else:
|
||||
return ()
|
||||
stack[-1] = f'{self.name}({stack[-1]})'
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NameStack:
|
||||
stack: tuple[Scope | Transform, ...] = ()
|
||||
|
||||
def extend(self, name: tuple[str, ...] | str) -> NameStack:
|
||||
if not isinstance(name, tuple):
|
||||
name = (name,)
|
||||
scopes = tuple(map(Scope, name))
|
||||
return NameStack(self.stack + scopes)
|
||||
|
||||
def wrap_name(self, name: str) -> str:
|
||||
if not self.stack:
|
||||
return name
|
||||
return f'{self}/{name}'
|
||||
def extend(self, name: str) -> NameStack:
|
||||
return NameStack((*self.stack, Scope(name)))
|
||||
|
||||
def transform(self, transform_name: str) -> NameStack:
|
||||
return NameStack((*self.stack, Transform(transform_name)))
|
||||
@ -129,10 +119,10 @@ class NameStack:
|
||||
return NameStack(other.stack + self.stack)
|
||||
|
||||
def __str__(self) -> str:
|
||||
scope: tuple[str, ...] = ()
|
||||
scope: list[str] = []
|
||||
for elem in self.stack[::-1]:
|
||||
scope = elem.wrap(scope)
|
||||
return '/'.join(scope)
|
||||
elem.wrap(scope)
|
||||
return '/'.join(reversed(scope))
|
||||
|
||||
|
||||
def new_name_stack(name: str = '') -> NameStack:
|
||||
|
Loading…
x
Reference in New Issue
Block a user