Remove more exported names from jax.interpreters.xla.

None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf.
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
Peter Hawkins 2023-02-16 11:54:25 -08:00 committed by jax authors
parent 6b545a2ddc
commit 54269c1145
13 changed files with 80 additions and 146 deletions
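For downstream code that still imports any of these names from jax.interpreters.xla, the replacements used throughout this commit suggest a direct migration. The sketch below is illustrative only and not part of the commit; describe_executable is a hypothetical caller, and the jax._src paths it imports from are private and not covered by the deprecation policy.

# Illustrative migration sketch (not part of this commit): the substitutions
# below mirror those made in the diffs, with describe_executable a hypothetical
# caller. Note that jax._src modules are private and may change without notice.
from typing import Optional, Sequence

from jax._src.lib import xla_client as xc

# Old re-export in jax.interpreters.xla     ->  spelling used by this commit
#   xla.XlaLoadedExecutable                 ->  xc.LoadedExecutable
#   xla.Backend                             ->  xc.Client
#   xla.Device                              ->  xc.Device
#   xla.XlaShape                            ->  xc.Shape
#   xla.extend_name_stack / xla.wrap_name   ->  util.extend_name_stack / util.wrap_name
#                                               (from jax._src import util)

def describe_executable(executable: xc.LoadedExecutable,
                        devices: Optional[Sequence[xc.Device]] = None) -> str:
  # Annotation-only usage, matching the type-annotation changes in the diffs below.
  n = 0 if devices is None else len(devices)
  return f"loaded executable targeting {n} explicitly requested device(s)"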

View File

@@ -78,7 +78,7 @@ from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp
from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.interpreters import xla
from jax._src.config import (
config,
@@ -560,7 +560,7 @@ class _BackendAndDeviceInfo(NamedTuple):
committed_to_device: bool
class _FastpathData(NamedTuple):
xla_executable: xla.XlaLoadedExecutable
xla_executable: xc.LoadedExecutable
out_pytree_def: Any
sticky_device: Optional[xc.Device]
avals: Iterable[Any]
@@ -1101,7 +1101,7 @@ def xla_computation(fun: Callable,
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, xla.ShapedArray):
if not isinstance(out_aval, ShapedArray):
raise RuntimeError("As we want to propagate the weak_type, we need "
"to get a ShapedArray, otherwise this "
"information is lost")
@@ -2327,7 +2327,7 @@ def _python_pmap(
class _PmapFastpathData(NamedTuple):
version: int # For forward and backward compatibility
xla_executable: xla.XlaLoadedExecutable
xla_executable: xc.LoadedExecutable
in_handler: Any
out_handler: Any
out_pytree_def: Any

View File

@@ -87,7 +87,6 @@ Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaLoadedExecutable = xla.XlaLoadedExecutable
CompileOptions = xc.CompileOptions
map, unsafe_map = util.safe_map, map
@@ -810,7 +809,7 @@ if MYPY:
ResultHandler = Any
else:
class ResultHandler(Protocol):
def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any:
def __call__(self, env: Optional[Sequence[Any]], *args: xc.Buffer) -> Any:
"""Boxes raw buffers into their user-facing representation."""
def aval_to_result_handler(sticky_device: Optional[Device],
@@ -894,7 +893,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
return input_bufs, _remove_tokens
def _execute_compiled(name: str, compiled: XlaLoadedExecutable,
def _execute_compiled(name: str, compiled: xc.LoadedExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Sequence[int],
result_handler: Callable, has_unordered_effects: bool,
@@ -919,7 +918,7 @@ def _execute_compiled(name: str, compiled: XlaLoadedExecutable,
def _execute_replicated(name: str,
compiled: XlaLoadedExecutable,
compiled: xc.LoadedExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Sequence[int],
result_handler: Callable,
@@ -939,7 +938,7 @@ def _execute_replicated(name: str,
for device in compiled.local_devices()]
input_bufs_flip = list(unsafe_zip(*input_bufs))
out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip)
out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if from_lower_sharding_computation:
@@ -1098,7 +1097,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
compile_options: CompileOptions,
backend: Backend) -> Optional[XlaLoadedExecutable]:
backend: Backend) -> Optional[xc.LoadedExecutable]:
"""Looks up `computation` in the persisent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
@@ -1117,7 +1116,7 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
compile_time_secs: float,
module_name: str, compile_options: CompileOptions,
backend: Backend, compiled: XlaLoadedExecutable,
backend: Backend, compiled: xc.LoadedExecutable,
host_callbacks: List[Any]):
"""Writes `serialized_computation` to the persistent compilation cache."""
# Avoid import cycle between jax and jax.experimental

View File

@@ -1016,8 +1016,8 @@ def lower_jaxpr_to_fun(
args.append(hlo.CreateTokenOp().results)
else:
args.append(arg)
callee_name_stack = xla.extend_name_stack(ctx.name_stack,
util.wrap_name(name, api_name))
callee_name_stack = util.extend_name_stack(
ctx.name_stack, util.wrap_name(name, api_name))
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
*args, dim_var_values=dim_var_values)

View File

@@ -1261,11 +1261,11 @@ def parallel_callable(fun: lu.WrappedFun,
@dataclasses.dataclass(frozen=True)
class ParallelCallableInfo:
name: str
backend: xla.Backend
backend: xc.Client
axis_name: core.AxisName
axis_size: int
global_axis_size: int
devices: Optional[Sequence[xla.Device]]
devices: Optional[Sequence[xc.Device]]
in_axes: Iterable[Optional[int]]
out_axes_thunk: Callable[[], Sequence[Optional[int]]]
avals: Sequence[core.AbstractValue]
@@ -1370,7 +1370,7 @@ def lower_parallel_callable(
axis_name: core.AxisName,
axis_size: int,
global_axis_size: int,
devices: Optional[Sequence[xla.Device]],
devices: Optional[Sequence[xc.Device]],
name: str,
in_axes: Iterable[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
@@ -1733,6 +1733,7 @@ def _get_pmap_sharding(devices, specs):
multi_host_supported_collectives: Set[core.Primitive] = set()
def check_multihost_collective_allowlist(jaxpr):
used_collectives = set(xla.jaxpr_collectives(jaxpr))
if not used_collectives.issubset(multi_host_supported_collectives):
@@ -2295,8 +2296,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
sub_ctx = ctx.module_context.replace(
axis_context=mlir.ReplicaAxisContext(new_env),
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
util.wrap_name(name, 'pmap')))
name_stack=util.extend_name_stack(ctx.module_context.name_stack,
util.wrap_name(name, 'pmap')))
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
*in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
@@ -2778,7 +2779,7 @@ ShardingInfo = Tuple[
def _get_and_check_device_assignment(
shardings: Iterable[ShardingInfo],
devices: Optional[Sequence[xc.Device]],
) -> Tuple[xla.Backend, Sequence[xc.Device]]:
) -> Tuple[xc.Client, Sequence[xc.Device]]:
from jax._src.api import local_devices
first_sharding_info = None
@@ -3572,7 +3573,7 @@ class UnloadedMeshExecutable:
class MeshExecutableFastpathData(NamedTuple):
xla_executable: xla.XlaLoadedExecutable
xla_executable: xc.LoadedExecutable
out_pytree_def: Any
in_shardings: Sequence[sharding_internal.XLACompatibleSharding]
out_shardings: Sequence[sharding_internal.XLACompatibleSharding]

View File

@@ -35,14 +35,11 @@ from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
partition_list)
# TODO: update callers to refer to new location.
from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401
from jax._src.util import wrap_name as wrap_name # noqa: F401
from jax._src.typing import Shape
from jax._src.lib import xla_bridge as xb
@@ -55,28 +52,12 @@ xe = xc._xla
xops = xc._xla.ops
# Types
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaOp = xc.XlaOp
XlaShape = xc.Shape
XlaBuilder = xc.XlaBuilder
XlaLoadedExecutable = Any
XlaLoadedExecutable = xc.LoadedExecutable # type:ignore
# TODO(phawkins): update code to point to new locations.
DeviceArray = device_array.DeviceArray
_DeviceArray = device_array._DeviceArray
_CppDeviceArray = xe.Buffer
make_device_array = device_array.make_device_array
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]:
def _make_array_shape(a: ShapedArray) -> Sequence[xc.Shape]:
if a.dtype == dtypes.float0:
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
else:
@@ -89,22 +70,6 @@ def get_canonical_source_file(frame: source_info_util.Frame):
'', source_file)
return source_file
tracebacks = {}
def make_op_metadata(primitive: core.Primitive,
params: Dict, *,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "",
) -> xc.OpMetadata:
eqn_str = (str(source_info.name_stack) + '/'
+ str_eqn_compact(primitive.name, params))
tracebacks[eqn_str] = source_info.traceback
frame = source_info_util.user_frame(source_info)
return xc.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=get_canonical_source_file(frame) if frame else None,
source_line=frame.start_line if frame else None)
# Utilities
def parameter(builder, num, shape, name=None, replicated=None):
@@ -176,47 +141,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
### handlers
# Numpy dtypes -> XLA primitive types
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
np.dtype('bool'): xc.PrimitiveType.PRED,
np.dtype('int8'): xc.PrimitiveType.S8,
np.dtype('int16'): xc.PrimitiveType.S16,
np.dtype('int32'): xc.PrimitiveType.S32,
np.dtype('int64'): xc.PrimitiveType.S64,
np.dtype('uint8'): xc.PrimitiveType.U8,
np.dtype('uint16'): xc.PrimitiveType.U16,
np.dtype('uint32'): xc.PrimitiveType.U32,
np.dtype('uint64'): xc.PrimitiveType.U64,
np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
np.dtype('float16'): xc.PrimitiveType.F16,
np.dtype('float32'): xc.PrimitiveType.F32,
np.dtype('float64'): xc.PrimitiveType.F64,
np.dtype('complex64'): xc.PrimitiveType.C64,
np.dtype('complex128'): xc.PrimitiveType.C128,
}
def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
"""Converts a NumPy dtype into an XLA PrimitiveType."""
# Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
# but may not hash correctly. Make sure we have a true np.dtype.
assert isinstance(dtype, np.dtype), type(dtype)
try:
return _dtype_to_primitive_type[dtype]
except KeyError as err:
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
# JAX abstract values -> XLA shapes
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
try:
return xla_shape_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
xla_shape_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[XlaShape]]] = {
Callable[[Any], Sequence[xc.Shape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
@@ -493,8 +427,8 @@ if not MYPY:
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw
) -> Sequence[XlaOp]:
*args: xc.XlaOp, **kw
) -> Sequence[xc.XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any
@@ -545,7 +479,7 @@ def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
*args: xc.XlaOp, **kw) -> Sequence[xc.XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):

View File

@@ -837,7 +837,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
name_stack=extend_name_stack(name_stack, f'branch_{i}_fun'))
out_vals, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr.jaxpr, tokens_in,
map(mlir.ir_constants, jaxpr.consts),

View File

@@ -1472,7 +1472,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
cond_args = cond_args[num_tokens:]
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, 'cond'))
name_stack=extend_name_stack(name_stack, 'cond'))
((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z), dim_var_values=ctx.dim_var_values)
@@ -1504,15 +1504,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, 'body'))
name_stack=extend_name_stack(name_stack, 'body'))
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts),
*(y + z), dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in body_effects]
if batched:
body_pred_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack,
'body_pred'))
name_stack=extend_name_stack(name_stack, 'body_pred'))
((body_pred,),), _ = mlir.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, cond_jaxpr.consts),

View File

@@ -40,13 +40,13 @@ from jax._src.sharding import NamedSharding
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import ad
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
merge_lists, partition_list, extend_name_stack)
from jax import lax
source_info_util.register_exclusion(__file__)
@@ -1373,8 +1373,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
@@ -1440,8 +1440,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
@@ -1491,8 +1491,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# translation rule just because the extra tuple stuff is a pain.
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')),
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')),
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')

View File

@@ -37,7 +37,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tu
import jax
from jax import tree_util
from jax.lib import xla_client as xc
from jax.interpreters import mlir
from jax._src import core
from jax._src import source_info_util
@@ -45,8 +45,7 @@ from jax._src import traceback_util
from jax._src import util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
source_info_util.register_exclusion(__file__)
@@ -200,7 +199,7 @@ class Lowering(Protocol):
class XlaExecutable(Executable):
def xla_extension_executable(self) -> xla.XlaLoadedExecutable:
def xla_extension_executable(self) -> xc.LoadedExecutable:
raise NotImplementedError("must override")
def call(self, *args_flat) -> Sequence[Any]:

View File

@@ -21,6 +21,7 @@ from typing import List, Optional
from jax.experimental.compilation_cache.gfile_cache import GFileCache
from jax._src import path as pathlib
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib import version_str as jaxlib_version_str
from jax.interpreters import xla
@@ -50,7 +51,7 @@ def initialize_cache(path):
def get_executable(xla_computation, compile_options,
backend) -> Optional[xla.XlaLoadedExecutable]:
backend) -> Optional[xla_client.LoadedExecutable]:
"""Returns the cached executable if present, or None otherwise."""
assert _cache is not None, "initialize_cache must be called before you can call get_executable()"
cache_key = get_cache_key(xla_computation, compile_options, backend)
@@ -63,7 +64,7 @@ def get_executable(xla_computation, compile_options,
return xla_executable_deserialized
def put_executable(module_name, xla_computation, compile_options,
executable: xla.XlaLoadedExecutable, backend):
executable: xla_client.LoadedExecutable, backend):
"""Adds 'executable' to the cache, possibly evicting older entries."""
assert _cache is not None, "initialize_cache must be called before you can call put_executable()"
cache_key = get_cache_key(xla_computation, compile_options, backend)

View File

@@ -411,7 +411,7 @@ def _code_generator_and_avals(
xla_comp = xla_client.XlaComputation(func_tf_hlo)
# Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
if not res_shape.is_static():
msg = ("Compiled TensorFlow function has dynamic output shape " +
f"{res_shape}. call_tf can used " +

View File

@@ -1083,6 +1083,19 @@ class TensorFlowTracer(core.Tracer):
def full_lower(self):
return self
def _make_op_metadata(primitive: core.Primitive,
params: Dict, *,
source_info: source_info_util.SourceInfo,
) -> xla_client.OpMetadata:
eqn_str = (str(source_info.name_stack) + '/'
+ core.str_eqn_compact(primitive.name, params))
frame = source_info_util.user_frame(source_info)
return xla_client.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=xla.get_canonical_source_file(frame) if frame else None,
source_line=frame.start_line if frame else None)
class TensorFlowTrace(core.Trace):
"""Trace class that underlies the jax2tf transformation.
@@ -1168,9 +1181,8 @@ class TensorFlowTrace(core.Trace):
with tf.name_scope(_sanitize_scope_name(scope)):
if _thread_local_state.include_xla_op_metadata:
op_metadata = xla.make_op_metadata(primitive, params,
name_stack=current_name_stack,
source_info=source_info_util.current())
op_metadata = _make_op_metadata(primitive, params,
source_info=source_info_util.current())
op_metadata_proto = xla_data_pb2.OpMetadata(
op_type=op_metadata.op_type,
op_name=op_metadata.op_name,

View File

@@ -14,53 +14,32 @@
from jax._src.interpreters.xla import (
AxisEnv as AxisEnv,
Backend as Backend,
Buffer as Buffer,
ConcreteArray as ConcreteArray,
Shape as Shape,
ShapedArray as ShapedArray,
SpatialSharding as SpatialSharding,
TranslationContext as TranslationContext,
TranslationRule as TranslationRule,
XlaBuilder as XlaBuilder,
XlaLoadedExecutable as XlaLoadedExecutable,
XlaOp as XlaOp,
XlaShape as XlaShape,
_CppDeviceArray as _CppDeviceArray,
_DeviceArray as _DeviceArray,
abstractify as abstractify,
aval_to_xla_shapes as aval_to_xla_shapes,
axis_groups as axis_groups,
axis_read as axis_read,
backend_specific_translations as backend_specific_translations,
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
check_backend_matches as check_backend_matches,
dtype_to_primitive_type as dtype_to_primitive_type,
extend_axis_env as extend_axis_env,
extend_name_stack as extend_name_stack,
jaxpr_collectives as jaxpr_collectives,
make_device_array as make_device_array,
make_op_metadata as make_op_metadata,
new_name_stack as new_name_stack,
parameter as parameter,
partition_list as partition_list,
primitive_subcomputation as primitive_subcomputation,
pytype_aval_mappings as pytype_aval_mappings,
register_collective_primitive as register_collective_primitive,
register_initial_style_primitive as register_initial_style_primitive,
register_translation as register_translation,
sharding_to_proto as sharding_to_proto,
translations as translations,
xb as xb,
xc as xc,
xe as xe,
xla_call as xla_call,
xla_call_p as xla_call_p,
xla_destructure as xla_destructure,
xla_shape_handlers as xla_shape_handlers,
)
from jax._src.core import (
ShapedArray as ShapedArray,
ConcreteArray as ConcreteArray,
)
# TODO(phawkins): update users.
from jax._src.dispatch import (
apply_primitive as apply_primitive,
@@ -68,9 +47,19 @@ from jax._src.dispatch import (
device_put as device_put,
)
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc # type: ignore
from jax._src.interpreters.xla import (
Device as _deprecated_Device,
_deprecated_Device = xc.Device
XlaOp = xc.XlaOp
xe = xc._xla
Backend = xe.Client
Buffer = xc.Buffer
_CppDeviceArray = xe.Buffer
from jax._src.device_array import (
make_device_array as make_device_array,
_DeviceArray as _DeviceArray,
DeviceArray as _deprecated_DeviceArray,
)
@@ -92,8 +81,8 @@ del _deprecation_getattr
import typing
if typing.TYPE_CHECKING:
from jax._src.interpreters.xla import (
Device as Device,
Device = xc.Device
from jax._src.device_array import (
DeviceArray as DeviceArray,
)
del typing
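The final hunk above references a _deprecation_getattr helper that is not shown in this diff. As a rough, generic sketch of how such a module-level deprecation shim works (PEP 562), assuming JAX's actual helper differs in name, signature, and messages:

# Generic sketch of a module-level deprecation getattr (PEP 562); not JAX's
# actual helper, whose name, signature, and messages may differ.
import warnings
from typing import Any, Callable, Dict, Tuple

def make_deprecation_getattr(
    module: str, deprecations: Dict[str, Tuple[str, Any]]
) -> Callable[[str], Any]:
  def _getattr(name: str) -> Any:
    if name in deprecations:
      message, value = deprecations[name]
      warnings.warn(message, DeprecationWarning, stacklevel=2)
      return value
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
  return _getattr

# At module scope in a re-export shim, one would then write something like:
#   __getattr__ = make_deprecation_getattr(__name__, {
#       "Device": ("jax.interpreters.xla.Device is deprecated.", _deprecated_Device),
#   })
#   del make_deprecation_getattr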