mirror of
https://github.com/ROCm/jax.git
remove dot primitive in favor of dot_general
parent 096a52a3a3
commit 6d29c4e352
171 jax/lax/lax.py
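
Context, not part of the diff below: the replacement body of dot delegates to dot_general, whose dimension_numbers are a pair ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)). A minimal sketch of that equivalence, assuming only the public jax.lax API (the variable names here are illustrative, not from the commit):

# Not part of the commit: a matrix-matrix dot is dot_general with the last
# axis of lhs contracted against the first axis of rhs and no batch dims.
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((3, 4), jnp.float32)
rhs = jnp.ones((4, 5), jnp.float32)

via_dot = lax.dot(lhs, rhs)
via_dot_general = lax.dot_general(
    lhs, rhs,
    dimension_numbers=(((lhs.ndim - 1,), (0,)), ((), ())))  # (contracting, batch)

assert jnp.allclose(via_dot, via_dot_general)
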
@@ -508,7 +508,12 @@ def dot(lhs, rhs, precision=None):
    rhs = broadcast(rhs, (1,))
    return reduce(mul(lhs, rhs), _zero(lhs), add, (len(lhs_shape) - 1,))

  return dot_p.bind(lhs, rhs, precision=_canonicalize_precision(precision))
  if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and lhs.shape[-1] == rhs.shape[0]:
    return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
                       precision=precision)
  else:
    raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        lhs.shape, rhs.shape))

def dot_general(lhs, rhs, dimension_numbers, precision=None):
  """More general contraction operator.
@@ -2019,109 +2024,6 @@ def _reshape_axis_out_of(src, size1, x):
  shape[src:src+1] = [size1, size2]
  return reshape(x, shape)


def _dot_shape_rule(lhs, rhs, precision):
  if lhs.ndim == 0 or rhs.ndim == 0:
    msg = "Dot only supports rank 1 or above, got shapes {} and {}."
    raise TypeError(msg.format(lhs.shape, rhs.shape))
  if lhs.ndim > 2 or rhs.ndim > 2:
    msg = "Dot only supports rank 2 or less, got shapes {} and {}."
    raise TypeError(msg.format(lhs.shape, rhs.shape))

  def require(shape_cond):
    if not shape_cond:
      msg = "Incompatible shapes for dot: got {} and {}."
      raise TypeError(msg.format(lhs.shape, rhs.shape))

  if lhs.ndim == rhs.ndim == 1:
    require(lhs.shape == rhs.shape)
    return ()
  elif lhs.ndim == rhs.ndim == 2:
    require(lhs.shape[1] == rhs.shape[0])
    return (lhs.shape[0], rhs.shape[1])
  elif rhs.ndim == 1:
    require(lhs.shape[-1] == rhs.shape[0])
    return lhs.shape[:-1]
  else:
    require(lhs.shape[-1] == rhs.shape[-2])
    return lhs.shape[:-1] + rhs.shape[:-2] + rhs.shape[-1:]

def _dot_transpose_lhs(t, rhs, precision):
  if onp.ndim(t) == onp.ndim(rhs) == 2:
    return dot(t, transpose(rhs, (1, 0)), precision=precision)
  elif onp.ndim(t) == 1 and onp.ndim(rhs) == 2:
    return dot(rhs, t, precision=precision)
  elif onp.ndim(t) == onp.ndim(rhs) == 1:
    return _outer(t, rhs)
  elif onp.ndim(t) == 0 or onp.ndim(rhs) == 0:
    return mul(t, rhs)
  else:
    raise TypeError

def _dot_transpose_rhs(t, lhs, precision):
  if onp.ndim(lhs) == onp.ndim(t) == 2:
    return dot(transpose(lhs, (1, 0)), t)
  elif onp.ndim(lhs) == 2 and onp.ndim(t) == 1:
    return dot(t, lhs, precision=precision)
  elif onp.ndim(t) == onp.ndim(lhs) == 1:
    return _outer(lhs, t)
  elif onp.ndim(t) == 0 or onp.ndim(lhs) == 0:
    return mul(t, lhs)
  else:
    raise TypeError

def _outer(x, y):
  assert onp.ndim(x) == onp.ndim(y) == 1
  return mul(reshape(x, (x.shape[0], 1)), reshape(y, (1, y.shape[0])))

def _dot_batch_rule(batched_args, batch_dims, precision=None):
  lhs, rhs = batched_args
  lbd, rbd = batch_dims
  T = lambda x: transpose(x, onp.arange(onp.ndim(x))[::-1])

  # in some cases, we can call dot instead of dot_general
  if max(onp.ndim(lhs), onp.ndim(rhs)) <= 2:
    if rbd is None:
      assert lbd in (0, 1)
      if lbd == 0:
        return dot(lhs, rhs, precision=precision), 0
      else:
        return dot(T(rhs), lhs, precision=precision), onp.ndim(rhs) - 1

    if lbd is None:
      assert rbd in (0, 1)
      if rbd == onp.ndim(rhs) - 1:
        return dot(lhs, rhs, precision=precision), onp.ndim(lhs) - 1
      else:
        return dot(rhs, T(lhs), precision=precision), 0

    assert lbd is not None and rbd is not None
    assert lhs.ndim == rhs.ndim == 2 # dot only supports rank 1 and above
    lhs = batching.moveaxis(lhs, lbd, 0)
    rhs = batching.moveaxis(rhs, rbd, 0)
    out = dot_general(lhs, rhs, [((1,), (1,)), ((0,), (0,))],
                      precision=precision)
    return out, 0

  if lbd is None:
    assert rbd is not None
    lhs = broadcast(lhs, (rhs.shape[rbd],))
  else:
    lhs = batching.moveaxis(lhs, lbd, 0)
  lhs_batch = (0,)
  lhs_contracting = (onp.ndim(lhs) - 1,)

  if rbd is None:
    assert lbd is not None
    rhs = broadcast(rhs, (lhs.shape[0],))
  else:
    rhs = batching.moveaxis(rhs, rbd, 0)
  rhs_batch = (0,)
  rhs_contracting = (onp.arange(1, onp.ndim(rhs))[-2:][0],)

  dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
  return dot_general(lhs, rhs, dim_nums, precision=precision), 0

def _precision_config(precision):
  if precision is not None:
    config = xla_client.PrecisionConfig()
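
Aside, not part of the commit: with dot_p removed, vmap of a matrix product is batched through dot_general's batch rule rather than the deleted _dot_batch_rule above. A quick sketch under that assumption, using only public JAX functions:

# Not part of the commit: vmap of lax.dot now lowers via dot_general batching.
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(24.0).reshape(2, 3, 4)  # batch of lhs matrices
ys = jnp.arange(40.0).reshape(2, 4, 5)  # batch of rhs matrices
batched = jax.vmap(lax.dot)(xs, ys)
assert jnp.allclose(batched, jnp.matmul(xs, ys))
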
@@ -2129,55 +2031,6 @@ def _precision_config(precision):
    return config
  return None

def _dot_translation_rule(c, lhs, rhs, precision):
  return c.Dot(lhs, rhs, precision_config=_precision_config(precision))

def _dot_polymorphic_shape_rule(shape_exprs, precision):
  del precision # Unused.
  lhs_shape, rhs_shape = shape_exprs
  lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)

  if lhs_ndim == rhs_ndim == 1:
    if not lhs_shape == rhs_shape: raise ShapeError
    return ShapeExpr(())
  elif lhs_ndim == rhs_ndim == 2:
    if not lhs_shape[1] == rhs_shape[0]: raise ShapeError
    return ShapeExpr((lhs_shape[0], rhs_shape[1]))
  elif rhs_ndim == 1:
    if not lhs_shape[1] == rhs_shape[0]: raise ShapeError
    return ShapeExpr((lhs_shape[0],))
  else:
    if not lhs_shape[0] == rhs_shape[0]: raise ShapeError
    return ShapeExpr((rhs_shape[1],))

def _dot_masking_rule(padded_vals, logical_shapes, precision):
  lhs, rhs = padded_vals
  lhs_shape, rhs_shape = logical_shapes
  lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)

  if lhs_ndim == rhs_ndim == 1:
    masked_lhs = select(iota(onp.int32, lhs.shape[0]) < lhs_shape[0],
                        lhs, zeros_like_array(lhs))
    return dot_p.bind(masked_lhs, rhs, precision=precision)
  elif lhs_ndim == rhs_ndim == 2:
    # TODO could avoid select if we check whether contracted axis is masked
    masked_lhs = select(broadcasted_iota(onp.int32, lhs.shape, 1) < lhs_shape[1],
                        lhs, zeros_like_array(lhs))
    return dot_p.bind(masked_lhs, rhs, precision=precision)
  elif rhs_ndim == 1:
    raise NotImplementedError
  else:
    raise NotImplementedError


_dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
dot_p = standard_primitive(_dot_shape_rule, _dot_dtype_rule, 'dot',
                           _dot_translation_rule)
ad.defbilinear(dot_p, _dot_transpose_lhs, _dot_transpose_rhs)
batching.primitive_batchers[dot_p] = _dot_batch_rule
masking.shape_rules[dot_p] = _dot_polymorphic_shape_rule
masking.masking_rules[dot_p] = _dot_masking_rule


def _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision):
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
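
Aside, not part of the commit: the deleted ad.defbilinear registration means gradients of dot now flow through dot_general's transpose rules instead of _dot_transpose_lhs/_dot_transpose_rhs. A small check of the expected gradient, again a sketch using only public JAX functions:

# Not part of the commit: grad of sum(dot(A, B)) w.r.t. A equals ones @ B.T.
import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.arange(6.0).reshape(2, 3)
rhs = jnp.arange(12.0).reshape(3, 4)

g = jax.grad(lambda a: lax.dot(a, rhs).sum())(lhs)
assert jnp.allclose(g, jnp.ones((2, 4)) @ rhs.T)
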
@@ -2276,8 +2129,8 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers,
    else:
      # move the new dimension to the end of lhs to avoid changing batch dims
      lhs = batching.moveaxis(lhs, lbd, lhs.ndim - 1)
      # lhs tensor product dims in result come after batch dims
      result_batch_dim = lhs.ndim - len(lhs_contract) - 1
      # lhs tensor product dims in result come after batch dims
      result_batch_dim = lhs.ndim - len(lhs_contract) - 1
  else:
    if rhs_batch == () or rbd > onp.max(rhs_batch):
      # can avoid transposes
@@ -2288,10 +2141,10 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers,
    else:
      # move the new dimension to the end of rhs to avoid changing batch dims
      rhs = batching.moveaxis(rhs, rbd, rhs.ndim - 1)
      # rhs tensor product dims in result come after batch dims + lhs tensor
      # product dims
      result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
                          rhs.ndim - len(rhs_contract) - 1)
      # rhs tensor product dims in result come after batch dims + lhs tensor
      # product dims
      result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
                          rhs.ndim - len(rhs_contract) - 1)
  new_dimension_numbers = [(lhs_contract, rhs_contract), (lhs_batch, rhs_batch)]
  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                            precision=precision)
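
Aside, not part of the commit: the comments above describe dot_general's output axis ordering (batch dims first, then lhs free dims, then rhs free dims). A short illustration under that reading:

# Not part of the commit: output axes of dot_general are (batch, lhs free, rhs free).
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3, 5))  # (batch, m, k)
rhs = jnp.ones((2, 5, 7))  # (batch, k, n)
out = lax.dot_general(lhs, rhs, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
assert out.shape == (2, 3, 7)
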
@@ -411,11 +411,6 @@ _defreducer(lax.reduce_max_p, pmax_p)
_defreducer(lax.reduce_min_p, pmin_p)


def _dot_papply_rule(name, size, vals, dims, precision):
  x, _ = vals
  dim_nums = [((x.ndim,), (0,)), ((), ())]
  return _dot_general_papply_rule(name, size, vals, dims, dim_nums, precision)

def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers,
                             precision):
  x, y = vals
@@ -700,7 +695,6 @@ def _gather_papply_rule(
  raise NotImplementedError


parallel.papply_primitive_rules[lax.dot_p] = _dot_papply_rule
parallel.papply_primitive_rules[lax.dot_general_p] = _dot_general_papply_rule
parallel.papply_primitive_rules[lax.reshape_p] = _reshape_papply_rule
parallel.papply_primitive_rules[lax.transpose_p] = _transpose_papply_rule