From 55c6bdfe9c5d0631200cb76e4f56481b3656b03f Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 21 Sep 2020 14:18:31 -0700 Subject: [PATCH] Clean-up todos related to the upgrade of jaxlib. PiperOrigin-RevId: 332932271 --- jax/api.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/jax/api.py b/jax/api.py index b1bbb2bfe..4cbb79d9d 100644 --- a/jax/api.py +++ b/jax/api.py @@ -49,6 +49,7 @@ from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure, treedef_is_leaf, Partial) from .util import (unzip2, curry, partial, safe_map, safe_zip, prod, split_list, extend_name_stack, wrap_name, cache) +from .lib import jax_jit from .lib import xla_bridge as xb from .lib import xla_client as xc # Unused imports to be exported @@ -308,9 +309,7 @@ def _cpp_jit( xla_result.shaped_arrays, xla_result.lazy_expressions) - # TODO(jblespiau): Remove when C++ jit has landed (jaxlib.version >= 0.1.54) # Delay the import, because it requires a new version of jaxlib. - from .lib import jax_jit # pylint: disable=g-import-not-at-top cpp_jitted_f = jax_jit.jit(fun, cache_miss, python_jitted_f, FLAGS.jax_enable_x64, config.read("jax_disable_jit"), static_argnums) @@ -414,16 +413,12 @@ def disable_jit(): prev_val = _thread_local_state.jit_is_disabled _thread_local_state.jit_is_disabled = True - # TODO(jblespiau): Remove when C++ jit has landed (jaxlib.version >= 0.1.54) - if hasattr(lib, "jax_jit") and hasattr(lib.jax_jit, "set_disable_jit"): - prev_cpp_val = lib.jax_jit.get_disable_jit() - lib.jax_jit.set_disable_jit(True) - + prev_cpp_val = lib.jax_jit.get_disable_jit() + lib.jax_jit.set_disable_jit(True) yield finally: _thread_local_state.jit_is_disabled = prev_val - if hasattr(lib, "jax_jit") and hasattr(lib.jax_jit, "set_disable_jit"): - lib.jax_jit.set_disable_jit(prev_cpp_val) + lib.jax_jit.set_disable_jit(prev_cpp_val) def _jit_is_disabled():