mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Export Device as jax.Device.
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
This commit is contained in:
parent
365262b77a
commit
74f1ab0503
@ -22,13 +22,3 @@ jax.lib.xla_client
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
jax.lib.xla_extension
|
||||
---------------------
|
||||
|
||||
.. currentmodule:: jaxlib.xla_extension
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Device
|
||||
|
@ -145,4 +145,5 @@ Miscellaneous
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Device
|
||||
print_environment_info
|
||||
|
@ -65,6 +65,7 @@ from jax._src.config import (
|
||||
from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval
|
||||
from jax._src.environment_info import print_environment_info as print_environment_info
|
||||
from jax._src.api import (
|
||||
Device as Device,
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
effects_barrier,
|
||||
block_until_ready as block_until_ready,
|
||||
|
@ -101,6 +101,8 @@ _dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
AxisName = Any
|
||||
|
||||
Device = xc.Device
|
||||
|
||||
# These TypeVars are used below to express the fact that function types
|
||||
# (i.e. call signatures) are invariant under the vmap transformation.
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
Loading…
x
Reference in New Issue
Block a user