mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use finfo(dtype).tiny as uniform minval
Otherwise, using the default clopen uniform for truncated_normal introduces a slight shift of the empirical mean.
This commit is contained in:
parent
675bfd54b5
commit
a44c1caf21
@ -441,7 +441,9 @@ def _truncated_normal(key, lower, upper, shape, dtype):
|
||||
sqrt2 = onp.array(onp.sqrt(2), dtype)
|
||||
a = lax.erf(lax.convert_element_type(lower, dtype) / sqrt2)
|
||||
b = lax.erf(lax.convert_element_type(upper, dtype) / sqrt2)
|
||||
u = uniform(key, shape, dtype)
|
||||
if not onp.issubdtype(dtype, onp.floating):
|
||||
raise TypeError("truncated_normal only accepts floating point dtypes.")
|
||||
u = uniform(key, shape, dtype, minval=onp.finfo(dtype).tiny)
|
||||
return sqrt2 * lax.erf_inv(a + u * (b - a))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user