diff --git a/CHANGELOG.md b/CHANGELOG.md index e289615ff..ea91ecce9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,16 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.14 (unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master). +* New features: + +* Breaking changes: + +* Bug fixes: + * The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter + properly to apply only during the just-in-time conversion + ({jax-issue}`#6720`). + * Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured + `tf.Variable` ({jax-issue}`#6572`). * Bug fixes: * The {func}`jax2tf.convert` now converts `lax.dot_general` using the diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 26af1fe2c..fafea2703 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -190,8 +190,6 @@ def convert(fun: Callable, *, A version of `fun` that expects TfVals as arguments (or tuple/lists/dicts) thereof, and returns TfVals as outputs. """ - global _enable_xla - _enable_xla = enable_xla api._check_callable(fun) def converted_fun(*args: TfVal) -> TfVal: @@ -276,6 +274,9 @@ def convert(fun: Callable, *, try: global _shape_env assert not _shape_env, f"Unexpected shape environment {_shape_env}" + global _enable_xla + prev_enable_xla = _enable_xla + _enable_xla = enable_xla _shape_env = shapeenv if with_gradient: @@ -296,6 +297,7 @@ def convert(fun: Callable, *, for o, _ in out_flat_raw] finally: _shape_env = {} + _enable_xla = prev_enable_xla out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat] out = tree_util.tree_unflatten(out_tree_thunk(), out_flat) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 7b03afe55..e3d8cee59 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized import jax from jax import dtypes +from jax import lax from jax import numpy as jnp from jax import test_util as jtu from jax.config import config @@ -516,6 +517,34 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5], jnp.bfloat16)) + def test_enable_xla(self): + # Tests that enable_xla flag is properly scoped to a conversion. + def fun(x): + # Can be converted only if enable_xla is on, due to negative padding. + return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)]) + + tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True) + tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False) + x = np.ones((2, 3), dtype=np.float32) + + self.assertAllClose(fun(x), tf_fun_with_xla(x)) + with self.assertRaisesRegex(NotImplementedError, + "Call to pad cannot be converted with enable_xla=False"): + tf_fun_without_xla(x) + + # Now in reverse order + def fun2(x): + # Can be converted only if enable_xla is on, due to negative padding. + return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)]) + + tf_fun2_without_xla = jax2tf.convert(fun2, enable_xla=False) + tf_fun2_with_xla = jax2tf.convert(fun2, enable_xla=True) + + with self.assertRaisesRegex(NotImplementedError, + "Call to pad cannot be converted with enable_xla=False"): + tf_fun2_without_xla(x) + self.assertAllClose(fun(x), tf_fun2_with_xla(x)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())