[call_tf] Fix call_tf lowering for multi-platform lowering

call_tf has per-platform lowering because the lowering
of the called TF function may depend on the platform. When
doing multi-platform lowering, this means that we lower
call_tf several times and wrap the lowerings in a
conditional. This results in an assertion failure
in add_to_call_tf_concrete_function_list, because we
attempt to add the same function multiple times.
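
As a rough sketch (mirroring the test added below, with a
hypothetical f_jax), a multi-platform export that hits this
path looks like:

    import jax.numpy as jnp
    import numpy as np
    import tensorflow as tf
    from jax.experimental import jax2tf
    from jax.experimental.export import export

    def f_jax(x):
      # The called TF function needs a per-platform lowering.
      return jnp.cos(jax2tf.call_tf(tf.math.sin)(x))

    x = np.arange(4, dtype=np.float32)
    # Two platforms => call_tf is lowered once per platform; the
    # second lowering used to trip the assertion. (As in the test,
    # the requested platforms must have devices available.)
    exp = export.export(f_jax, lowering_platforms=("cpu", "cuda"))(x)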

Here we remove the assertion (as far as we know, it is
OK to add multiple functions with the same name, because
all we care about is the index of the called function in
the list). We also reuse the existing entry when adding
an identical function.
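
The reuse is the standard index-or-append pattern. A minimal
standalone sketch (the real list holds TF ConcreteFunctions,
not strings):

    def add_unique(fn, fns: list) -> int:
      try:
        # An identical entry already exists: reuse its index.
        return fns.index(fn)
      except ValueError:
        # First occurrence: append and return the new index.
        fns.append(fn)
        return len(fns) - 1

    fns = []
    assert add_unique("sin", fns) == 0
    assert add_unique("cos", fns) == 1
    assert add_unique("sin", fns) == 0  # a duplicate reuses index 0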

We add tests for call_tf with multi-platform lowering.
Author: George Necula, 2023-10-13 09:42:12 -07:00
Parent: 8d49f9a159
Commit: db44249afc
4 changed files with 96 additions and 19 deletions

@@ -1029,7 +1029,10 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
 ### Importing
 def call_exported(exported: Exported) -> Callable[..., jax.Array]:
+  if not isinstance(exported, Exported):
+    raise ValueError(
+        "The exported argument must be an export.Exported. "
+        f"Found {exported}.")
   @jax.custom_vjp
   def f_flat(*args_flat):
     return call_exported_p.bind(*args_flat, exported=exported)

@@ -629,9 +629,11 @@ def emit_tf_embedded_graph_custom_call(
     raise ValueError(
         "call_tf_graph=True only support exporting by jax2tf.convert currently."
     )
+  # TODO(necula): It is dangerous to modify global state when lowering because
+  # there are a number of lowering caches that only cache the StableHLO.
+  # See call_tf_test.py:test_multi_platform_call_tf_graph.
   called_index = add_to_call_tf_concrete_function_list(
       concrete_function_flat_tf, call_tf_concrete_function_list)
-  call_target_name = "tf.call_tf_function"
   tf_backend_config = {
       "has_token_input_output": ir.BoolAttr.get(ordered),
       "called_index": mlir.i64_attr(called_index),
@@ -649,7 +651,7 @@ def emit_tf_embedded_graph_custom_call(
   custom_call = hlo.CustomCallOp(
       result_types,
       operands,
-      call_target_name=ir.StringAttr.get(call_target_name),
+      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
       has_side_effect=ir.BoolAttr.get(has_side_effects),
       api_version=mlir.i32_attr(2),
       called_computations=ir.ArrayAttr.get([]),
@@ -669,8 +671,9 @@ def emit_tf_embedded_graph_custom_call(
 def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
-  func_name = concrete_tf_fn.function_def.signature.name
-  assert func_name not in [f.function_def.signature.name for f in call_tf_concrete_function_list]
-  called_index = len(call_tf_concrete_function_list)
-  call_tf_concrete_function_list.append(concrete_tf_fn)
+  try:
+    called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
+  except ValueError:
+    called_index = len(call_tf_concrete_function_list)
+    call_tf_concrete_function_list.append(concrete_tf_fn)
   return called_index

@@ -29,6 +29,7 @@ from jax._src import test_util as jtu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax.experimental import jax2tf
+from jax.experimental.export import export
 from jax.experimental.jax2tf.tests import tf_test_util
 import numpy as np
@@ -65,6 +66,18 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d
 class CallTfTest(tf_test_util.JaxToTfTestCase):
+  @classmethod
+  def setUpClass(cls):
+    # One TF device of each device_type
+    cls.tf_devices = []
+    for tf_device in tf.config.list_logical_devices():
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
+        cls.tf_devices.append(tf_device)
+    super(CallTfTest, cls).setUpClass()
+
   def setUp(self):
     if tf is None:
       raise unittest.SkipTest("Test requires tensorflow")
@@ -737,6 +750,73 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     ):
       _ = fun_jax_3(x)
 
+  def test_multi_platform(self):
+    def tf_fun(x):
+      return tf.math.sin(x)
+
+    def f_jax(x):
+      return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))
+
+    x = np.arange(12, dtype=np.float32).reshape((3, 4))
+    # Find platforms that are available for both JAX and TF
+    # Pick one device from each available platform
+    jax_platforms = []
+    for backend in ["cpu", "gpu", "tpu"]:
+      try:
+        devices = jax.devices(backend)
+      except RuntimeError:
+        devices = []
+      if devices:
+        jax_platforms.append(devices[0].platform)
+    jax_and_tf_platforms = (
+        set(jax_platforms) & {d.device_type.lower()
+                              for d in self.__class__.tf_devices})
+    # TODO(b/306753579): call_tf can only be lowered when we have a device
+    lowering_platforms = tuple(
+        p if p != "gpu" else "cuda"
+        for p in jax_and_tf_platforms)
+    exp = export.export(f_jax,
+                        lowering_platforms=lowering_platforms)(x)
+    for jax_platform in jax_and_tf_platforms:
+      with self.subTest(jax_platform):
+        jax_device = jax.devices(jax_platform)[0]
+        x_device = jax.device_put(x, jax_device)
+        logging.info("Running harness natively on %s", jax_device)
+        native_res = f_jax(x_device)
+        logging.info("Running exported harness on %s", jax_device)
+        exported_res = export.call_exported(exp)(x_device)
+        self.assertAllClose(native_res, exported_res)
+
+  def test_multi_platform_call_tf_graph(self):
+    def tf_fun(x):
+      return tf.math.sin(x)
+
+    def f_jax(x):
+      return jnp.cos(jax2tf.call_tf(tf_fun,
+                                    call_tf_graph=True,
+                                    ordered=True)(jnp.cos(x)))
+
+    x = np.arange(12, dtype=np.float32).reshape((3, 4))
+    # When we use call_tf_graph we can serialize for multiple platforms
+    lowering_platforms = ("tpu", "cpu", "cuda")
+    # We must use jax2tf.convert to run a call_tf(call_tf_graph)
+    # TODO(necula): if we remove the tf.function and we have multiple platforms
+    # then we attempt to lower call_tf multiple times and only the first
+    # lowering will have the proper side effects for the function_list.
+    f_tf = tf.function(jax2tf.convert(
+        f_jax,
+        native_serialization=True,
+        native_serialization_platforms=lowering_platforms))
+    for tf_device in self.__class__.tf_devices:
+      with self.subTest(tf_device.device_type):
+        logging.info(
+            f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
+        with tf.device(tf_device):
+          res = f_tf(x)
+        self.assertAllClose(res, f_jax(x))
+
 class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
   """Reloading output of jax2tf into JAX with call_tf"""
@@ -1693,5 +1773,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
     res = loaded_model.call(*data_inputs)
     self.assertAllClose(jax_func(*data_inputs), res)
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())

@@ -21,7 +21,6 @@ import functools
 import math
 import os
 import re
-from typing import Optional
 import unittest
 
 from absl import logging
@@ -60,16 +59,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
   @classmethod
   def setUpClass(cls):
-    # Pick one device from each available platform
-    cls.jax_platforms = []
-    for backend in ["cpu", "gpu", "tpu"]:
-      try:
-        devices = jax.devices(backend)
-      except RuntimeError:
-        devices = []
-      if devices:
-        cls.jax_platforms.append(devices[0].platform)
     # One TF device of each device_type
     cls.tf_devices = []
     for tf_device in (tf.config.list_logical_devices("TPU") +
@@ -1741,9 +1730,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         native_serialization=True,
         native_serialization_platforms=("cpu", "cuda", "tpu"))
     for tf_device in self.__class__.tf_devices:
-      logging.info(
-          f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
-      with tf.device(tf_device):
-        res = f_tf(x)
+      logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
+      tf_device_jax_platform = dict(
+          CPU="cpu", GPU="cuda", TPU="tpu"
+      )[tf_device.device_type]