Mirror of https://github.com/ROCm/jax.git
Fix type errors with current mypy and NumPy.
Enable type stubs for jaxlib. Fix a nondeterminism problem in jax2tf tests.
parent b7e9a0b18b
commit d658108d36
@@ -502,7 +502,7 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
   return concatenate_p.bind(*operands, dimension=dimension)
 
 Precision = xla_client.PrecisionConfig.Precision
-Precision.__str__ = lambda precision: precision.name
+Precision.__str__ = lambda precision: precision.name  # type: ignore
 PrecisionType = Any
 PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
                       Tuple[PrecisionType, PrecisionType]]
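Context for the new `# type: ignore`: mypy rejects assigning to a method on a class ("Cannot assign to a method"), even when the runtime assignment is fine. A minimal sketch of the pattern, with a hypothetical enum standing in for the xla_client class:

import enum

class Precision(enum.Enum):  # hypothetical stand-in, not the real class
    DEFAULT = 0
    HIGHEST = 2

# Works at runtime; mypy flags the method assignment, hence the suppression.
Precision.__str__ = lambda precision: precision.name  # type: ignore

print(str(Precision.HIGHEST))  # -> HIGHEST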
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
-from typing import Callable, NamedTuple, Optional, Union
+from typing import Any, Callable, NamedTuple, Optional, Union
 from functools import partial
 
 import jax
@@ -23,6 +23,8 @@ from .line_search import line_search
 _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
 
 
+Array = Any
+
 class LBFGSResults(NamedTuple):
   """Results from L-BFGS optimization
 
@@ -47,32 +49,32 @@ class LBFGSResults(NamedTuple):
       5 = line search failed
   ls_status: integer describing the end status of the last line search
   """
-  converged: Union[bool, jnp.ndarray]
-  failed: Union[bool, jnp.ndarray]
-  k: Union[int, jnp.ndarray]
-  nfev: Union[int, jnp.ndarray]
-  ngev: Union[int, jnp.ndarray]
-  x_k: jnp.ndarray
-  f_k: jnp.ndarray
-  g_k: jnp.ndarray
-  s_history: jnp.ndarray
-  y_history: jnp.ndarray
-  rho_history: jnp.ndarray
-  gamma: Union[float, jnp.ndarray]
-  status: Union[int, jnp.ndarray]
-  ls_status: Union[int, jnp.ndarray]
+  converged: Union[bool, Array]
+  failed: Union[bool, Array]
+  k: Union[int, Array]
+  nfev: Union[int, Array]
+  ngev: Union[int, Array]
+  x_k: Array
+  f_k: Array
+  g_k: Array
+  s_history: Array
+  y_history: Array
+  rho_history: Array
+  gamma: Union[float, Array]
+  status: Union[int, Array]
+  ls_status: Union[int, Array]
 
 
 def _minimize_lbfgs(
     fun: Callable,
-    x0: jnp.ndarray,
-    maxiter: Optional[int] = None,
+    x0: Array,
+    maxiter: Optional[float] = None,
     norm=jnp.inf,
     maxcor: int = 10,
     ftol: float = 2.220446049250313e-09,
     gtol: float = 1e-05,
-    maxfun: Optional[int] = None,
-    maxgrad: Optional[int] = None,
+    maxfun: Optional[float] = None,
+    maxgrad: Optional[float] = None,
     maxls: int = 20,
 ):
   """
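The thread running through these three L-BFGS hunks is the module-level `Array = Any` alias, which replaces `jnp.ndarray` in annotations, presumably because those annotations were tripping the current mypy. The alias keeps signatures self-documenting while type-checking as `Any`. A minimal sketch of the pattern (class name hypothetical):

from typing import Any, NamedTuple, Union

Array = Any  # reads as "a JAX array", type-checks as Any

class StepState(NamedTuple):       # hypothetical, mirroring the LBFGSResults style
    converged: Union[bool, Array]  # a Python bool or a traced/abstract value
    x_k: Array                     # current iterate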
@@ -877,9 +877,9 @@ def _outside_call_translation_rule(
       [array_sharding_proto] * len(non_empty_flat_results_aval) +
       [token_sharding_proto])
 
-  shape = tuple(shape.with_major_to_minor_layout_if_absent()
-                for x in non_empty_flat_results_aval
-                for shape in xla.aval_to_xla_shapes(x))
+  shape = [shape.with_major_to_minor_layout_if_absent()
+           for x in non_empty_flat_results_aval
+           for shape in xla.aval_to_xla_shapes(x)]
 
   build_infeed = functools.partial(xops.InfeedWithToken,
                                    after_outfeed_itoken,
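One subtlety in the rewritten comprehension: `shape` is both the assignment target and the innermost loop variable. This is well-defined in Python 3 because a comprehension has its own scope, so the loop variable never leaks; only the finished list is bound to the outer name. A small demonstration with stand-in data:

def to_shapes(x):      # stand-in for xla.aval_to_xla_shapes
    return [x, x * 2]

avals = [1, 2, 3]      # stand-in for non_empty_flat_results_aval
shape = [shape + 1 for x in avals for shape in to_shapes(x)]
print(shape)           # [2, 3, 3, 5, 4, 7]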
@@ -1599,6 +1599,7 @@ def _conv_general_dilated(lhs, rhs, *,
   # Follow the lowering for complex convolutions from
   # lax._conv_general_dilated_translation. We can use the same conversion on all
   # platforms because on XLA:TPU the compiler does the same as a rewrite.
+  preferred_float_et: Optional[Any]
   if np.issubdtype(_in_avals[0].dtype, np.complexfloating):
     if preferred_element_type is not None:
       # Convert complex dtype to types used for real and imaginary parts
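The added line is a bare annotation: it declares the type of `preferred_float_et` before the branches that assign it, so mypy unifies the `None` and non-`None` assignments instead of inferring from whichever it sees first. A minimal sketch of the technique (names hypothetical):

from typing import Any, Optional

def choose_float_type(is_complex: bool, preferred: Optional[Any]) -> Optional[Any]:
    preferred_float_et: Optional[Any]  # declare before the branches
    if is_complex and preferred is not None:
        preferred_float_et = preferred  # e.g. the real-component dtype
    else:
        preferred_float_et = None
    return preferred_float_et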
@@ -585,7 +585,7 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
         raise ValueError(msg)
       return dim_size
     # We have a dimension polynomial for a known dimension.
-    dim_var = dim_poly.to_var()
+    dim_var = dim_poly.to_var()  # type: ignore
     if dim_var is not None:
       shape_var_map[dim_spec].add(dim_size)  # type: ignore
     return dim_poly
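Another targeted `# type: ignore`, this time on a call. A plausible reading (assumed here, not verified against shape_poly's type hierarchy) is that `dim_poly` is Union-typed at this point even though the surrounding logic guarantees the polynomial case, so a one-line suppression keeps the rest of the file checked:

from typing import Union

class DimPoly:                    # hypothetical stand-in class
    def to_var(self) -> str:
        return "b"

def var_of(dim: Union[int, DimPoly], known_poly: bool) -> str:
    if known_poly:                # a guarantee mypy cannot follow
        return dim.to_var()  # type: ignore
    return str(dim)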
@@ -528,7 +528,7 @@ def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3):
                          y=y)
 
 
-for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
+for dtype in [d for d in jtu.dtypes.all if d not in jtu.dtypes.boolean]:
   # Validate dtypes and y values for some special cases.
   for y in range(-3, 5):
     if np.issubdtype(dtype, np.integer) and y < 0:
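This hunk is the nondeterminism fix from the commit message: iterating a set ignores the source ordering of `jtu.dtypes.all`, and the iteration order can differ from process to process (for string-keyed contents, hash randomization alone changes it), so the generated test harnesses varied between runs. The list comprehension yields the same contents in a stable order. An illustration with stand-in values:

all_dtypes = ["bfloat16", "float32", "int8", "bool_"]  # stand-in for jtu.dtypes.all
boolean = ["bool_"]                                    # stand-in for jtu.dtypes.boolean

nondeterministic = set(all_dtypes) - set(boolean)      # order arbitrary, can vary across runs
deterministic = [d for d in all_dtypes if d not in boolean]
assert deterministic == ["bfloat16", "float32", "int8"]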
@@ -105,7 +105,7 @@ _xla_extension_version = getattr(xla_client, '_version', 0)
 try:
   from jaxlib import tpu_client as tpu_driver_client  # pytype: disable=import-error
 except:
-  tpu_driver_client = None
+  tpu_driver_client = None  # type: ignore
 
 
 cuda_path: Optional[str]
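The `# type: ignore` here supports the optional-import idiom: after the `from jaxlib import ...` line, mypy types `tpu_driver_client` as a module, so rebinding it to `None` in the except branch is a type error without the suppression. The general shape of the idiom (module name hypothetical):

try:
    import some_optional_backend as backend  # hypothetical optional dependency
except ImportError:
    backend = None  # type: ignore  # name was inferred as a module type

def backend_available() -> bool:
    return backend is not None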
mypy.ini
@@ -4,8 +4,6 @@ disable_error_code = attr-defined
 
 [mypy-absl.*]
 ignore_missing_imports = True
-[mypy-jaxlib.*]
-ignore_missing_imports = True
 [mypy-numpy.*]
 ignore_missing_imports = True
 [mypy-opt_einsum.*]
@@ -18,3 +16,5 @@ ignore_errors = True
 ignore_errors = True
 [mypy-jax.experimental.jax2tf.tests.primitive_harness]
 ignore_errors = True
+[mypy-libtpu.*]
+ignore_missing_imports = True
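These mypy.ini edits pair with "Enable type stubs for jaxlib" from the commit message: deleting the `[mypy-jaxlib.*]` ignore_missing_imports override lets mypy read jaxlib's own type information instead of treating every jaxlib import as `Any`, while the new `[mypy-libtpu.*]` section covers a dependency that ships no types. Per PEP 561, a package opts in with a `py.typed` marker plus inline annotations or stub files, sketched here with hypothetical names:

# jaxlib/example.pyi -- hypothetical stub file; the package also ships an
# empty "py.typed" marker so type checkers trust its type information.
def version() -> str: ...

class Device:
    platform: str
    def __repr__(self) -> str: ...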