mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[jax2tf] Fix the scoping of the enable_xla conversion parameter
Previously, the global enable_xla flag was set upon entry to `jax.convert`. It should instead be set only for the duration of the just-in-time conversion, which may happen later when the converted function is invoked.
This commit is contained in:
parent
1e9c7e4995
commit
2ad9c0c34c
10
CHANGELOG.md
10
CHANGELOG.md
@ -10,6 +10,16 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.2.14 (unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
|
||||
* New features:
|
||||
|
||||
* Breaking changes:
|
||||
|
||||
* Bug fixes:
|
||||
* The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
|
||||
properly to apply only during the just-in-time conversion
|
||||
({jax-issue}`#6720`).
|
||||
* Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
|
||||
`tf.Variable` ({jax-issue}`#6572`).
|
||||
|
||||
* Bug fixes:
|
||||
* The {func}`jax2tf.convert` now converts `lax.dot_general` using the
|
||||
|
@ -190,8 +190,6 @@ def convert(fun: Callable, *,
|
||||
A version of `fun` that expects TfVals as arguments (or
|
||||
tuple/lists/dicts) thereof, and returns TfVals as outputs.
|
||||
"""
|
||||
global _enable_xla
|
||||
_enable_xla = enable_xla
|
||||
api._check_callable(fun)
|
||||
|
||||
def converted_fun(*args: TfVal) -> TfVal:
|
||||
@ -276,6 +274,9 @@ def convert(fun: Callable, *,
|
||||
try:
|
||||
global _shape_env
|
||||
assert not _shape_env, f"Unexpected shape environment {_shape_env}"
|
||||
global _enable_xla
|
||||
prev_enable_xla = _enable_xla
|
||||
_enable_xla = enable_xla
|
||||
_shape_env = shapeenv
|
||||
|
||||
if with_gradient:
|
||||
@ -296,6 +297,7 @@ def convert(fun: Callable, *,
|
||||
for o, _ in out_flat_raw]
|
||||
finally:
|
||||
_shape_env = {}
|
||||
_enable_xla = prev_enable_xla
|
||||
|
||||
out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
|
||||
out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
|
||||
|
@ -22,6 +22,7 @@ from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
@ -516,6 +517,34 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5],
|
||||
jnp.bfloat16))
|
||||
|
||||
def test_enable_xla(self):
|
||||
# Tests that enable_xla flag is properly scoped to a conversion.
|
||||
def fun(x):
|
||||
# Can be converted only if enable_xla is on, due to negative padding.
|
||||
return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])
|
||||
|
||||
tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
|
||||
tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
|
||||
x = np.ones((2, 3), dtype=np.float32)
|
||||
|
||||
self.assertAllClose(fun(x), tf_fun_with_xla(x))
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"Call to pad cannot be converted with enable_xla=False"):
|
||||
tf_fun_without_xla(x)
|
||||
|
||||
# Now in reverse order
|
||||
def fun2(x):
|
||||
# Can be converted only if enable_xla is on, due to negative padding.
|
||||
return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])
|
||||
|
||||
tf_fun2_without_xla = jax2tf.convert(fun2, enable_xla=False)
|
||||
tf_fun2_with_xla = jax2tf.convert(fun2, enable_xla=True)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"Call to pad cannot be converted with enable_xla=False"):
|
||||
tf_fun2_without_xla(x)
|
||||
self.assertAllClose(fun(x), tf_fun2_with_xla(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user