mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make jax.ensure_compile_time_eval
correctly exposed as a public API
This function was added as a public API (#7987) but py.type static checkers do not recognize it as a public API because of the alias name. `jax.eval_context` exists only for backward compatibility, so the correct import would be to import `ensure_compile_time_eval` directly from `jax._src.core`.
This commit is contained in:
parent
704b9fc454
commit
479f33680d
@ -62,7 +62,7 @@ from jax._src.config import (
|
||||
transfer_guard_device_to_host as transfer_guard_device_to_host,
|
||||
spmd_mode as spmd_mode,
|
||||
)
|
||||
from jax._src.core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval
|
||||
from jax._src.environment_info import print_environment_info as print_environment_info
|
||||
from jax._src.api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
|
Loading…
x
Reference in New Issue
Block a user