gather passing all operand vmap tests

This commit is contained in:
Matthew Johnson 2019-02-06 10:58:41 -08:00
parent b6cb3509cd
commit b53eb241f7
2 changed files with 550 additions and 527 deletions

View File

@ -2085,40 +2085,32 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
operand_bdim, start_indices_bdim = batch_dims
if operand_bdim is not None and start_indices_bdim is None:
collapsed_slice_dims = set(dimension_numbers.collapsed_slice_dims)
num_preceding_window_dims = sum( # 2
1 for i in range(len(slice_sizes))
if i < operand_bdim and i not in collapsed_slice_dims)
offset_dims = list(dimension_numbers.offset_dims) # [2, 3, 7]
if num_preceding_window_dims == 0:
bdim_offset_dim = 0
else:
bdim_offset_dim = offset_dims[num_preceding_window_dims - 1] + 1
new_offset_dims = (offset_dims[:num_preceding_window_dims]
+ [bdim_offset_dim]
+ list(onp.add(1, offset_dims[num_preceding_window_dims:])))
new_offset_dims = tuple(new_offset_dims)
slice_sizes = list(slice_sizes)
slice_sizes.insert(operand_bdim, operand.shape[operand_bdim])
slice_sizes = tuple(slice_sizes)
collapsed_slice_dims = tuple(i + 1 if i >= operand_bdim else i
for i in dimension_numbers.collapsed_slice_dims)
start_index_map = tuple(i + 1 if i > operand_bdim else i
for i in dimension_numbers.start_index_map)
operand = batching.move_dim_to_front(operand, operand_bdim)
slice_sizes = (operand.shape[0],) + slice_sizes
offset_dims = (0,) + tuple(onp.add(1, dimension_numbers.offset_dims))
collapsed_slice_dims = tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map))
dnums = GatherDimensionNumbers(
offset_dims=new_offset_dims,
offset_dims=offset_dims,
collapsed_slice_dims=collapsed_slice_dims,
start_index_map=start_index_map,
index_vector_dim=dimension_numbers.index_vector_dim)
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), bdim_offset_dim
slice_sizes=slice_sizes), 0
elif operand_bdim is None and start_indices_bdim is not None:
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
index_vector_dim = dimension_numbers.index_vector_dim + 1
start_index_map = (0,) + tuple(dimension_numbers.start_index_map)
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
start_index_map=start_index_map, index_vector_dim=index_vector_dim)
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), 0
else:
raise NotImplementedError # TODO(mattjj):
raise NotImplementedError # TODO(mattjj, phawkins)
gather_p = standard_primitive(

File diff suppressed because it is too large Load Diff