Fix type errors with current mypy and NumPy.

Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
This commit is contained in:
Peter Hawkins 2021-06-24 10:51:06 -04:00
parent b7e9a0b18b
commit d658108d36
8 changed files with 31 additions and 28 deletions

View File

@ -502,7 +502,7 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
return concatenate_p.bind(*operands, dimension=dimension)
Precision = xla_client.PrecisionConfig.Precision
Precision.__str__ = lambda precision: precision.name
Precision.__str__ = lambda precision: precision.name # type: ignore
PrecisionType = Any
PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
Tuple[PrecisionType, PrecisionType]]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
from typing import Callable, NamedTuple, Optional, Union
from typing import Any, Callable, NamedTuple, Optional, Union
from functools import partial
import jax
@ -23,6 +23,8 @@ from .line_search import line_search
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
Array = Any
class LBFGSResults(NamedTuple):
"""Results from L-BFGS optimization
@ -47,32 +49,32 @@ class LBFGSResults(NamedTuple):
5 = line search failed
ls_status: integer describing the end status of the last line search
"""
converged: Union[bool, jnp.ndarray]
failed: Union[bool, jnp.ndarray]
k: Union[int, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
ngev: Union[int, jnp.ndarray]
x_k: jnp.ndarray
f_k: jnp.ndarray
g_k: jnp.ndarray
s_history: jnp.ndarray
y_history: jnp.ndarray
rho_history: jnp.ndarray
gamma: Union[float, jnp.ndarray]
status: Union[int, jnp.ndarray]
ls_status: Union[int, jnp.ndarray]
converged: Union[bool, Array]
failed: Union[bool, Array]
k: Union[int, Array]
nfev: Union[int, Array]
ngev: Union[int, Array]
x_k: Array
f_k: Array
g_k: Array
s_history: Array
y_history: Array
rho_history: Array
gamma: Union[float, Array]
status: Union[int, Array]
ls_status: Union[int, Array]
def _minimize_lbfgs(
fun: Callable,
x0: jnp.ndarray,
maxiter: Optional[int] = None,
x0: Array,
maxiter: Optional[float] = None,
norm=jnp.inf,
maxcor: int = 10,
ftol: float = 2.220446049250313e-09,
gtol: float = 1e-05,
maxfun: Optional[int] = None,
maxgrad: Optional[int] = None,
maxfun: Optional[float] = None,
maxgrad: Optional[float] = None,
maxls: int = 20,
):
"""

View File

@ -877,9 +877,9 @@ def _outside_call_translation_rule(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])
shape = tuple(shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x))
shape = [shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x)]
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,

View File

@ -1599,6 +1599,7 @@ def _conv_general_dilated(lhs, rhs, *,
# Follow the lowering for complex convolutions from
# lax._conv_general_dilated_translation. We can use the same conversion on all
# platforms because on XLA:TPU the compiler does the same as a rewrite.
preferred_float_et: Optional[Any]
if np.issubdtype(_in_avals[0].dtype, np.complexfloating):
if preferred_element_type is not None:
# Convert complex dtype to types used for real and imaginary parts

View File

@ -585,7 +585,7 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
raise ValueError(msg)
return dim_size
# We have a dimension polynomial for a known dimension.
dim_var = dim_poly.to_var()
dim_var = dim_poly.to_var() # type: ignore
if dim_var is not None:
shape_var_map[dim_spec].add(dim_size) # type: ignore
return dim_poly

View File

@ -528,7 +528,7 @@ def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3):
y=y)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
for dtype in [d for d in jtu.dtypes.all if d not in jtu.dtypes.boolean]:
# Validate dtypes and y values for some special cases.
for y in range(-3, 5):
if np.issubdtype(dtype, np.integer) and y < 0:

View File

@ -105,7 +105,7 @@ _xla_extension_version = getattr(xla_client, '_version', 0)
try:
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
except:
tpu_driver_client = None
tpu_driver_client = None # type: ignore
cuda_path: Optional[str]

View File

@ -4,8 +4,6 @@ disable_error_code = attr-defined
[mypy-absl.*]
ignore_missing_imports = True
[mypy-jaxlib.*]
ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True
[mypy-opt_einsum.*]
@ -18,3 +16,5 @@ ignore_errors = True
ignore_errors = True
[mypy-jax.experimental.jax2tf.tests.primitive_harness]
ignore_errors = True
[mypy-libtpu.*]
ignore_missing_imports = True