Merge pull request #18217 from gnecula:multi_call_tf

PiperOrigin-RevId: 576218473
jax authors 2023-10-24 12:13:29 -07:00
commit 4897a5fb5a
4 changed files with 96 additions and 19 deletions


@@ -1029,7 +1029,10 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
### Importing
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported):
raise ValueError(
"The exported argument must be an export.Exported. "
f"Found {exported}.")
@jax.custom_vjp
def f_flat(*args_flat):
return call_exported_p.bind(*args_flat, exported=exported)

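For context, a minimal round-trip sketch of the API touched above (assuming only a CPU backend is needed): export.export produces an Exported, export.call_exported replays it, and anything that is not an Exported now hits the new ValueError.

import jax.numpy as jnp
import numpy as np
from jax.experimental.export import export

def f_jax(x):
  return jnp.sin(x)

x = np.arange(4, dtype=np.float32)
exp = export.export(f_jax)(x)  # returns an export.Exported
np.testing.assert_allclose(export.call_exported(exp)(x), np.sin(x), rtol=1e-6)

try:
  export.call_exported(f_jax)  # not an Exported: raises the ValueError added above
except ValueError as err:
  print(err)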

@@ -629,9 +629,11 @@ def emit_tf_embedded_graph_custom_call(
raise ValueError(
"call_tf_graph=True only support exporting by jax2tf.convert currently."
)
# TODO(necula): It is dangerous to modify global state when lowering because
# there are a number of lowering caches that only cache the StableHLO.
# See call_tf_test.py:test_multi_platform_call_tf_graph.
called_index = add_to_call_tf_concrete_function_list(
concrete_function_flat_tf, call_tf_concrete_function_list)
call_target_name = "tf.call_tf_function"
tf_backend_config = {
"has_token_input_output": ir.BoolAttr.get(ordered),
"called_index": mlir.i64_attr(called_index),
@@ -649,7 +651,7 @@ def emit_tf_embedded_graph_custom_call(
custom_call = hlo.CustomCallOp(
result_types,
operands,
call_target_name=ir.StringAttr.get(call_target_name),
call_target_name=ir.StringAttr.get("tf.call_tf_function"),
has_side_effect=ir.BoolAttr.get(has_side_effects),
api_version=mlir.i32_attr(2),
called_computations=ir.ArrayAttr.get([]),
@@ -669,8 +671,9 @@ def emit_tf_embedded_graph_custom_call(
def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
func_name = concrete_tf_fn.function_def.signature.name
assert func_name not in [f.function_def.signature.name for f in call_tf_concrete_function_list]
called_index = len(call_tf_concrete_function_list)
call_tf_concrete_function_list.append(concrete_tf_fn)
try:
called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
except ValueError:
called_index = len(call_tf_concrete_function_list)
call_tf_concrete_function_list.append(concrete_tf_fn)
return called_index
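A self-contained sketch of the dedup-by-index pattern introduced above, using plain strings in place of TF concrete functions (add_once and the "sin_graph"/"cos_graph" names are hypothetical placeholders): re-registering the same function now reuses its existing called_index instead of tripping the old assertion.

def add_once(fn, fn_list: list) -> int:
  # Mirrors add_to_call_tf_concrete_function_list: reuse the index if the
  # function was already registered, otherwise append it.
  try:
    called_index = fn_list.index(fn)
  except ValueError:
    called_index = len(fn_list)
    fn_list.append(fn)
  return called_index

fns = []
assert add_once("sin_graph", fns) == 0
assert add_once("cos_graph", fns) == 1
assert add_once("sin_graph", fns) == 0  # a duplicate lowering reuses index 0
assert len(fns) == 2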


@@ -29,6 +29,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np
@@ -65,6 +66,18 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d
class CallTfTest(tf_test_util.JaxToTfTestCase):
@classmethod
def setUpClass(cls):
# One TF device of each device_type
cls.tf_devices = []
for tf_device in tf.config.list_logical_devices():
if tf_device.device_type == "TPU_SYSTEM":
continue # A virtual device
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
cls.tf_devices.append(tf_device)
super(CallTfTest, cls).setUpClass()
def setUp(self):
if tf is None:
raise unittest.SkipTest("Test requires tensorflow")
@@ -737,6 +750,73 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
):
_ = fun_jax_3(x)
def test_multi_platform(self):
def tf_fun(x):
return tf.math.sin(x)
def f_jax(x):
return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))
x = np.arange(12, dtype=np.float32).reshape((3, 4))
# Find platforms that are available for both JAX and TF
# Pick one device from each available platform
jax_platforms = []
for backend in ["cpu", "gpu", "tpu"]:
try:
devices = jax.devices(backend)
except RuntimeError:
devices = []
if devices:
jax_platforms.append(devices[0].platform)
jax_and_tf_platforms = (
set(jax_platforms) & {d.device_type.lower()
for d in self.__class__.tf_devices})
# TODO(b/306753579): call_tf can only be lowered when we have a device
lowering_platforms = tuple(
p if p != "gpu" else "cuda"
for p in jax_and_tf_platforms)
exp = export.export(f_jax,
lowering_platforms=lowering_platforms)(x)
for jax_platform in jax_and_tf_platforms:
with self.subTest(jax_platform):
jax_device = jax.devices(jax_platform)[0]
x_device = jax.device_put(x, jax_device)
logging.info("Running harness natively on %s", jax_device)
native_res = f_jax(x_device)
logging.info("Running exported harness on %s", jax_device)
exported_res = export.call_exported(exp)(x_device)
self.assertAllClose(native_res, exported_res)
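A small helper sketch (hypothetical, not part of this change) for the platform juggling done in the test above: jax.devices reports platforms as "cpu"/"gpu"/"tpu", while lowering_platforms expects "cuda" in place of "gpu".

import jax

def available_lowering_platforms() -> tuple[str, ...]:
  platforms = []
  for backend in ("cpu", "gpu", "tpu"):
    try:
      devices = jax.devices(backend)
    except RuntimeError:
      devices = []  # backend not linked into this process
    if devices:
      platform = devices[0].platform
      platforms.append("cuda" if platform == "gpu" else platform)
  return tuple(platforms)

print(available_lowering_platforms())  # e.g. ("cpu",) on a CPU-only host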
def test_multi_platform_call_tf_graph(self):
def tf_fun(x):
return tf.math.sin(x)
def f_jax(x):
return jnp.cos(jax2tf.call_tf(tf_fun,
call_tf_graph=True,
ordered=True)(jnp.cos(x)))
x = np.arange(12, dtype=np.float32).reshape((3, 4))
# When we use call_tf_graph we can serialize for multiple platforms
lowering_platforms = ("tpu", "cpu", "cuda")
# We must use jax2tf.convert to run a call_tf(call_tf_graph)
# TODO(necula): if we remove the tf.function and we have multiple platforms
# then we attempt to lower call_tf multiple times and only the first
# lowering will have the proper side effects for the function_list.
f_tf = tf.function(jax2tf.convert(
f_jax,
native_serialization=True,
native_serialization_platforms=lowering_platforms))
for tf_device in self.__class__.tf_devices:
with self.subTest(tf_device.device_type):
logging.info(
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
with tf.device(tf_device):
res = f_tf(x)
self.assertAllClose(res, f_jax(x))
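A runnable sketch (pure-JAX body, no call_tf, so it sidesteps the TODO above) of the same multi-platform serialization pattern: convert once for several lowering platforms, then run the resulting tf.function on each visible TF device.

import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

f_tf = tf.function(jax2tf.convert(
    lambda x: jnp.cos(jnp.sin(x)),
    native_serialization=True,
    native_serialization_platforms=("cpu", "cuda", "tpu")))

x = np.arange(12, dtype=np.float32).reshape((3, 4))
for tf_device in tf.config.list_logical_devices():
  if tf_device.device_type == "TPU_SYSTEM":
    continue  # virtual device, skipped as in setUpClass above
  with tf.device(tf_device):
    np.testing.assert_allclose(f_tf(x).numpy(), np.cos(np.sin(x)), rtol=1e-6)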
class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
"Reloading output of jax2tf into JAX with call_tf"
@@ -1693,5 +1773,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
res = loaded_model.call(*data_inputs)
self.assertAllClose(jax_func(*data_inputs), res)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())


@@ -21,7 +21,6 @@ import functools
import math
import os
import re
from typing import Optional
import unittest
from absl import logging
@@ -60,16 +59,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@classmethod
def setUpClass(cls):
# Pick one device from each available platform
cls.jax_platforms = []
for backend in ["cpu", "gpu", "tpu"]:
try:
devices = jax.devices(backend)
except RuntimeError:
devices = []
if devices:
cls.jax_platforms.append(devices[0].platform)
# One TF device of each device_type
cls.tf_devices = []
for tf_device in (tf.config.list_logical_devices("TPU") +
@@ -1741,9 +1730,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
native_serialization=True,
native_serialization_platforms=("cpu", "cuda", "tpu"))
for tf_device in self.__class__.tf_devices:
logging.info(
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
with tf.device(tf_device):
res = f_tf(x)
logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
tf_device_jax_platform = dict(
CPU="cpu", GPU="cuda", TPU="tpu"
)[tf_device.device_type]