mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Respect jax_disable_jit in pjit
PiperOrigin-RevId: 503297194
This commit is contained in:
parent
622522e4a8
commit
6dd4ebc8da
@ -122,6 +122,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def wrapped(*args, **kwargs):
|
||||
if config.jax_disable_jit:
|
||||
return fun(*args, **kwargs)
|
||||
return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
|
||||
|
||||
return wrapped
|
||||
|
Loading…
x
Reference in New Issue
Block a user