mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
683289c4ad
commit
bcd5deb269
@ -66,7 +66,8 @@ def compute_weight_mat(input_size: int, output_size: int, scale,
|
||||
with np.errstate(invalid='ignore', divide='ignore'):
|
||||
weights = jnp.where(
|
||||
jnp.abs(total_weight_sum) > 1000. * np.finfo(np.float32).eps,
|
||||
weights / total_weight_sum, 0)
|
||||
jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
|
||||
0)
|
||||
# Zero out weights where the sample location is completely outside the input
|
||||
# range.
|
||||
# Note sample_f has already had the 0.5 removed, hence the weird range below.
|
||||
|
Loading…
x
Reference in New Issue
Block a user