diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 93e2df3b8..c4f1ddcc0 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1305,7 +1305,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
                 preferred_element_type: DTypeLike | None = None,
-                out_type=None) -> Array:
+                out_sharding=None) -> Array:
   """General dot product/contraction operator.

   Wraps XLA's `DotGeneral
@@ -1351,12 +1351,12 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
     by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
     non-contracting/non-batch dimensions.
   """
-  if out_type is not None and not config.sharding_in_types.value:
-    raise NotImplementedError("out_type only works when sharding_in_types "
+  if out_sharding is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_sharding only works when sharding_in_types "
                               "config is True.")
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError(
-        '`out_type` argument of `dot_general` only supports NamedSharding '
+        '`out_sharding` argument of `dot_general` only supports NamedSharding '
         'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
@@ -1370,7 +1370,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
                             preferred_element_type=preferred_element_type,
-                            out_type=out_type)
+                            out_sharding=out_sharding)


 def ragged_dot(
@@ -3456,8 +3456,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):

 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_type):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+                            out_sharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
@@ -3536,15 +3536,15 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):

 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type):
+                               out_sharding):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
-  if out_type is not None:
-    assert isinstance(out_type, NamedSharding)
-    return out_type
+  if out_sharding is not None:
+    assert isinstance(out_sharding, NamedSharding)
+    return out_sharding
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
@@ -3580,8 +3580,8 @@ def tuple_delete(tup, idx):

 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_type):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+                            out_sharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
@@ -3629,7 +3629,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):

 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type, swap_ans=False):
+                               out_sharding, swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3650,7 +3650,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
     ds = None
   dot_general_out = dot_general(g, y, dims, precision=precision,
                                 preferred_element_type=preferred_element_type,
-                                out_type=ds)
+                                out_sharding=ds)
   x_bar = transpose(dot_general_out, tuple(out_axes))
   if x_bar.dtype != x.aval.dtype:
     x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
@@ -3658,12 +3658,12 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,

 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type):
+                               out_sharding):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
       g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-      preferred_element_type=preferred_element_type, out_type=out_type,
+      preferred_element_type=preferred_element_type, out_sharding=out_sharding,
       swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
@@ -3678,7 +3678,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
-    out_type,
+    out_sharding,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3708,8 +3708,8 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  if out_type is not None:
-    raise NotImplementedError("vmap with out_type is not supported. "
+  if out_sharding is not None:
+    raise NotImplementedError("vmap with out_sharding is not supported. "
                               "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
@@ -3717,7 +3717,7 @@ def _dot_batch_rule(
       rhs,
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
+      out_sharding=out_sharding,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
@@ -3939,7 +3939,7 @@ def get_algorithm_compute_types(

 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       out_type, platform: str = "default"):
+                       out_sharding, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -4028,8 +4028,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
-    if out_type is not None:
-      assert aval_out.sharding == out_type
+    if out_sharding is not None:
+      assert aval_out.sharding == out_sharding
     result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
   if accumulation_aval.dtype != aval_out.dtype:
     result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
@@ -4090,7 +4090,7 @@ def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
   return _dot_general_dtype_rule(
       lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision, preferred_element_type=preferred_element_type,
-      out_type=None)
+      out_sharding=None)


 def _ragged_dot_jvp_rule(
@@ -4213,9 +4213,9 @@ def _ragged_dot_invoke_prim(
     new_dimension_numbers,
     precision,
     preferred_element_type,
-    out_type,
+    out_sharding,
 ):
-  del out_type
+  del out_sharding
   return ragged_dot(
       lhs,
       rhs,
@@ -4244,7 +4244,7 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=None,
+      out_sharding=None,
   )
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 8f078ceee..764c55365 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -9508,7 +9508,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array: ...

 @overload
@@ -9521,7 +9521,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array: ...

 @export
@@ -9533,7 +9533,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array:
   """Einstein summation
@@ -9769,7 +9769,7 @@ def einsum(
   if spec is not None:
     einsum = jax.named_call(einsum, name=spec)
   return einsum(operands, contractions, precision,
-                preferred_element_type, _dot_general, out_type)
+                preferred_element_type, _dot_general, out_sharding)


 # Enable other modules to override einsum_contact_path.
@@ -9869,16 +9869,16 @@ def _einsum(
     precision,
     preferred_element_type,
     _dot_general=lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ):
-  if out_type is not None and not config.sharding_in_types.value:
-    raise NotImplementedError("out_type only works when sharding_in_types "
+  if out_sharding is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_sharding only works when sharding_in_types "
                               "config is True.")
-  out_type = canonicalize_sharding(out_type)
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+  out_sharding = canonicalize_sharding(out_sharding)
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError(
-        "`out_type` argument of `einsum` only supports NamedSharding instances."
-        " Please file a bug if this is not enough for your use case.")
+        "`out_sharding` argument of `einsum` only supports NamedSharding"
+        " instances. Please file a bug if this is not enough for your use case.")
   dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
   operands = list(map(asarray, operands))
   if preferred_element_type is None:
@@ -10000,25 +10000,27 @@ def _einsum(
       names = batch_names_str + remaining_rhs_names + remaining_lhs_names
       if names == result_names:
         dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
-        k_out_type = {} if out_type is None else {'out_type': out_type}
+        k_out_sharding = ({} if out_sharding is None else
+                          {'out_sharding': out_sharding})
         operand = _dot_general(rhs, lhs, dimension_numbers, precision,
                                preferred_element_type=preferred_element_type,
-                               **k_out_type)
+                               **k_out_sharding)
       else:
         names = batch_names_str + remaining_lhs_names + remaining_rhs_names
-        if (config.sharding_in_types.value and out_type is not None and
+        if (config.sharding_in_types.value and out_sharding is not None and
             names != result_names):
-          spec = out_type.spec
+          spec = out_sharding.spec
           inverse_spec = tuple(spec[result_names.index(name)] for name in names)
-          dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
+          dot_general_out_sharding = NamedSharding(out_sharding.mesh,
+                                                   P(*inverse_spec))
         else:
-          dot_general_out_type = out_type  # type: ignore
+          dot_general_out_sharding = out_sharding  # type: ignore
         dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
-        dot_general_out_type = ({} if dot_general_out_type is None else  # type: ignore
-                                {'out_type': dot_general_out_type})
+        dot_general_out_sharding = ({} if dot_general_out_sharding is None else  # type: ignore
+                                    {'out_sharding': dot_general_out_sharding})
         operand = _dot_general(lhs, rhs, dimension_numbers, precision,
                                preferred_element_type=preferred_element_type,
-                               **dot_general_out_type)
+                               **dot_general_out_sharding)
     else:
       raise NotImplementedError  # if this is actually reachable, open an issue!
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index a87c8990e..72facd9a7 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -2019,11 +2019,11 @@ def _dot_general_lowering(
     b,
     *,
     dimension_numbers,
-    out_type,
+    out_sharding,
     precision,
     preferred_element_type,
 ):
-  del preferred_element_type, out_type  # Unused.
+  del preferred_element_type, out_sharding  # Unused.
   ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())
diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index 0d8c95d42..644c3324b 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -364,7 +364,7 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
 def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: tuple[PrecisionType, PrecisionType] | None,
                  preferred_element_type: DType | None,
-                 out_type=None,
+                 out_sharding=None,
                  _in_avals: Sequence[core.ShapedArray],
                  _out_aval: core.ShapedArray):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 317cd7313..242077bbf 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -2168,7 +2168,7 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
 def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: lax_internal.CanonicalPrecision,
                  preferred_element_type: DType | None,
-                 out_type=None,
+                 out_sharding=None,
                  _in_avals: Sequence[core.ShapedArray],
                  _out_aval: core.ShapedArray):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py
index 13ff61785..7f67fe1f8 100644
--- a/jax/experimental/sparse/bcoo.py
+++ b/jax/experimental/sparse/bcoo.py
@@ -610,7 +610,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
                      preferred_element_type: None = None,
-                     out_type=None) -> BCOO | Array:
+                     out_sharding=None) -> BCOO | Array:
   """A general contraction operation.

   Args:
@@ -628,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision, out_type  # unused
+  del precision, out_sharding  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
@@ -1055,7 +1055,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
           'preferred_element_type': None,
-          'out_type': None}
+          'out_sharding': None}
   A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
   return A, B, indices
diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py
index ed7e53d4c..7fefd1572 100644
--- a/jax/experimental/sparse/bcsr.py
+++ b/jax/experimental/sparse/bcsr.py
@@ -463,7 +463,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
                      preferred_element_type: None = None,
-                     out_type=None) -> Array:
+                     out_sharding=None) -> Array:
   """A general contraction operation.

   Args:
@@ -480,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
     are sparse, the result will be sparse, of type BCSR. If either input is
     dense, the result will be dense, of type ndarray.
   """
-  del precision, out_type  # unused
+  del precision, out_sharding  # unused
   if isinstance(rhs, (np.ndarray, jax.Array)):
     if isinstance(lhs, (np.ndarray, jax.Array)):
       return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py
index 2cb765676..36e9a9c51 100644
--- a/jax/experimental/sparse/util.py
+++ b/jax/experimental/sparse/util.py
@@ -111,4 +111,4 @@ def _dot_general_validated_shape(
   rhs = core.ShapedArray(rhs_shape, np.float32)
   return _dot_general_shape_rule(
       lhs, rhs, dimension_numbers=dimension_numbers,
-      precision=None, preferred_element_type=None, out_type=None)
+      precision=None, preferred_element_type=None, out_sharding=None)
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index 82ba29267..e491f8d7a 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -372,7 +372,7 @@ def einsum(
     preferred_element_type: DTypeLike | None = ...,
     _use_xeinsum: builtins.bool = False,
     _dot_general: Callable[..., Array] = ...,
-    out_type: NamedSharding | P | None = ...,
+    out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...

 @overload
@@ -386,7 +386,7 @@ def einsum(
     preferred_element_type: DTypeLike | None = ...,
     _use_xeinsum: builtins.bool = False,
     _dot_general: Callable[..., Array] = ...,
-    out_type: NamedSharding | P | None = ...,
+    out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...
 @overload
 def einsum(
@@ -398,7 +398,7 @@ def einsum(
     preferred_element_type: DTypeLike | None = ...,
     _use_xeinsum: builtins.bool = ...,
     _dot_general: Callable[..., Array] = ...,
-    out_type: NamedSharding | P | None = ...,
+    out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...

 @overload
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5ad79ed38..180dea92f 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4906,14 +4906,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out[1].sharding, arr2.sharding)

   @jtu.with_user_mesh((4,), ('x',))
-  def test_dot_general_out_type(self, mesh):
+  def test_dot_general_out_sharding(self, mesh):
     np_inp1 = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))

     @jax.jit
     def f(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return jnp.sum(out)
@@ -5217,7 +5217,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
-  def test_einsum_with_out_type(self, mesh):
+  def test_einsum_with_out_sharding(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
     arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5225,7 +5225,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+                       out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out
@@ -5238,7 +5238,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def g(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out
@@ -5268,7 +5268,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def h(x, y):
       spec = P('x', None, 'y', None)
-      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=spec)
+      out = jnp.einsum('btd,dhq->bhtq', x, y, out_sharding=spec)
       self.assertEqual(out.sharding.spec, spec)
       return out
@@ -5963,7 +5963,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(auto_mesh, P(None, None)))
+                       out_sharding=NamedSharding(auto_mesh, P(None, None)))
       return out

     with self.assertRaisesRegex(
@@ -5973,7 +5973,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x, y):
       with hidden_axes('x'):
-        out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+        out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       return out

     with self.assertRaisesRegex(
@@ -6046,7 +6046,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
       self.assertEqual(z.sharding.spec, P('x', 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P('x', None))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', None))
       self.assertEqual(a.sharding.spec, P('x', None))
       return a
@@ -6076,7 +6076,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(y.sharding.spec, P('x', 'y'))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P('x', 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P('x', 'y'))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', 'y'))
       self.assertEqual(a.sharding.spec, P('x', 'y'))
       return a
@@ -6102,7 +6102,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)
       self.assertEqual(z.sharding.spec, P(None, 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P(None, 'y'))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P(None, 'y'))
       self.assertEqual(a.sharding.spec, P(None, 'y'))
       return a
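
Example of the renamed keyword in use, mirroring the updated tests above. This is a
minimal sketch, not part of the diff: it assumes four available devices, a mesh axis
named 'x', and that the `sharding_in_types` config is enabled (otherwise dot_general
and einsum raise NotImplementedError, per the checks in this change).

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

# Hypothetical 4-device mesh; the tests use jtu.with_user_mesh((4,), ('x',)) instead.
mesh = jax.make_mesh((4,), ('x',))

x = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', None)))
y = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P(None, 'x')))

@jax.jit
def f(x, y):
  # `out_sharding` (previously `out_type`) constrains the sharding of the result;
  # it accepts a NamedSharding here, and einsum also accepts a bare PartitionSpec.
  return jnp.einsum('xy,yz->xz', x, y,
                    out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))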