Prevent nans in scale_and_translate

fixes #6780
This commit is contained in:
Luke Pfister 2021-05-18 15:22:06 -06:00
parent 683289c4ad
commit bcd5deb269

View File

@ -66,7 +66,8 @@ def compute_weight_mat(input_size: int, output_size: int, scale,
with np.errstate(invalid='ignore', divide='ignore'):
weights = jnp.where(
jnp.abs(total_weight_sum) > 1000. * np.finfo(np.float32).eps,
weights / total_weight_sum, 0)
jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
0)
# Zero out weights where the sample location is completely outside the input
# range.
# Note sample_f has already had the 0.5 removed, hence the weird range below.