mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update nn_test.py
Add test for fixed integer-type Gelu behaviour.
This commit is contained in:
parent
4ce88ec71f
commit
32b2a8ff00
@ -84,6 +84,11 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
val = nn.glu(jnp.array([1.0, 0.0]), axis=0)
|
||||
self.assertAllClose(val, jnp.array([0.5]))
|
||||
|
||||
def testGeluIntType(self):
|
||||
val_float = nn.gelu(jnp.array(-1.0))
|
||||
val_int = nn.gelu(jnp.array(-1))
|
||||
self.assertAllClose(val_float, val_int)
|
||||
|
||||
@parameterized.parameters(False, True)
|
||||
def testGelu(self, approximate):
|
||||
def gelu_reference(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user