mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy. Also: * shorten a number of alias chains. * move make_op_metadata() into its only caller in jax2tf * delete the unused function dtype_to_primitive_type. PiperOrigin-RevId: 510205315
This commit is contained in:
parent
6b545a2ddc
commit
54269c1145
@ -78,7 +78,7 @@ from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import xla
|
||||
|
||||
from jax._src.config import (
|
||||
config,
|
||||
@ -560,7 +560,7 @@ class _BackendAndDeviceInfo(NamedTuple):
|
||||
committed_to_device: bool
|
||||
|
||||
class _FastpathData(NamedTuple):
|
||||
xla_executable: xla.XlaLoadedExecutable
|
||||
xla_executable: xc.LoadedExecutable
|
||||
out_pytree_def: Any
|
||||
sticky_device: Optional[xc.Device]
|
||||
avals: Iterable[Any]
|
||||
@ -1101,7 +1101,7 @@ def xla_computation(fun: Callable,
|
||||
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
|
||||
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
|
||||
for out_aval in out_avals:
|
||||
if not isinstance(out_aval, xla.ShapedArray):
|
||||
if not isinstance(out_aval, ShapedArray):
|
||||
raise RuntimeError("As we want to propagate the weak_type, we need "
|
||||
"to get a ShapedArray, otherwise this "
|
||||
"information is lost")
|
||||
@ -2327,7 +2327,7 @@ def _python_pmap(
|
||||
|
||||
class _PmapFastpathData(NamedTuple):
|
||||
version: int # For forward and backward compatibility
|
||||
xla_executable: xla.XlaLoadedExecutable
|
||||
xla_executable: xc.LoadedExecutable
|
||||
in_handler: Any
|
||||
out_handler: Any
|
||||
out_pytree_def: Any
|
||||
|
@ -87,7 +87,6 @@ Backend = xe.Client
|
||||
Device = xc.Device
|
||||
Buffer = xe.Buffer
|
||||
|
||||
XlaLoadedExecutable = xla.XlaLoadedExecutable
|
||||
CompileOptions = xc.CompileOptions
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
@ -810,7 +809,7 @@ if MYPY:
|
||||
ResultHandler = Any
|
||||
else:
|
||||
class ResultHandler(Protocol):
|
||||
def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any:
|
||||
def __call__(self, env: Optional[Sequence[Any]], *args: xc.Buffer) -> Any:
|
||||
"""Boxes raw buffers into their user-facing representation."""
|
||||
|
||||
def aval_to_result_handler(sticky_device: Optional[Device],
|
||||
@ -894,7 +893,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
|
||||
return input_bufs, _remove_tokens
|
||||
|
||||
|
||||
def _execute_compiled(name: str, compiled: XlaLoadedExecutable,
|
||||
def _execute_compiled(name: str, compiled: xc.LoadedExecutable,
|
||||
input_handler: Optional[Callable],
|
||||
output_buffer_counts: Sequence[int],
|
||||
result_handler: Callable, has_unordered_effects: bool,
|
||||
@ -919,7 +918,7 @@ def _execute_compiled(name: str, compiled: XlaLoadedExecutable,
|
||||
|
||||
|
||||
def _execute_replicated(name: str,
|
||||
compiled: XlaLoadedExecutable,
|
||||
compiled: xc.LoadedExecutable,
|
||||
input_handler: Optional[Callable],
|
||||
output_buffer_counts: Sequence[int],
|
||||
result_handler: Callable,
|
||||
@ -939,7 +938,7 @@ def _execute_replicated(name: str,
|
||||
for device in compiled.local_devices()]
|
||||
input_bufs_flip = list(unsafe_zip(*input_bufs))
|
||||
out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip)
|
||||
out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
|
||||
out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore
|
||||
check_special(name, out_flat)
|
||||
out_bufs = unflatten(out_flat, output_buffer_counts)
|
||||
if from_lower_sharding_computation:
|
||||
@ -1098,7 +1097,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
|
||||
def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
|
||||
compile_options: CompileOptions,
|
||||
backend: Backend) -> Optional[XlaLoadedExecutable]:
|
||||
backend: Backend) -> Optional[xc.LoadedExecutable]:
|
||||
"""Looks up `computation` in the persisent compilation cache."""
|
||||
# Avoid import cycle between jax and jax.experimental
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
@ -1117,7 +1116,7 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
|
||||
def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
|
||||
compile_time_secs: float,
|
||||
module_name: str, compile_options: CompileOptions,
|
||||
backend: Backend, compiled: XlaLoadedExecutable,
|
||||
backend: Backend, compiled: xc.LoadedExecutable,
|
||||
host_callbacks: List[Any]):
|
||||
"""Writes `serialized_computation` to the persistent compilation cache."""
|
||||
# Avoid import cycle between jax and jax.experimental
|
||||
|
@ -1016,8 +1016,8 @@ def lower_jaxpr_to_fun(
|
||||
args.append(hlo.CreateTokenOp().results)
|
||||
else:
|
||||
args.append(arg)
|
||||
callee_name_stack = xla.extend_name_stack(ctx.name_stack,
|
||||
util.wrap_name(name, api_name))
|
||||
callee_name_stack = util.extend_name_stack(
|
||||
ctx.name_stack, util.wrap_name(name, api_name))
|
||||
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
|
||||
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
|
||||
*args, dim_var_values=dim_var_values)
|
||||
|
@ -1261,11 +1261,11 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ParallelCallableInfo:
|
||||
name: str
|
||||
backend: xla.Backend
|
||||
backend: xc.Client
|
||||
axis_name: core.AxisName
|
||||
axis_size: int
|
||||
global_axis_size: int
|
||||
devices: Optional[Sequence[xla.Device]]
|
||||
devices: Optional[Sequence[xc.Device]]
|
||||
in_axes: Iterable[Optional[int]]
|
||||
out_axes_thunk: Callable[[], Sequence[Optional[int]]]
|
||||
avals: Sequence[core.AbstractValue]
|
||||
@ -1370,7 +1370,7 @@ def lower_parallel_callable(
|
||||
axis_name: core.AxisName,
|
||||
axis_size: int,
|
||||
global_axis_size: int,
|
||||
devices: Optional[Sequence[xla.Device]],
|
||||
devices: Optional[Sequence[xc.Device]],
|
||||
name: str,
|
||||
in_axes: Iterable[Optional[int]],
|
||||
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
|
||||
@ -1733,6 +1733,7 @@ def _get_pmap_sharding(devices, specs):
|
||||
multi_host_supported_collectives: Set[core.Primitive] = set()
|
||||
|
||||
|
||||
|
||||
def check_multihost_collective_allowlist(jaxpr):
|
||||
used_collectives = set(xla.jaxpr_collectives(jaxpr))
|
||||
if not used_collectives.issubset(multi_host_supported_collectives):
|
||||
@ -2295,8 +2296,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
axis_context=mlir.ReplicaAxisContext(new_env),
|
||||
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
|
||||
util.wrap_name(name, 'pmap')))
|
||||
name_stack=util.extend_name_stack(ctx.module_context.name_stack,
|
||||
util.wrap_name(name, 'pmap')))
|
||||
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
|
||||
*in_nodes_sharded,
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
@ -2778,7 +2779,7 @@ ShardingInfo = Tuple[
|
||||
def _get_and_check_device_assignment(
|
||||
shardings: Iterable[ShardingInfo],
|
||||
devices: Optional[Sequence[xc.Device]],
|
||||
) -> Tuple[xla.Backend, Sequence[xc.Device]]:
|
||||
) -> Tuple[xc.Client, Sequence[xc.Device]]:
|
||||
from jax._src.api import local_devices
|
||||
|
||||
first_sharding_info = None
|
||||
@ -3572,7 +3573,7 @@ class UnloadedMeshExecutable:
|
||||
|
||||
|
||||
class MeshExecutableFastpathData(NamedTuple):
|
||||
xla_executable: xla.XlaLoadedExecutable
|
||||
xla_executable: xc.LoadedExecutable
|
||||
out_pytree_def: Any
|
||||
in_shardings: Sequence[sharding_internal.XLACompatibleSharding]
|
||||
out_shardings: Sequence[sharding_internal.XLACompatibleSharding]
|
||||
|
@ -35,14 +35,11 @@ from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
|
||||
partition_list)
|
||||
|
||||
# TODO: update callers to refer to new location.
|
||||
from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401
|
||||
from jax._src.util import wrap_name as wrap_name # noqa: F401
|
||||
from jax._src.typing import Shape
|
||||
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
@ -55,28 +52,12 @@ xe = xc._xla
|
||||
xops = xc._xla.ops
|
||||
|
||||
# Types
|
||||
Backend = xe.Client
|
||||
Device = xc.Device
|
||||
Buffer = xe.Buffer
|
||||
|
||||
XlaOp = xc.XlaOp
|
||||
XlaShape = xc.Shape
|
||||
XlaBuilder = xc.XlaBuilder
|
||||
XlaLoadedExecutable = Any
|
||||
XlaLoadedExecutable = xc.LoadedExecutable # type:ignore
|
||||
|
||||
# TODO(phawkins): update code to point to new locations.
|
||||
DeviceArray = device_array.DeviceArray
|
||||
_DeviceArray = device_array._DeviceArray
|
||||
_CppDeviceArray = xe.Buffer
|
||||
make_device_array = device_array.make_device_array
|
||||
|
||||
|
||||
def identity(x): return x
|
||||
|
||||
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
|
||||
def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]:
|
||||
def _make_array_shape(a: ShapedArray) -> Sequence[xc.Shape]:
|
||||
if a.dtype == dtypes.float0:
|
||||
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
|
||||
else:
|
||||
@ -89,22 +70,6 @@ def get_canonical_source_file(frame: source_info_util.Frame):
|
||||
'', source_file)
|
||||
return source_file
|
||||
|
||||
tracebacks = {}
|
||||
def make_op_metadata(primitive: core.Primitive,
|
||||
params: Dict, *,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
name_stack: Union[str, source_info_util.NameStack] = "",
|
||||
) -> xc.OpMetadata:
|
||||
eqn_str = (str(source_info.name_stack) + '/'
|
||||
+ str_eqn_compact(primitive.name, params))
|
||||
tracebacks[eqn_str] = source_info.traceback
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
return xc.OpMetadata(
|
||||
op_type=primitive.name,
|
||||
op_name=eqn_str,
|
||||
source_file=get_canonical_source_file(frame) if frame else None,
|
||||
source_line=frame.start_line if frame else None)
|
||||
|
||||
# Utilities
|
||||
|
||||
def parameter(builder, num, shape, name=None, replicated=None):
|
||||
@ -176,47 +141,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
|
||||
|
||||
### handlers
|
||||
|
||||
# Numpy dtypes -> XLA primitive types
|
||||
|
||||
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
|
||||
np.dtype('bool'): xc.PrimitiveType.PRED,
|
||||
np.dtype('int8'): xc.PrimitiveType.S8,
|
||||
np.dtype('int16'): xc.PrimitiveType.S16,
|
||||
np.dtype('int32'): xc.PrimitiveType.S32,
|
||||
np.dtype('int64'): xc.PrimitiveType.S64,
|
||||
np.dtype('uint8'): xc.PrimitiveType.U8,
|
||||
np.dtype('uint16'): xc.PrimitiveType.U16,
|
||||
np.dtype('uint32'): xc.PrimitiveType.U32,
|
||||
np.dtype('uint64'): xc.PrimitiveType.U64,
|
||||
np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
|
||||
np.dtype('float16'): xc.PrimitiveType.F16,
|
||||
np.dtype('float32'): xc.PrimitiveType.F32,
|
||||
np.dtype('float64'): xc.PrimitiveType.F64,
|
||||
np.dtype('complex64'): xc.PrimitiveType.C64,
|
||||
np.dtype('complex128'): xc.PrimitiveType.C128,
|
||||
}
|
||||
|
||||
def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
|
||||
"""Converts a NumPy dtype into an XLA PrimitiveType."""
|
||||
# Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
|
||||
# but may not hash correctly. Make sure we have a true np.dtype.
|
||||
assert isinstance(dtype, np.dtype), type(dtype)
|
||||
try:
|
||||
return _dtype_to_primitive_type[dtype]
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
|
||||
|
||||
|
||||
# JAX abstract values -> XLA shapes
|
||||
|
||||
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
|
||||
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
|
||||
try:
|
||||
return xla_shape_handlers[type(aval)](aval)
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
|
||||
|
||||
xla_shape_handlers: Dict[Type[core.AbstractValue],
|
||||
Callable[[Any], Sequence[XlaShape]]] = {
|
||||
Callable[[Any], Sequence[xc.Shape]]] = {
|
||||
ShapedArray: _make_array_shape,
|
||||
ConcreteArray: _make_array_shape,
|
||||
}
|
||||
@ -493,8 +427,8 @@ if not MYPY:
|
||||
def __call__(self, ctx: TranslationContext,
|
||||
avals_in: Sequence[core.AbstractValue],
|
||||
avals_out: Sequence[core.AbstractValue],
|
||||
*args: XlaOp, **kw
|
||||
) -> Sequence[XlaOp]:
|
||||
*args: xc.XlaOp, **kw
|
||||
) -> Sequence[xc.XlaOp]:
|
||||
"""A translation rule lowers a primitive invocation into an XLA HLO."""
|
||||
else:
|
||||
TranslationRule = Any
|
||||
@ -545,7 +479,7 @@ def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
|
||||
@functools.wraps(f)
|
||||
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
|
||||
avals_out: Sequence[core.AbstractValue],
|
||||
*args: XlaOp, **kw) -> Sequence[XlaOp]:
|
||||
*args: xc.XlaOp, **kw) -> Sequence[xc.XlaOp]:
|
||||
ans = f(ctx.builder, *args, **kw)
|
||||
if (prim.multiple_results or
|
||||
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
|
||||
|
@ -837,7 +837,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
|
||||
branch = case_op.regions[i].blocks.append()
|
||||
with ir.InsertionPoint(branch):
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
|
||||
name_stack=extend_name_stack(name_stack, f'branch_{i}_fun'))
|
||||
out_vals, tokens_out = mlir.jaxpr_subcomp(
|
||||
sub_ctx, jaxpr.jaxpr, tokens_in,
|
||||
map(mlir.ir_constants, jaxpr.consts),
|
||||
|
@ -1472,7 +1472,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
cond_args = cond_args[num_tokens:]
|
||||
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
|
||||
cond_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack, 'cond'))
|
||||
name_stack=extend_name_stack(name_stack, 'cond'))
|
||||
((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
|
||||
_map(mlir.ir_constants, cond_jaxpr.consts),
|
||||
*(x + z), dim_var_values=ctx.dim_var_values)
|
||||
@ -1504,15 +1504,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
|
||||
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
|
||||
body_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack, 'body'))
|
||||
name_stack=extend_name_stack(name_stack, 'body'))
|
||||
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
|
||||
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts),
|
||||
*(y + z), dim_var_values=ctx.dim_var_values)
|
||||
out_tokens = [tokens_out.get(eff) for eff in body_effects]
|
||||
if batched:
|
||||
body_pred_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack,
|
||||
'body_pred'))
|
||||
name_stack=extend_name_stack(name_stack, 'body_pred'))
|
||||
((body_pred,),), _ = mlir.jaxpr_subcomp(
|
||||
body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
|
||||
_map(mlir.ir_constants, cond_jaxpr.consts),
|
||||
|
@ -40,13 +40,13 @@ from jax._src.sharding import NamedSharding
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import ad
|
||||
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
|
||||
as_hashable_function, distributed_debug_log,
|
||||
tuple_insert, moveaxis, split_list, wrap_name,
|
||||
merge_lists, partition_list)
|
||||
merge_lists, partition_list, extend_name_stack)
|
||||
from jax import lax
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
@ -1373,8 +1373,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
|
||||
@ -1440,8 +1440,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
|
||||
@ -1491,8 +1491,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')),
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')),
|
||||
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
|
||||
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
|
@ -37,7 +37,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tu
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.lib import xla_client as xc
|
||||
from jax.interpreters import mlir
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
@ -45,8 +45,7 @@ from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import use_stablehlo
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
@ -200,7 +199,7 @@ class Lowering(Protocol):
|
||||
|
||||
class XlaExecutable(Executable):
|
||||
|
||||
def xla_extension_executable(self) -> xla.XlaLoadedExecutable:
|
||||
def xla_extension_executable(self) -> xc.LoadedExecutable:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def call(self, *args_flat) -> Sequence[Any]:
|
||||
|
@ -21,6 +21,7 @@ from typing import List, Optional
|
||||
|
||||
from jax.experimental.compilation_cache.gfile_cache import GFileCache
|
||||
from jax._src import path as pathlib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax.interpreters import xla
|
||||
@ -50,7 +51,7 @@ def initialize_cache(path):
|
||||
|
||||
|
||||
def get_executable(xla_computation, compile_options,
|
||||
backend) -> Optional[xla.XlaLoadedExecutable]:
|
||||
backend) -> Optional[xla_client.LoadedExecutable]:
|
||||
"""Returns the cached executable if present, or None otherwise."""
|
||||
assert _cache is not None, "initialize_cache must be called before you can call get_executable()"
|
||||
cache_key = get_cache_key(xla_computation, compile_options, backend)
|
||||
@ -63,7 +64,7 @@ def get_executable(xla_computation, compile_options,
|
||||
return xla_executable_deserialized
|
||||
|
||||
def put_executable(module_name, xla_computation, compile_options,
|
||||
executable: xla.XlaLoadedExecutable, backend):
|
||||
executable: xla_client.LoadedExecutable, backend):
|
||||
"""Adds 'executable' to the cache, possibly evicting older entries."""
|
||||
assert _cache is not None, "initialize_cache must be called before you can call put_executable()"
|
||||
cache_key = get_cache_key(xla_computation, compile_options, backend)
|
||||
|
@ -411,7 +411,7 @@ def _code_generator_and_avals(
|
||||
xla_comp = xla_client.XlaComputation(func_tf_hlo)
|
||||
|
||||
# Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
|
||||
def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
|
||||
def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
|
||||
if not res_shape.is_static():
|
||||
msg = ("Compiled TensorFlow function has dynamic output shape " +
|
||||
f"{res_shape}. call_tf can used " +
|
||||
|
@ -1083,6 +1083,19 @@ class TensorFlowTracer(core.Tracer):
|
||||
def full_lower(self):
|
||||
return self
|
||||
|
||||
def _make_op_metadata(primitive: core.Primitive,
|
||||
params: Dict, *,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
) -> xla_client.OpMetadata:
|
||||
eqn_str = (str(source_info.name_stack) + '/'
|
||||
+ core.str_eqn_compact(primitive.name, params))
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
return xla_client.OpMetadata(
|
||||
op_type=primitive.name,
|
||||
op_name=eqn_str,
|
||||
source_file=xla.get_canonical_source_file(frame) if frame else None,
|
||||
source_line=frame.start_line if frame else None)
|
||||
|
||||
|
||||
class TensorFlowTrace(core.Trace):
|
||||
"""Trace class that underlies the jax2tf transformation.
|
||||
@ -1168,9 +1181,8 @@ class TensorFlowTrace(core.Trace):
|
||||
|
||||
with tf.name_scope(_sanitize_scope_name(scope)):
|
||||
if _thread_local_state.include_xla_op_metadata:
|
||||
op_metadata = xla.make_op_metadata(primitive, params,
|
||||
name_stack=current_name_stack,
|
||||
source_info=source_info_util.current())
|
||||
op_metadata = _make_op_metadata(primitive, params,
|
||||
source_info=source_info_util.current())
|
||||
op_metadata_proto = xla_data_pb2.OpMetadata(
|
||||
op_type=op_metadata.op_type,
|
||||
op_name=op_metadata.op_name,
|
||||
|
@ -14,53 +14,32 @@
|
||||
|
||||
from jax._src.interpreters.xla import (
|
||||
AxisEnv as AxisEnv,
|
||||
Backend as Backend,
|
||||
Buffer as Buffer,
|
||||
ConcreteArray as ConcreteArray,
|
||||
Shape as Shape,
|
||||
ShapedArray as ShapedArray,
|
||||
SpatialSharding as SpatialSharding,
|
||||
TranslationContext as TranslationContext,
|
||||
TranslationRule as TranslationRule,
|
||||
XlaBuilder as XlaBuilder,
|
||||
XlaLoadedExecutable as XlaLoadedExecutable,
|
||||
XlaOp as XlaOp,
|
||||
XlaShape as XlaShape,
|
||||
_CppDeviceArray as _CppDeviceArray,
|
||||
_DeviceArray as _DeviceArray,
|
||||
abstractify as abstractify,
|
||||
aval_to_xla_shapes as aval_to_xla_shapes,
|
||||
axis_groups as axis_groups,
|
||||
axis_read as axis_read,
|
||||
backend_specific_translations as backend_specific_translations,
|
||||
canonicalize_dtype as canonicalize_dtype,
|
||||
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
|
||||
check_backend_matches as check_backend_matches,
|
||||
dtype_to_primitive_type as dtype_to_primitive_type,
|
||||
extend_axis_env as extend_axis_env,
|
||||
extend_name_stack as extend_name_stack,
|
||||
jaxpr_collectives as jaxpr_collectives,
|
||||
make_device_array as make_device_array,
|
||||
make_op_metadata as make_op_metadata,
|
||||
new_name_stack as new_name_stack,
|
||||
parameter as parameter,
|
||||
partition_list as partition_list,
|
||||
primitive_subcomputation as primitive_subcomputation,
|
||||
pytype_aval_mappings as pytype_aval_mappings,
|
||||
register_collective_primitive as register_collective_primitive,
|
||||
register_initial_style_primitive as register_initial_style_primitive,
|
||||
register_translation as register_translation,
|
||||
sharding_to_proto as sharding_to_proto,
|
||||
translations as translations,
|
||||
xb as xb,
|
||||
xc as xc,
|
||||
xe as xe,
|
||||
xla_call as xla_call,
|
||||
xla_call_p as xla_call_p,
|
||||
xla_destructure as xla_destructure,
|
||||
xla_shape_handlers as xla_shape_handlers,
|
||||
)
|
||||
|
||||
from jax._src.core import (
|
||||
ShapedArray as ShapedArray,
|
||||
ConcreteArray as ConcreteArray,
|
||||
)
|
||||
|
||||
# TODO(phawkins): update users.
|
||||
from jax._src.dispatch import (
|
||||
apply_primitive as apply_primitive,
|
||||
@ -68,9 +47,19 @@ from jax._src.dispatch import (
|
||||
device_put as device_put,
|
||||
)
|
||||
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc # type: ignore
|
||||
|
||||
from jax._src.interpreters.xla import (
|
||||
Device as _deprecated_Device,
|
||||
_deprecated_Device = xc.Device
|
||||
XlaOp = xc.XlaOp
|
||||
xe = xc._xla
|
||||
Backend = xe.Client
|
||||
Buffer = xc.Buffer
|
||||
_CppDeviceArray = xe.Buffer
|
||||
|
||||
from jax._src.device_array import (
|
||||
make_device_array as make_device_array,
|
||||
_DeviceArray as _DeviceArray,
|
||||
DeviceArray as _deprecated_DeviceArray,
|
||||
)
|
||||
|
||||
@ -92,8 +81,8 @@ del _deprecation_getattr
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.interpreters.xla import (
|
||||
Device as Device,
|
||||
Device = xc.Device
|
||||
from jax._src.device_array import (
|
||||
DeviceArray as DeviceArray,
|
||||
)
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user