Speed up name stack printing.

If we repeatedly form tuples by concatenation during printing, we make what should be a linear time operation quadratic.

Also simplify the API contract of extend() to only add a single element, and remove the unused method wrap_name.

PiperOrigin-RevId: 718570432
This commit is contained in:
Peter Hawkins 2025-01-22 16:23:47 -08:00 committed by jax authors
parent 0fccabcd49
commit cd51e9dd14

View File

@ -86,32 +86,22 @@ def register_inclusion(path: str):
class Scope(NamedTuple):
name: str
def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
return (self.name, *stack)
def wrap(self, stack: list[str]):
stack.append(self.name)
class Transform(NamedTuple):
name: str
def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
def wrap(self, stack: list[str]):
if stack:
return (f'{self.name}({stack[0]})', *stack[1:])
else:
return ()
stack[-1] = f'{self.name}({stack[-1]})'
@dataclasses.dataclass(frozen=True)
class NameStack:
stack: tuple[Scope | Transform, ...] = ()
def extend(self, name: tuple[str, ...] | str) -> NameStack:
if not isinstance(name, tuple):
name = (name,)
scopes = tuple(map(Scope, name))
return NameStack(self.stack + scopes)
def wrap_name(self, name: str) -> str:
if not self.stack:
return name
return f'{self}/{name}'
def extend(self, name: str) -> NameStack:
return NameStack((*self.stack, Scope(name)))
def transform(self, transform_name: str) -> NameStack:
return NameStack((*self.stack, Transform(transform_name)))
@ -129,10 +119,10 @@ class NameStack:
return NameStack(other.stack + self.stack)
def __str__(self) -> str:
scope: tuple[str, ...] = ()
scope: list[str] = []
for elem in self.stack[::-1]:
scope = elem.wrap(scope)
return '/'.join(scope)
elem.wrap(scope)
return '/'.join(reversed(scope))
def new_name_stack(name: str = '') -> NameStack: