mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
update mypy & related package versions
This commit is contained in:
parent
375777f43c
commit
7972b98a7b
@ -14,11 +14,11 @@ repos:
|
||||
- id: flake8
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: 'v0.931'
|
||||
rev: 'v0.942'
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: jax/
|
||||
additional_dependencies: [types-requests==0.1.11, jaxlib==0.1.74]
|
||||
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]
|
||||
|
||||
- repo: https://github.com/mwouts/jupytext
|
||||
rev: v1.13.6
|
||||
|
@ -379,7 +379,7 @@ def _one_hot(x: Array, num_classes: int, *,
|
||||
f"but {num_classes} != {axis_size}") from None
|
||||
axis_idx = lax.axis_index(axis)
|
||||
return jnp.asarray(x == axis_idx, dtype=dtype)
|
||||
axis = operator.index(axis)
|
||||
axis = operator.index(axis) # type: ignore[arg-type]
|
||||
lhs = lax.expand_dims(x, (axis,))
|
||||
rhs_shape = [1] * x.ndim
|
||||
rhs_shape.insert(output_pos_axis, num_classes)
|
||||
|
@ -157,7 +157,7 @@ def sharding_to_proto(sharding: SpatialSharding):
|
||||
proto.type = xc.OpSharding.Type.REPLICATED
|
||||
else:
|
||||
proto.type = xc.OpSharding.Type.OTHER
|
||||
proto.tile_assignment_dimensions = list(sharding)
|
||||
proto.tile_assignment_dimensions = list(sharding) # type: ignore
|
||||
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
|
||||
return proto
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user