diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 113c87b60..0561a5b73 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1040,7 +1040,8 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]], def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + out_type=None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -1086,6 +1087,10 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError( + '`out_type` argument of `dot_general` only supports NamedSharding ' + 'instances. Please file a bug if this is not enough for your use case.') (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), api_util._ensure_index_tuple(rhs_contract)) @@ -1097,7 +1102,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_type=out_type) def ragged_dot( @@ -3002,7 +3008,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, not dtypes.issubdtype(new_dtype, np.complexfloating)): operand = hlo.real(operand) aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) - return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)] + out = mlir.convert_hlo(ctx, operand, aval_in, aval_out) + if config.sharding_in_types.value: + proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [out] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -3164,7 +3174,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -3241,24 +3252,28 @@ def _check_specs_match(lhs_spec, rhs_spec, msg): raise TypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): if lhs.sharding.mesh != rhs.sharding.mesh: raise ValueError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') + if out_type is not None: + return out_type + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch) rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch) msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " - f"to have the consistent sharding, got {lhs_batch_spec} and " - f"{rhs_batch_spec}.") + f"to have the consistent sharding, got {lhs_batch_spec} and " + f"{rhs_batch_spec}.") _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg) lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) msg = ("dot_general requires contracting dimensions to have consistent " - f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") + f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg) return _dot_general_sharding_computation( @@ -3280,7 +3295,8 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): del dimension_numbers # unused # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. @@ -3327,7 +3343,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - swap_ans=False): + out_type, swap_ans=False): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim x_kept = remaining(range(x_ndim), x_contract, x_batch) @@ -3347,12 +3363,14 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) y_bar = _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type, swap_ans=True) + preferred_element_type=preferred_element_type, out_type=out_type, + swap_ans=True) if y_bar.dtype != y.aval.dtype: y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) return y_bar @@ -3366,6 +3384,7 @@ def _dot_batch_rule( batch_dims, *, dimension_numbers, + out_type, precision, preferred_element_type: DTypeLike | None, **_, @@ -3395,12 +3414,16 @@ def _dot_batch_rule( rhs_shape = batching.bdim_as_shape(rbd, rhs.shape) else: rhs_shape = np.shape(rhs) + if out_type is not None: + raise NotImplementedError("vmap with out_type is not supported. " + "Please open an issue.") batched_out = invoke_prim( lhs, rhs, new_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, + out_type=out_type, ) result_batch_dim = batching.shape_as_bdim( result_stack_dim, @@ -3570,7 +3593,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, - platform: str = "default"): + out_type, platform: str = "default"): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -3658,6 +3681,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, **algorithm_kwarg, ) if config.sharding_in_types.value: + if out_type is not None: + assert aval_out.sharding == out_type out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp) if accumulation_aval.dtype != aval_out.dtype: @@ -3711,12 +3736,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S return (m, n) def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: + precision, preferred_element_type: DTypeLike | None, + **_) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. - return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type) + return _dot_general_dtype_rule( + lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, preferred_element_type=preferred_element_type, + out_type=None) def _ragged_dot_jvp_rule( @@ -3855,6 +3883,7 @@ def _ragged_dot_batch_rule( *, precision, preferred_element_type: DTypeLike | None, + out_type, **_, ): invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2]) @@ -3868,6 +3897,7 @@ def _ragged_dot_batch_rule( dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type, + out_type=out_type, ) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 971f5116e..9c62c25d0 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -67,10 +67,10 @@ from jax._src.typing import ( DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, ) from jax._src.util import ( - NumpyComplexWarning, - canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) -from jax.sharding import Sharding, SingleDeviceSharding + NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, + ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) +from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, + PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum @@ -8955,6 +8955,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: ... @overload @@ -8967,6 +8968,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: ... def einsum( @@ -8977,6 +8979,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: """Einstein summation @@ -9208,11 +9211,11 @@ def einsum( contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) if spec is not None: einsum = jax.named_call(einsum, name=spec) return einsum(operands, contractions, precision, - preferred_element_type, _dot_general) + preferred_element_type, _dot_general, out_type) # Enable other modules to override einsum_contact_path. @@ -9311,7 +9314,12 @@ def _einsum( precision, preferred_element_type, _dot_general=lax.dot_general, + out_type=None, ): + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError( + "`out_type` argument of `einsum` only supports NamedSharding instances." + " Please file a bug if this is not enough for your use case.") dtypes.check_user_dtype_supported(preferred_element_type, "einsum") operands = list(map(asarray, operands)) if preferred_element_type is None: @@ -9434,12 +9442,21 @@ def _einsum( if names == result_names: dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch)) operand = _dot_general(rhs, lhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_type=out_type) else: names = batch_names_str + remaining_lhs_names + remaining_rhs_names + if (config.sharding_in_types.value and out_type is not None and + names != result_names): + spec = out_type.spec + inverse_spec = tuple(spec[result_names.index(name)] for name in names) + dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec)) + else: + dot_general_out_type = out_type # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) operand = _dot_general(lhs, rhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_type=dot_general_out_type) else: raise NotImplementedError # if this is actually reachable, open an issue! @@ -9452,7 +9469,8 @@ def _einsum( operand = lax.transpose(operand, perm) operands.append(operand) # used in next iteration - return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type) + return lax_internal._convert_element_type(operands[0], preferred_element_type, + output_weak_type) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index b0a2b4dbc..79919e638 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2089,10 +2089,11 @@ def _dot_general_lowering( b, *, dimension_numbers, + out_type, precision, preferred_element_type, ): - del preferred_element_type # Unused. + del preferred_element_type, out_type # Unused. ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers assert batch_dims == ((), ()) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index a5cfa5f9b..dcf9cafb5 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2180,7 +2180,7 @@ def _conv_general_dilated(lhs, rhs, *, tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated -def _dot_general(lhs, rhs, *, dimension_numbers, +def _dot_general(lhs, rhs, *, dimension_numbers, out_type, precision: lax_internal.CanonicalPrecision, preferred_element_type: DType | None, _in_avals: Sequence[core.ShapedArray], diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index f65f7b0a1..477f63474 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -606,8 +606,11 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun( bcoo_dot_general_p = core.Primitive('bcoo_dot_general') -def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers, - precision: None = None, preferred_element_type: None = None) -> BCOO | Array: +def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, + dimension_numbers: DotDimensionNumbers, + precision: None = None, + preferred_element_type: None = None, + out_type=None) -> BCOO | Array: """A general contraction operation. Args: @@ -625,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: the result will be dense, of type ndarray. """ # TODO(jakevdp) make use of these? - del precision # unused + del precision, out_type # unused if isinstance(lhs, BCOO) and isinstance(rhs, BCOO): shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers) @@ -1051,7 +1054,8 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers) indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True) kwds = {'dimension_numbers': dimension_numbers, 'precision': None, - 'preferred_element_type': None} + 'preferred_element_type': None, + 'out_type': None} A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds) return A, B, indices diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 8aa7d80c7..372bce034 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -462,7 +462,8 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general') def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, dimension_numbers: DotDimensionNumbers, precision: None = None, - preferred_element_type: None = None) -> Array: + preferred_element_type: None = None, + out_type=None) -> Array: """A general contraction operation. Args: @@ -479,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray. """ - del precision # unused + del precision, out_type # unused if isinstance(rhs, (np.ndarray, jax.Array)): if isinstance(lhs, (np.ndarray, jax.Array)): return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers, diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 7ef1ed781..2cb765676 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -111,4 +111,4 @@ def _dot_general_validated_shape( rhs = core.ShapedArray(rhs_shape, np.float32) return _dot_general_shape_rule( lhs, rhs, dimension_numbers=dimension_numbers, - precision=None, preferred_element_type=None) + precision=None, preferred_element_type=None, out_type=None) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d3b96676a..fd65c7953 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4945,6 +4945,60 @@ class ShardingInTypesTest(jtu.JaxTestCase): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + def test_einsum_with_out_type(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertArraysEqual(out, np_inp @ np_inp.T) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + lowered_text = f.lower(arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + + @jax.jit + def g(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr4 = jax.device_put(np_inp.T, NamedSharding(mesh, P('x', 'y'))) + out2 = g(arr3, arr4) + self.assertArraysEqual(out2, np_inp @ np_inp.T) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + + def test_einsum_inverse(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(64) + + @jax.jit + def h(x, y): + s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None)) + out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s) + self.assertEqual(out.sharding.spec, s.spec) + return out + + arr1 = jax.device_put(np_inp.reshape(8, 4, 2), + NamedSharding(mesh, P('x', 'y', None))) + arr2 = jax.device_put(np_inp.reshape(2, 4, 8), + NamedSharding(mesh, P(None, 'x', 'y'))) + out = h(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None))) + + lowered_text = h.lower(arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):