Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Added batching rules for convolutions + pooling.

Added batching rules: conv_general_dilated_batch_rule, select_and_scatter_add_batch_rule, reduce_window_max_batch_rule, reduce_window_sum_batch_rule.

parent d968e1e572
commit a15bad401f

jax/lax.py | 112 lines changed
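What these rules enable: once a primitive has an entry in batching.primitive_batchers, jax.vmap can map it over a leading axis instead of failing. A minimal sketch of the convolution case, written against the present-day jax.lax API; the shapes and names below are illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical setup: a stack of 8 NCHW inputs sharing one OIHW kernel.
xs = jnp.ones((8, 1, 3, 32, 32))
w = jnp.ones((4, 3, 3, 3))

conv = lambda x: lax.conv_general_dilated(x, w, window_strides=(1, 1),
                                          padding='SAME')
ys = jax.vmap(conv)(xs)  # dispatches to the conv batching rule
print(ys.shape)          # (8, 1, 4, 32, 32)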
@@ -1128,13 +1128,61 @@ def conv_general_dilated_translation_rule(
  return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
                              rhs_dilation, dimension_numbers)

def conv_general_dilated_batch_rule(
    batched_args, batch_dims, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, **unused_kwargs):
  lhs, rhs = batched_args
  lhs_bdim, rhs_bdim = batch_dims
  lhs_dim, rhs_dim, out_dim = dimension_numbers

  if lhs_bdim is not None and rhs_bdim is not None:
    # Both operands batched: unroll, convolve pairwise, and stack the results.
    lhs = batching.move_dim_to_front(lhs, lhs_bdim)
    rhs = batching.move_dim_to_front(rhs, rhs_bdim)

    outputs = [
        conv_general_dilated(l, r, window_strides, padding,
                             lhs_dilation, rhs_dilation, dimension_numbers)
        for l, r in zip(lhs, rhs)]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0

  elif lhs_bdim is not None:
    # Only lhs batched: fold the vmapped dimension into the convolution's own
    # batch ("N") dimension, run one convolution, then split the axes apart.
    # Currently we don't handle cases where the batch dimension of the
    # convolution isn't the first dimension.
    if lhs_dim[0] != 0 or out_dim[0] != 0:
      raise NotImplementedError
    lhs = batching.move_dim_to_front(lhs, lhs_dim[0])
    lhs = batching.move_dim_to_front(lhs, lhs_bdim)

    batched_size = lhs.shape[0]
    n_size = lhs.shape[1]

    lhs = reshape(lhs, (batched_size * n_size,) + lhs.shape[2:])
    outputs = conv_general_dilated(
        lhs, rhs, window_strides, padding,
        lhs_dilation, rhs_dilation, dimension_numbers)
    outputs = reshape(outputs, (batched_size, n_size,) + outputs.shape[1:])

    return outputs, 0
  elif rhs_bdim is not None:
    # Only rhs batched: apply each kernel to the shared lhs and stack.
    # TODO(schsam): Consider a loop instead of unrolling.
    rhs = batching.move_dim_to_front(rhs, rhs_bdim)
    outputs = [
        conv_general_dilated(lhs, x, window_strides, padding,
                             lhs_dilation, rhs_dilation, dimension_numbers)
        for x in rhs]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
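The lhs-only path above avoids unrolling entirely by reusing the convolution's existing batch dimension. A standalone sketch of that reshape trick, with illustrative shapes not taken from the diff:

import jax.numpy as jnp

x = jnp.ones((8, 2, 3, 32, 32))        # (vmap_batch, N, C, H, W)
x = x.reshape((8 * 2,) + x.shape[2:])  # fold vmap batch into N: (16, 3, 32, 32)
# ... one unbatched convolution over the folded array ...
y = jnp.ones((16, 4, 32, 32))          # stand-in for the conv output
y = y.reshape((8, 2) + y.shape[1:])    # unfold: (8, 2, 4, 32, 32)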
conv_general_dilated_p = standard_primitive(
    conv_general_dilated_shape_rule, conv_general_dilated_dtype_rule,
    'conv_general_dilated', conv_general_dilated_translation_rule)
ad.defbilinear(conv_general_dilated_p,
               conv_general_dilated_transpose_lhs,
               conv_general_dilated_transpose_rhs)

batching.primitive_batchers[
    conv_general_dilated_p] = conv_general_dilated_batch_rule

def dot_shape_rule(lhs, rhs):
  if lhs.ndim == 0 or rhs.ndim == 0:
@@ -2291,12 +2339,26 @@ def reduce_window_sum_transpose_rule(cotangent, window_dimensions,
      xla_bridge.get_xla_client().PaddingType.VALID)
  assert result.shape == input_shape
  return [result]

def reduce_window_sum_batch_rule(
    batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
  operand, = batched_args
  bdim, = bdims

  if bdim is not None:
    # Insert a size-1 window with unit stride at the batch position so the
    # reduction passes the vmapped dimension through untouched.
    window_dimensions = \
        window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
    window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

  operand = _reduce_window_sum(
      operand, window_dimensions, window_strides, padding)

  return operand, 0
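The rule above never unrolls: a size-1 window along the batch axis makes the windowed sum a no-op there. A hedged check of that equivalence using the present-day jax.lax API (names and shapes are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12, dtype=jnp.float32).reshape(2, 6)  # leading batch axis of 2
sum_pool = lambda v: lax.reduce_window(v, 0., lax.add, (3,), (1,), 'VALID')
batched = jax.vmap(sum_pool)(x)
# Same result as one call with a (1, 3) window and (1, 1) strides:
direct = lax.reduce_window(x, 0., lax.add, (1, 3), (1, 1), 'VALID')
assert jnp.allclose(batched, direct)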
reduce_window_sum_p = standard_primitive(
    reduce_window_sum_shape_rule, _input_dtype, 'reduce_window_sum',
    reduce_window_sum_translation_rule)
ad.deflinear(reduce_window_sum_p, reduce_window_sum_transpose_rule)

batching.primitive_batchers[reduce_window_sum_p] = reduce_window_sum_batch_rule

def reduce_window_chooser_translation_rule(
    prim, identity, c, operand, window_dimensions, window_strides, padding):
@@ -2338,6 +2400,20 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
      onp.subtract(operand_padded, window_dimensions), window_strides) + 1
  return tuple(t)

def reduce_window_max_batch_rule(
    batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
  operand, = batched_args
  bdim, = bdims

  if bdim is not None:
    # Same trick as the sum rule: a size-1 window leaves the batch dim alone.
    window_dimensions = \
        window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
    window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

  operand = _reduce_window_max(
      operand, window_dimensions, window_strides, padding)

  return operand, 0
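The max rule is the same construction applied to _reduce_window_max. Under the same illustrative setup as the sum example above, vmap of a max pool agrees with a single reduce_window call whose window is 1 along the batch axis:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12, dtype=jnp.float32).reshape(2, 6)
max_pool = lambda v: lax.reduce_window(v, -jnp.inf, lax.max, (3,), (1,), 'VALID')
assert jnp.allclose(
    jax.vmap(max_pool)(x),
    lax.reduce_window(x, -jnp.inf, lax.max, (1, 3), (1, 1), 'VALID'))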
reduce_window_max_translation_rule = partial(
    reduce_window_chooser_translation_rule, max_p, _get_max_identity)
@@ -2345,7 +2421,7 @@ reduce_window_max_p = standard_primitive(
    common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
    reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(reduce_window_chooser_jvp_rule, max_p))

batching.primitive_batchers[reduce_window_max_p] = reduce_window_max_batch_rule

reduce_window_min_translation_rule = partial(
    reduce_window_chooser_translation_rule, min_p, _get_min_identity)
@@ -2402,12 +2478,40 @@ def select_and_scatter_add_transpose(
      window_strides, padding)
  return [result, None]

def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
  source, operand = batched_args
  s_bdims, o_bdims = batch_dims

  if s_bdims is not None and o_bdims is not None:
    # Both inputs batched: unroll pairwise and stack the results.
    source = batching.move_dim_to_front(source, s_bdims)
    operand = batching.move_dim_to_front(operand, o_bdims)
    outputs = [
        _select_and_scatter_add(s, o, **kwargs) for s, o in zip(source, operand)]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
  elif s_bdims is not None:
    # Only source batched: broadcast the shared operand across the batch.
    source = batching.move_dim_to_front(source, s_bdims)
    outputs = [
        _select_and_scatter_add(s, operand, **kwargs) for s in source]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
  elif o_bdims is not None:
    # Only operand batched: broadcast the shared source across the batch.
    operand = batching.move_dim_to_front(operand, o_bdims)
    outputs = [
        _select_and_scatter_add(source, o, **kwargs) for o in operand]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
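select_and_scatter_add appears in the transpose (gradient) path of max pooling, as the select_and_scatter_add_transpose context above suggests, so this rule is what lets a pooling backward pass run under vmap. A hedged end-to-end check against the present-day API, with illustrative shapes:

import jax
import jax.numpy as jnp
from jax import lax

max_pool = lambda v: lax.reduce_window(v, -jnp.inf, lax.max, (2,), (2,), 'VALID')
pool_grad = jax.grad(lambda v: max_pool(v).sum())
xs = jnp.arange(24, dtype=jnp.float32).reshape(3, 8)
g = jax.vmap(pool_grad)(xs)  # backward pass uses select_and_scatter_add
print(g.shape)               # (3, 8)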
select_and_scatter_add_p = standard_primitive(
    select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add',
    select_and_scatter_add_translation)
ad.primitive_transposes[select_and_scatter_add_p] = \
    select_and_scatter_add_transpose

batching.primitive_batchers[select_and_scatter_add_p] = \
    select_and_scatter_add_batch_rule

def _select_and_gather_add_shape_rule(
    tangents, operand, select_prim, pair_select_jaxpr, pair_select_consts,
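Taken together, the commit's net effect is that a typical conv-plus-pool layer can be vmapped end to end. A final illustrative sketch (present-day API, made-up shapes):

import jax
import jax.numpy as jnp
from jax import lax

def layer(x, w):
  # x: one (C, H, W) example; w: an OIHW kernel.
  y = lax.conv_general_dilated(x[None], w, (1, 1), 'SAME')
  y = lax.reduce_window(y, -jnp.inf, lax.max, (1, 1, 2, 2), (1, 1, 2, 2), 'VALID')
  return y[0]

w = jnp.ones((4, 3, 3, 3))
xs = jnp.ones((8, 3, 32, 32))
ys = jax.vmap(lambda x: layer(x, w))(xs)
print(ys.shape)  # (8, 4, 16, 16)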