Fix some type errors in lax.py found by pytype. (#3292)

Peter Hawkins 2020-06-02 10:27:14 -04:00 committed by GitHub
parent 042df4ebff
commit dd81a8dded

lax.py

@@ -2278,9 +2278,9 @@ masking.defvectorized(bitcast_convert_type_p)
 def _conv_general_dilated_shape_rule(
-    lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
-    dimension_numbers, feature_group_count, batch_group_count,
-    **unused_kwargs):
+    lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding,
+    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
+    batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
   assert type(dimension_numbers) is ConvDimensionNumbers
   if not feature_group_count > 0:
     msg = ("conv_general_dilated feature_group_count "
@@ -2317,7 +2317,7 @@ def _conv_general_dilated_shape_rule(
     msg = ("conv_general_dilated rhs output feature dimension size must be a "
            "multiple of batch_group_count, but {} is not a multiple of {}.")
     raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
-                                batch_ground_count))
+                                batch_group_count))
   if not batch_group_count > 0 and feature_group_count > 0:
     msg = ("At most one of batch_group_count and feature_group_count may be > "
@@ -2669,7 +2669,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
       # product dims
       result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
                           rhs.ndim - len(rhs_contract) - 1)
-  new_dimension_numbers = [(lhs_contract, rhs_contract), (lhs_batch, rhs_batch)]
+  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
   batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                             precision=precision)
   return batched_out, int(result_batch_dim)
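
Note: dot_general expects dimension_numbers as a tuple of tuples, ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)). A list is the wrong declared type, and tuples are also hashable, which suits JAX's tracing caches. A small usage sketch, assuming jax is installed:

    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((2, 3, 4))   # (batch, m, k)
    rhs = jnp.ones((2, 4, 5))   # (batch, k, n)
    # ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
    dnums = (((2,), (1,)), ((0,), (0,)))
    out = lax.dot_general(lhs, rhs, dnums)
    print(out.shape)            # (2, 3, 5)
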
@@ -2899,7 +2899,8 @@ def _pad_transpose(t, operand, padding_value, *, padding_config):
   total = lambda x: _reduce_sum(x, list(range(t.ndim)))
 
   def t_op():
-    unpad_config = zip(onp.negative(lo), onp.negative(hi), onp.zeros_like(interior))
+    unpad_config = safe_zip(onp.negative(lo), onp.negative(hi),
+                            onp.zeros_like(interior))
     unpadded = pad(t, onp.array(0., t.dtype), unpad_config)
     return slice(unpadded, onp.zeros_like(lo), unpadded.shape, onp.add(interior, 1))
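
Note: safe_zip is JAX's length-checking zip (from jax.util). Builtin zip silently truncates to the shortest argument and, in Python 3, returns a one-shot iterator; safe_zip insists the lengths agree and returns a concrete list. Roughly (error text illustrative):

    def safe_zip(*args):
      # Unlike builtin zip, refuse to silently drop trailing elements.
      n = len(args[0])
      for arg in args[1:]:
        assert len(arg) == n, f"length mismatch: {list(map(len, args))}"
      return list(zip(*args))
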
@@ -3317,8 +3318,8 @@ def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
   else:
     real_limits = onp.add(onp.add(start_indices, 1),
                           onp.multiply(onp.subtract(t.shape, 1), strides))
-    pads = zip(start_indices, onp.subtract(operand_shape, real_limits),
-               onp.subtract(strides, 1))
+    pads = safe_zip(start_indices, onp.subtract(operand_shape, real_limits),
+                    onp.subtract(strides, 1))
   result = pad(t, _const(t, 0), pads)
   assert result.shape == operand_shape
   return [result]
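
Note: the same zip to safe_zip swap, and typing is the point here too: pad consumes its padding config as a concrete sequence of (low, high, interior) triples, which pytype sees as a List rather than an Iterator. A usage sketch of lax.pad with an explicit config, assuming jax is installed:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(4.0)
    padded = lax.pad(x, 0.0, [(1, 2, 0)])  # one (lo, hi, interior) per dim
    print(padded)  # [0. 0. 1. 2. 3. 0. 0.]
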
@@ -4243,7 +4244,9 @@ def _generic_reduce_window_batch_rule(
       x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
       window_strides=window_strides, padding=padding)
   return _reduce_window_batch_rule(reduce_window, (operand,), (bdim,),
-                                   window_dimensions, window_strides, padding)
+                                   window_dimensions=window_dimensions,
+                                   window_strides=window_strides,
+                                   padding=padding)
 
 reduce_window_p = standard_primitive(
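
Note: spelling the trailing arguments as keywords matches how _reduce_window_batch_rule takes them; if the callee declares them keyword-only (a common pattern for rule signatures in this file), a positional call is a TypeError that pytype reports before runtime. A minimal illustration with a hypothetical rule:

    def rule(x, *, window_dimensions, window_strides, padding):
      return x

    rule(1, window_dimensions=(2,), window_strides=(1,), padding=())  # OK
    # rule(1, (2,), (1,), ())  # TypeError: rule() takes 1 positional argument
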
@@ -5428,14 +5431,14 @@ def _conv_general_proto(dimension_numbers):
 def _conv_general_vjp_lhs_padding(
     in_shape, window_dimensions, window_strides, out_shape, padding,
-    lhs_dilation, rhs_dilation):
+    lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
   lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation)
   rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation)
   out_dilated_shape = _dilate_shape(out_shape, window_strides)
   pad_before = onp.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
   pad_after = (onp.add(lhs_dilated_shape, rhs_dilated_shape) - 1
                - out_dilated_shape - pad_before)
-  return zip(pad_before, pad_after)
+  return safe_zip(pad_before, pad_after)
 
 def _conv_general_vjp_rhs_padding(
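
Note: the List[Tuple[int, int]] return annotation only type-checks because safe_zip returns a materialized list; pytype infers builtin zip(...) as an Iterator. A reduced sketch of the mismatch (pairs_bad and pairs_good are hypothetical):

    from typing import List, Tuple

    def pairs_bad(a, b) -> List[Tuple[int, int]]:
      return zip(a, b)  # pytype: bad-return-type, zip() is an Iterator, not a List

    def pairs_good(a, b) -> List[Tuple[int, int]]:
      return list(zip(a, b))  # what safe_zip effectively returns, with a length check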