mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
remove string annotations from core.py
This commit is contained in:
parent
359b614b5f
commit
c72d8f6b09
50
jax/core.py
50
jax/core.py
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
@ -55,13 +55,13 @@ map, unsafe_map = safe_map, map
|
||||
# -------------------- jaxprs --------------------
|
||||
|
||||
class Jaxpr:
|
||||
constvars: List['Var']
|
||||
invars: List['Var']
|
||||
outvars: List['Atom']
|
||||
eqns: List['JaxprEqn']
|
||||
constvars: List[Var]
|
||||
invars: List[Var]
|
||||
outvars: List[Atom]
|
||||
eqns: List[JaxprEqn]
|
||||
|
||||
def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
|
||||
outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn]):
|
||||
"""
|
||||
Args:
|
||||
constvars: list of variables introduced for constants. Array constants are
|
||||
@ -114,7 +114,7 @@ def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
|
||||
|
||||
class ClosedJaxpr:
|
||||
jaxpr: Jaxpr
|
||||
consts: List['Any']
|
||||
consts: List[Any]
|
||||
|
||||
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
|
||||
assert len(consts) == len(jaxpr.constvars)
|
||||
@ -160,9 +160,9 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
|
||||
|
||||
|
||||
class JaxprEqn(NamedTuple):
|
||||
invars: List['Atom']
|
||||
outvars: List['Var']
|
||||
primitive: 'Primitive'
|
||||
invars: List[Atom]
|
||||
outvars: List[Var]
|
||||
primitive: Primitive
|
||||
params: Dict[str, Any]
|
||||
source_info: source_info_util.SourceInfo
|
||||
|
||||
@ -182,9 +182,9 @@ class Var:
|
||||
# by object id, but pretty printing might collide.
|
||||
count: int
|
||||
suffix: str
|
||||
aval: 'AbstractValue'
|
||||
aval: AbstractValue
|
||||
|
||||
def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
|
||||
def __init__(self, count: int, suffix: str, aval: AbstractValue):
|
||||
self.count = count
|
||||
self.suffix = suffix
|
||||
self.aval = raise_to_shaped(aval)
|
||||
@ -213,7 +213,7 @@ def _jaxpr_vars(jaxpr):
|
||||
(v for eqn in jaxpr.eqns for v in eqn.outvars))
|
||||
|
||||
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
|
||||
suffix: str = '') -> Callable[['AbstractValue'], Var]:
|
||||
suffix: str = '') -> Callable[[AbstractValue], Var]:
|
||||
"""Produce distinct variables, printed with the optional suffix.
|
||||
|
||||
If `jaxprs` is provided, the variables produced will be distinct from those in
|
||||
@ -232,7 +232,7 @@ def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
|
||||
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
|
||||
# treat it as a special case of one. Its `aval` is similarly inexact.
|
||||
class DropVar(Var):
|
||||
def __init__(self, aval: 'AbstractValue'):
|
||||
def __init__(self, aval: AbstractValue):
|
||||
super().__init__(-1, '', aval)
|
||||
def __repr__(self): return '_'
|
||||
|
||||
@ -240,7 +240,7 @@ class Literal:
|
||||
__slots__ = ["val", "aval", "hash"]
|
||||
|
||||
val: Any
|
||||
aval: 'AbstractValue'
|
||||
aval: AbstractValue
|
||||
hash: Optional[int]
|
||||
|
||||
def __init__(self, val, aval):
|
||||
@ -358,16 +358,16 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
|
||||
class Trace:
|
||||
__slots__ = ['main', 'level', 'sublevel']
|
||||
|
||||
main: 'MainTrace'
|
||||
main: MainTrace
|
||||
level: int
|
||||
sublevel: 'Sublevel'
|
||||
sublevel: Sublevel
|
||||
|
||||
def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
|
||||
def __init__(self, main: MainTrace, sublevel: Sublevel) -> None:
|
||||
self.main = main
|
||||
self.level = main.level
|
||||
self.sublevel = sublevel
|
||||
|
||||
def full_raise(self, val) -> 'Tracer':
|
||||
def full_raise(self, val) -> Tracer:
|
||||
if not isinstance(val, Tracer):
|
||||
return self.pure(val)
|
||||
val._assert_live()
|
||||
@ -960,10 +960,10 @@ class AbstractValue:
|
||||
except AttributeError:
|
||||
return self.__class__.__name__
|
||||
|
||||
def strip_weak_type(self) -> 'AbstractValue':
|
||||
def strip_weak_type(self) -> AbstractValue:
|
||||
return self
|
||||
|
||||
def strip_named_shape(self) -> 'AbstractValue':
|
||||
def strip_named_shape(self) -> AbstractValue:
|
||||
return self
|
||||
|
||||
def join(self, other):
|
||||
@ -1167,7 +1167,7 @@ AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
|
||||
|
||||
class DShapedArray(UnshapedArray):
|
||||
__slots__ = ['shape']
|
||||
shape: Tuple[AxisSize, ...] # see comment above
|
||||
shape: Tuple[AxisSize, ...] # noqa: F821
|
||||
array_abstraction_level = 2
|
||||
|
||||
def __init__(self, shape, dtype, weak_type):
|
||||
@ -1813,7 +1813,7 @@ class MapPrimitive(Primitive):
|
||||
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
|
||||
return [subfun], new_params
|
||||
|
||||
def map_bind(primitive: 'MapPrimitive', fun, *args, out_axes_thunk, **params):
|
||||
def map_bind(primitive: MapPrimitive, fun, *args, out_axes_thunk, **params):
|
||||
# The new thunk depends deterministically on the old thunk and the wrapped
|
||||
# function. Any caching already has to include the wrapped function as part
|
||||
# of the key, so we only use the previous thunk for equality checks.
|
||||
@ -2134,7 +2134,7 @@ def check_jaxpr(jaxpr: Jaxpr):
|
||||
raise JaxprTypeError(msg) from None
|
||||
|
||||
def _check_jaxpr(
|
||||
ctx_factory: Callable[[], Tuple['JaxprPpContext', 'JaxprPpSettings']],
|
||||
ctx_factory: Callable[[], Tuple[JaxprPpContext, JaxprPpSettings]],
|
||||
jaxpr: Jaxpr,
|
||||
in_avals: Sequence[AbstractValue]) -> None:
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user