mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
parallelization rule for general dot
This commit is contained in:
parent
4cd5da74e2
commit
e4a041ce53
44
jax/lax.py
44
jax/lax.py
@ -2197,11 +2197,55 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers):
|
||||
batched_out = dot_general(lhs, rhs, new_dimension_numbers)
|
||||
return batched_out, 0
|
||||
|
||||
|
||||
def _dot_general_papply_rule(name, vals, dims, dimension_numbers):
|
||||
x, y = vals
|
||||
xdim, ydim = dims
|
||||
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
|
||||
if len(lhs_batch) > 0 or len(rhs_batch) > 0:
|
||||
raise NotImplementedError
|
||||
|
||||
def adjust_dims(dims, thresh):
|
||||
return tuple(i - 1 if i >= thresh else i for i in dims if i != thresh)
|
||||
|
||||
sub_lhs_contract, sub_rhs_contract = lhs_contract, rhs_contract
|
||||
if xdim is not None:
|
||||
sub_lhs_contract = adjust_dims(lhs_contract, xdim)
|
||||
if ydim is not None:
|
||||
sub_rhs_contract = adjust_dims(rhs_contract, ydim)
|
||||
|
||||
sub_dimension_numbers = (
|
||||
(sub_lhs_contract, sub_rhs_contract), (lhs_batch, rhs_batch))
|
||||
|
||||
if xdim in lhs_contract and ydim in rhs_contract:
|
||||
z = dot_general(x, y, sub_dimension_numbers)
|
||||
return psum(z, name), None
|
||||
elif xdim in lhs_contract:
|
||||
if ydim is not None: # Cannot hide two dimensions, so collect one
|
||||
y = pcollect(y, name)
|
||||
return dot_general(x, y, sub_dimension_numbers), xdim
|
||||
elif ydim in rhs_contract:
|
||||
if xdim is not None: # Cannot hide two dimensions, so collect one
|
||||
x = pcollect(x, name)
|
||||
return dot_general(x, y, sub_dimension_numbers), ydim
|
||||
elif xdim is not None:
|
||||
if ydim is not None: # Cannot hide two dimensions, so collect one
|
||||
y = pcollect(y, name)
|
||||
return dot_general(x, y, sub_dimension_numbers), xdim
|
||||
elif ydim is not None:
|
||||
return dot_general(x, y, sub_dimension_numbers), ydim
|
||||
else:
|
||||
return dot_general(x, y, sub_dimension_numbers), None
|
||||
|
||||
|
||||
# Build the `dot_general` primitive from its shape rule, dtype rule and name.
dot_general_p = standard_primitive(
    _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general')

# Attach the per-transformation rules for this primitive: the two transpose
# functions (lhs, rhs) via `ad.defbilinear`, then the batching and papply
# rules in their respective lookup tables.
ad.defbilinear(
    dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
parallel.papply_primitive_rules[dot_general_p] = _dot_general_papply_rule
|
||||
|
||||
|
||||
def _broadcast_shape_rule(operand, sizes):
|
||||
|
Loading…
x
Reference in New Issue
Block a user