Make jax.ensure_compile_time_eval correctly exposed as a public API

This function was added as a public API (#7987) but py.type static
checkers do not recognize it as a public API because of the alias name.

`jax.eval_context` exists only for backward compatibility, so the
correct import would be to import `ensure_compile_time_eval` directly
from `jax._src.core`.
This commit is contained in:
Jongwook Choi 2022-12-21 20:12:08 -05:00
parent 704b9fc454
commit 479f33680d

View File

@ -62,7 +62,7 @@ from jax._src.config import (
transfer_guard_device_to_host as transfer_guard_device_to_host,
spmd_mode as spmd_mode,
)
from jax._src.core import eval_context as ensure_compile_time_eval
from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info
from jax._src.api import (
ad, # TODO(phawkins): update users to avoid this.