mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Guard has_explicit_device
with xla_client version
This commit is contained in:
parent
ea54754c49
commit
426c7356fb
@ -63,6 +63,7 @@ from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import broadcast_prefix
|
||||
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
@ -567,6 +568,10 @@ def _cpp_jit(
|
||||
|
||||
return _BackendAndDeviceInfo(default_device, committed_to_device)
|
||||
|
||||
jitted_f_kwargs = {}
|
||||
if xla_extension_version >= 71:
|
||||
jitted_f_kwargs["has_explicit_device"] = (
|
||||
device is not None or backend is not None)
|
||||
cpp_jitted_f = jax_jit.jit(
|
||||
fun,
|
||||
cache_miss,
|
||||
@ -575,7 +580,7 @@ def _cpp_jit(
|
||||
static_argnames=static_argnames,
|
||||
donate_argnums=donate_argnums,
|
||||
cache=_cpp_jit_cache,
|
||||
has_explicit_device=device is not None or backend is not None)
|
||||
**jitted_f_kwargs) # type: ignore
|
||||
f_jitted = wraps(fun)(cpp_jitted_f)
|
||||
|
||||
f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
|
||||
|
Loading…
x
Reference in New Issue
Block a user