Update nn_test.py

Add test for fixed integer-type Gelu behaviour.
This commit is contained in:
russbates 2022-08-11 15:41:07 +01:00 committed by GitHub
parent 4ce88ec71f
commit 32b2a8ff00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -84,6 +84,11 @@ class NNFunctionsTest(jtu.JaxTestCase):
val = nn.glu(jnp.array([1.0, 0.0]), axis=0)
self.assertAllClose(val, jnp.array([0.5]))
def testGeluIntType(self):
val_float = nn.gelu(jnp.array(-1.0))
val_int = nn.gelu(jnp.array(-1))
self.assertAllClose(val_float, val_int)
@parameterized.parameters(False, True)
def testGelu(self, approximate):
def gelu_reference(x):