remove string annotations from core.py

This commit is contained in:
Matthew Johnson 2022-04-03 11:17:57 -07:00
parent 359b614b5f
commit c72d8f6b09

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from collections import namedtuple
@ -55,13 +55,13 @@ map, unsafe_map = safe_map, map
# -------------------- jaxprs --------------------
class Jaxpr:
constvars: List['Var']
invars: List['Var']
outvars: List['Atom']
eqns: List['JaxprEqn']
constvars: List[Var]
invars: List[Var]
outvars: List[Atom]
eqns: List[JaxprEqn]
def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn]):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -114,7 +114,7 @@ def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
class ClosedJaxpr:
jaxpr: Jaxpr
consts: List['Any']
consts: List[Any]
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
assert len(consts) == len(jaxpr.constvars)
@ -160,9 +160,9 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
class JaxprEqn(NamedTuple):
invars: List['Atom']
outvars: List['Var']
primitive: 'Primitive'
invars: List[Atom]
outvars: List[Var]
primitive: Primitive
params: Dict[str, Any]
source_info: source_info_util.SourceInfo
@ -182,9 +182,9 @@ class Var:
# by object id, but pretty printing might collide.
count: int
suffix: str
aval: 'AbstractValue'
aval: AbstractValue
def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
def __init__(self, count: int, suffix: str, aval: AbstractValue):
self.count = count
self.suffix = suffix
self.aval = raise_to_shaped(aval)
@ -213,7 +213,7 @@ def _jaxpr_vars(jaxpr):
(v for eqn in jaxpr.eqns for v in eqn.outvars))
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
suffix: str = '') -> Callable[['AbstractValue'], Var]:
suffix: str = '') -> Callable[[AbstractValue], Var]:
"""Produce distinct variables, printed with the optional suffix.
If `jaxprs` is provided, the variables produced will be distinct from those in
@ -232,7 +232,7 @@ def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
def __init__(self, aval: 'AbstractValue'):
def __init__(self, aval: AbstractValue):
super().__init__(-1, '', aval)
def __repr__(self): return '_'
@ -240,7 +240,7 @@ class Literal:
__slots__ = ["val", "aval", "hash"]
val: Any
aval: 'AbstractValue'
aval: AbstractValue
hash: Optional[int]
def __init__(self, val, aval):
@ -358,16 +358,16 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
class Trace:
__slots__ = ['main', 'level', 'sublevel']
main: 'MainTrace'
main: MainTrace
level: int
sublevel: 'Sublevel'
sublevel: Sublevel
def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
def __init__(self, main: MainTrace, sublevel: Sublevel) -> None:
self.main = main
self.level = main.level
self.sublevel = sublevel
def full_raise(self, val) -> 'Tracer':
def full_raise(self, val) -> Tracer:
if not isinstance(val, Tracer):
return self.pure(val)
val._assert_live()
@ -960,10 +960,10 @@ class AbstractValue:
except AttributeError:
return self.__class__.__name__
def strip_weak_type(self) -> 'AbstractValue':
def strip_weak_type(self) -> AbstractValue:
return self
def strip_named_shape(self) -> 'AbstractValue':
def strip_named_shape(self) -> AbstractValue:
return self
def join(self, other):
@ -1167,7 +1167,7 @@ AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # see comment above
shape: Tuple[AxisSize, ...] # noqa: F821
array_abstraction_level = 2
def __init__(self, shape, dtype, weak_type):
@ -1813,7 +1813,7 @@ class MapPrimitive(Primitive):
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
def map_bind(primitive: 'MapPrimitive', fun, *args, out_axes_thunk, **params):
def map_bind(primitive: MapPrimitive, fun, *args, out_axes_thunk, **params):
# The new thunk depends deterministically on the old thunk and the wrapped
# function. Any caching already has to include the wrapped function as part
# of the key, so we only use the previous thunk for equality checks.
@ -2134,7 +2134,7 @@ def check_jaxpr(jaxpr: Jaxpr):
raise JaxprTypeError(msg) from None
def _check_jaxpr(
ctx_factory: Callable[[], Tuple['JaxprPpContext', 'JaxprPpSettings']],
ctx_factory: Callable[[], Tuple[JaxprPpContext, JaxprPpSettings]],
jaxpr: Jaxpr,
in_avals: Sequence[AbstractValue]) -> None: