mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix some type errors in lax.py found by pytype. (#3292)
This commit is contained in:
parent
042df4ebff
commit
dd81a8dded
@ -2278,9 +2278,9 @@ masking.defvectorized(bitcast_convert_type_p)
|
||||
|
||||
|
||||
def _conv_general_dilated_shape_rule(
|
||||
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count,
|
||||
**unused_kwargs):
|
||||
lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
|
||||
batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
|
||||
assert type(dimension_numbers) is ConvDimensionNumbers
|
||||
if not feature_group_count > 0:
|
||||
msg = ("conv_general_dilated feature_group_count "
|
||||
@ -2317,7 +2317,7 @@ def _conv_general_dilated_shape_rule(
|
||||
msg = ("conv_general_dilated rhs output feature dimension size must be a "
|
||||
"multiple of batch_group_count, but {} is not a multiple of {}.")
|
||||
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
|
||||
batch_ground_count))
|
||||
batch_group_count))
|
||||
|
||||
if not batch_group_count > 0 and feature_group_count > 0:
|
||||
msg = ("At most one of batch_group_count and feature_group_count may be > "
|
||||
@ -2669,7 +2669,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
# product dims
|
||||
result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
|
||||
rhs.ndim - len(rhs_contract) - 1)
|
||||
new_dimension_numbers = [(lhs_contract, rhs_contract), (lhs_batch, rhs_batch)]
|
||||
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
|
||||
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
|
||||
precision=precision)
|
||||
return batched_out, int(result_batch_dim)
|
||||
@ -2899,7 +2899,8 @@ def _pad_transpose(t, operand, padding_value, *, padding_config):
|
||||
total = lambda x: _reduce_sum(x, list(range(t.ndim)))
|
||||
|
||||
def t_op():
|
||||
unpad_config = zip(onp.negative(lo), onp.negative(hi), onp.zeros_like(interior))
|
||||
unpad_config = safe_zip(onp.negative(lo), onp.negative(hi),
|
||||
onp.zeros_like(interior))
|
||||
unpadded = pad(t, onp.array(0., t.dtype), unpad_config)
|
||||
return slice(unpadded, onp.zeros_like(lo), unpadded.shape, onp.add(interior, 1))
|
||||
|
||||
@ -3317,8 +3318,8 @@ def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
|
||||
else:
|
||||
real_limits = onp.add(onp.add(start_indices, 1),
|
||||
onp.multiply(onp.subtract(t.shape, 1), strides))
|
||||
pads = zip(start_indices, onp.subtract(operand_shape, real_limits),
|
||||
onp.subtract(strides, 1))
|
||||
pads = safe_zip(start_indices, onp.subtract(operand_shape, real_limits),
|
||||
onp.subtract(strides, 1))
|
||||
result = pad(t, _const(t, 0), pads)
|
||||
assert result.shape == operand_shape
|
||||
return [result]
|
||||
@ -4243,7 +4244,9 @@ def _generic_reduce_window_batch_rule(
|
||||
x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
|
||||
window_strides=window_strides, padding=padding)
|
||||
return _reduce_window_batch_rule(reduce_window, (operand,), (bdim,),
|
||||
window_dimensions, window_strides, padding)
|
||||
window_dimensions=window_dimensions,
|
||||
window_strides=window_strides,
|
||||
padding=padding)
|
||||
|
||||
|
||||
reduce_window_p = standard_primitive(
|
||||
@ -5428,14 +5431,14 @@ def _conv_general_proto(dimension_numbers):
|
||||
|
||||
def _conv_general_vjp_lhs_padding(
|
||||
in_shape, window_dimensions, window_strides, out_shape, padding,
|
||||
lhs_dilation, rhs_dilation):
|
||||
lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
|
||||
lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation)
|
||||
rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation)
|
||||
out_dilated_shape = _dilate_shape(out_shape, window_strides)
|
||||
pad_before = onp.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
|
||||
pad_after = (onp.add(lhs_dilated_shape, rhs_dilated_shape) - 1
|
||||
- out_dilated_shape - pad_before)
|
||||
return zip(pad_before, pad_after)
|
||||
return safe_zip(pad_before, pad_after)
|
||||
|
||||
|
||||
def _conv_general_vjp_rhs_padding(
|
||||
|
Loading…
x
Reference in New Issue
Block a user