Repair various type errors

This commit is contained in:
Neil Girdhar 2023-11-16 21:08:44 -05:00
parent a8e2ee9b65
commit 1e580457ba
4 changed files with 23 additions and 15 deletions

View File

@ -32,8 +32,8 @@ from functools import partial
import math
import operator
import types
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol,
TypeVar, Union)
from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
@ -851,10 +851,12 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
try:
shape = list(shape)
except TypeError:
shape = [shape]
# TODO: Consider warning here since shape is supposed to be a sequence, so
# this should not happen.
shape = cast(list[Any], [shape])
if any(ndim(s) != 0 for s in shape):
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices = [0] * len(shape)
out_indices: list[ArrayLike] = [0] * len(shape)
for i, s in reversed(list(enumerate(shape))):
indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s)
oob_pos = indices_arr > 0
@ -1137,7 +1139,12 @@ def where(
else:
util.check_arraylike("where", acondition, if_true, if_false)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
raise ValueError("size and fill_value arguments cannot be used in "
"three-term where function.")
if if_true is None or if_false is None:
raise ValueError("Either both or neither of the x and y arguments "
"should be provided to jax.numpy.where, got "
f"{if_true} and {if_false}.")
return util._where(acondition, if_true, if_false)

View File

@ -46,7 +46,7 @@ def _roots_no_zeros(p: Array) -> Array:
@jit
def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
# Avoid lapack errors when p is all zero
p = _where(len(p) == num_leading_zeros, 1.0, p)
# Roll any leading zeros to the end & compute the roots
@ -85,7 +85,7 @@ strip_zeros : bool, default=True
""")
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
check_arraylike("roots", p)
p_arr = atleast_1d(*promote_dtypes_inexact(p))
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p_arr.size < 2:
@ -96,7 +96,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
"The error occurred in the jnp.roots() function. To use this within a "
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
"will be result in some returned roots being set to NaN.")
"will result in some returned roots being set to NaN.")
return _roots_no_zeros(p_arr[num_leading_zeros:])
else:
return _roots_with_zeros(p_arr, num_leading_zeros)

View File

@ -18,7 +18,7 @@ from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, Generic, TypeVar, Union
from typing import Any, Union
from jax._src import core
from jax._src import effects
@ -74,8 +74,6 @@ StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
# ## `Ref`s
Aval = TypeVar("Aval", bound=core.AbstractValue)
@dataclasses.dataclass
class RefIndexer:
ref_or_view: Any
@ -124,7 +122,7 @@ class RefView:
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
class AbstractRef(core.AbstractValue, Generic[Aval]):
class AbstractRef(core.AbstractValue):
__slots__ = ["inner_aval"]
def __init__(self, inner_aval: core.AbstractValue):
@ -212,6 +210,6 @@ def get_ref_state_effects(
def shaped_array_ref(shape: tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef[core.AbstractValue]:
named_shape = None) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
named_shape=named_shape))

View File

@ -138,10 +138,13 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]
lists[b].append(x)
return lists
def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]:
def merge_lists(bs: Sequence[bool],
l0: Sequence[T1],
l1: Sequence[T2]
) -> list[T1 | T2]:
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
i0, i1 = iter(l0), iter(l1)
out = [next(i1) if b else next(i0) for b in bs]
out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs]
sentinel = object()
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
return out