mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
jax2tf: better handling for opaque dtypes
This commit is contained in:
parent
2845df03fc
commit
b031cc2660
@ -363,6 +363,10 @@ class KeyTyRules:
|
||||
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape), # type: ignore
|
||||
jnp.dtype('uint32'))]
|
||||
|
||||
@staticmethod
|
||||
def physical_const(val) -> Array:
|
||||
return val.unsafe_raw_array()
|
||||
|
||||
@staticmethod
|
||||
def physical_op_sharding(aval, op_sharding_proto):
|
||||
key_shape = aval.dtype.impl.key_shape
|
||||
|
@ -1087,6 +1087,8 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
|
||||
# The float0 type is not known to TF.
|
||||
if jax_dtype == dtypes.float0:
|
||||
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
|
||||
if hasattr(val, 'dtype') and core.is_opaque_dtype(val.dtype):
|
||||
val = val.dtype._rules.physical_const(val)
|
||||
tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
|
||||
if do_memoize:
|
||||
_thread_local_state.constant_cache[const_key] = (val, tf_val)
|
||||
|
@ -1751,6 +1751,32 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertAllClose(res_tf.numpy(), res_jax)
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_custom_prng=True)
|
||||
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_key_argument(self):
|
||||
func = lambda key: jax.random.uniform(key, ())
|
||||
key = jax.random.PRNGKey(0)
|
||||
key_raw = jax.random.key_data(key)
|
||||
with self.assertWarnsRegex(FutureWarning, "Raw arrays as random keys.*"):
|
||||
tf_result = jax2tf.convert(func)(key_raw)
|
||||
jax_result = func(key)
|
||||
self.assertEqual(tf_result, jax_result)
|
||||
|
||||
def test_key_from_seed(self):
|
||||
func = lambda seed: jax.random.uniform(jax.random.PRNGKey(seed), ())
|
||||
seed = 1701
|
||||
tf_result = jax2tf.convert(func)(seed)
|
||||
jax_result = func(seed)
|
||||
self.assertEqual(tf_result, jax_result)
|
||||
|
||||
def test_key_closure(self):
|
||||
func = lambda: jax.random.uniform(global_key, ())
|
||||
global_key = jax.random.PRNGKey(0)
|
||||
tf_result = jax2tf.convert(func)()
|
||||
jax_result = func()
|
||||
self.assertEqual(tf_result, jax_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: Remove once tensorflow is 2.10.0 everywhere.
|
||||
if not hasattr(tfxla, "optimization_barrier"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user