mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11727 from gnecula:call_tf_checks
PiperOrigin-RevId: 508685246
This commit is contained in:
commit
d09f3c2ee4
@ -116,15 +116,30 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
|
||||
return tf.TensorSpec(a_tf_shape, a_tf_dtype)
|
||||
args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))
|
||||
|
||||
def check_tf_result(r_tf):
|
||||
# Check that the TF function returns values of expected types. This
|
||||
# improves error reporting, preventing hard-to-diagnose errors downstream
|
||||
try:
|
||||
jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
|
||||
except Exception as e:
|
||||
msg = ("The called TF function returns a result that is not "
|
||||
f"convertible to JAX: {r_tf}.")
|
||||
raise ValueError(msg) from e
|
||||
|
||||
res_treedef = None # We'll store here the result treedef
|
||||
res_tf_flat = None # For error reporting
|
||||
# The function below will be called at least once, either in eager
|
||||
# or in graph mode.
|
||||
def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
|
||||
args_tf = args_treedef.unflatten(args_tf_flat)
|
||||
res_tf = callable_tf(*args_tf)
|
||||
nonlocal res_treedef
|
||||
nonlocal res_treedef, res_tf_flat
|
||||
res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
|
||||
for r_tf in res_tf_flat:
|
||||
check_tf_result(r_tf)
|
||||
assert res_treedef is None or res_treedef == res_treedef_now, f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}"
|
||||
for r_tf in res_tf_flat:
|
||||
check_tf_result(r_tf)
|
||||
res_treedef = res_treedef_now
|
||||
return res_tf_flat
|
||||
|
||||
@ -158,6 +173,21 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
|
||||
function_flat_tf=function_flat_tf,
|
||||
args_flat_sig_tf=args_flat_sig_tf,
|
||||
has_side_effects=has_side_effects)
|
||||
|
||||
assert res_treedef is not None
|
||||
# Sometimes, in compiled mode, we get a different number of results than we
|
||||
# got when tracing the TF function (and building the res_treedef). This
|
||||
# can happen, e.g., when returning tf.TensorArray, which appears as one
|
||||
# leaf when tracing but after compilation we get a tuple. See
|
||||
# call_tf_test.test_error_bad_result_tensorarray.
|
||||
if res_treedef.num_leaves != len(res_jax_flat):
|
||||
# It is not clear if this error can happen once we have check_tf_result
|
||||
# in callable_flat_tf, but we keep it for safety.
|
||||
msg = (f"Incorrect number of results ({len(res_jax_flat)}) from the "
|
||||
"called TF function after compilation. "
|
||||
f"Expected {res_treedef.num_leaves} leaves based on observed "
|
||||
f"results during tracing: {res_tf_flat}.")
|
||||
raise ValueError(msg)
|
||||
return res_treedef.unflatten(res_jax_flat)
|
||||
|
||||
# Define the fwd and bwd custom_vjp functions
|
||||
|
@ -903,7 +903,11 @@ def _to_tf_dtype(jax_dtype):
|
||||
def _to_jax_dtype(tf_dtype):
|
||||
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
|
||||
# due to float0 and 64-bit behavior.
|
||||
return dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
|
||||
dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
|
||||
if dt not in dtypes._jax_dtype_set:
|
||||
raise TypeError(f"dtype {dt} is not a valid JAX array "
|
||||
"type. Only arrays of numeric types are supported by JAX.")
|
||||
return dt
|
||||
|
||||
|
||||
def _maybe_decode_gda(gda_or_py_object: Any):
|
||||
|
@ -129,7 +129,17 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
res = fun_jax(x, y)
|
||||
self.assertAllClose((np.float32(12.), np.float64(11.)), res)
|
||||
|
||||
def test_eval_non_compileable_strings(self):
|
||||
def test_result_tuple(self):
|
||||
x1 = np.ones(3, dtype=np.int32)
|
||||
x2 = np.ones(5, dtype=np.float32)
|
||||
def fun_tf():
|
||||
return tf.tuple([x1, x2])
|
||||
|
||||
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
|
||||
res = fun_jax()
|
||||
self.assertAllClose(res, (x1, x2))
|
||||
|
||||
def test_error_non_compileable_strings(self):
|
||||
# Check that in op-by-op we call a function in eager mode.
|
||||
def f_tf_non_compileable(x):
|
||||
return tf.strings.length(tf.strings.format("Hello {}!", [x]))
|
||||
@ -145,19 +155,43 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
_call_tf_non_compileable_error):
|
||||
lax.cond(True, lambda x: f_jax(x), lambda x: f_jax(x), x)
|
||||
|
||||
def test_eval_non_compileable_dynamic_shape(self):
|
||||
def test_error_non_compileable_dynamic_shape(self):
|
||||
# Check that in op-by-op we call a function in eager mode.
|
||||
def f_tf_non_compileable(x):
|
||||
return tf.cond(x[0], lambda: x[1:], lambda: x)
|
||||
|
||||
f_jax = jax2tf.call_tf(f_tf_non_compileable)
|
||||
x = np.array([True, False], dtype=np.bool_)
|
||||
self.assertAllClose(f_tf_non_compileable(x), f_jax(x))
|
||||
self.assertAllClose(f_tf_non_compileable(x), f_jax(x)) # Works in eager mode
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
_call_tf_dynamic_shape_error):
|
||||
jax.jit(f_jax)(x)
|
||||
|
||||
def test_error_bad_result_tensorarray(self):
|
||||
# Call a function that returns a tf.TensorArray. This should be detected
|
||||
# early on. If we don't the function is actually compileable but returns
|
||||
# a tuple instead of a single result.
|
||||
def fun_tf():
|
||||
ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
|
||||
ta = ta.unstack([0, 1, 2, 3, 4])
|
||||
return ta
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The called TF function returns a result that is not convertible to JAX"):
|
||||
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
|
||||
fun_jax()
|
||||
|
||||
def test_error_bad_result_string(self):
|
||||
def fun_tf():
|
||||
return tf.constant("foo")
|
||||
|
||||
# Now under jit, should fail because the function is not compileable
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The called TF function returns a result that is not convertible to JAX"):
|
||||
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
|
||||
fun_jax()
|
||||
|
||||
@_parameterized_jit
|
||||
def test_control_flow(self, with_jit=True):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user