Misc. small fixes.

This commit is contained in:
sschoenholz 2019-01-28 19:08:05 -08:00 committed by GitHub
parent 6181560ef0
commit 3b8d43cefa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1136,9 +1136,9 @@ def conv_general_dilated_batch_rule(
lhs_dim, rhs_dim, out_dim = dimension_numbers
if lhs_bdim is not None and rhs_bdim is not None:
#TODO(#212): use a map construct instead of unrolling.
lhs = batching.move_dim_to_front(lhs, lhs_bdim)
rhs = batching.move_dim_to_front(rhs, rhs_bdim)
outputs = [
conv_general_dilated(l, r, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers)
@ -1152,12 +1152,10 @@ def conv_general_dilated_batch_rule(
# convolution isn't the first dimension.
if lhs_dim[0] != 0 or out_dim[0] != 0:
raise NotImplementedError
lhs = batching.move_dim_to_front(lhs, lhs_dim[0])
lhs = batching.move_dim_to_front(lhs, lhs_bdim)
batched_size = lhs.shape[0]
n_size = lhs.shape[1]
lhs = reshape(lhs, (batched_size * n_size,) + lhs.shape[2:])
outputs = conv_general_dilated(
lhs, rhs, window_strides, padding,
@ -1166,7 +1164,7 @@ def conv_general_dilated_batch_rule(
return outputs, 0
elif rhs_bdim is not None:
# TODO(schsam): Consider a loop instead of unrolling.
#TODO(#212): use a map construct instead of unrolling.
rhs = batching.move_dim_to_front(rhs, rhs_bdim)
outputs = [
conv_general_dilated(lhs, x, window_strides, padding,
@ -2339,6 +2337,7 @@ def reduce_window_sum_transpose_rule(cotangent, window_dimensions,
xla_bridge.get_xla_client().PaddingType.VALID)
assert result.shape == input_shape
return [result]
def reduce_window_sum_batch_rule(
batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
operand, = batched_args
@ -2483,6 +2482,7 @@ def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
s_bdims, o_bdims = batch_dims
if s_bdims is not None and o_bdims is not None:
#TODO(#212): use a map construct instead of unrolling.
source = batching.move_dim_to_front(source, s_bdims)
operand = batching.move_dim_to_front(operand, o_bdims)
outputs = [
@ -2491,6 +2491,7 @@ def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
outputs = concatenate(outputs, 0)
return outputs, 0
elif s_bdims is not None:
#TODO(#212): use a map construct instead of unrolling.
source = batching.move_dim_to_front(source, s_bdims)
outputs = [
_select_and_scatter_add(s, operand, **kwargs) for s in source]
@ -2498,6 +2499,7 @@ def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
outputs = concatenate(outputs, 0)
return outputs, 0
elif o_bdims is not None:
#TODO(#212): use a map construct instead of unrolling.
operand = batching.move_dim_to_front(operand, o_bdims)
outputs = [
_select_and_scatter_add(source, o, **kwargs) for o in operand]