From 74f1ab05038c5808c4a6254bdaa11fe5630a8e03 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 2 Feb 2023 12:58:15 -0500 Subject: [PATCH] Export Device as jax.Device. Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type. --- docs/jax.lib.rst | 10 ---------- docs/jax.rst | 1 + jax/__init__.py | 1 + jax/_src/api.py | 2 ++ 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/docs/jax.lib.rst b/docs/jax.lib.rst index a36631286..2513ccf1d 100644 --- a/docs/jax.lib.rst +++ b/docs/jax.lib.rst @@ -22,13 +22,3 @@ jax.lib.xla_client .. autosummary:: :toctree: _autosummary - -jax.lib.xla_extension ---------------------- - -.. currentmodule:: jaxlib.xla_extension - -.. autosummary:: - :toctree: _autosummary - - Device diff --git a/docs/jax.rst b/docs/jax.rst index 1d3eaaf89..3d354ceee 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -145,4 +145,5 @@ Miscellaneous .. autosummary:: :toctree: _autosummary + Device print_environment_info diff --git a/jax/__init__.py b/jax/__init__.py index 8d10fc2e2..a1a2923e4 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -65,6 +65,7 @@ from jax._src.config import ( from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval from jax._src.environment_info import print_environment_info as print_environment_info from jax._src.api import ( + Device as Device, ad, # TODO(phawkins): update users to avoid this. effects_barrier, block_until_ready as block_until_ready, diff --git a/jax/_src/api.py b/jax/_src/api.py index be220b934..844eff31c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -101,6 +101,8 @@ _dtype = partial(dtypes.dtype, canonicalize=True) AxisName = Any +Device = xc.Device + # These TypeVars are used below to express the fact that function types # (i.e. call signatures) are invariant under the vmap transformation. F = TypeVar("F", bound=Callable)