[sharding_in_types] Add an out_type argument to einsum and dot_general to allow specifying the type of the output. Right now it only accepts a NamedSharding, but in the future we can allow a polymorphic type: jax.ShapeDtypeStruct | Sharding | Layout.
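
A minimal usage sketch, assuming 4 devices arranged into a 2x2 mesh and the sharding_in_types config enabled; it mirrors the test_einsum_with_out_type test added below (the mesh construction here is illustrative, not part of the change):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Assumed setup: a 2x2 device mesh with axes 'x' and 'y'.
    mesh = Mesh(np.asarray(jax.devices()).reshape(2, 2), ('x', 'y'))
    x = jax.device_put(np.arange(16.0).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
    y = jax.device_put(np.arange(16.0).reshape(2, 8), NamedSharding(mesh, P('y', 'x')))

    @jax.jit
    def f(x, y):
      # Request the sharding of the einsum result directly at the call site.
      return jnp.einsum('xy,yz->xz', x, y,
                        out_type=NamedSharding(mesh, P('x', None)))

    out = f(x, y)  # out.sharding == NamedSharding(mesh, P('x', None))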

PiperOrigin-RevId: 688399552
Yash Katariya 2024-10-21 22:23:15 -07:00 committed by jax authors
parent 5d3cac6603
commit ebb75db8a5
8 changed files with 142 additions and 34 deletions

@ -1040,7 +1040,8 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
preferred_element_type: DTypeLike | None = None,
out_type=None) -> Array:
"""General dot product/contraction operator.
Wraps XLA's `DotGeneral
@ -1086,6 +1087,10 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
non-contracting/non-batch dimensions.
"""
if out_type is not None and not isinstance(out_type, NamedSharding):
raise NotImplementedError(
'`out_type` argument of `dot_general` only supports NamedSharding '
'instances. Please file a bug if this is not enough for your use case.')
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
api_util._ensure_index_tuple(rhs_contract))
@ -1097,7 +1102,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(cdims, bdims),
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type)
preferred_element_type=preferred_element_type,
out_type=out_type)
def ragged_dot(
@ -3002,7 +3008,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = hlo.real(operand)
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
if config.sharding_in_types.value:
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [out]
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
@ -3164,7 +3174,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
out_type):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
@ -3241,12 +3252,16 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
raise TypeError(msg)
def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
out_type):
if lhs.sharding.mesh != rhs.sharding.mesh:
raise ValueError(
'Mesh of both lhs and rhs should match. Got lhs:'
f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
if out_type is not None:
return out_type
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
@ -3280,7 +3295,8 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
out_type):
del dimension_numbers # unused
# We're mostly matching XLA's logic here, namely in shape_inference.cc and
# primitive_util.h's HigherPrecisionType, e.g.
@ -3327,7 +3343,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
swap_ans=False):
out_type, swap_ans=False):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim
x_kept = remaining(range(x_ndim), x_contract, x_batch)
@ -3347,12 +3363,14 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
return x_bar
def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
out_type):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
y_bar = _dot_general_transpose_lhs(
g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type, swap_ans=True)
preferred_element_type=preferred_element_type, out_type=out_type,
swap_ans=True)
if y_bar.dtype != y.aval.dtype:
y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
return y_bar
@ -3366,6 +3384,7 @@ def _dot_batch_rule(
batch_dims,
*,
dimension_numbers,
out_type,
precision,
preferred_element_type: DTypeLike | None,
**_,
@ -3395,12 +3414,16 @@ def _dot_batch_rule(
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
else:
rhs_shape = np.shape(rhs)
if out_type is not None:
raise NotImplementedError("vmap with out_type is not supported. "
"Please open an issue.")
batched_out = invoke_prim(
lhs,
rhs,
new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_type=out_type,
)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
@ -3570,7 +3593,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision, preferred_element_type: np.dtype | None,
platform: str = "default"):
out_type, platform: str = "default"):
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@ -3658,6 +3681,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
**algorithm_kwarg,
)
if config.sharding_in_types.value:
if out_type is not None:
assert aval_out.sharding == out_type
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
if accumulation_aval.dtype != aval_out.dtype:
@ -3711,12 +3736,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
return (m, n)
def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
precision, preferred_element_type: DTypeLike | None,
**_) -> np.dtype:
if not dtypes.issubdtype(group_sizes.dtype, np.integer):
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision, preferred_element_type=preferred_element_type)
return _dot_general_dtype_rule(
lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision, preferred_element_type=preferred_element_type,
out_type=None)
def _ragged_dot_jvp_rule(
@ -3855,6 +3883,7 @@ def _ragged_dot_batch_rule(
*,
precision,
preferred_element_type: DTypeLike | None,
out_type,
**_,
):
invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
@ -3868,6 +3897,7 @@ def _ragged_dot_batch_rule(
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision,
preferred_element_type=preferred_element_type,
out_type=out_type,
)
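
The same out_type plumbing is exposed on lax.dot_general directly by the changes above. A hedged sketch of a plain matmul expressed as dot_general, reusing the assumed 2x2 mesh setup from the commit message (only a NamedSharding is accepted for now; anything else raises NotImplementedError):

    import numpy as np
    import jax
    from jax import lax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Assumed setup: a 2x2 device mesh over 4 devices, as in the sketch above.
    mesh = Mesh(np.asarray(jax.devices()).reshape(2, 2), ('x', 'y'))
    x = jax.device_put(np.arange(16.0).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
    y = jax.device_put(np.arange(16.0).reshape(2, 8), NamedSharding(mesh, P('y', 'x')))

    # Contract x's dim 1 with y's dim 0, no batch dims, and request that the
    # result come out sharded as P('x', None).
    matmul = jax.jit(lambda a, b: lax.dot_general(
        a, b, dimension_numbers=(((1,), (0,)), ((), ())),
        out_type=NamedSharding(mesh, P('x', None))))
    out = matmul(x, y)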

@ -67,10 +67,10 @@ from jax._src.typing import (
DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
)
from jax._src.util import (
NumpyComplexWarning,
canonicalize_axis as _canonicalize_axis,
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
from jax.sharding import Sharding, SingleDeviceSharding
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np
import opt_einsum
@ -8955,6 +8955,7 @@ def einsum(
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
out_type=None,
) -> Array: ...
@overload
@ -8967,6 +8968,7 @@ def einsum(
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
out_type=None,
) -> Array: ...
def einsum(
@ -8977,6 +8979,7 @@ def einsum(
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
out_type=None,
) -> Array:
"""Einstein summation
@ -9208,11 +9211,11 @@ def einsum(
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
return einsum(operands, contractions, precision,
preferred_element_type, _dot_general)
preferred_element_type, _dot_general, out_type)
# Enable other modules to override einsum_contact_path.
@ -9311,7 +9314,12 @@ def _einsum(
precision,
preferred_element_type,
_dot_general=lax.dot_general,
out_type=None,
):
if out_type is not None and not isinstance(out_type, NamedSharding):
raise NotImplementedError(
"`out_type` argument of `einsum` only supports NamedSharding instances."
" Please file a bug if this is not enough for your use case.")
dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
operands = list(map(asarray, operands))
if preferred_element_type is None:
@ -9434,12 +9442,21 @@ def _einsum(
if names == result_names:
dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
operand = _dot_general(rhs, lhs, dimension_numbers, precision,
preferred_element_type=preferred_element_type)
preferred_element_type=preferred_element_type,
out_type=out_type)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
if (config.sharding_in_types.value and out_type is not None and
names != result_names):
spec = out_type.spec
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
else:
dot_general_out_type = out_type # type: ignore
dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
operand = _dot_general(lhs, rhs, dimension_numbers, precision,
preferred_element_type=preferred_element_type)
preferred_element_type=preferred_element_type,
out_type=dot_general_out_type)
else:
raise NotImplementedError # if this is actually reachable, open an issue!
@ -9452,7 +9469,8 @@ def _einsum(
operand = lax.transpose(operand, perm)
operands.append(operand) # used in next iteration
return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)
return lax_internal._convert_element_type(operands[0], preferred_element_type,
output_weak_type)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
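
The inverse_spec computation in the hunk above translates the requested output spec from the order of the final result names into the dimension order the inner dot_general actually emits (batch names, then remaining lhs names, then remaining rhs names). A small standalone illustration, using the 'btd,dhq->bhtq' contraction from the new test_einsum_inverse test:

    from jax.sharding import PartitionSpec as P

    result_names = 'bhtq'           # dimension order of the final einsum output
    names = 'bthq'                  # order emitted by the inner dot_general
    spec = P('x', None, 'y', None)  # requested out_type.spec, in result order

    # Permute the spec entries into dot_general's output order; einsum then
    # transposes the result back to result_names order afterwards.
    inverse_spec = tuple(spec[result_names.index(name)] for name in names)
    assert inverse_spec == ('x', 'y', None, None)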

@ -2089,10 +2089,11 @@ def _dot_general_lowering(
b,
*,
dimension_numbers,
out_type,
precision,
preferred_element_type,
):
del preferred_element_type # Unused.
del preferred_element_type, out_type # Unused.
((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
assert batch_dims == ((), ())

@ -2180,7 +2180,7 @@ def _conv_general_dilated(lhs, rhs, *,
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
def _dot_general(lhs, rhs, *, dimension_numbers, out_type,
precision: lax_internal.CanonicalPrecision,
preferred_element_type: DType | None,
_in_avals: Sequence[core.ShapedArray],

@ -606,8 +606,11 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
dimension_numbers: DotDimensionNumbers,
precision: None = None,
preferred_element_type: None = None,
out_type=None) -> BCOO | Array:
"""A general contraction operation.
Args:
@ -625,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
the result will be dense, of type ndarray.
"""
# TODO(jakevdp) make use of these?
del precision # unused
del precision, out_type # unused
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
dimension_numbers)
@ -1051,7 +1054,8 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
kwds = {'dimension_numbers': dimension_numbers,
'precision': None,
'preferred_element_type': None}
'preferred_element_type': None,
'out_type': None}
A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
return A, B, indices

@ -462,7 +462,8 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
precision: None = None,
preferred_element_type: None = None) -> Array:
preferred_element_type: None = None,
out_type=None) -> Array:
"""A general contraction operation.
Args:
@ -479,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
are sparse, the result will be sparse, of type BCSR. If either input is
dense, the result will be dense, of type ndarray.
"""
del precision # unused
del precision, out_type # unused
if isinstance(rhs, (np.ndarray, jax.Array)):
if isinstance(lhs, (np.ndarray, jax.Array)):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,

@ -111,4 +111,4 @@ def _dot_general_validated_shape(
rhs = core.ShapedArray(rhs_shape, np.float32)
return _dot_general_shape_rule(
lhs, rhs, dimension_numbers=dimension_numbers,
precision=None, preferred_element_type=None)
precision=None, preferred_element_type=None, out_type=None)

@ -4945,6 +4945,60 @@ class ShardingInTypesTest(jtu.JaxTestCase):
_, out = g(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
def test_einsum_with_out_type(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@jax.jit
def f(x, y):
out = jnp.einsum('xy,yz->xz', x, y,
out_type=NamedSharding(x.sharding.mesh, P('x', None)))
self.assertEqual(out.sharding.spec, P('x', None))
return out
out = f(arr1, arr2)
self.assertArraysEqual(out, np_inp @ np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
lowered_text = f.lower(arr1, arr2).as_text()
self.assertIn('@Sharding', lowered_text)
@jax.jit
def g(x, y):
out = jnp.einsum('xy,yz->xz', x, y,
out_type=NamedSharding(x.sharding.mesh, P('x', None)))
self.assertEqual(out.sharding.spec, P('x', None))
return out
arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr4 = jax.device_put(np_inp.T, NamedSharding(mesh, P('x', 'y')))
out2 = g(arr3, arr4)
self.assertArraysEqual(out2, np_inp @ np_inp.T)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None)))
def test_einsum_inverse(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(64)
@jax.jit
def h(x, y):
s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None))
out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s)
self.assertEqual(out.sharding.spec, s.spec)
return out
arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
NamedSharding(mesh, P('x', 'y', None)))
arr2 = jax.device_put(np_inp.reshape(2, 4, 8),
NamedSharding(mesh, P(None, 'x', 'y')))
out = h(arr1, arr2)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None)))
lowered_text = h.lower(arr1, arr2).as_text()
self.assertIn('@Sharding', lowered_text)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):