mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax2tf: correctly handle opaque dtype in jax2tf pure()
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure(). PiperOrigin-RevId: 535408671
This commit is contained in:
parent
8534f0bfc3
commit
bbae2edd12
@ -1244,8 +1244,8 @@ class TensorFlowTrace(core.Trace):
|
||||
return val
|
||||
tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True)
|
||||
return TensorFlowTracer(
|
||||
self, val, core.ShapedArray(tf_val.shape, jax_dtype,
|
||||
weak_type=dtypes.is_weakly_typed(val)))
|
||||
self, tf_val, core.ShapedArray(np.shape(val), jax_dtype,
|
||||
weak_type=dtypes.is_weakly_typed(val)))
|
||||
|
||||
def lift(self, val: core.Tracer) -> TensorFlowTracer:
|
||||
# This would be called when we need to raise a tracer from a lower-level
|
||||
|
@ -1780,7 +1780,10 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual(tf_result, jax_result)
|
||||
|
||||
def test_key_closure(self):
|
||||
func = lambda: jax.random.uniform(global_key, ())
|
||||
def func():
|
||||
# Include nontrivial shape operations to catch tracing bugs.
|
||||
key = global_key.reshape(1).squeeze()
|
||||
return jax.random.uniform(key)
|
||||
global_key = jax.random.PRNGKey(0)
|
||||
tf_result = jax2tf.convert(func)()
|
||||
jax_result = func()
|
||||
|
Loading…
x
Reference in New Issue
Block a user