Respect jax_disable_jit in pjit

PiperOrigin-RevId: 503297194
This commit is contained in:
Yash Katariya 2023-01-19 16:35:23 -08:00 committed by jax authors
parent 622522e4a8
commit 6dd4ebc8da

View File

@ -122,6 +122,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
@wraps(fun)
@api_boundary
def wrapped(*args, **kwargs):
if config.jax_disable_jit:
return fun(*args, **kwargs)
return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
return wrapped