mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
start to sketch out gather batching rule (WIP)
This commit is contained in:
parent
42b3218e90
commit
cde5c925fd
28
jax/lax.py
28
jax/lax.py
@ -2079,12 +2079,40 @@ def _gather_transpose_rule(t, operand, start_indices, dimension_numbers,
|
||||
index_vector_dim=dimension_numbers.index_vector_dim)
|
||||
return [scatter_add(zeros, start_indices, t, scatter_dnums), ad_util.zero]
|
||||
|
||||
def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
|
||||
slice_sizes, operand_shape):
|
||||
operand, start_indices = batched_args
|
||||
operand_bdim, start_indices_bdim = batch_dims
|
||||
|
||||
if operand_bdim is not None and start_indices_bdim is None:
|
||||
slice_sizes = list(slice_sizes)
|
||||
slice_sizes.insert(operand_bdim, operand.shape[operand_bdim])
|
||||
|
||||
offset_dims = tuple(dimension_numbers.offset_dims) + (operand_bdim,)
|
||||
|
||||
collapsed_slice_dims = tuple(i+1 if i >= operand_bdim else i
|
||||
for i in dimension_numbers.collapsed_slice_dims)
|
||||
|
||||
dnums = GatherDimensionNumbers(
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=collapsed_slice_dims,
|
||||
start_index_map=dimension_numbers.start_index_map,
|
||||
index_vector_dim=dimension_numbers.index_vector_dim)
|
||||
|
||||
out_bdim = 0 # TODO
|
||||
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), out_bdim
|
||||
else:
|
||||
raise NotImplementedError # TODO(mattjj):
|
||||
|
||||
|
||||
gather_p = standard_primitive(
|
||||
_gather_shape_rule, _gather_dtype_rule, 'gather',
|
||||
_gather_translation_rule)
|
||||
ad.defjvp(gather_p, _gather_jvp_rule, None)
|
||||
ad.primitive_transposes[gather_p] = _gather_transpose_rule
|
||||
batching.primitive_batchers[gather_p] = _gather_batching_rule
|
||||
|
||||
|
||||
ScatterDimensionNumbers = collections.namedtuple(
|
||||
|
Loading…
x
Reference in New Issue
Block a user