update mypy & related package versions

This commit is contained in:
Jake VanderPlas 2022-04-15 08:55:06 -07:00
parent 375777f43c
commit 7972b98a7b
3 changed files with 4 additions and 4 deletions

View File

@ -14,11 +14,11 @@ repos:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.931'
rev: 'v0.942'
hooks:
- id: mypy
files: jax/
additional_dependencies: [types-requests==0.1.11, jaxlib==0.1.74]
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]
- repo: https://github.com/mwouts/jupytext
rev: v1.13.6

View File

@ -379,7 +379,7 @@ def _one_hot(x: Array, num_classes: int, *,
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x == axis_idx, dtype=dtype)
axis = operator.index(axis)
axis = operator.index(axis) # type: ignore[arg-type]
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)

View File

@ -157,7 +157,7 @@ def sharding_to_proto(sharding: SpatialSharding):
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_dimensions = list(sharding) # type: ignore
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
return proto