[jax2tf] Fix the scoping of the enable_xla conversion parameter

Previously, the global enable_xla flag was set upon entry to
`jax2tf.convert`. It should instead be set only for the duration
of the just-in-time conversion, which may happen later, when
the converted function is invoked.
George Necula 2021-05-11 11:11:37 +03:00
parent 1e9c7e4995
commit 2ad9c0c34c
3 changed files with 43 additions and 2 deletions
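
In user-facing terms, the fix means that the `enable_xla` setting takes effect when the converted function is actually traced and converted, not when `jax2tf.convert` is called. A small sketch of the expected behavior, mirroring the new test added below (the sketch itself is illustrative and not part of the commit):

```python
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def fun(x):
  # Needs the XLA converter because of the negative low padding.
  return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])

f_without_xla = jax2tf.convert(fun, enable_xla=False)
f_with_xla = jax2tf.convert(fun, enable_xla=True)  # previously, this second call overwrote the global flag

x = np.ones((2, 3), dtype=np.float32)
f_with_xla(x)     # converts fine: enable_xla=True is in effect during this conversion
f_without_xla(x)  # raises NotImplementedError ("Call to pad cannot be converted with enable_xla=False")
```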


@@ -10,6 +10,16 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.14 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
* New features:
* Breaking changes:
* Bug fixes:
  * The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
    properly to apply only during the just-in-time conversion
    ({jax-issue}`#6720`).
  * Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
    `tf.Variable` ({jax-issue}`#6572`).
* Bug fixes:
  * The {func}`jax2tf.convert` now converts `lax.dot_general` using the


@@ -190,8 +190,6 @@ def convert(fun: Callable, *,
    A version of `fun` that expects TfVals as arguments (or
    tuple/lists/dicts) thereof, and returns TfVals as outputs.
  """
  global _enable_xla
  _enable_xla = enable_xla
  api._check_callable(fun)

  def converted_fun(*args: TfVal) -> TfVal:

@@ -276,6 +274,9 @@
    try:
      global _shape_env
      assert not _shape_env, f"Unexpected shape environment {_shape_env}"
      global _enable_xla
      prev_enable_xla = _enable_xla
      _enable_xla = enable_xla
      _shape_env = shapeenv
      if with_gradient:

@@ -296,6 +297,7 @@
                  for o, _ in out_flat_raw]
    finally:
      _shape_env = {}
      _enable_xla = prev_enable_xla

    out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
    out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
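
The essence of the change above is to save and restore the module-level flag around the conversion itself, instead of setting it once when `convert` is called. A minimal standalone sketch of that pattern (simplified names, not the actual jax2tf internals):

```python
_enable_xla = True  # module-level default consulted by the conversion rules

def convert(fun, *, enable_xla=True):
  def converted_fun(*args):
    global _enable_xla
    prev_enable_xla = _enable_xla
    _enable_xla = enable_xla          # in effect only while this conversion runs
    try:
      return fun(*args)               # stands in for the actual JAX-to-TF interpretation
    finally:
      _enable_xla = prev_enable_xla   # restore whatever was set before
  return converted_fun
```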


@@ -22,6 +22,7 @@ from absl.testing import parameterized
import jax
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
@@ -516,6 +517,34 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
                        tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5],
                                                                    jnp.bfloat16))

  def test_enable_xla(self):
    # Tests that enable_xla flag is properly scoped to a conversion.
    def fun(x):
      # Can be converted only if enable_xla is on, due to negative padding.
      return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])

    tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
    tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
    x = np.ones((2, 3), dtype=np.float32)

    self.assertAllClose(fun(x), tf_fun_with_xla(x))
    with self.assertRaisesRegex(NotImplementedError,
                                "Call to pad cannot be converted with enable_xla=False"):
      tf_fun_without_xla(x)

    # Now in reverse order
    def fun2(x):
      # Can be converted only if enable_xla is on, due to negative padding.
      return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])

    tf_fun2_without_xla = jax2tf.convert(fun2, enable_xla=False)
    tf_fun2_with_xla = jax2tf.convert(fun2, enable_xla=True)

    with self.assertRaisesRegex(NotImplementedError,
                                "Call to pad cannot be converted with enable_xla=False"):
      tf_fun2_without_xla(x)
    self.assertAllClose(fun2(x), tf_fun2_with_xla(x))

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())