Allow jnp.int32(tracer) to work. (#3235)

This commit is contained in:
Peter Hawkins 2020-05-28 20:46:48 -04:00 committed by GitHub
parent e48a4e012b
commit 7944879cdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View File

@ -147,7 +147,7 @@ class _ScalarMeta(type):
return not (self == other)
def __call__(self, x):
return array(self.dtype.type(x), dtype=self.dtype)
return array(x, dtype=self.dtype)
def _make_scalar_type(np_scalar_type):
return _ScalarMeta(np_scalar_type.__name__, (object,),

View File

@ -173,5 +173,11 @@ class DtypesTest(jtu.JaxTestCase):
np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))
def testScalarCastInsideJitWorks(self):
# jnp.int32(tracer) should work.
self.assertEqual(jnp.int32(101),
jax.jit(lambda x: jnp.int32(x))(jnp.float32(101.4)))
if __name__ == "__main__":
absltest.main()