remove checks since they are redundant and we can change out_aval because of various reasons

PiperOrigin-RevId: 721535417
This commit is contained in:
Yash Katariya 2025-01-30 15:13:51 -08:00 committed by jax authors
parent 9107ee4a22
commit 1f33cad321

View File

@ -4063,8 +4063,6 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
**algorithm_kwarg,
)
if config.sharding_in_types.value:
if out_sharding is not None:
assert aval_out.sharding == out_sharding
result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
if accumulation_aval.dtype != aval_out.dtype:
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
@ -4532,8 +4530,6 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
out = mlir.broadcast_in_dim(ctx, x, aval_out,
broadcast_dimensions=broadcast_dimensions)
if config.sharding_in_types.value:
if sharding is not None:
assert sharding == aval_out.sharding, (sharding, aval_out.sharding)
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
@ -5154,8 +5150,6 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
out = mlir.reshape(ctx, x, aval_out)
if config.sharding_in_types.value:
if sharding is not None:
assert sharding == aval_out.sharding
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]