mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use input dtype for constants in jax.nn.gelu (#2259)
This commit is contained in:
parent
6cceb2c778
commit
eda91a048b
@ -180,7 +180,8 @@ def gelu(x):
|
||||
speed. For more information, see `Gaussian Error Linear Units (GELUs)
|
||||
<https://arxiv.org/abs/1606.08415>`_, section 2.
|
||||
"""
|
||||
cdf = 0.5 * (1.0 + np.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * x**3))))
|
||||
sqrt_2_over_pi = onp.sqrt(2 / onp.pi).astype(x.dtype)
|
||||
cdf = 0.5 * (1.0 + np.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
|
||||
return x * cdf
|
||||
|
||||
def glu(x, axis=-1):
|
||||
|
@ -51,6 +51,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
val = nn.elu(1e4)
|
||||
self.assertAllClose(val, 1e4, check_dtypes=False)
|
||||
|
||||
@parameterized.parameters(*itertools.product(
|
||||
(np.float32, np.bfloat16, np.float16),
|
||||
(nn.gelu, nn.relu, nn.softplus, nn.sigmoid)))
|
||||
def testDtypeMatchesInput(self, dtype, fn):
|
||||
x = np.zeros((), dtype=dtype)
|
||||
out = fn(x)
|
||||
self.assertEqual(out.dtype, dtype)
|
||||
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testEluMemory(self):
|
||||
# see https://github.com/google/jax/pull/1640
|
||||
|
Loading…
x
Reference in New Issue
Block a user