mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Clean-up todos related to the upgrade of jaxlib.
PiperOrigin-RevId: 332932271
This commit is contained in:
parent
9f53d2a8d8
commit
55c6bdfe9c
13
jax/api.py
13
jax/api.py
@ -49,6 +49,7 @@ from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
|
||||
treedef_is_leaf, Partial)
|
||||
from .util import (unzip2, curry, partial, safe_map, safe_zip, prod, split_list,
|
||||
extend_name_stack, wrap_name, cache)
|
||||
from .lib import jax_jit
|
||||
from .lib import xla_bridge as xb
|
||||
from .lib import xla_client as xc
|
||||
# Unused imports to be exported
|
||||
@ -308,9 +309,7 @@ def _cpp_jit(
|
||||
xla_result.shaped_arrays,
|
||||
xla_result.lazy_expressions)
|
||||
|
||||
# TODO(jblespiau): Remove when C++ jit has landed (jaxlib.version >= 0.1.54)
|
||||
# Delay the import, because it requires a new version of jaxlib.
|
||||
from .lib import jax_jit # pylint: disable=g-import-not-at-top
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss,
|
||||
python_jitted_f, FLAGS.jax_enable_x64,
|
||||
config.read("jax_disable_jit"), static_argnums)
|
||||
@ -414,16 +413,12 @@ def disable_jit():
|
||||
prev_val = _thread_local_state.jit_is_disabled
|
||||
_thread_local_state.jit_is_disabled = True
|
||||
|
||||
# TODO(jblespiau): Remove when C++ jit has landed (jaxlib.version >= 0.1.54)
|
||||
if hasattr(lib, "jax_jit") and hasattr(lib.jax_jit, "set_disable_jit"):
|
||||
prev_cpp_val = lib.jax_jit.get_disable_jit()
|
||||
lib.jax_jit.set_disable_jit(True)
|
||||
|
||||
prev_cpp_val = lib.jax_jit.get_disable_jit()
|
||||
lib.jax_jit.set_disable_jit(True)
|
||||
yield
|
||||
finally:
|
||||
_thread_local_state.jit_is_disabled = prev_val
|
||||
if hasattr(lib, "jax_jit") and hasattr(lib.jax_jit, "set_disable_jit"):
|
||||
lib.jax_jit.set_disable_jit(prev_cpp_val)
|
||||
lib.jax_jit.set_disable_jit(prev_cpp_val)
|
||||
|
||||
|
||||
def _jit_is_disabled():
|
||||
|
Loading…
x
Reference in New Issue
Block a user