parallelization rule for general dot

This commit is contained in:
Roy Frostig 2019-03-22 14:10:38 -07:00
parent 4cd5da74e2
commit e4a041ce53

View File

@ -2197,11 +2197,55 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers):
batched_out = dot_general(lhs, rhs, new_dimension_numbers)
return batched_out, 0
def _dot_general_papply_rule(name, vals, dims, dimension_numbers):
  """Parallelization (papply) rule for ``dot_general``.

  Args:
    name: axis name forwarded to ``psum`` / ``pcollect``.
    vals: pair ``(x, y)`` of operands.
    dims: pair ``(xdim, ydim)``. Each entry is either ``None`` or the index of
      the operand dimension that is hidden along ``name``; contracting
      dimension indices are renumbered as if that dimension were removed.
      NOTE(review): "hidden" is taken from the inline comments below; confirm
      the exact convention against the papply machinery.
    dimension_numbers: ``((lhs_contract, rhs_contract),
      (lhs_batch, rhs_batch))`` as accepted by ``dot_general``.

  Returns:
    A pair ``(result, dim)`` where ``dim`` is ``None`` or the dimension index
    reported for the result.

  Raises:
    NotImplementedError: if any batch dimensions are present, or if both
      operands are split along contracting dimensions that are not contracted
      against each other.
  """
  x, y = vals
  xdim, ydim = dims

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  if len(lhs_batch) > 0 or len(rhs_batch) > 0:
    raise NotImplementedError(
        'papply of dot_general with batch dimensions: '
        'xdim={}, ydim={}, dimension_numbers={}'.format(
            xdim, ydim, dimension_numbers))

  def adjust_dims(dims, thresh):
    # Drop `thresh` and shift every larger index down by one, i.e. renumber
    # `dims` as if dimension `thresh` no longer existed.
    return tuple(i - 1 if i > thresh else i for i in dims if i != thresh)

  sub_lhs_contract, sub_rhs_contract = lhs_contract, rhs_contract
  if xdim is not None:
    sub_lhs_contract = adjust_dims(lhs_contract, xdim)
  if ydim is not None:
    sub_rhs_contract = adjust_dims(rhs_contract, ydim)

  sub_dimension_numbers = (
      (sub_lhs_contract, sub_rhs_contract), (lhs_batch, rhs_batch))

  if xdim in lhs_contract and ydim in rhs_contract:
    # Summing per-slice products with psum is only a valid contraction when
    # the two split dimensions are contracted against *each other*. Otherwise
    # dropping them would silently re-pair the remaining contracting
    # dimensions incorrectly.
    if (xdim, ydim) not in zip(lhs_contract, rhs_contract):
      raise NotImplementedError(
          'papply of dot_general with operands split along contracting '
          'dimensions that are not contracted with each other: '
          'xdim={}, ydim={}, dimension_numbers={}'.format(
              xdim, ydim, dimension_numbers))
    z = dot_general(x, y, sub_dimension_numbers)
    return psum(z, name), None
  elif xdim in lhs_contract:
    # NOTE(review): sub_dimension_numbers was computed before the pcollect
    # below and assumes ydim is still hidden on y; also xdim is contracted
    # here yet is reported as the result dimension. Confirm both are intended.
    if ydim is not None:  # Cannot hide two dimensions, so collect one
      y = pcollect(y, name)
    return dot_general(x, y, sub_dimension_numbers), xdim
  elif ydim in rhs_contract:
    # NOTE(review): mirror image of the branch above; same caveats apply.
    if xdim is not None:  # Cannot hide two dimensions, so collect one
      x = pcollect(x, name)
    return dot_general(x, y, sub_dimension_numbers), ydim
  elif xdim is not None:
    if ydim is not None:  # Cannot hide two dimensions, so collect one
      y = pcollect(y, name)
    return dot_general(x, y, sub_dimension_numbers), xdim
  elif ydim is not None:
    return dot_general(x, y, sub_dimension_numbers), ydim
  else:
    return dot_general(x, y, sub_dimension_numbers), None
# Define the primitive named 'dot_general' from its shape and dtype rules.
dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general')
# Register the primitive with `ad` as bilinear, supplying separate transpose
# rules for the lhs and rhs operands.
ad.defbilinear(dot_general_p,
_dot_general_transpose_lhs, _dot_general_transpose_rhs)
# Register the batching rule for this primitive.
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
# Register the papply (parallelization) rule for this primitive.
parallel.papply_primitive_rules[dot_general_p] = _dot_general_papply_rule
def _broadcast_shape_rule(operand, sizes):