Mirror of https://github.com/ROCm/jax.git
Merge pull request #18217 from gnecula:multi_call_tf
PiperOrigin-RevId: 576218473
Commit 4897a5fb5a
@@ -1029,7 +1029,10 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
### Importing

def call_exported(exported: Exported) -> Callable[..., jax.Array]:
  if not isinstance(exported, Exported):
    raise ValueError(
        "The exported argument must be an export.Exported. "
        f"Found {exported}.")
  @jax.custom_vjp
  def f_flat(*args_flat):
    return call_exported_p.bind(*args_flat, exported=exported)

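Note: for orientation, a minimal usage sketch of call_exported as used later in this diff's tests. It follows the jax.experimental.export API exercised by the commit; the toy function, shapes, and tolerance are illustrative assumptions, not part of the change.

import jax.numpy as jnp
import numpy as np
from jax.experimental.export import export

def f_jax(x):
  return jnp.cos(jnp.sin(x))

x = np.arange(12, dtype=np.float32).reshape((3, 4))
exp = export.export(f_jax)(x)           # trace and serialize f_jax for this input shape
f_reloaded = export.call_exported(exp)  # callable that binds call_exported_p
np.testing.assert_allclose(f_reloaded(x), f_jax(x), rtol=1e-6)
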
@@ -629,9 +629,11 @@ def emit_tf_embedded_graph_custom_call(
    raise ValueError(
        "call_tf_graph=True only support exporting by jax2tf.convert currently."
    )
+  # TODO(necula): It is dangerous to modify global state when lowering because
+  # there are a number of lowering caches that only cache the StableHLO.
+  # See call_tf_test.py:test_multi_platform_call_tf_graph.
  called_index = add_to_call_tf_concrete_function_list(
      concrete_function_flat_tf, call_tf_concrete_function_list)
-  call_target_name = "tf.call_tf_function"
  tf_backend_config = {
      "has_token_input_output": ir.BoolAttr.get(ordered),
      "called_index": mlir.i64_attr(called_index),

@@ -649,7 +651,7 @@ def emit_tf_embedded_graph_custom_call(
  custom_call = hlo.CustomCallOp(
      result_types,
      operands,
-      call_target_name=ir.StringAttr.get(call_target_name),
+      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
      has_side_effect=ir.BoolAttr.get(has_side_effects),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),

@@ -669,8 +671,9 @@ def emit_tf_embedded_graph_custom_call(


def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
-  func_name = concrete_tf_fn.function_def.signature.name
-  assert func_name not in [f.function_def.signature.name for f in call_tf_concrete_function_list]
-  called_index = len(call_tf_concrete_function_list)
-  call_tf_concrete_function_list.append(concrete_tf_fn)
+  try:
+    called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
+  except ValueError:
+    called_index = len(call_tf_concrete_function_list)
+    call_tf_concrete_function_list.append(concrete_tf_fn)
  return called_index

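Note: the change above makes registration idempotent. Looking the function up with list.index before appending means that lowering the same concrete function again (for example, once per platform) reuses the existing entry instead of appending a duplicate. A standalone sketch of that behavior with plain Python objects; the helper name is hypothetical, not the one in call_tf.py:

def add_deduped(fn, fn_list):
  try:
    return fn_list.index(fn)
  except ValueError:
    fn_list.append(fn)
    return len(fn_list) - 1

fns = []
f = object()
assert add_deduped(f, fns) == 0
assert add_deduped(f, fns) == 0          # second registration reuses index 0
assert add_deduped(object(), fns) == 1   # a different function gets a new index
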
@@ -29,6 +29,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np

@@ -65,6 +66,18 @@ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has d

class CallTfTest(tf_test_util.JaxToTfTestCase):

  @classmethod
  def setUpClass(cls):
    # One TF device of each device_type
    cls.tf_devices = []
    for tf_device in tf.config.list_logical_devices():
      if tf_device.device_type == "TPU_SYSTEM":
        continue  # A virtual device
      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
        cls.tf_devices.append(tf_device)

    super(CallTfTest, cls).setUpClass()

  def setUp(self):
    if tf is None:
      raise unittest.SkipTest("Test requires tensorflow")

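Note: the setUpClass above keeps exactly one TF logical device per device_type and skips the virtual TPU_SYSTEM device. A standalone sketch of the same pattern, runnable with any TensorFlow install (which devices are found depends on the local machine):

import tensorflow as tf

devices_by_type = {}
for d in tf.config.list_logical_devices():
  if d.device_type == "TPU_SYSTEM":      # virtual device, not a compute device
    continue
  devices_by_type.setdefault(d.device_type, d)  # keep the first device of each type
print(sorted(devices_by_type))           # e.g. ['CPU'] on a CPU-only machine
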
@@ -737,6 +750,73 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
    ):
      _ = fun_jax_3(x)

  def test_multi_platform(self):
    def tf_fun(x):
      return tf.math.sin(x)

    def f_jax(x):
      return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))
    x = np.arange(12, dtype=np.float32).reshape((3, 4))

    # Find platforms that are available for both JAX and TF
    # Pick one device from each available platform
    jax_platforms = []
    for backend in ["cpu", "gpu", "tpu"]:
      try:
        devices = jax.devices(backend)
      except RuntimeError:
        devices = []
      if devices:
        jax_platforms.append(devices[0].platform)

    jax_and_tf_platforms = (
        set(jax_platforms) & {d.device_type.lower()
                              for d in self.__class__.tf_devices})

    # TODO(b/306753579): call_tf can only be lowered when we have a device
    lowering_platforms = tuple(
        p if p != "gpu" else "cuda"
        for p in jax_and_tf_platforms)

    exp = export.export(f_jax,
                        lowering_platforms=lowering_platforms)(x)
    for jax_platform in jax_and_tf_platforms:
      with self.subTest(jax_platform):
        jax_device = jax.devices(jax_platform)[0]
        x_device = jax.device_put(x, jax_device)
        logging.info("Running harness natively on %s", jax_device)
        native_res = f_jax(x_device)
        logging.info("Running exported harness on %s", jax_device)
        exported_res = export.call_exported(exp)(x_device)
        self.assertAllClose(native_res, exported_res)

  def test_multi_platform_call_tf_graph(self):
    def tf_fun(x):
      return tf.math.sin(x)

    def f_jax(x):
      return jnp.cos(jax2tf.call_tf(tf_fun,
                                    call_tf_graph=True,
                                    ordered=True)(jnp.cos(x)))
    x = np.arange(12, dtype=np.float32).reshape((3, 4))
    # When we use call_tf_graph we can serialize for multiple platforms
    lowering_platforms = ("tpu", "cpu", "cuda")
    # We must use jax2tf.convert to run a call_tf(call_tf_graph)
    # TODO(necula): if we remove the tf.function and we have multiple platforms
    # then we attempt to lower call_tf multiple times and only the first
    # lowering will have the proper side effects for the function_list.
    f_tf = tf.function(jax2tf.convert(
        f_jax,
        native_serialization=True,
        native_serialization_platforms=lowering_platforms))
    for tf_device in self.__class__.tf_devices:
      with self.subTest(tf_device.device_type):
        logging.info(
            f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
        with tf.device(tf_device):
          res = f_tf(x)
        self.assertAllClose(res, f_jax(x))


class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
  "Reloading output of jax2tf into JAX with call_tf"

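Note: the second test above runs the converted function directly under tf.function; call_tf_graph=True is otherwise consumed through TF serialization (per the error message earlier in this diff it must go through jax2tf.convert). A hedged sketch of how such a converted function might be packaged into a SavedModel; the module layout, input signature, and save path are illustrative assumptions, not part of this commit:

import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.cos(jax2tf.call_tf(tf.math.sin, call_tf_graph=True)(x))

f_tf = tf.function(
    jax2tf.convert(f_jax,
                   native_serialization=True,
                   native_serialization_platforms=("cpu", "cuda", "tpu")),
    input_signature=[tf.TensorSpec((3, 4), tf.float32)],  # illustrative shape
    autograph=False)

module = tf.Module()
module.f = f_tf
tf.saved_model.save(module, "/tmp/call_tf_graph_module")  # path is illustrative
restored = tf.saved_model.load("/tmp/call_tf_graph_module")
print(restored.f(tf.ones((3, 4), tf.float32)))
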
@@ -1693,5 +1773,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
    res = loaded_model.call(*data_inputs)
    self.assertAllClose(jax_func(*data_inputs), res)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

@@ -21,7 +21,6 @@ import functools
import math
import os
import re
from typing import Optional
import unittest

from absl import logging

@@ -60,16 +59,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

  @classmethod
  def setUpClass(cls):
    # Pick one device from each available platform
    cls.jax_platforms = []
    for backend in ["cpu", "gpu", "tpu"]:
      try:
        devices = jax.devices(backend)
      except RuntimeError:
        devices = []
      if devices:
        cls.jax_platforms.append(devices[0].platform)

    # One TF device of each device_type
    cls.tf_devices = []
    for tf_device in (tf.config.list_logical_devices("TPU") +

@@ -1741,9 +1730,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
        native_serialization=True,
        native_serialization_platforms=("cpu", "cuda", "tpu"))
    for tf_device in self.__class__.tf_devices:
      logging.info(
          f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
      with tf.device(tf_device):
        res = f_tf(x)
      logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
      tf_device_jax_platform = dict(
          CPU="cpu", GPU="cuda", TPU="tpu"
      )[tf_device.device_type]

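Note: the CPU/GPU/TPU to cpu/cuda/tpu mapping above is what lets the test pick the matching JAX platform for a given TF logical device. A small illustrative sketch of the same lookup, restricted to CPU so it runs anywhere:

import jax
import tensorflow as tf

tf_device_jax_platform = dict(CPU="cpu", GPU="cuda", TPU="tpu")
for tf_device in tf.config.list_logical_devices("CPU"):
  platform = tf_device_jax_platform[tf_device.device_type]
  jax_device = jax.devices(platform)[0]   # first JAX device on the matching platform
  print(f"{tf_device.name} -> {jax_device}")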