Copybara import of the project:

--
b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d by George Necula <gcnecula@gmail.com>:

[shape_poly] Fix lowering when we have both dimension variables and tokens

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16575 from gnecula:call_tf_poly b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d
PiperOrigin-RevId: 544252624
This commit is contained in:
George Necula 2023-06-28 22:14:22 -07:00 committed by jax authors
parent 64b0962f4e
commit 46aa9e0b31
2 changed files with 24 additions and 1 deletions

View File

@ -889,6 +889,7 @@ def lower_jaxpr_to_fun(
output_token_types = []
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * num_tokens
# Order of arguments: dim vars, tokens, array inputs
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
@ -968,7 +969,7 @@ def lower_jaxpr_to_fun(
attrs["tf.aliasing_output"] = i32_attr(alias)
if num_tokens > 0:
token_arg_attrs = arg_attrs[num_dim_vars:num_tokens]
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
for attrs in token_arg_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)

View File

@ -1541,6 +1541,28 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
)
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
@parameterized.named_parameters([
dict(testcase_name=f"{ordered=}", ordered=ordered)
for ordered in [True, False]
])
def test_call_tf_graph_polymorphic(self, ordered: bool):
@tf.function(jit_compile=True, autograph=False)
@partial(jax2tf.convert,
with_gradient=False,
native_serialization=True,
polymorphic_shapes=["(b)"])
@jax.jit
def tf_f_2(x):
tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
jax2tf.call_tf(tf_f,
call_tf_graph=True,
ordered=ordered,
output_shape_dtype=None)(x)
return x
x = np.arange(3, dtype=np.int32)
_ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())