Fix validation code in lax.conv (#3279)

This commit is contained in:
Jake Vanderplas 2020-06-03 10:33:19 -07:00 committed by GitHub
parent b998044ffe
commit 0db57cb541
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2313,13 +2313,14 @@ def _conv_general_dilated_shape_rule(
msg = ("conv_general_dilated batch_group_count must divide lhs batch "
"dimension size, but {} does not divide {}.")
raise ValueError(msg.format(batch_group_count, lhs_batch_count))
if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
msg = ("conv_general_dilated rhs output feature dimension size must be a "
"multiple of batch_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
batch_group_count))
if not batch_group_count > 0 and feature_group_count > 0:
if batch_group_count > 1 and feature_group_count > 1:
msg = ("At most one of batch_group_count and feature_group_count may be > "
"1, got batch_group_count={} and feature_group_count={}")
raise ValueError(msg.format(batch_group_count, feature_group_count))