mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix validation code in lax.conv (#3279)
This commit is contained in:
parent
b998044ffe
commit
0db57cb541
@ -2313,13 +2313,14 @@ def _conv_general_dilated_shape_rule(
|
||||
msg = ("conv_general_dilated batch_group_count must divide lhs batch "
|
||||
"dimension size, but {} does not divide {}.")
|
||||
raise ValueError(msg.format(batch_group_count, lhs_batch_count))
|
||||
if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
|
||||
|
||||
if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
|
||||
msg = ("conv_general_dilated rhs output feature dimension size must be a "
|
||||
"multiple of batch_group_count, but {} is not a multiple of {}.")
|
||||
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
|
||||
batch_group_count))
|
||||
|
||||
if not batch_group_count > 0 and feature_group_count > 0:
|
||||
if batch_group_count > 1 and feature_group_count > 1:
|
||||
msg = ("At most one of batch_group_count and feature_group_count may be > "
|
||||
"1, got batch_group_count={} and feature_group_count={}")
|
||||
raise ValueError(msg.format(batch_group_count, feature_group_count))
|
||||
|
Loading…
x
Reference in New Issue
Block a user