Use finfo(dtype).tiny as uniform minval

Otherwise, using the default clopen uniform for truncated_normal
introduces a slight shift of the empirical mean.
This commit is contained in:
Trevor Cai 2019-10-01 22:28:31 +01:00
parent 675bfd54b5
commit a44c1caf21

View File

@ -441,7 +441,9 @@ def _truncated_normal(key, lower, upper, shape, dtype):
sqrt2 = onp.array(onp.sqrt(2), dtype)
a = lax.erf(lax.convert_element_type(lower, dtype) / sqrt2)
b = lax.erf(lax.convert_element_type(upper, dtype) / sqrt2)
u = uniform(key, shape, dtype)
if not onp.issubdtype(dtype, onp.floating):
raise TypeError("truncated_normal only accepts floating point dtypes.")
u = uniform(key, shape, dtype, minval=onp.finfo(dtype).tiny)
return sqrt2 * lax.erf_inv(a + u * (b - a))