Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)

Rename out_type -> out_sharding parameter on einsum

PiperOrigin-RevId: 716454800

This commit is contained in:
parent 49224d6cdb
commit 97cd748376
@@ -1305,7 +1305,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
                 preferred_element_type: DTypeLike | None = None,
-                out_type=None) -> Array:
+                out_sharding=None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1351,12 +1351,12 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
   non-contracting/non-batch dimensions.
   """
-  if out_type is not None and not config.sharding_in_types.value:
-    raise NotImplementedError("out_type only works when sharding_in_types "
+  if out_sharding is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_sharding only works when sharding_in_types "
                               "config is True.")
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError(
-        '`out_type` argument of `dot_general` only supports NamedSharding '
+        '`out_sharding` argument of `dot_general` only supports NamedSharding '
         'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
@@ -1370,7 +1370,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
                      dimension_numbers=(cdims, bdims),
                      precision=canonicalize_precision(precision),
                      preferred_element_type=preferred_element_type,
-                     out_type=out_type)
+                     out_sharding=out_sharding)
 
 
 def ragged_dot(
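
For orientation, a minimal usage sketch of the renamed keyword on the public API. It is illustrative only: it assumes the experimental sharding_in_types config is enabled and that an 'x' mesh axis is available; the mesh and array shapes are not part of this change.

import jax
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ('x',))

lhs = jax.device_put(np.arange(16.).reshape(8, 2),
                     NamedSharding(mesh, P('x', None)))
rhs = jax.device_put(np.arange(16.).reshape(2, 8),
                     NamedSharding(mesh, P(None, 'x')))

# Contract dim 1 of lhs with dim 0 of rhs; no batch dims. The keyword is now
# `out_sharding` (formerly `out_type`) and must be a NamedSharding instance,
# per the checks in the hunk above. Without the sharding_in_types config this
# call raises NotImplementedError.
out = lax.dot_general(lhs, rhs,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      out_sharding=NamedSharding(mesh, P('x', None)))
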
@@ -3456,8 +3456,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_type):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+                            out_sharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
@@ -3536,15 +3536,15 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type):
+                               out_sharding):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
-  if out_type is not None:
-    assert isinstance(out_type, NamedSharding)
-    return out_type
+  if out_sharding is not None:
+    assert isinstance(out_sharding, NamedSharding)
+    return out_sharding
 
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
@@ -3580,8 +3580,8 @@ def tuple_delete(tup, idx):
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_type):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+                            out_sharding):
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
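
The sharding rule above has two visible branches: operand meshes must agree, and an explicit out_sharding short-circuits the inference. A standalone restatement of just those branches (illustrative helper, not JAX internals; the `...` marks the inference path not shown in this hunk):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def sketch_output_sharding(lhs_sharding, rhs_sharding, out_sharding=None):
  # Operand meshes must match, mirroring the ValueError in the rule above.
  if lhs_sharding.mesh != rhs_sharding.mesh:
    raise ValueError('Mesh of both lhs and rhs should match. Got lhs:'
                     f' {lhs_sharding.mesh} and rhs: {rhs_sharding.mesh}')
  # An explicit out_sharding wins and must be a NamedSharding.
  if out_sharding is not None:
    assert isinstance(out_sharding, NamedSharding)
    return out_sharding
  ...  # otherwise the rule derives the spec from batch/lhs/rhs dims (not shown here)

mesh = Mesh(np.array(jax.devices()), ('x',))
s = NamedSharding(mesh, P('x', None))
assert sketch_output_sharding(s, s, out_sharding=s) == s
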
@@ -3629,7 +3629,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type, swap_ans=False):
+                               out_sharding, swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3650,7 +3650,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
     ds = None
   dot_general_out = dot_general(g, y, dims, precision=precision,
                                 preferred_element_type=preferred_element_type,
-                                out_type=ds)
+                                out_sharding=ds)
   x_bar = transpose(dot_general_out, tuple(out_axes))
   if x_bar.dtype != x.aval.dtype:
     x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
@@ -3658,12 +3658,12 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type):
+                               out_sharding):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
       g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-      preferred_element_type=preferred_element_type, out_type=out_type,
+      preferred_element_type=preferred_element_type, out_sharding=out_sharding,
       swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
@@ -3678,7 +3678,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
-    out_type,
+    out_sharding,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3708,8 +3708,8 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  if out_type is not None:
-    raise NotImplementedError("vmap with out_type is not supported. "
+  if out_sharding is not None:
+    raise NotImplementedError("vmap with out_sharding is not supported. "
                               "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
@@ -3717,7 +3717,7 @@ def _dot_batch_rule(
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
+      out_sharding=out_sharding,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
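
The batch rule bails out when an explicit out_sharding is present, so composing vmap with the new keyword is not supported yet. A sketch of the combination that is expected to raise (illustrative shapes; assumes an active 'x' mesh and the sharding_in_types config):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def f(x, y):
  return jnp.einsum('ij,jk->ik', x, y, out_sharding=P('x', None))

xs = jnp.ones((4, 8, 2))
ys = jnp.ones((4, 2, 8))

# Per the batch rule above, this is expected to raise
# NotImplementedError("vmap with out_sharding is not supported. ...").
# jax.vmap(f)(xs, ys)
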
@@ -3939,7 +3939,7 @@ def get_algorithm_compute_types(
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       out_type, platform: str = "default"):
+                       out_sharding, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -4028,8 +4028,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
-    if out_type is not None:
-      assert aval_out.sharding == out_type
+    if out_sharding is not None:
+      assert aval_out.sharding == out_sharding
     result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
   if accumulation_aval.dtype != aval_out.dtype:
     result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
@@ -4090,7 +4090,7 @@ def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
   return _dot_general_dtype_rule(
       lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision, preferred_element_type=preferred_element_type,
-      out_type=None)
+      out_sharding=None)
 
 
 def _ragged_dot_jvp_rule(
@@ -4213,9 +4213,9 @@ def _ragged_dot_invoke_prim(
     new_dimension_numbers,
     precision,
     preferred_element_type,
-    out_type,
+    out_sharding,
 ):
-  del out_type
+  del out_sharding
   return ragged_dot(
       lhs,
       rhs,
@@ -4244,7 +4244,7 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
      precision=precision,
      preferred_element_type=preferred_element_type,
-      out_type=None,
+      out_sharding=None,
   )
 
 
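
The ragged-dot rules above only thread out_sharding=None through to dot_general's helpers for now; ragged_dot itself does not expose a sharding argument. For context, a minimal sketch of the ragged_dot call shapes (group sizes are illustrative):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((6, 4))                         # (m, k)
rhs = jnp.ones((3, 4, 5))                      # (g, k, n)
group_sizes = jnp.array([2, 1, 3], jnp.int32)  # sums to m; rows of lhs per group

out = lax.ragged_dot(lhs, rhs, group_sizes)    # result shape (6, 5)
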
@@ -9508,7 +9508,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array: ...
 
 @overload
@@ -9521,7 +9521,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array: ...
 
 @export
@@ -9533,7 +9533,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ) -> Array:
   """Einstein summation
 
@@ -9769,7 +9769,7 @@ def einsum(
   if spec is not None:
     einsum = jax.named_call(einsum, name=spec)
   return einsum(operands, contractions, precision,
-                preferred_element_type, _dot_general, out_type)
+                preferred_element_type, _dot_general, out_sharding)
 
 
 # Enable other modules to override einsum_contact_path.
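
At the jnp.einsum level the renamed keyword accepts either a NamedSharding or a bare PartitionSpec (see the .pyi overloads further down); a bare spec is resolved against the active mesh inside _einsum. A usage sketch mirroring the tests at the end of this diff, assuming four devices arranged as a 2x2 ('x', 'y') mesh and the sharding_in_types config:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
arr1 = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P('y', 'x')))

@jax.jit
def f(x, y):
  # Either form works: out_sharding=P('x', None) or an explicit NamedSharding.
  return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))

# out = f(arr1, arr2)   # expected: out.sharding.spec == P('x', None)
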
@@ -9869,16 +9869,16 @@ def _einsum(
     precision,
     preferred_element_type,
     _dot_general=lax.dot_general,
-    out_type=None,
+    out_sharding=None,
 ):
-  if out_type is not None and not config.sharding_in_types.value:
-    raise NotImplementedError("out_type only works when sharding_in_types "
+  if out_sharding is not None and not config.sharding_in_types.value:
+    raise NotImplementedError("out_sharding only works when sharding_in_types "
                               "config is True.")
-  out_type = canonicalize_sharding(out_type)
-  if out_type is not None and not isinstance(out_type, NamedSharding):
+  out_sharding = canonicalize_sharding(out_sharding)
+  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError(
-        "`out_type` argument of `einsum` only supports NamedSharding instances."
-        " Please file a bug if this is not enough for your use case.")
+        "`out_sharding` argument of `einsum` only supports NamedSharding"
+        " instances. Please file a bug if this is not enough for your use case.")
   dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
   operands = list(map(asarray, operands))
   if preferred_element_type is None:
@@ -10000,25 +10000,27 @@ def _einsum(
         names = batch_names_str + remaining_rhs_names + remaining_lhs_names
         if names == result_names:
           dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
-          k_out_type = {} if out_type is None else {'out_type': out_type}
+          k_out_sharding = ({} if out_sharding is None else
+                            {'out_sharding': out_sharding})
           operand = _dot_general(rhs, lhs, dimension_numbers, precision,
                                  preferred_element_type=preferred_element_type,
-                                 **k_out_type)
+                                 **k_out_sharding)
         else:
           names = batch_names_str + remaining_lhs_names + remaining_rhs_names
-          if (config.sharding_in_types.value and out_type is not None and
+          if (config.sharding_in_types.value and out_sharding is not None and
               names != result_names):
-            spec = out_type.spec
+            spec = out_sharding.spec
             inverse_spec = tuple(spec[result_names.index(name)] for name in names)
-            dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
+            dot_general_out_sharding = NamedSharding(out_sharding.mesh,
+                                                     P(*inverse_spec))
          else:
-            dot_general_out_type = out_type  # type: ignore
+            dot_general_out_sharding = out_sharding  # type: ignore
          dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
-          dot_general_out_type = ({} if dot_general_out_type is None else  # type: ignore
-                                  {'out_type': dot_general_out_type})
+          dot_general_out_sharding = ({} if dot_general_out_sharding is None else  # type: ignore
+                                      {'out_sharding': dot_general_out_sharding})
          operand = _dot_general(lhs, rhs, dimension_numbers, precision,
                                 preferred_element_type=preferred_element_type,
-                                 **dot_general_out_type)
+                                 **dot_general_out_sharding)
       else:
         raise NotImplementedError  # if this is actually reachable, open an issue!
 
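
When the contraction emits axes in a different order than the requested result, _einsum permutes the user's spec so it lines up with the axis order dot_general will actually produce. The permutation itself is plain Python and can be illustrated standalone (the axis strings below are hypothetical):

from jax.sharding import PartitionSpec as P

result_names = 'bhtq'            # output order requested in the einsum spec
names = 'bthq'                   # order the underlying dot_general produces
spec = P('x', None, 'y', None)   # user out_sharding spec, aligned to result_names

# Re-index the spec so each entry follows its axis into `names`.
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
print(P(*inverse_spec))          # PartitionSpec('x', 'y', None, None)
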
@@ -2019,11 +2019,11 @@ def _dot_general_lowering(
     b,
     *,
     dimension_numbers,
-    out_type,
+    out_sharding,
     precision,
     preferred_element_type,
 ):
-  del preferred_element_type, out_type  # Unused.
+  del preferred_element_type, out_sharding  # Unused.
   ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())
 
@@ -364,7 +364,7 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
 def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: tuple[PrecisionType, PrecisionType] | None,
                  preferred_element_type: DType | None,
-                 out_type=None,
+                 out_sharding=None,
                  _in_avals: Sequence[core.ShapedArray],
                  _out_aval: core.ShapedArray):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
@@ -2168,7 +2168,7 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
 def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: lax_internal.CanonicalPrecision,
                  preferred_element_type: DType | None,
-                 out_type=None,
+                 out_sharding=None,
                  _in_avals: Sequence[core.ShapedArray],
                  _out_aval: core.ShapedArray):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
@@ -610,7 +610,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
                      preferred_element_type: None = None,
-                     out_type=None) -> BCOO | Array:
+                     out_sharding=None) -> BCOO | Array:
   """A general contraction operation.
 
   Args:
@@ -628,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision, out_type  # unused
+  del precision, out_sharding  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
@@ -1055,7 +1055,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
           'preferred_element_type': None,
-          'out_type': None}
+          'out_sharding': None}
   A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
   return A, B, indices
 
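
The sparse wrappers accept the renamed keyword only to discard it (the `del` above), so explicit output shardings are not yet plumbed through sparse dots. A minimal call-shape sketch (shapes are illustrative):

import jax.numpy as jnp
from jax.experimental import sparse

lhs = sparse.BCOO.fromdense(jnp.eye(4))
rhs = jnp.arange(16.).reshape(4, 4)

# out_sharding is accepted for API compatibility but ignored here.
out = sparse.bcoo_dot_general(lhs, rhs,
                              dimension_numbers=(((1,), (0,)), ((), ())),
                              out_sharding=None)
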
@@ -463,7 +463,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
                      preferred_element_type: None = None,
-                     out_type=None) -> Array:
+                     out_sharding=None) -> Array:
   """A general contraction operation.
 
   Args:
@@ -480,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
     are sparse, the result will be sparse, of type BCSR. If either input is
     dense, the result will be dense, of type ndarray.
   """
-  del precision, out_type  # unused
+  del precision, out_sharding  # unused
   if isinstance(rhs, (np.ndarray, jax.Array)):
     if isinstance(lhs, (np.ndarray, jax.Array)):
       return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
@@ -111,4 +111,4 @@ def _dot_general_validated_shape(
   rhs = core.ShapedArray(rhs_shape, np.float32)
   return _dot_general_shape_rule(
       lhs, rhs, dimension_numbers=dimension_numbers,
-      precision=None, preferred_element_type=None, out_type=None)
+      precision=None, preferred_element_type=None, out_sharding=None)
@@ -372,7 +372,7 @@ def einsum(
    preferred_element_type: DTypeLike | None = ...,
    _use_xeinsum: builtins.bool = False,
    _dot_general: Callable[..., Array] = ...,
-   out_type: NamedSharding | P | None = ...,
+   out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...
 
 @overload
@@ -386,7 +386,7 @@ def einsum(
    preferred_element_type: DTypeLike | None = ...,
    _use_xeinsum: builtins.bool = False,
    _dot_general: Callable[..., Array] = ...,
-   out_type: NamedSharding | P | None = ...,
+   out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...
 @overload
 def einsum(
@@ -398,7 +398,7 @@ def einsum(
    preferred_element_type: DTypeLike | None = ...,
    _use_xeinsum: builtins.bool = ...,
    _dot_general: Callable[..., Array] = ...,
-   out_type: NamedSharding | P | None = ...,
+   out_sharding: NamedSharding | P | None = ...,
 ) -> Array: ...
 
 @overload
@@ -4906,14 +4906,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out[1].sharding, arr2.sharding)
 
   @jtu.with_user_mesh((4,), ('x',))
-  def test_dot_general_out_type(self, mesh):
+  def test_dot_general_out_sharding(self, mesh):
     np_inp1 = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
 
     @jax.jit
     def f(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return jnp.sum(out)
 
@@ -5217,7 +5217,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
-  def test_einsum_with_out_type(self, mesh):
+  def test_einsum_with_out_sharding(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
     arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5225,7 +5225,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+                       out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out
 
@@ -5238,7 +5238,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     @jax.jit
     def g(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out
 
@@ -5268,7 +5268,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def h(x, y):
       spec = P('x', None, 'y', None)
-      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=spec)
+      out = jnp.einsum('btd,dhq->bhtq', x, y, out_sharding=spec)
       self.assertEqual(out.sharding.spec, spec)
       return out
 
@@ -5963,7 +5963,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
    @jax.jit
    def f(x, y):
      out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(auto_mesh, P(None, None)))
+                       out_sharding=NamedSharding(auto_mesh, P(None, None)))
      return out
 
    with self.assertRaisesRegex(
@@ -5973,7 +5973,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
    @jax.jit
    def g(x, y):
      with hidden_axes('x'):
-        out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+        out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
      return out
 
    with self.assertRaisesRegex(
@@ -6046,7 +6046,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(y.sharding.spec, P('x', 'y'))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P('x', 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P('x', None))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', None))
      self.assertEqual(a.sharding.spec, P('x', None))
      return a
 
@@ -6076,7 +6076,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(y.sharding.spec, P('x', 'y'))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P('x', 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P('x', 'y'))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', 'y'))
      self.assertEqual(a.sharding.spec, P('x', 'y'))
      return a
 
@@ -6102,7 +6102,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(y.sharding.spec, P(None, 'y'))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P(None, 'y'))
-      a = jnp.einsum('ab,bc->ac', z, z.T, out_type=P(None, 'y'))
+      a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P(None, 'y'))
      self.assertEqual(a.sharding.spec, P(None, 'y'))
      return a
 