start to sketch out gather batching rule (WIP)

This commit is contained in:
Matthew Johnson 2019-02-03 09:00:16 -08:00
parent 42b3218e90
commit cde5c925fd

View File

@ -2079,12 +2079,40 @@ def _gather_transpose_rule(t, operand, start_indices, dimension_numbers,
index_vector_dim=dimension_numbers.index_vector_dim)
return [scatter_add(zeros, start_indices, t, scatter_dnums), ad_util.zero]
def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
slice_sizes, operand_shape):
operand, start_indices = batched_args
operand_bdim, start_indices_bdim = batch_dims
if operand_bdim is not None and start_indices_bdim is None:
slice_sizes = list(slice_sizes)
slice_sizes.insert(operand_bdim, operand.shape[operand_bdim])
offset_dims = tuple(dimension_numbers.offset_dims) + (operand_bdim,)
collapsed_slice_dims = tuple(i+1 if i >= operand_bdim else i
for i in dimension_numbers.collapsed_slice_dims)
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=collapsed_slice_dims,
start_index_map=dimension_numbers.start_index_map,
index_vector_dim=dimension_numbers.index_vector_dim)
out_bdim = 0 # TODO
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), out_bdim
else:
raise NotImplementedError # TODO(mattjj):
gather_p = standard_primitive(
_gather_shape_rule, _gather_dtype_rule, 'gather',
_gather_translation_rule)
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
ScatterDimensionNumbers = collections.namedtuple(