Export Device as jax.Device.

Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
This commit is contained in:
Peter Hawkins 2023-02-02 12:58:15 -05:00
parent 365262b77a
commit 74f1ab0503
4 changed files with 4 additions and 10 deletions

View File

@ -22,13 +22,3 @@ jax.lib.xla_client
.. autosummary::
:toctree: _autosummary
jax.lib.xla_extension
---------------------
.. currentmodule:: jaxlib.xla_extension
.. autosummary::
:toctree: _autosummary
Device

View File

@ -145,4 +145,5 @@ Miscellaneous
.. autosummary::
:toctree: _autosummary
Device
print_environment_info

View File

@ -65,6 +65,7 @@ from jax._src.config import (
from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info
from jax._src.api import (
Device as Device,
ad, # TODO(phawkins): update users to avoid this.
effects_barrier,
block_until_ready as block_until_ready,

View File

@ -101,6 +101,8 @@ _dtype = partial(dtypes.dtype, canonicalize=True)
AxisName = Any
Device = xc.Device
# These TypeVars are used below to express the fact that function types
# (i.e. call signatures) are invariant under the vmap transformation.
F = TypeVar("F", bound=Callable)