mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
remove checks since they are redundant and we can change out_aval because of various reasons
PiperOrigin-RevId: 721535417
This commit is contained in:
parent
9107ee4a22
commit
1f33cad321
@ -4063,8 +4063,6 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
**algorithm_kwarg,
|
||||
)
|
||||
if config.sharding_in_types.value:
|
||||
if out_sharding is not None:
|
||||
assert aval_out.sharding == out_sharding
|
||||
result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
|
||||
if accumulation_aval.dtype != aval_out.dtype:
|
||||
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
|
||||
@ -4532,8 +4530,6 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
|
||||
out = mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=broadcast_dimensions)
|
||||
if config.sharding_in_types.value:
|
||||
if sharding is not None:
|
||||
assert sharding == aval_out.sharding, (sharding, aval_out.sharding)
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
@ -5154,8 +5150,6 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
|
||||
out = mlir.reshape(ctx, x, aval_out)
|
||||
if config.sharding_in_types.value:
|
||||
if sharding is not None:
|
||||
assert sharding == aval_out.sharding
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user