Clean-up todos related to the upgrade of jaxlib.

PiperOrigin-RevId: 332932271
This commit is contained in:
jax authors 2020-09-21 14:18:31 -07:00
parent 9f53d2a8d8
commit 55c6bdfe9c

View File

@ -49,6 +49,7 @@ from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
treedef_is_leaf, Partial)
from .util import (unzip2, curry, partial, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache)
from .lib import jax_jit
from .lib import xla_bridge as xb
from .lib import xla_client as xc
# Unused imports to be exported
@ -308,9 +309,7 @@ def _cpp_jit(
xla_result.shaped_arrays,
xla_result.lazy_expressions)
# TODO(jblespiau): Remove when C++ jit has landed (jaxlib.version >= 0.1.54)
# Delay the import, because it requires a new version of jaxlib.
from .lib import jax_jit # pylint: disable=g-import-not-at-top
cpp_jitted_f = jax_jit.jit(fun, cache_miss,
python_jitted_f, FLAGS.jax_enable_x64,
config.read("jax_disable_jit"), static_argnums)
@ -414,16 +413,12 @@ def disable_jit():
prev_val = _thread_local_state.jit_is_disabled
_thread_local_state.jit_is_disabled = True
# TODO(jblespiau): Remove when C++ jit has landed (jaxlib.version >= 0.1.54)
if hasattr(lib, "jax_jit") and hasattr(lib.jax_jit, "set_disable_jit"):
prev_cpp_val = lib.jax_jit.get_disable_jit()
lib.jax_jit.set_disable_jit(True)
prev_cpp_val = lib.jax_jit.get_disable_jit()
lib.jax_jit.set_disable_jit(True)
yield
finally:
_thread_local_state.jit_is_disabled = prev_val
if hasattr(lib, "jax_jit") and hasattr(lib.jax_jit, "set_disable_jit"):
lib.jax_jit.set_disable_jit(prev_cpp_val)
lib.jax_jit.set_disable_jit(prev_cpp_val)
def _jit_is_disabled():