mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Misc. small fixes.
This commit is contained in:
parent
6181560ef0
commit
3b8d43cefa
12
jax/lax.py
12
jax/lax.py
@ -1136,9 +1136,9 @@ def conv_general_dilated_batch_rule(
|
||||
lhs_dim, rhs_dim, out_dim = dimension_numbers
|
||||
|
||||
if lhs_bdim is not None and rhs_bdim is not None:
|
||||
#TODO(#212): use a map construct instead of unrolling.
|
||||
lhs = batching.move_dim_to_front(lhs, lhs_bdim)
|
||||
rhs = batching.move_dim_to_front(rhs, rhs_bdim)
|
||||
|
||||
outputs = [
|
||||
conv_general_dilated(l, r, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dimension_numbers)
|
||||
@ -1152,12 +1152,10 @@ def conv_general_dilated_batch_rule(
|
||||
# convolution isn't the first dimension.
|
||||
if lhs_dim[0] != 0 or out_dim[0] != 0:
|
||||
raise NotImplementedError
|
||||
lhs = batching.move_dim_to_front(lhs, lhs_dim[0])
|
||||
|
||||
lhs = batching.move_dim_to_front(lhs, lhs_bdim)
|
||||
|
||||
batched_size = lhs.shape[0]
|
||||
n_size = lhs.shape[1]
|
||||
|
||||
lhs = reshape(lhs, (batched_size * n_size,) + lhs.shape[2:])
|
||||
outputs = conv_general_dilated(
|
||||
lhs, rhs, window_strides, padding,
|
||||
@ -1166,7 +1164,7 @@ def conv_general_dilated_batch_rule(
|
||||
|
||||
return outputs, 0
|
||||
elif rhs_bdim is not None:
|
||||
# TODO(schsam): Consider a loop instead of unrolling.
|
||||
#TODO(#212): use a map construct instead of unrolling.
|
||||
rhs = batching.move_dim_to_front(rhs, rhs_bdim)
|
||||
outputs = [
|
||||
conv_general_dilated(lhs, x, window_strides, padding,
|
||||
@ -2339,6 +2337,7 @@ def reduce_window_sum_transpose_rule(cotangent, window_dimensions,
|
||||
xla_bridge.get_xla_client().PaddingType.VALID)
|
||||
assert result.shape == input_shape
|
||||
return [result]
|
||||
|
||||
def reduce_window_sum_batch_rule(
|
||||
batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
|
||||
operand, = batched_args
|
||||
@ -2483,6 +2482,7 @@ def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
|
||||
s_bdims, o_bdims = batch_dims
|
||||
|
||||
if s_bdims is not None and o_bdims is not None:
|
||||
#TODO(#212): use a map construct instead of unrolling.
|
||||
source = batching.move_dim_to_front(source, s_bdims)
|
||||
operand = batching.move_dim_to_front(operand, o_bdims)
|
||||
outputs = [
|
||||
@ -2491,6 +2491,7 @@ def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
|
||||
outputs = concatenate(outputs, 0)
|
||||
return outputs, 0
|
||||
elif s_bdims is not None:
|
||||
#TODO(#212): use a map construct instead of unrolling.
|
||||
source = batching.move_dim_to_front(source, s_bdims)
|
||||
outputs = [
|
||||
_select_and_scatter_add(s, operand, **kwargs) for s in source]
|
||||
@ -2498,6 +2499,7 @@ def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
|
||||
outputs = concatenate(outputs, 0)
|
||||
return outputs, 0
|
||||
elif o_bdims is not None:
|
||||
#TODO(#212): use a map construct instead of unrolling.
|
||||
operand = batching.move_dim_to_front(operand, o_bdims)
|
||||
outputs = [
|
||||
_select_and_scatter_add(source, o, **kwargs) for o in operand]
|
||||
|
Loading…
x
Reference in New Issue
Block a user