[call_tf] Fix call_tf lowering for multi-platform lowering
call_tf has per-platform lowering because the lowering of the called TF function may depend on the platform. With multi-platform lowering this means that we lower call_tf several times and wrap the lowerings in a conditional, which triggers an assertion failure in add_to_call_tf_concrete_function_list because we attempt to add the same function multiple times. Here we remove the assertion (as far as I know, it is OK to add multiple functions with the same name, because all we care about is the index of the called function in the list) and reuse the existing entry when an identical function is added again. We also add tests for call_tf with multi-platform lowering.
parent 8d49f9a159
commit db44249afc
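For context, here is a minimal sketch of the scenario this commit fixes, modeled on the tests added below. The platform list is illustrative and assumes matching devices are available; it is not part of the commit.

import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from jax.experimental import jax2tf
from jax.experimental.export import export

def tf_fun(x):
  return tf.math.sin(x)

def f_jax(x):
  # A JAX function that calls back into TF; call_tf is lowered per platform.
  return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))

x = np.arange(12, dtype=np.float32).reshape((3, 4))
# Exporting for more than one platform lowers call_tf once per platform, and
# each lowering registers the same TF concrete function; previously the second
# registration hit the assertion in add_to_call_tf_concrete_function_list.
exp = export.export(f_jax, lowering_platforms=("cpu", "cuda"))(x)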
@@ -1029,7 +1029,10 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
 
 ### Importing
 
 def call_exported(exported: Exported) -> Callable[..., jax.Array]:
+  if not isinstance(exported, Exported):
+    raise ValueError(
+        "The exported argument must be an export.Exported. "
+        f"Found {exported}.")
   @jax.custom_vjp
   def f_flat(*args_flat):
     return call_exported_p.bind(*args_flat, exported=exported)
@@ -629,9 +629,11 @@ def emit_tf_embedded_graph_custom_call(
     raise ValueError(
         "call_tf_graph=True only support exporting by jax2tf.convert currently."
     )
+  # TODO(necula): It is dangerous to modify global state when lowering because
+  # there are a number of lowering caches that only cache the StableHLO.
+  # See call_tf_test.py:test_multi_platform_call_tf_graph.
   called_index = add_to_call_tf_concrete_function_list(
       concrete_function_flat_tf, call_tf_concrete_function_list)
-  call_target_name = "tf.call_tf_function"
   tf_backend_config = {
       "has_token_input_output": ir.BoolAttr.get(ordered),
       "called_index": mlir.i64_attr(called_index),
@@ -649,7 +651,7 @@ def emit_tf_embedded_graph_custom_call(
   custom_call = hlo.CustomCallOp(
       result_types,
       operands,
-      call_target_name=ir.StringAttr.get(call_target_name),
+      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
       has_side_effect=ir.BoolAttr.get(has_side_effects),
       api_version=mlir.i32_attr(2),
       called_computations=ir.ArrayAttr.get([]),
@@ -669,8 +671,9 @@ def emit_tf_embedded_graph_custom_call(
 
 
 def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
-  func_name = concrete_tf_fn.function_def.signature.name
-  assert func_name not in [f.function_def.signature.name for f in call_tf_concrete_function_list]
-  called_index = len(call_tf_concrete_function_list)
-  call_tf_concrete_function_list.append(concrete_tf_fn)
+  try:
+    called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
+  except ValueError:
+    called_index = len(call_tf_concrete_function_list)
+    call_tf_concrete_function_list.append(concrete_tf_fn)
   return called_index
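As an illustration (not part of the diff), the reuse-or-append behavior that replaces the assertion can be shown standalone; ConcreteFnStub below is a hypothetical stand-in for a TF concrete function.

class ConcreteFnStub:
  """Hypothetical stand-in for a TF ConcreteFunction; only identity matters."""

def add_or_reuse(fn, fn_list):
  # Same idea as add_to_call_tf_concrete_function_list after this change:
  # return the index of an identical entry if present, else append.
  try:
    return fn_list.index(fn)
  except ValueError:
    fn_list.append(fn)
    return len(fn_list) - 1

fns = []
f = ConcreteFnStub()
assert add_or_reuse(f, fns) == 0  # first registration appends at index 0
assert add_or_reuse(f, fns) == 0  # an identical registration reuses index 0
assert len(fns) == 1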
@@ -29,6 +29,7 @@ from jax._src import test_util as jtu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax.experimental import jax2tf
+from jax.experimental.export import export
 from jax.experimental.jax2tf.tests import tf_test_util
 import numpy as np
 
@@ -65,6 +66,18 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d
 
 class CallTfTest(tf_test_util.JaxToTfTestCase):
 
+  @classmethod
+  def setUpClass(cls):
+    # One TF device of each device_type
+    cls.tf_devices = []
+    for tf_device in tf.config.list_logical_devices():
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
+        cls.tf_devices.append(tf_device)
+
+    super(CallTfTest, cls).setUpClass()
+
   def setUp(self):
     if tf is None:
       raise unittest.SkipTest("Test requires tensorflow")
@@ -737,6 +750,73 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     ):
       _ = fun_jax_3(x)
 
+  def test_multi_platform(self):
+    def tf_fun(x):
+      return tf.math.sin(x)
+
+    def f_jax(x):
+      return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))
+    x = np.arange(12, dtype=np.float32).reshape((3, 4))
+
+    # Find platforms that are available for both JAX and TF
+    # Pick one device from each available platform
+    jax_platforms = []
+    for backend in ["cpu", "gpu", "tpu"]:
+      try:
+        devices = jax.devices(backend)
+      except RuntimeError:
+        devices = []
+      if devices:
+        jax_platforms.append(devices[0].platform)
+
+    jax_and_tf_platforms = (
+        set(jax_platforms) & {d.device_type.lower()
+                              for d in self.__class__.tf_devices})
+
+    # TODO(b/306753579): call_tf can only be lowered when we have a device
+    lowering_platforms = tuple(
+        p if p != "gpu" else "cuda"
+        for p in jax_and_tf_platforms)
+
+    exp = export.export(f_jax,
+                        lowering_platforms=lowering_platforms)(x)
+    for jax_platform in jax_and_tf_platforms:
+      with self.subTest(jax_platform):
+        jax_device = jax.devices(jax_platform)[0]
+        x_device = jax.device_put(x, jax_device)
+        logging.info("Running harness natively on %s", jax_device)
+        native_res = f_jax(x_device)
+        logging.info("Running exported harness on %s", jax_device)
+        exported_res = export.call_exported(exp)(x_device)
+        self.assertAllClose(native_res, exported_res)
+
+  def test_multi_platform_call_tf_graph(self):
+    def tf_fun(x):
+      return tf.math.sin(x)
+
+    def f_jax(x):
+      return jnp.cos(jax2tf.call_tf(tf_fun,
+                                    call_tf_graph=True,
+                                    ordered=True)(jnp.cos(x)))
+    x = np.arange(12, dtype=np.float32).reshape((3, 4))
+    # When we use call_tf_graph we can serialize for multiple platforms
+    lowering_platforms = ("tpu", "cpu", "cuda")
+    # We must use jax2tf.convert to run a call_tf(call_tf_graph)
+    # TODO(necula): if we remove the tf.function and we have multiple platforms
+    # then we attempt to lower call_tf multiple times and only the first
+    # lowering will have the proper side effects for the function_list.
+    f_tf = tf.function(jax2tf.convert(
+        f_jax,
+        native_serialization=True,
+        native_serialization_platforms=lowering_platforms))
+    for tf_device in self.__class__.tf_devices:
+      with self.subTest(tf_device.device_type):
+        logging.info(
+            f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
+        with tf.device(tf_device):
+          res = f_tf(x)
+        self.assertAllClose(res, f_jax(x))
+
 
 class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
   "Reloading output of jax2tf into JAX with call_tf"
@@ -1693,5 +1773,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
     res = loaded_model.call(*data_inputs)
     self.assertAllClose(jax_func(*data_inputs), res)
 
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
@@ -21,7 +21,6 @@ import functools
 import math
 import os
 import re
 from typing import Optional
 import unittest
 
 from absl import logging
@@ -60,16 +59,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
 
   @classmethod
   def setUpClass(cls):
-    # Pick one device from each available platform
-    cls.jax_platforms = []
-    for backend in ["cpu", "gpu", "tpu"]:
-      try:
-        devices = jax.devices(backend)
-      except RuntimeError:
-        devices = []
-      if devices:
-        cls.jax_platforms.append(devices[0].platform)
-
     # One TF device of each device_type
     cls.tf_devices = []
     for tf_device in (tf.config.list_logical_devices("TPU") +
@@ -1741,9 +1730,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         native_serialization=True,
         native_serialization_platforms=("cpu", "cuda", "tpu"))
     for tf_device in self.__class__.tf_devices:
+      logging.info(
+          f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
       with tf.device(tf_device):
         res = f_tf(x)
-      logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
       tf_device_jax_platform = dict(
          CPU="cpu", GPU="cuda", TPU="tpu"
       )[tf_device.device_type]