Copybara import of the project:
-- b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d by George Necula <gcnecula@gmail.com>:

[shape_poly] Fix lowering when we have both dimension variables and tokens

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16575 from gnecula:call_tf_poly b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d
PiperOrigin-RevId: 544252624
This commit is contained in:
parent 64b0962f4e
commit 46aa9e0b31
@@ -889,6 +889,7 @@ def lower_jaxpr_to_fun(
   output_token_types = []
   token_types = [token_type() for _ in effects]
   token_avals = [core.AbstractToken] * num_tokens
+  # Order of arguments: dim vars, tokens, array inputs
   input_avals = dim_var_avals + token_avals + jaxpr.in_avals
   input_types = [*dim_var_types, *token_types, *input_types]
   output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
@@ -968,7 +969,7 @@ def lower_jaxpr_to_fun(
         attrs["tf.aliasing_output"] = i32_attr(alias)
 
   if num_tokens > 0:
-    token_arg_attrs = arg_attrs[num_dim_vars:num_tokens]
+    token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
     for attrs in token_arg_attrs:
       attrs["jax.token"] = ir.BoolAttr.get(True)
 
@@ -1541,6 +1541,28 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
     )
     _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
 
+  @parameterized.named_parameters([
+      dict(testcase_name=f"{ordered=}", ordered=ordered)
+      for ordered in [True, False]
+  ])
+  def test_call_tf_graph_polymorphic(self, ordered: bool):
+    @tf.function(jit_compile=True, autograph=False)
+    @partial(jax2tf.convert,
+             with_gradient=False,
+             native_serialization=True,
+             polymorphic_shapes=["(b)"])
+    @jax.jit
+    def tf_f_2(x):
+      tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
+      jax2tf.call_tf(tf_f,
+                     call_tf_graph=True,
+                     ordered=ordered,
+                     output_shape_dtype=None)(x)
+      return x
+
+    x = np.arange(3, dtype=np.int32)
+    _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
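For context only (not part of this commit): a minimal, hedged sketch of the shape-polymorphic conversion the new test relies on. polymorphic_shapes=["(b)"] introduces a dimension variable b, which is the num_dim_vars > 0 case that, combined with the call_tf token, triggered the slicing bug fixed above. The function inc and the concrete values below are assumptions for the example.

import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def inc(x):
  return x + 1

# One symbolic dimension "b": the lowered module receives it as a dim var argument.
inc_tf = tf.function(
    jax2tf.convert(inc, polymorphic_shapes=["(b)"], native_serialization=True),
    autograph=False)
cf = inc_tf.get_concrete_function(tf.TensorSpec([None], tf.int32))
print(cf(tf.constant(np.arange(5, dtype=np.int32))))  # works for any length b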