https://github.com/google/jax/pull/15391 had the unintentional side effect of causing these names not to be imported by default. Restore the status quo by importing them.
PiperOrigin-RevId: 521898088
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.
It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:
```
from jax._src.lib import xla_bridge
@mock.patch.object(xla_bridge, 'process_index')
...
```
A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:
```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```
However, this solution requires the `jax._src` be present in the JAX namespace.
Ideally users wouldn't mock our internals at all, but that requires significantly more work.
PiperOrigin-RevId: 512295203
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.
Fix some overly-strict type annotations.
PiperOrigin-RevId: 508461199
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
This function was added as a public API (#7987) but py.type static
checkers do not recognize it as a public API because of the alias name.
`jax.eval_context` exists only for backward compatibility, so the
correct import would be to import `ensure_compile_time_eval` directly
from `jax._src.core`.
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.
Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since its a new behavior for `jit` (previously jit only worked on single device arrays).
* Overtime (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.
PiperOrigin-RevId: 485075693
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.
PiperOrigin-RevId: 473243243
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module
PiperOrigin-RevId: 461146464