add numpy indexing batching tests

This commit is contained in:
Matthew Johnson 2019-02-10 18:36:21 -08:00
parent b53eb241f7
commit 6dfe2d6e36
2 changed files with 605 additions and 504 deletions

View File

@ -2102,15 +2102,30 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
index_vector_dim = dimension_numbers.index_vector_dim + 1
start_index_map = (0,) + tuple(dimension_numbers.start_index_map)
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
start_index_map=start_index_map, index_vector_dim=index_vector_dim)
start_index_map=dimension_numbers.start_index_map,
index_vector_dim=index_vector_dim)
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), 0
else:
raise NotImplementedError # TODO(mattjj, phawkins)
# TODO(mattjj,phawkins): this code is wrong
operand = batching.move_dim_to_front(operand, operand_bdim)
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
slice_sizes = (1,) + slice_sizes
collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map))
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=collapsed_slice_dims,
start_index_map=start_index_map,
index_vector_dim=dimension_numbers.index_vector_dim + 1)
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), 0
gather_p = standard_primitive(

File diff suppressed because it is too large Load Diff