mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Repair various type errors
This commit is contained in:
parent
a8e2ee9b65
commit
1e580457ba
@ -32,8 +32,8 @@ from functools import partial
|
||||
import math
|
||||
import operator
|
||||
import types
|
||||
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol,
|
||||
TypeVar, Union)
|
||||
from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
|
||||
Protocol, TypeVar, Union)
|
||||
from textwrap import dedent as _dedent
|
||||
import warnings
|
||||
|
||||
@ -851,10 +851,12 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
|
||||
try:
|
||||
shape = list(shape)
|
||||
except TypeError:
|
||||
shape = [shape]
|
||||
# TODO: Consider warning here since shape is supposed to be a sequence, so
|
||||
# this should not happen.
|
||||
shape = cast(list[Any], [shape])
|
||||
if any(ndim(s) != 0 for s in shape):
|
||||
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
|
||||
out_indices = [0] * len(shape)
|
||||
out_indices: list[ArrayLike] = [0] * len(shape)
|
||||
for i, s in reversed(list(enumerate(shape))):
|
||||
indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s)
|
||||
oob_pos = indices_arr > 0
|
||||
@ -1137,7 +1139,12 @@ def where(
|
||||
else:
|
||||
util.check_arraylike("where", acondition, if_true, if_false)
|
||||
if size is not None or fill_value is not None:
|
||||
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
|
||||
raise ValueError("size and fill_value arguments cannot be used in "
|
||||
"three-term where function.")
|
||||
if if_true is None or if_false is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments "
|
||||
"should be provided to jax.numpy.where, got "
|
||||
f"{if_true} and {if_false}.")
|
||||
return util._where(acondition, if_true, if_false)
|
||||
|
||||
|
||||
|
@ -46,7 +46,7 @@ def _roots_no_zeros(p: Array) -> Array:
|
||||
|
||||
|
||||
@jit
|
||||
def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
|
||||
def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
|
||||
# Avoid lapack errors when p is all zero
|
||||
p = _where(len(p) == num_leading_zeros, 1.0, p)
|
||||
# Roll any leading zeros to the end & compute the roots
|
||||
@ -85,7 +85,7 @@ strip_zeros : bool, default=True
|
||||
""")
|
||||
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
|
||||
check_arraylike("roots", p)
|
||||
p_arr = atleast_1d(*promote_dtypes_inexact(p))
|
||||
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
|
||||
if p_arr.ndim != 1:
|
||||
raise ValueError("Input must be a rank-1 array.")
|
||||
if p_arr.size < 2:
|
||||
@ -96,7 +96,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
|
||||
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
|
||||
"The error occurred in the jnp.roots() function. To use this within a "
|
||||
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
|
||||
"will be result in some returned roots being set to NaN.")
|
||||
"will result in some returned roots being set to NaN.")
|
||||
return _roots_no_zeros(p_arr[num_leading_zeros:])
|
||||
else:
|
||||
return _roots_with_zeros(p_arr, num_leading_zeros)
|
||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import math
|
||||
from typing import Any, Generic, TypeVar, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
@ -74,8 +74,6 @@ StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
|
||||
|
||||
# ## `Ref`s
|
||||
|
||||
Aval = TypeVar("Aval", bound=core.AbstractValue)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RefIndexer:
|
||||
ref_or_view: Any
|
||||
@ -124,7 +122,7 @@ class RefView:
|
||||
|
||||
|
||||
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
||||
class AbstractRef(core.AbstractValue, Generic[Aval]):
|
||||
class AbstractRef(core.AbstractValue):
|
||||
__slots__ = ["inner_aval"]
|
||||
|
||||
def __init__(self, inner_aval: core.AbstractValue):
|
||||
@ -212,6 +210,6 @@ def get_ref_state_effects(
|
||||
|
||||
def shaped_array_ref(shape: tuple[int, ...], dtype,
|
||||
weak_type: bool = False,
|
||||
named_shape = None) -> AbstractRef[core.AbstractValue]:
|
||||
named_shape = None) -> AbstractRef:
|
||||
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
|
||||
named_shape=named_shape))
|
||||
|
@ -138,10 +138,13 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]
|
||||
lists[b].append(x)
|
||||
return lists
|
||||
|
||||
def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]:
|
||||
def merge_lists(bs: Sequence[bool],
|
||||
l0: Sequence[T1],
|
||||
l1: Sequence[T2]
|
||||
) -> list[T1 | T2]:
|
||||
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
|
||||
i0, i1 = iter(l0), iter(l1)
|
||||
out = [next(i1) if b else next(i0) for b in bs]
|
||||
out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs]
|
||||
sentinel = object()
|
||||
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
|
||||
return out
|
||||
|
Loading…
x
Reference in New Issue
Block a user