jax2tf: better handling for opaque dtypes

This commit is contained in:
Jake VanderPlas 2023-05-04 14:22:15 -07:00
parent 2845df03fc
commit b031cc2660
3 changed files with 32 additions and 0 deletions

View File

@ -363,6 +363,10 @@ class KeyTyRules:
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape), # type: ignore
jnp.dtype('uint32'))]
@staticmethod
def physical_const(val) -> Array:
return val.unsafe_raw_array()
@staticmethod
def physical_op_sharding(aval, op_sharding_proto):
key_shape = aval.dtype.impl.key_shape

View File

@ -1087,6 +1087,8 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
# The float0 type is not known to TF.
if jax_dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
if hasattr(val, 'dtype') and core.is_opaque_dtype(val.dtype):
val = val.dtype._rules.physical_const(val)
tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
if do_memoize:
_thread_local_state.constant_cache[const_key] = (val, tf_val)

View File

@ -1751,6 +1751,32 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(res_tf.numpy(), res_jax)
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
def test_key_argument(self):
func = lambda key: jax.random.uniform(key, ())
key = jax.random.PRNGKey(0)
key_raw = jax.random.key_data(key)
with self.assertWarnsRegex(FutureWarning, "Raw arrays as random keys.*"):
tf_result = jax2tf.convert(func)(key_raw)
jax_result = func(key)
self.assertEqual(tf_result, jax_result)
def test_key_from_seed(self):
func = lambda seed: jax.random.uniform(jax.random.PRNGKey(seed), ())
seed = 1701
tf_result = jax2tf.convert(func)(seed)
jax_result = func(seed)
self.assertEqual(tf_result, jax_result)
def test_key_closure(self):
func = lambda: jax.random.uniform(global_key, ())
global_key = jax.random.PRNGKey(0)
tf_result = jax2tf.convert(func)()
jax_result = func()
self.assertEqual(tf_result, jax_result)
if __name__ == "__main__":
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if not hasattr(tfxla, "optimization_barrier"):