2022-09-22 12:26:48 -07:00
|
|
|
|
# Copyright 2021 The JAX Authors.
|
2021-02-04 10:56:13 +02:00
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
"""Allows JAX to call TensorFlow functions with support for autodiff.
|
|
|
|
|
|
|
|
|
|
**Experimental: please give feedback, and expect changes.**
|
|
|
|
|
|
|
|
|
|
This module introduces the function :func:`call_tf` that allows JAX to call
|
|
|
|
|
TensorFlow functions.
|
|
|
|
|
|
|
|
|
|
For examples and details, see
|
2024-09-20 07:51:48 -07:00
|
|
|
|
https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
"""
|
2023-12-11 13:59:29 +00:00
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2024-06-26 14:44:52 -04:00
|
|
|
|
from collections.abc import Callable, Sequence
|
2023-12-13 15:43:12 +01:00
|
|
|
|
import dataclasses
|
2021-06-25 07:39:09 +02:00
|
|
|
|
import functools
|
2024-06-26 14:44:52 -04:00
|
|
|
|
from typing import Any
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2022-11-30 23:21:55 -08:00
|
|
|
|
from absl import logging
|
2021-02-04 10:56:13 +02:00
|
|
|
|
import jax
|
2021-02-05 12:04:25 +02:00
|
|
|
|
from jax import dlpack
|
2021-06-16 13:25:56 +03:00
|
|
|
|
from jax import dtypes
|
2023-02-15 23:40:12 -08:00
|
|
|
|
from jax import numpy as jnp
|
2021-02-04 10:56:13 +02:00
|
|
|
|
from jax import tree_util
|
2023-02-01 17:50:00 -08:00
|
|
|
|
from jax._src import ad_util
|
2023-04-26 10:20:12 -07:00
|
|
|
|
from jax._src import core
|
2023-02-01 17:50:00 -08:00
|
|
|
|
from jax._src import effects
|
|
|
|
|
from jax._src import util
|
2023-04-26 10:20:12 -07:00
|
|
|
|
from jax._src.lib import xla_client
|
2022-04-19 13:59:28 -07:00
|
|
|
|
from jax._src.lib.mlir import ir
|
|
|
|
|
from jax._src.lib.mlir.dialects import func as func_dialect
|
2022-12-15 20:59:34 -08:00
|
|
|
|
from jax._src.lib.mlir.dialects import hlo
|
2021-11-24 07:47:48 -08:00
|
|
|
|
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
2024-07-17 09:56:18 -07:00
|
|
|
|
from jax._src.interpreters import mlir
|
2021-02-04 10:56:13 +02:00
|
|
|
|
import numpy as np
|
2023-04-26 10:20:12 -07:00
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2021-06-16 13:25:56 +03:00
|
|
|
|
map = util.safe_map
|
|
|
|
|
zip = util.safe_zip
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
TfConcreteFunction = Any
|
2023-03-08 14:09:35 -08:00
|
|
|
|
TfVal = jax2tf_internal.TfVal
|
2021-06-25 07:39:09 +02:00
|
|
|
|
|
2021-02-05 12:04:25 +02:00
|
|
|
|
# The platforms for which to use DLPack to avoid copying (only works on GPU
|
2023-08-18 16:50:36 -04:00
|
|
|
|
# and CPU at the moment, and only for Array). For CPU we don't need
|
2021-02-05 12:04:25 +02:00
|
|
|
|
# DLPack, if we are careful.
|
|
|
|
|
_DLPACK_PLATFORMS = ("gpu",)
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2023-07-21 10:43:03 -07:00
|
|
|
|
class UnspecifiedOutputShapeDtype:
|
|
|
|
|
pass
|
2023-03-24 11:26:44 -07:00
|
|
|
|
|
|
|
|
|
def call_tf(
|
|
|
|
|
callable_tf: Callable,
|
|
|
|
|
has_side_effects=True,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered=False,
|
2023-07-21 10:43:03 -07:00
|
|
|
|
output_shape_dtype=UnspecifiedOutputShapeDtype(),
|
2023-05-03 09:04:01 -07:00
|
|
|
|
call_tf_graph=False,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
) -> Callable:
|
2021-02-04 10:56:13 +02:00
|
|
|
|
"""Calls a TensorFlow function from JAX, with support for reverse autodiff.
|
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
The ``callable_tf`` will be called with TensorFlow-compatible arguments (
|
2021-02-04 10:56:13 +02:00
|
|
|
|
numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
|
|
|
|
|
function must return the same type of results.
|
|
|
|
|
|
|
|
|
|
If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
|
2024-07-24 10:23:29 -07:00
|
|
|
|
or :func:`jax.pmap`, or a control-flow primitive) then
|
2023-03-24 11:26:44 -07:00
|
|
|
|
``callable_tf`` will be compiled with ``tf.function(callable_tf,
|
|
|
|
|
jit_compile=True)``
|
2021-02-04 10:56:13 +02:00
|
|
|
|
and the resulting XLA computation will be embedded in JAX's XLA computation.
|
|
|
|
|
|
|
|
|
|
If ``call_tf`` appears outside a JAX staging context, it will be called inline
|
|
|
|
|
using TensorFlow eager mode.
|
|
|
|
|
|
|
|
|
|
The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
|
2021-06-25 07:39:09 +02:00
|
|
|
|
``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
|
2021-02-04 10:56:13 +02:00
|
|
|
|
that the gradient will be TensorFlow-accurate, e.g., will respect the
|
2021-06-25 07:39:09 +02:00
|
|
|
|
custom gradients that may be defined for the code in ``callable_tf``.
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
For an example and more details see the
|
2023-03-24 11:26:44 -07:00
|
|
|
|
`README
|
2024-09-20 07:51:48 -07:00
|
|
|
|
<https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
Args:
|
2021-06-25 07:39:09 +02:00
|
|
|
|
callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
|
2021-02-04 10:56:13 +02:00
|
|
|
|
arguments.
|
2023-01-12 08:44:53 +01:00
|
|
|
|
has_side_effects: if True then it ensures that instances of this primitive
|
|
|
|
|
are not removed or replicated by JAX optimizations such as dead-code
|
|
|
|
|
elimination.
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered: If true, calls are modeled as having ordered effects.
|
|
|
|
|
output_shape_dtype: An optional declaration of the expected shape and dtype
|
|
|
|
|
of the result of the called TensorFlow function. If given it will be used
|
|
|
|
|
during JAX tracing to form the abstract values of the results of the
|
2023-03-24 11:26:44 -07:00
|
|
|
|
`call_tf`. If not given then we form a `tf.Graph` for the called
|
|
|
|
|
TensorFlow function and we use the TensorFlow-inferred shapes and types.
|
|
|
|
|
Must be a pytree matching the structure of the nested structure returned
|
|
|
|
|
from the TensorFlow function, containing objects with `.shape` and
|
|
|
|
|
`.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
|
2023-05-03 09:04:01 -07:00
|
|
|
|
call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
|
|
|
|
|
future.
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
|
2021-02-04 10:56:13 +02:00
|
|
|
|
Returns: a JAX callable that can be invoked with JAX pytree arguments, in
|
2023-03-24 11:26:44 -07:00
|
|
|
|
op-by-op mode or in a staged context. This callable can be used with JAX's
|
|
|
|
|
reverse-mode autodiff (:func:`jax.grad`).
|
2021-02-04 10:56:13 +02:00
|
|
|
|
"""
|
|
|
|
|
@jax.custom_vjp
|
|
|
|
|
def make_call(*args_jax):
|
|
|
|
|
"""We wrap it all in `make_call` so that we can attach custom VJP."""
|
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
|
2021-06-16 13:25:56 +03:00
|
|
|
|
# Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
|
|
|
|
|
def canonical_arg(v):
|
|
|
|
|
v = v if getattr(v, "dtype", None) else np.asarray(v)
|
|
|
|
|
dtype = dtypes.canonicalize_dtype(v.dtype)
|
|
|
|
|
if dtype != v.dtype:
|
|
|
|
|
v = v.astype(dtype)
|
|
|
|
|
return v
|
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
|
2021-07-10 18:49:25 +03:00
|
|
|
|
def make_tensorspec(a_jax):
|
|
|
|
|
a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
|
2023-04-27 13:42:23 +02:00
|
|
|
|
a_tf_shape = [d if core.is_constant_dim(d) else None for d in a_jax.shape]
|
2023-01-12 13:04:09 -08:00
|
|
|
|
return tf.TensorSpec(a_tf_shape, a_tf_dtype)
|
2021-07-10 18:49:25 +03:00
|
|
|
|
args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))
|
2021-06-25 07:39:09 +02:00
|
|
|
|
|
2023-07-21 10:43:03 -07:00
|
|
|
|
if not isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
|
2023-03-08 14:09:35 -08:00
|
|
|
|
output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
|
|
|
|
|
output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
|
|
|
|
|
else:
|
|
|
|
|
output_avals, output_shape_dtype_tree = None, None
|
2022-08-04 09:50:41 +03:00
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
res_treedef = None # We'll store here the result treedef
|
2022-08-04 09:50:41 +03:00
|
|
|
|
res_tf_flat = None # For error reporting
|
2021-06-25 07:39:09 +02:00
|
|
|
|
# The function below will be called at least once, either in eager
|
2023-03-08 14:09:35 -08:00
|
|
|
|
# mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
|
2021-06-25 07:39:09 +02:00
|
|
|
|
def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
|
|
|
|
|
args_tf = args_treedef.unflatten(args_tf_flat)
|
|
|
|
|
res_tf = callable_tf(*args_tf)
|
2023-05-06 08:29:45 -07:00
|
|
|
|
|
|
|
|
|
# b/279454591: When `callable_tf` is a tf function with zero outputs, it
|
|
|
|
|
# returns a `StatefulPartitionedCall` (if the function is stateful) or
|
|
|
|
|
# `PartitionedCall` (if the function is stateless) op instead of
|
|
|
|
|
# tf.Tensors. We work around this issue by replacing the output `res_tf`
|
|
|
|
|
# with an empty list.
|
|
|
|
|
|
|
|
|
|
if isinstance(res_tf, tf.Operation):
|
|
|
|
|
assert (
|
|
|
|
|
res_tf.type == "StatefulPartitionedCall"
|
|
|
|
|
or res_tf.type == "PartitionedCall"
|
|
|
|
|
)
|
|
|
|
|
t_out = res_tf.get_attr("Tout")
|
|
|
|
|
# t_out should be an empty list.
|
|
|
|
|
assert not t_out, (
|
|
|
|
|
"The TF function returned an unexpected result, please check its"
|
|
|
|
|
f" function body. res_tf = {res_tf}"
|
|
|
|
|
)
|
|
|
|
|
res_tf = t_out
|
|
|
|
|
|
2022-08-04 09:50:41 +03:00
|
|
|
|
nonlocal res_treedef, res_tf_flat
|
2021-06-25 07:39:09 +02:00
|
|
|
|
res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
|
2023-03-08 14:09:35 -08:00
|
|
|
|
assert res_treedef is None or res_treedef == res_treedef_now, (
|
|
|
|
|
f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
|
2021-06-25 07:39:09 +02:00
|
|
|
|
res_treedef = res_treedef_now
|
2023-03-08 14:09:35 -08:00
|
|
|
|
if output_avals is not None:
|
|
|
|
|
if res_treedef != output_shape_dtype_tree:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The pytree of the TensorFlow function results does not match the "
|
|
|
|
|
"pytree of the declared output_shape_dtype:\n"
|
|
|
|
|
f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
|
|
|
|
|
assert len(output_avals) == len(res_tf_flat)
|
|
|
|
|
|
2023-04-26 10:20:12 -07:00
|
|
|
|
checked_res_tf_flat = [
|
|
|
|
|
check_tf_result(i, r_tf, r_aval)
|
|
|
|
|
for i, (r_tf, r_aval) in enumerate(
|
2023-04-27 13:42:23 +02:00
|
|
|
|
zip(res_tf_flat,
|
|
|
|
|
(output_avals
|
|
|
|
|
if output_avals is not None
|
|
|
|
|
else (None,) * len(res_tf_flat))))]
|
2023-04-26 10:20:12 -07:00
|
|
|
|
return checked_res_tf_flat
|
2021-06-25 07:39:09 +02:00
|
|
|
|
|
|
|
|
|
# Prepare a tf.function ahead of time, to cache the concrete functions. This
|
|
|
|
|
# won't be used in op-by-op execution mode.
|
2023-04-26 10:20:12 -07:00
|
|
|
|
function_flat_tf = tf.function(
|
2023-05-03 09:04:01 -07:00
|
|
|
|
callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)
|
2021-06-16 13:25:56 +03:00
|
|
|
|
|
2021-02-04 10:56:13 +02:00
|
|
|
|
res_jax_flat = call_tf_p.bind(
|
2021-06-25 07:39:09 +02:00
|
|
|
|
*args_flat_jax,
|
2021-06-10 17:01:22 +02:00
|
|
|
|
# Carry the actual function such that op-by-op call can call in TF eager mode.
|
2021-06-25 07:39:09 +02:00
|
|
|
|
callable_flat_tf=callable_flat_tf,
|
|
|
|
|
function_flat_tf=function_flat_tf,
|
2023-01-12 08:44:53 +01:00
|
|
|
|
args_flat_sig_tf=args_flat_sig_tf,
|
2023-03-08 14:09:35 -08:00
|
|
|
|
output_avals=output_avals,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
has_side_effects=has_side_effects,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered=ordered,
|
|
|
|
|
call_tf_graph=call_tf_graph,
|
|
|
|
|
)
|
2022-08-04 09:50:41 +03:00
|
|
|
|
|
2023-03-08 14:09:35 -08:00
|
|
|
|
# We must have called callable_flat_tf by nοw
|
2022-08-04 09:50:41 +03:00
|
|
|
|
assert res_treedef is not None
|
2021-02-04 10:56:13 +02:00
|
|
|
|
return res_treedef.unflatten(res_jax_flat)
|
|
|
|
|
|
|
|
|
|
# Define the fwd and bwd custom_vjp functions
|
|
|
|
|
def make_call_vjp_fwd(*args_jax):
|
2021-06-16 10:20:24 +03:00
|
|
|
|
# Return the primal arguments as the residual
|
2021-02-04 10:56:13 +02:00
|
|
|
|
return make_call(*args_jax), args_jax
|
|
|
|
|
|
2021-06-16 10:20:24 +03:00
|
|
|
|
def make_call_vjp_bwd(residual_jax, ct_res_jax):
|
|
|
|
|
args_jax = residual_jax # residual is the primal argument
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2021-06-16 10:20:24 +03:00
|
|
|
|
def tf_vjp_fun(args_tf, ct_res_tf):
|
2021-02-04 10:56:13 +02:00
|
|
|
|
"""Invoke TF gradient."""
|
2021-06-16 10:20:24 +03:00
|
|
|
|
|
2024-08-12 09:23:40 -07:00
|
|
|
|
# TF does not like us to watch non-float vars or Nones.
|
|
|
|
|
def replace_non_float_or_none(arg_tf):
|
|
|
|
|
if arg_tf is not None and (
|
|
|
|
|
arg_tf.dtype.is_floating or arg_tf.dtype.is_complex
|
|
|
|
|
):
|
2023-02-15 23:40:12 -08:00
|
|
|
|
return arg_tf
|
2021-06-16 10:20:24 +03:00
|
|
|
|
else:
|
2023-04-27 13:42:23 +02:00
|
|
|
|
# When watched, this will be ignored. When used in results it will
|
2021-06-16 10:20:24 +03:00
|
|
|
|
# result in a floating 0. gradient, which JAX will ignore (and
|
|
|
|
|
# replace it with a float0)
|
|
|
|
|
return tf.zeros((), dtype=tf.float32)
|
|
|
|
|
|
2024-08-12 09:23:40 -07:00
|
|
|
|
watched_args_tf = tf.nest.map_structure(
|
|
|
|
|
replace_non_float_or_none, args_tf
|
|
|
|
|
)
|
2021-02-04 10:56:13 +02:00
|
|
|
|
with tf.GradientTape(persistent=True) as tape:
|
2021-06-16 10:20:24 +03:00
|
|
|
|
tape.watch(watched_args_tf)
|
2021-06-25 07:39:09 +02:00
|
|
|
|
res = callable_tf(*args_tf)
|
2021-06-16 10:20:24 +03:00
|
|
|
|
|
|
|
|
|
tf.nest.assert_same_structure(res, ct_res_tf)
|
|
|
|
|
dres_darg = tape.gradient(
|
2024-08-12 09:23:40 -07:00
|
|
|
|
tf.nest.map_structure(replace_non_float_or_none, res),
|
2021-06-16 10:20:24 +03:00
|
|
|
|
sources=watched_args_tf,
|
|
|
|
|
output_gradients=ct_res_tf,
|
2024-08-12 09:23:40 -07:00
|
|
|
|
unconnected_gradients=tf.UnconnectedGradients.ZERO,
|
|
|
|
|
)
|
2021-06-16 10:20:24 +03:00
|
|
|
|
|
2023-02-17 09:38:41 -08:00
|
|
|
|
dres_darg = tree_util.tree_map(
|
|
|
|
|
lambda x: x if x is None else tf.convert_to_tensor(x),
|
|
|
|
|
dres_darg,
|
|
|
|
|
)
|
2024-08-21 09:48:59 -07:00
|
|
|
|
|
|
|
|
|
# callable_tf may mutate (the structure of) args_tf, thus we check against
|
|
|
|
|
# watched_args_tf which should be structurally the same as the original
|
|
|
|
|
# args_tf.
|
|
|
|
|
tf.nest.assert_same_structure(dres_darg, watched_args_tf)
|
2021-06-16 10:20:24 +03:00
|
|
|
|
return dres_darg
|
|
|
|
|
|
2021-02-04 10:56:13 +02:00
|
|
|
|
# Use call_tf to call the VJP function
|
2021-06-29 14:57:42 +03:00
|
|
|
|
ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
|
|
|
|
|
# We must make the float0s that JAX expects
|
|
|
|
|
def fix_float0(arg_jax, ct_arg_jax):
|
Update users of jax.tree.map() to be more careful about how they handle Nones.
Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.
Fix code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.
PiperOrigin-RevId: 673258116
2024-09-10 23:53:24 -07:00
|
|
|
|
if arg_jax is None:
|
|
|
|
|
return None
|
2021-06-29 14:57:42 +03:00
|
|
|
|
arg_dtype = dtypes.result_type(arg_jax) # May be scalar
|
|
|
|
|
ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
|
|
|
|
|
if ct_arg_dtype != ct_arg_jax.dtype:
|
|
|
|
|
return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
|
|
|
|
|
ct_arg_dtype))
|
|
|
|
|
return ct_arg_jax
|
|
|
|
|
|
Update users of jax.tree.map() to be more careful about how they handle Nones.
Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.
Fix code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.
PiperOrigin-RevId: 673258116
2024-09-10 23:53:24 -07:00
|
|
|
|
ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax,
|
|
|
|
|
is_leaf=lambda x: x is None)
|
2021-06-29 14:57:42 +03:00
|
|
|
|
return ct_args_jax_fixed
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
|
2021-06-25 07:39:09 +02:00
|
|
|
|
return util.wraps(callable_tf)(make_call)
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
|
def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
|
2023-03-08 14:09:35 -08:00
|
|
|
|
# Check that the TF function returns values of expected types. This
|
|
|
|
|
# improves error reporting, preventing hard-to-diagnose errors downstream
|
|
|
|
|
try:
|
|
|
|
|
jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
msg = ("The called TF function returns a result that is not "
|
|
|
|
|
f"convertible to JAX: {r_tf}.")
|
|
|
|
|
raise ValueError(msg) from e
|
|
|
|
|
|
|
|
|
|
if r_aval is None:
|
|
|
|
|
return r_tf
|
|
|
|
|
# We convert to TF type, and canonicalize to 32-bit if necessary
|
|
|
|
|
r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
|
|
|
|
|
# Checking shapes is trickier in presence of dynamic shapes. I wish we could
|
|
|
|
|
# check at runtime that the returned shape matches the declared shape. I wish
|
|
|
|
|
# that tf.ensure_shape did this, but it can only take shapes that contain None
|
|
|
|
|
# not computed shapes. However, in eager mode we should be able to resolve
|
|
|
|
|
# the declared shapes to constants and we get better checking.
|
|
|
|
|
if tf.executing_eagerly():
|
|
|
|
|
r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
|
|
|
|
|
else:
|
|
|
|
|
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
|
|
|
|
|
# We do as much checking as we can here, instead of relying on tf.ensure_shape
|
|
|
|
|
# because the latter gives different errors in eager vs. compiled mode.
|
2023-04-26 10:20:12 -07:00
|
|
|
|
# TODO(b/279454591): This strange error is from TF. Eager function suppose
|
|
|
|
|
# return tf Val with concrete shape but not. Here we change exception to warn
|
|
|
|
|
# and bypass it. This case need revisit on TF side.
|
|
|
|
|
try:
|
|
|
|
|
_ = len(r_tf.shape)
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
msg = (
|
|
|
|
|
"The shape check test cannot be performed because the shape of the"
|
|
|
|
|
"`r_tf` tensor cannot be obtained."
|
|
|
|
|
f"r_tf = {r_tf}, r_aval = {r_aval}"
|
|
|
|
|
)
|
|
|
|
|
msg += str(e)
|
|
|
|
|
logging.warning(msg)
|
|
|
|
|
return r_tf
|
2023-03-08 14:09:35 -08:00
|
|
|
|
if (r_tf.dtype != r_aval_dtype_tf or
|
|
|
|
|
len(r_tf.shape) != len(r_aval_shape_tf) or
|
|
|
|
|
any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
|
|
|
|
|
for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
|
|
|
|
|
msg = ("The shapes or dtypes returned by the TensorFlow function "
|
|
|
|
|
"do not match the declared output_shape_dtype:\n"
|
|
|
|
|
f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
# At this point tf.ensure_shape does not do much, it should never throw an
|
|
|
|
|
# error, albeit it may refine the shape a bit.
|
|
|
|
|
return tf.ensure_shape(r_tf, r_aval_shape_tf)
|
|
|
|
|
|
|
|
|
|
|
2021-02-04 10:56:13 +02:00
|
|
|
|
call_tf_p = core.Primitive("call_tf")
|
|
|
|
|
call_tf_p.multiple_results = True
|
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
|
|
|
|
|
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
|
2021-02-05 12:04:25 +02:00
|
|
|
|
# On GPU we use dlpack to avoid copies of data to the host.
|
|
|
|
|
def _arg_jax_to_tf(arg_jax):
|
2023-02-10 20:56:02 -08:00
|
|
|
|
if (isinstance(arg_jax, jax.Array) and
|
2023-06-08 05:51:00 -07:00
|
|
|
|
list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
|
2024-02-27 10:35:03 -08:00
|
|
|
|
arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
|
2024-04-11 16:44:19 +00:00
|
|
|
|
arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
|
2021-02-05 12:04:25 +02:00
|
|
|
|
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
|
2023-08-18 16:50:36 -04:00
|
|
|
|
# The following avoids copies to the host on CPU, always for Array
|
2021-02-05 12:04:25 +02:00
|
|
|
|
# and even for ndarray if they are sufficiently aligned.
|
|
|
|
|
# TODO(necula): on TPU this copies to the host!
|
2023-12-20 12:47:43 -08:00
|
|
|
|
if getattr(arg_jax, 'dtype', None) == dtypes.float0:
|
|
|
|
|
return tf.zeros(shape=arg_jax.shape,
|
|
|
|
|
dtype=jax2tf_internal._tf_np_dtype_for_float0)
|
2021-02-05 12:04:25 +02:00
|
|
|
|
return tf.constant(np.asarray(arg_jax))
|
|
|
|
|
|
|
|
|
|
args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
|
2021-06-10 17:01:22 +02:00
|
|
|
|
with jax2tf_internal.inside_call_tf():
|
|
|
|
|
# Call in TF eager mode
|
2021-06-25 07:39:09 +02:00
|
|
|
|
res_tf_flat = callable_flat_tf(*args_tf_flat)
|
2021-02-05 12:04:25 +02:00
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
def _res_tf_to_jax(res_tf: TfVal):
|
2024-02-27 10:35:03 -08:00
|
|
|
|
res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
|
|
|
|
|
if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES:
|
2021-02-05 12:04:25 +02:00
|
|
|
|
res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
|
|
|
|
|
res_jax_platform = res_tf_platform.lower()
|
2024-05-10 11:29:49 -07:00
|
|
|
|
if res_jax_platform in _DLPACK_PLATFORMS:
|
2021-02-05 12:04:25 +02:00
|
|
|
|
res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
|
2021-06-22 06:38:22 -07:00
|
|
|
|
return jax.dlpack.from_dlpack(res_dlpack)
|
2021-02-05 12:04:25 +02:00
|
|
|
|
|
2023-02-15 23:40:12 -08:00
|
|
|
|
# When working with a bfloat16 scalar tf.Tensor,np.asarray() can fail.
|
|
|
|
|
# To handle this special case, we create a numpy copy.
|
|
|
|
|
if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
|
|
|
|
|
return jax.device_put(jnp.array(res_tf.numpy()))
|
|
|
|
|
else:
|
|
|
|
|
return jax.device_put(np.asarray(res_tf))
|
2021-02-05 12:04:25 +02:00
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
return list(map(_res_tf_to_jax, res_tf_flat))
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
call_tf_p.def_impl(_call_tf_impl)
|
|
|
|
|
|
2022-11-30 23:21:55 -08:00
|
|
|
|
@functools.lru_cache(maxsize=128)
|
|
|
|
|
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.ConcreteFunction
|
|
|
|
|
with jax2tf_internal.inside_call_tf():
|
|
|
|
|
return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
|
|
|
|
|
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2023-02-01 17:50:00 -08:00
|
|
|
|
# Mark the effectful instances of call_tf
|
2023-12-13 15:43:12 +01:00
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
2023-02-01 17:50:00 -08:00
|
|
|
|
class CallTfEffect(effects.Effect):
|
|
|
|
|
__str__ = lambda _: "CallTfEffect"
|
|
|
|
|
|
|
|
|
|
call_tf_effect = CallTfEffect()
|
2022-12-16 20:42:19 -08:00
|
|
|
|
|
2023-02-01 17:50:00 -08:00
|
|
|
|
effects.lowerable_effects.add_type(CallTfEffect)
|
|
|
|
|
effects.control_flow_allowed_effects.add_type(CallTfEffect)
|
|
|
|
|
effects.remat_allowed_effects.add_type(CallTfEffect)
|
|
|
|
|
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
|
2022-12-16 20:42:19 -08:00
|
|
|
|
|
|
|
|
|
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
class CallTfOrderedEffect(effects.Effect):
|
|
|
|
|
__str__ = lambda _: "CallTfOrderedEffect"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
call_tf_ordered_effect = CallTfOrderedEffect()
|
|
|
|
|
|
|
|
|
|
effects.lowerable_effects.add_type(CallTfOrderedEffect)
|
|
|
|
|
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
|
|
|
|
|
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
|
|
|
|
|
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
|
|
|
|
|
effects.ordered_effects.add_type(CallTfOrderedEffect)
|
2023-09-18 11:23:49 -07:00
|
|
|
|
effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
|
|
|
|
|
|
2023-04-26 10:20:12 -07:00
|
|
|
|
def _call_tf_abstract_eval(
|
|
|
|
|
*args_flat_avals,
|
|
|
|
|
function_flat_tf,
|
|
|
|
|
args_flat_sig_tf,
|
|
|
|
|
has_side_effects,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered,
|
2023-04-26 10:20:12 -07:00
|
|
|
|
output_avals,
|
2023-05-03 09:04:01 -07:00
|
|
|
|
call_tf_graph,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
**__,
|
|
|
|
|
):
|
2022-11-30 23:21:55 -08:00
|
|
|
|
# Called only when we form a Jaxpr, i.e., under jit, scan, etc.
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
effects = set()
|
|
|
|
|
if ordered:
|
|
|
|
|
effects.add(call_tf_ordered_effect)
|
|
|
|
|
elif has_side_effects:
|
|
|
|
|
effects.add(call_tf_effect)
|
2022-11-30 23:21:55 -08:00
|
|
|
|
|
2023-04-27 13:42:23 +02:00
|
|
|
|
# If no output_avals is given, then we ask TF to infer the output shapes.
|
2023-03-08 14:09:35 -08:00
|
|
|
|
# We call this even if output_avals is given because it will ensure that
|
|
|
|
|
# callable_flat_tf is called. Since _get_concrete_function_tf is cached
|
|
|
|
|
# there is a small cost of calling it more often than needed.
|
2022-11-30 23:21:55 -08:00
|
|
|
|
concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
|
|
|
|
|
args_flat_sig_tf)
|
2023-07-31 15:17:48 -07:00
|
|
|
|
|
|
|
|
|
# In the case that the tf.function has no return value
|
2023-04-26 10:20:12 -07:00
|
|
|
|
if len(concrete_function_flat_tf.outputs) == 0:
|
2023-09-22 09:05:42 +01:00
|
|
|
|
return (), effects
|
2023-04-26 10:20:12 -07:00
|
|
|
|
|
2023-03-08 14:09:35 -08:00
|
|
|
|
if output_avals is not None:
|
|
|
|
|
return output_avals, effects
|
|
|
|
|
|
2022-11-30 23:21:55 -08:00
|
|
|
|
def is_fully_known_shape(s):
|
2023-10-10 22:33:03 +01:00
|
|
|
|
return s.rank is not None and all(d is not None for d in s)
|
2023-03-08 14:09:35 -08:00
|
|
|
|
|
|
|
|
|
if all(is_fully_known_shape(s)
|
|
|
|
|
for s in concrete_function_flat_tf.output_shapes):
|
|
|
|
|
avals_from_tf = tuple(
|
|
|
|
|
# We convert to JAX type, and canonicalize to 32-bit if necessary
|
|
|
|
|
core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
|
|
|
|
|
for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
|
|
|
|
|
concrete_function_flat_tf.output_shapes))
|
|
|
|
|
return avals_from_tf, effects
|
|
|
|
|
|
|
|
|
|
msg = ("call_tf cannot call functions whose output has dynamic shape. "
|
|
|
|
|
f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
|
|
|
|
|
"Consider using the `output_shape_dtype` argument to call_tf. "
|
2024-09-20 07:51:48 -07:00
|
|
|
|
"\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
|
2023-03-08 14:09:35 -08:00
|
|
|
|
" for a discussion.")
|
|
|
|
|
raise ValueError(msg)
|
2022-11-30 23:21:55 -08:00
|
|
|
|
|
2023-04-26 10:20:12 -07:00
|
|
|
|
|
2022-12-16 20:42:19 -08:00
|
|
|
|
call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
|
2023-03-24 11:26:44 -07:00
|
|
|
|
def _call_tf_lowering(
|
2023-04-27 13:42:23 +02:00
|
|
|
|
ctx: mlir.LoweringRuleContext,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
*args_op,
|
|
|
|
|
platform,
|
|
|
|
|
function_flat_tf,
|
|
|
|
|
args_flat_sig_tf,
|
|
|
|
|
has_side_effects,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered,
|
2023-05-03 09:04:01 -07:00
|
|
|
|
call_tf_graph,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
output_avals,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
**_,
|
|
|
|
|
):
|
2022-11-30 23:21:55 -08:00
|
|
|
|
# We use the same TF lowering device as for the embedding JAX computation.
|
|
|
|
|
# One example when this is needed is when the code refers to variables on one
|
|
|
|
|
# device. Or, for sharding annotations (only supported on TPU).
|
2023-04-26 10:20:12 -07:00
|
|
|
|
|
2022-11-30 23:21:55 -08:00
|
|
|
|
if platform in ["cpu", "tpu"]:
|
|
|
|
|
tf_platform = platform.upper()
|
|
|
|
|
elif platform == "cuda":
|
|
|
|
|
tf_platform = "GPU"
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("platform {platform} not supported")
|
2021-07-10 18:49:25 +03:00
|
|
|
|
|
2022-11-30 23:21:55 -08:00
|
|
|
|
concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)
|
2021-06-25 07:39:09 +02:00
|
|
|
|
|
2021-07-10 18:49:25 +03:00
|
|
|
|
captured_inputs = []
|
2021-06-25 07:39:09 +02:00
|
|
|
|
if concrete_function_flat_tf.captured_inputs:
|
2021-02-04 10:56:13 +02:00
|
|
|
|
# The function uses either captured variables or tensors.
|
|
|
|
|
msg = (
|
2021-07-10 18:49:25 +03:00
|
|
|
|
"call_tf works best with a TensorFlow function that does not capture "
|
|
|
|
|
"variables or tensors from the context. "
|
2024-09-20 07:51:48 -07:00
|
|
|
|
"See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
|
2021-07-10 18:49:25 +03:00
|
|
|
|
f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
|
2021-02-04 10:56:13 +02:00
|
|
|
|
logging.warning(msg)
|
2021-06-25 07:39:09 +02:00
|
|
|
|
for inp in concrete_function_flat_tf.captured_inputs:
|
2021-09-14 14:15:40 +02:00
|
|
|
|
if inp.dtype == tf.resource: # A variable; lookup by handle
|
|
|
|
|
inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
|
|
|
|
|
assert len(inp_vars) == 1, f"Found {inp_vars}"
|
|
|
|
|
captured_inputs.append(inp_vars[0])
|
2021-02-04 10:56:13 +02:00
|
|
|
|
else:
|
2021-07-10 18:49:25 +03:00
|
|
|
|
captured_inputs.append(inp)
|
2021-06-16 13:25:56 +03:00
|
|
|
|
|
2024-10-24 10:12:54 +02:00
|
|
|
|
# The following use case happens when we call_tf a restored saved model that
|
|
|
|
|
# includes parameters (hence functions closing over tf.Variable), and then
|
|
|
|
|
# we jax2tf.convert it with native serialization, under tf.function (or
|
|
|
|
|
# for saving to saved model). The `np.asarray(inp)` fails because it thinks
|
|
|
|
|
# it is in TF graph mode. The `tf.init_scope()` lifts out of function-building
|
|
|
|
|
# graph scopes, and allows us to read the values of the variables
|
|
|
|
|
with tf.init_scope():
|
|
|
|
|
captured_ops = tuple(
|
|
|
|
|
mlir.ir_constant(np.asarray(inp))
|
|
|
|
|
for inp in captured_inputs
|
|
|
|
|
)
|
2023-04-27 13:42:23 +02:00
|
|
|
|
|
2023-05-03 09:04:01 -07:00
|
|
|
|
if call_tf_graph:
|
2023-03-24 11:26:44 -07:00
|
|
|
|
with jax2tf_internal.inside_call_tf():
|
|
|
|
|
return emit_tf_embedded_graph_custom_call(
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ctx,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
concrete_function_flat_tf,
|
|
|
|
|
tuple(args_op) + captured_ops,
|
|
|
|
|
has_side_effects,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered,
|
|
|
|
|
output_avals,
|
|
|
|
|
)
|
2023-03-24 11:26:44 -07:00
|
|
|
|
|
2023-02-01 14:30:35 -08:00
|
|
|
|
def convert_to_spec(x):
|
|
|
|
|
if isinstance(x, tf.TensorSpec):
|
|
|
|
|
return x
|
|
|
|
|
else:
|
|
|
|
|
return tf.TensorSpec.from_tensor(x)
|
2023-01-16 22:42:42 -08:00
|
|
|
|
|
2023-02-01 14:30:35 -08:00
|
|
|
|
args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]
|
2022-11-30 23:21:55 -08:00
|
|
|
|
|
2021-06-25 07:39:09 +02:00
|
|
|
|
with jax2tf_internal.inside_call_tf():
|
|
|
|
|
try:
|
2024-01-17 09:33:14 -08:00
|
|
|
|
func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
|
|
|
|
|
*args_tf_flat
|
|
|
|
|
)(stage="hlo_serialized", platform_name=tf_platform)
|
2021-06-25 07:39:09 +02:00
|
|
|
|
except Exception as e:
|
2023-04-27 13:42:23 +02:00
|
|
|
|
msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
|
|
|
|
|
"\ncall_tf can used " +
|
2022-11-30 23:21:55 -08:00
|
|
|
|
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
|
2023-04-27 13:42:23 +02:00
|
|
|
|
"compilable functions with static output shapes.\n" +
|
2024-09-20 07:51:48 -07:00
|
|
|
|
"See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
|
2023-04-27 13:42:23 +02:00
|
|
|
|
"\n\nCaught TensorFlow exception: " + str(e))
|
2021-06-25 07:39:09 +02:00
|
|
|
|
raise ValueError(msg) from e
|
|
|
|
|
|
2021-07-10 18:49:25 +03:00
|
|
|
|
xla_comp = xla_client.XlaComputation(func_tf_hlo)
|
|
|
|
|
|
|
|
|
|
# Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
|
2023-02-16 11:54:25 -08:00
|
|
|
|
def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
|
2021-12-01 11:34:37 +02:00
|
|
|
|
if not res_shape.is_static():
|
|
|
|
|
msg = ("Compiled TensorFlow function has dynamic output shape " +
|
|
|
|
|
f"{res_shape}. call_tf can used " +
|
|
|
|
|
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
|
2023-03-17 21:50:26 -07:00
|
|
|
|
"compilable functions with static output shapes. " +
|
2024-09-20 07:51:48 -07:00
|
|
|
|
"See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
|
2021-12-01 11:34:37 +02:00
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
2021-07-10 18:49:25 +03:00
|
|
|
|
res_dtype = res_shape.numpy_dtype()
|
|
|
|
|
jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
|
|
|
|
|
return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)
|
|
|
|
|
|
|
|
|
|
result_shape = xla_comp.program_shape().result_shape()
|
|
|
|
|
if not result_shape.is_tuple():
|
|
|
|
|
# TF does not wrap singletons as tuples, but JAX expects tuples because
|
|
|
|
|
# call_tf is a multiple_results primitive.
|
|
|
|
|
result_shapes = (result_shape,)
|
|
|
|
|
else:
|
2021-07-27 15:50:47 +03:00
|
|
|
|
result_shapes = result_shape.tuple_shapes() # type: ignore
|
2021-07-10 18:49:25 +03:00
|
|
|
|
|
2024-05-17 09:46:36 +01:00
|
|
|
|
result_avals = tuple(map(canonical_res_aval, result_shapes))
|
2021-07-10 18:49:25 +03:00
|
|
|
|
|
2023-04-27 13:42:23 +02:00
|
|
|
|
submodule = mlir.xla_computation_to_mlir_module(xla_comp)
|
|
|
|
|
symtab = ir.SymbolTable(submodule.operation)
|
|
|
|
|
callee_result_types = symtab["main"].type.results
|
|
|
|
|
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
|
|
|
|
f"call_tf_{function_flat_tf.name}",
|
2024-04-29 17:20:20 +03:00
|
|
|
|
submodule,
|
|
|
|
|
dst_symtab=ctx.module_context.symbol_table)
|
2023-04-27 13:42:23 +02:00
|
|
|
|
call = func_dialect.CallOp(callee_result_types,
|
|
|
|
|
ir.FlatSymbolRefAttr.get(fn),
|
|
|
|
|
tuple(args_op) + captured_ops)
|
2025-02-24 17:45:19 -05:00
|
|
|
|
flat_results = call.results
|
2022-04-19 13:59:28 -07:00
|
|
|
|
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
if ordered:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"ordered=True is not supported in the jitted context without"
|
|
|
|
|
" `call_tf_graph=True`"
|
|
|
|
|
)
|
|
|
|
|
|
2023-04-27 13:42:23 +02:00
|
|
|
|
outputs = []
|
|
|
|
|
for op, res_aval, res_shape in zip(flat_results, result_avals,
|
|
|
|
|
result_shapes):
|
|
|
|
|
if res_aval.dtype != res_shape.numpy_dtype():
|
|
|
|
|
op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
|
|
|
|
|
outputs.append(op)
|
|
|
|
|
return outputs
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
|
|
|
|
|
2022-11-30 23:21:55 -08:00
|
|
|
|
def _register_call_lowering(platform):
|
|
|
|
|
mlir.register_lowering(call_tf_p, functools.partial(_call_tf_lowering,
|
|
|
|
|
platform=platform),
|
|
|
|
|
platform=platform)
|
|
|
|
|
for platform in ("cpu", "cuda", "tpu"):
|
|
|
|
|
_register_call_lowering(platform)
|
2021-02-04 10:56:13 +02:00
|
|
|
|
|
2023-03-08 14:09:35 -08:00
|
|
|
|
# Support the call_tf under jax2tf.convert in eager mode
|
2021-06-10 17:01:22 +02:00
|
|
|
|
def _jax2tf_call_tf(*args: TfVal,
|
2021-06-25 07:39:09 +02:00
|
|
|
|
callable_flat_tf: Callable,
|
|
|
|
|
**_) -> TfVal:
|
2021-07-30 13:35:21 +03:00
|
|
|
|
with jax2tf_internal.inside_call_tf():
|
|
|
|
|
res_tf_flat = callable_flat_tf(*args)
|
2021-06-10 17:01:22 +02:00
|
|
|
|
return res_tf_flat
|
|
|
|
|
|
2021-07-10 18:49:25 +03:00
|
|
|
|
jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
|
2023-03-24 11:26:44 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def emit_tf_embedded_graph_custom_call(
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ctx: mlir.LoweringRuleContext,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
concrete_function_flat_tf,
|
2023-04-27 13:42:23 +02:00
|
|
|
|
operands: Sequence[ir.Value],
|
2023-03-24 11:26:44 -07:00
|
|
|
|
has_side_effects,
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
ordered,
|
2023-03-24 11:26:44 -07:00
|
|
|
|
output_avals,
|
|
|
|
|
):
|
2023-04-27 13:42:23 +02:00
|
|
|
|
"""Emits a custom call referencing a tf.Graph embedding of the TF function.
|
2023-03-24 11:26:44 -07:00
|
|
|
|
|
2023-05-25 15:58:16 -07:00
|
|
|
|
All call_tf called function information is stored in tf.metadata.
|
2023-03-24 11:26:44 -07:00
|
|
|
|
This includes:
|
2023-05-25 15:58:16 -07:00
|
|
|
|
(1) The called function name: This name will be used by the runtime to execute
|
2023-03-24 11:26:44 -07:00
|
|
|
|
the callback.
|
2023-05-25 15:58:16 -07:00
|
|
|
|
(2) The called function index in the XLACallModule `function_list` attribute.
|
2023-03-24 11:26:44 -07:00
|
|
|
|
"""
|
2023-05-25 15:58:16 -07:00
|
|
|
|
call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
|
|
|
|
|
if call_tf_concrete_function_list is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"call_tf_graph=True only support exporting by jax2tf.convert currently."
|
|
|
|
|
)
|
2023-10-13 09:42:12 -07:00
|
|
|
|
# TODO(necula): It is dangerous to modify global state when lowering because
|
|
|
|
|
# there are a number of lowering caches that only cache the StableHLO.
|
|
|
|
|
# See call_tf_test.py:test_multi_platform_call_tf_graph.
|
2023-05-25 15:58:16 -07:00
|
|
|
|
called_index = add_to_call_tf_concrete_function_list(
|
|
|
|
|
concrete_function_flat_tf, call_tf_concrete_function_list)
|
2023-05-03 09:04:01 -07:00
|
|
|
|
tf_backend_config = {
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
"has_token_input_output": ir.BoolAttr.get(ordered),
|
2023-05-25 15:58:16 -07:00
|
|
|
|
"called_index": mlir.i64_attr(called_index),
|
2023-03-24 11:26:44 -07:00
|
|
|
|
}
|
2023-09-22 09:05:42 +01:00
|
|
|
|
result_avals = ctx.avals_out if ctx.avals_out is not None else ()
|
2023-03-24 11:26:44 -07:00
|
|
|
|
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
operands = list(operands)
|
|
|
|
|
result_types = list(
|
2024-07-03 16:38:18 -04:00
|
|
|
|
mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
|
2023-03-24 11:26:44 -07:00
|
|
|
|
)
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
if ordered:
|
2024-07-01 08:42:48 -04:00
|
|
|
|
operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect))
|
2024-07-03 16:38:18 -04:00
|
|
|
|
result_types.insert(0, mlir.token_type())
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
|
|
|
|
|
custom_call = hlo.CustomCallOp(
|
2023-03-24 11:26:44 -07:00
|
|
|
|
result_types,
|
|
|
|
|
operands,
|
2023-10-13 09:42:12 -07:00
|
|
|
|
call_target_name=ir.StringAttr.get("tf.call_tf_function"),
|
2023-03-24 11:26:44 -07:00
|
|
|
|
has_side_effect=ir.BoolAttr.get(has_side_effects),
|
|
|
|
|
api_version=mlir.i32_attr(2),
|
|
|
|
|
called_computations=ir.ArrayAttr.get([]),
|
|
|
|
|
backend_config=ir.StringAttr.get(""),
|
|
|
|
|
)
|
|
|
|
|
# Store TF metadata in unregistered attribute
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
|
|
|
|
|
tf_backend_config
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
results = list(custom_call.results)
|
|
|
|
|
if ordered:
|
|
|
|
|
token = results.pop(0)
|
2024-07-01 08:42:48 -04:00
|
|
|
|
ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: token}))
|
Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
PiperOrigin-RevId: 534926215
2023-05-24 11:47:58 -07:00
|
|
|
|
|
|
|
|
|
return results
|
2023-05-25 15:58:16 -07:00
|
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
|
def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
|
2023-10-13 09:42:12 -07:00
|
|
|
|
try:
|
|
|
|
|
called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
|
|
|
|
|
except ValueError:
|
|
|
|
|
called_index = len(call_tf_concrete_function_list)
|
|
|
|
|
call_tf_concrete_function_list.append(concrete_tf_fn)
|
2023-05-25 15:58:16 -07:00
|
|
|
|
return called_index
|