mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow jnp.int32(tracer)
to work. (#3235)
This commit is contained in:
parent
e48a4e012b
commit
7944879cdd
@ -147,7 +147,7 @@ class _ScalarMeta(type):
|
||||
return not (self == other)
|
||||
|
||||
def __call__(self, x):
|
||||
return array(self.dtype.type(x), dtype=self.dtype)
|
||||
return array(x, dtype=self.dtype)
|
||||
|
||||
def _make_scalar_type(np_scalar_type):
|
||||
return _ScalarMeta(np_scalar_type.__name__, (object,),
|
||||
|
@ -173,5 +173,11 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
|
||||
np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))
|
||||
|
||||
def testScalarCastInsideJitWorks(self):
|
||||
# jnp.int32(tracer) should work.
|
||||
self.assertEqual(jnp.int32(101),
|
||||
jax.jit(lambda x: jnp.int32(x))(jnp.float32(101.4)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user