Guard has_explicit_device with xla_client version

This commit is contained in:
Sharad Vikram 2022-06-02 11:54:58 -07:00
parent ea54754c49
commit 426c7356fb

View File

@ -63,6 +63,7 @@ from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib import xla_extension_version
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
@ -567,6 +568,10 @@ def _cpp_jit(
return _BackendAndDeviceInfo(default_device, committed_to_device)
jitted_f_kwargs = {}
if xla_extension_version >= 71:
jitted_f_kwargs["has_explicit_device"] = (
device is not None or backend is not None)
cpp_jitted_f = jax_jit.jit(
fun,
cache_miss,
@ -575,7 +580,7 @@ def _cpp_jit(
static_argnames=static_argnames,
donate_argnums=donate_argnums,
cache=_cpp_jit_cache,
has_explicit_device=device is not None or backend is not None)
**jitted_f_kwargs) # type: ignore
f_jitted = wraps(fun)(cpp_jitted_f)
f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,