Rename out_type -> out_sharding parameter on einsum

PiperOrigin-RevId: 716454800
Author: Yash Katariya, 2025-01-16 18:16:12 -08:00 (committed by jax authors)
parent 49224d6cdb
commit 97cd748376
10 changed files with 75 additions and 73 deletions
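
For orientation, a minimal before/after sketch of the renamed keyword, based on the jnp.einsum call sites updated in the tests below. It assumes the experimental sharding_in_types config is enabled and a mesh with an 'x' axis is active (as the tests set up via jtu.with_user_mesh); the function name f is illustrative only.

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

@jax.jit
def f(x, y):
  # Before this commit: jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
  return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))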

View File

@@ -1305,7 +1305,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
                 preferred_element_type: DTypeLike | None = None,
-                out_type=None) -> Array:
+                out_sharding=None) -> Array:
   """General dot product/contraction operator.

   Wraps XLA's `DotGeneral
@@ -1351,12 +1351,12 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
     by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
     non-contracting/non-batch dimensions.
   """
-  if out_type is not None and not config.sharding_in_types.value:
-    raise NotImplementedError("out_type only works when sharding_in_types "
+  if out_sharding is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_sharding only works when sharding_in_types "
                               "config is True.")
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError(
-        '`out_type` argument of `dot_general` only supports NamedSharding '
+        '`out_sharding` argument of `dot_general` only supports NamedSharding '
         'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
@@ -1370,7 +1370,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
       dimension_numbers=(cdims, bdims),
       precision=canonicalize_precision(precision),
       preferred_element_type=preferred_element_type,
-      out_type=out_type)
+      out_sharding=out_sharding)


 def ragged_dot(
@@ -3456,8 +3456,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_type):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+                            out_sharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
@@ -3536,15 +3536,15 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type):
+                               out_sharding):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
-  if out_type is not None:
-    assert isinstance(out_type, NamedSharding)
-    return out_type
+  if out_sharding is not None:
+    assert isinstance(out_sharding, NamedSharding)
+    return out_sharding
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
@@ -3580,8 +3580,8 @@ def tuple_delete(tup, idx):
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_type):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+                            out_sharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
@@ -3629,7 +3629,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type, swap_ans=False):
+                               out_sharding, swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3650,7 +3650,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
     ds = None
   dot_general_out = dot_general(g, y, dims, precision=precision,
                                 preferred_element_type=preferred_element_type,
-                                out_type=ds)
+                                out_sharding=ds)
   x_bar = transpose(dot_general_out, tuple(out_axes))
   if x_bar.dtype != x.aval.dtype:
     x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
@@ -3658,12 +3658,12 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type):
+                               out_sharding):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
     g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-    preferred_element_type=preferred_element_type, out_type=out_type,
+    preferred_element_type=preferred_element_type, out_sharding=out_sharding,
     swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
@@ -3678,7 +3678,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
-    out_type,
+    out_sharding,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3708,8 +3708,8 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  if out_type is not None:
-    raise NotImplementedError("vmap with out_type is not supported. "
+  if out_sharding is not None:
+    raise NotImplementedError("vmap with out_sharding is not supported. "
                               "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
@@ -3717,7 +3717,7 @@ def _dot_batch_rule(
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
+      out_sharding=out_sharding,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
@@ -3939,7 +3939,7 @@ def get_algorithm_compute_types(
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       out_type, platform: str = "default"):
+                       out_sharding, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -4028,8 +4028,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
-    if out_type is not None:
-      assert aval_out.sharding == out_type
+    if out_sharding is not None:
+      assert aval_out.sharding == out_sharding
     result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
   if accumulation_aval.dtype != aval_out.dtype:
     result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
@@ -4090,7 +4090,7 @@ def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
   return _dot_general_dtype_rule(
       lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision, preferred_element_type=preferred_element_type,
-      out_type=None)
+      out_sharding=None)


 def _ragged_dot_jvp_rule(
@@ -4213,9 +4213,9 @@ def _ragged_dot_invoke_prim(
     new_dimension_numbers,
     precision,
     preferred_element_type,
-    out_type,
+    out_sharding,
 ):
-  del out_type
+  del out_sharding
   return ragged_dot(
       lhs,
       rhs,
@@ -4244,7 +4244,7 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=None,
+      out_sharding=None,
   )
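
A hedged usage sketch for the updated lax.dot_general signature above. Per the checks in this file, dot_general (unlike jnp.einsum) accepts only NamedSharding instances for out_sharding, and the sharding_in_types config must be enabled; sharded_matmul and mesh are illustrative names, not part of this change.

from jax import lax
from jax.sharding import NamedSharding, PartitionSpec as P

def sharded_matmul(x, y, mesh):
  # Contract the last dim of x with the first dim of y (a plain matmul),
  # constraining the output to be sharded as P('x', None) on `mesh`.
  dimension_numbers = (((1,), (0,)), ((), ()))
  return lax.dot_general(x, y, dimension_numbers,
                         out_sharding=NamedSharding(mesh, P('x', None)))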

View File

@@ -9508,7 +9508,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array: ...

 @overload
@@ -9521,7 +9521,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array: ...

 @export
@@ -9533,7 +9533,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array:
   """Einstein summation
@@ -9769,7 +9769,7 @@ def einsum(
   if spec is not None:
     einsum = jax.named_call(einsum, name=spec)
   return einsum(operands, contractions, precision,
-                preferred_element_type, _dot_general, out_type)
+                preferred_element_type, _dot_general, out_sharding)


 # Enable other modules to override einsum_contact_path.
@@ -9869,16 +9869,16 @@ def _einsum(
     precision,
     preferred_element_type,
     _dot_general=lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ):
-  if out_type is not None and not config.sharding_in_types.value:
-    raise NotImplementedError("out_type only works when sharding_in_types "
+  if out_sharding is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_sharding only works when sharding_in_types "
                               "config is True.")
-  out_type = canonicalize_sharding(out_type)
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+  out_sharding = canonicalize_sharding(out_sharding)
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError(
-        "`out_type` argument of `einsum` only supports NamedSharding instances."
-        " Please file a bug if this is not enough for your use case.")
+        "`out_sharding` argument of `einsum` only supports NamedSharding"
+        " instances. Please file a bug if this is not enough for your use case.")
   dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
   operands = list(map(asarray, operands))
   if preferred_element_type is None:
@@ -10000,25 +10000,27 @@ def _einsum(
       names = batch_names_str + remaining_rhs_names + remaining_lhs_names
       if names == result_names:
         dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
-        k_out_type = {} if out_type is None else {'out_type': out_type}
+        k_out_sharding = ({} if out_sharding is None else
+                          {'out_sharding': out_sharding})
         operand = _dot_general(rhs, lhs, dimension_numbers, precision,
                                preferred_element_type=preferred_element_type,
-                               **k_out_type)
+                               **k_out_sharding)
       else:
         names = batch_names_str + remaining_lhs_names + remaining_rhs_names
-        if (config.sharding_in_types.value and out_type is not None and
+        if (config.sharding_in_types.value and out_sharding is not None and
             names != result_names):
-          spec = out_type.spec
+          spec = out_sharding.spec
           inverse_spec = tuple(spec[result_names.index(name)] for name in names)
-          dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
+          dot_general_out_sharding = NamedSharding(out_sharding.mesh,
+                                                   P(*inverse_spec))
         else:
-          dot_general_out_type = out_type  # type: ignore
+          dot_general_out_sharding = out_sharding  # type: ignore
         dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
-        dot_general_out_type = ({} if dot_general_out_type is None else  # type: ignore
-                                {'out_type': dot_general_out_type})
+        dot_general_out_sharding = ({} if dot_general_out_sharding is None else  # type: ignore
+                                    {'out_sharding': dot_general_out_sharding})
         operand = _dot_general(lhs, rhs, dimension_numbers, precision,
                                preferred_element_type=preferred_element_type,
-                               **dot_general_out_type)
+                               **dot_general_out_sharding)
     else:
       raise NotImplementedError  # if this is actually reachable, open an issue!
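
The most subtle piece above is the inverse_spec computation: dot_general produces its output with axes ordered batch + remaining lhs + remaining rhs (the names string in the else branch), so the user-facing spec, which is ordered by result_names, is permuted before the dot_general call, and the trailing transpose restores the requested layout. A small, runnable illustration using the 'btd,dhq->bhtq' case exercised by the tests later in this commit:

from jax.sharding import PartitionSpec as P

result_names = 'bhtq'                 # einsum output order requested by the user
names = 'bthq'                        # batch + remaining lhs + remaining rhs
spec = ('x', None, 'y', None)         # out_sharding.spec, ordered by result_names
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
assert inverse_spec == ('x', 'y', None, None)
# dot_general returns axes in 'bthq' order sharded P('x', 'y', None, None);
# transposing to 'bhtq' then yields the requested P('x', None, 'y', None).
print(P(*inverse_spec))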

View File

@@ -2019,11 +2019,11 @@ def _dot_general_lowering(
     b,
     *,
     dimension_numbers,
-    out_type,
+    out_sharding,
     precision,
     preferred_element_type,
 ):
-  del preferred_element_type, out_type  # Unused.
+  del preferred_element_type, out_sharding  # Unused.
   ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())

View File

@@ -364,7 +364,7 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
 def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: tuple[PrecisionType, PrecisionType] | None,
                  preferred_element_type: DType | None,
-                 out_type=None,
+                 out_sharding=None,
                  _in_avals: Sequence[core.ShapedArray],
                  _out_aval: core.ShapedArray):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""

View File

@@ -2168,7 +2168,7 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
 def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: lax_internal.CanonicalPrecision,
                  preferred_element_type: DType | None,
-                 out_type=None,
+                 out_sharding=None,
                  _in_avals: Sequence[core.ShapedArray],
                  _out_aval: core.ShapedArray):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""

View File

@@ -610,7 +610,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
                      preferred_element_type: None = None,
-                     out_type=None) -> BCOO | Array:
+                     out_sharding=None) -> BCOO | Array:
   """A general contraction operation.

   Args:
@@ -628,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision, out_type  # unused
+  del precision, out_sharding  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
@@ -1055,7 +1055,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
           'preferred_element_type': None,
-          'out_type': None}
+          'out_sharding': None}
   A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
   return A, B, indices

View File

@@ -463,7 +463,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
                      preferred_element_type: None = None,
-                     out_type=None) -> Array:
+                     out_sharding=None) -> Array:
   """A general contraction operation.

   Args:
@@ -480,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
     are sparse, the result will be sparse, of type BCSR. If either input is
     dense, the result will be dense, of type ndarray.
   """
-  del precision, out_type  # unused
+  del precision, out_sharding  # unused
   if isinstance(rhs, (np.ndarray, jax.Array)):
     if isinstance(lhs, (np.ndarray, jax.Array)):
       return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,

View File

@@ -111,4 +111,4 @@ def _dot_general_validated_shape(
   rhs = core.ShapedArray(rhs_shape, np.float32)
   return _dot_general_shape_rule(
       lhs, rhs, dimension_numbers=dimension_numbers,
-      precision=None, preferred_element_type=None, out_type=None)
+      precision=None, preferred_element_type=None, out_sharding=None)

View File

@@ -372,7 +372,7 @@ def einsum(
     preferred_element_type: DTypeLike | None = ...,
     _use_xeinsum: builtins.bool = False,
     _dot_general: Callable[..., Array] = ...,
-    out_type: NamedSharding | P | None = ...,
+    out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...

 @overload
@@ -386,7 +386,7 @@ def einsum(
     preferred_element_type: DTypeLike | None = ...,
     _use_xeinsum: builtins.bool = False,
     _dot_general: Callable[..., Array] = ...,
-    out_type: NamedSharding | P | None = ...,
+    out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...

 @overload
 def einsum(
@@ -398,7 +398,7 @@ def einsum(
     preferred_element_type: DTypeLike | None = ...,
     _use_xeinsum: builtins.bool = ...,
     _dot_general: Callable[..., Array] = ...,
-    out_type: NamedSharding | P | None = ...,
+    out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...

 @overload

View File

@@ -4906,14 +4906,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out[1].sharding, arr2.sharding)

   @jtu.with_user_mesh((4,), ('x',))
-  def test_dot_general_out_type(self, mesh):
+  def test_dot_general_out_sharding(self, mesh):
     np_inp1 = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))

     @jax.jit
     def f(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return jnp.sum(out)
@@ -5217,7 +5217,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
-  def test_einsum_with_out_type(self, mesh):
+  def test_einsum_with_out_sharding(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
     arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5225,7 +5225,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+                       out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out
@@ -5238,7 +5238,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out
@@ -5268,7 +5268,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def h(x, y):
       spec = P('x', None, 'y', None)
-      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=spec)
+      out = jnp.einsum('btd,dhq->bhtq', x, y, out_sharding=spec)
       self.assertEqual(out.sharding.spec, spec)
       return out
@@ -5963,7 +5963,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(auto_mesh, P(None, None)))
+                       out_sharding=NamedSharding(auto_mesh, P(None, None)))
       return out

     with self.assertRaisesRegex(
@@ -5973,7 +5973,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x, y):
       with hidden_axes('x'):
-        out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+        out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       return out

     with self.assertRaisesRegex(
@@ -6046,7 +6046,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
       self.assertEqual(z.sharding.spec, P('x', 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P('x', None))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', None))
       self.assertEqual(a.sharding.spec, P('x', None))
       return a
@@ -6076,7 +6076,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
       self.assertEqual(z.sharding.spec, P('x', 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P('x', 'y'))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', 'y'))
       self.assertEqual(a.sharding.spec, P('x', 'y'))
       return a
@@ -6102,7 +6102,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(y.sharding.spec, P(None, 'y'))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P(None, 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P(None, 'y'))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P(None, 'y'))
       self.assertEqual(a.sharding.spec, P(None, 'y'))
       return a
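
As the stubs and tests above indicate, jnp.einsum's out_sharding accepts either a bare PartitionSpec or a full NamedSharding (the type stub allows NamedSharding | P | None, and _einsum passes the value through canonicalize_sharding). A short sketch of the two forms, assuming sharding_in_types is enabled and a mesh with an 'x' axis is active; g is an illustrative name.

import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

def g(x, y):
  # Bare PartitionSpec: resolved against the ambient mesh.
  out1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
  # Explicit NamedSharding: carries its mesh directly.
  out2 = jnp.einsum('xy,yz->xz', x, y,
                    out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))
  return out1, out2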