From 54269c114539eaa96eff51ed7a4d3c769e3879db Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 16 Feb 2023 11:54:25 -0800 Subject: [PATCH] Remove more exported names from jax.interpreters.xla. None of these appear to have public users, and this module is not included in the deprecation policy. Also: * shorten a number of alias chains. * move make_op_metadata() into its only caller in jax2tf * delete the unused function dtype_to_primitive_type. PiperOrigin-RevId: 510205315 --- jax/_src/api.py | 8 +- jax/_src/dispatch.py | 13 ++- jax/_src/interpreters/mlir.py | 4 +- jax/_src/interpreters/pxla.py | 15 ++-- jax/_src/interpreters/xla.py | 80 ++----------------- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/loops.py | 7 +- jax/_src/maps.py | 16 ++-- jax/_src/stages.py | 7 +- .../compilation_cache/compilation_cache.py | 5 +- jax/experimental/jax2tf/call_tf.py | 2 +- jax/experimental/jax2tf/jax2tf.py | 18 ++++- jax/interpreters/xla.py | 49 +++++------- 13 files changed, 80 insertions(+), 146 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 3d656c395..0c04748ba 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -78,7 +78,7 @@ from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp from jax.custom_transpose import custom_transpose from jax.interpreters import partial_eval as pe from jax.interpreters import mlir -from jax.interpreters import xla +from jax._src.interpreters import xla from jax._src.config import ( config, @@ -560,7 +560,7 @@ class _BackendAndDeviceInfo(NamedTuple): committed_to_device: bool class _FastpathData(NamedTuple): - xla_executable: xla.XlaLoadedExecutable + xla_executable: xc.LoadedExecutable out_pytree_def: Any sticky_device: Optional[xc.Device] avals: Iterable[Any] @@ -1101,7 +1101,7 @@ def xla_computation(fun: Callable, ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] out_shape = tree_unflatten(out_tree(), out_shapes_flat) for out_aval in out_avals: - if not isinstance(out_aval, xla.ShapedArray): + if not isinstance(out_aval, ShapedArray): raise RuntimeError("As we want to propagate the weak_type, we need " "to get a ShapedArray, otherwise this " "information is lost") @@ -2327,7 +2327,7 @@ def _python_pmap( class _PmapFastpathData(NamedTuple): version: int # For forward and backward compatibility - xla_executable: xla.XlaLoadedExecutable + xla_executable: xc.LoadedExecutable in_handler: Any out_handler: Any out_pytree_def: Any diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 10ad11942..0c7b72307 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -87,7 +87,6 @@ Backend = xe.Client Device = xc.Device Buffer = xe.Buffer -XlaLoadedExecutable = xla.XlaLoadedExecutable CompileOptions = xc.CompileOptions map, unsafe_map = util.safe_map, map @@ -810,7 +809,7 @@ if MYPY: ResultHandler = Any else: class ResultHandler(Protocol): - def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any: + def __call__(self, env: Optional[Sequence[Any]], *args: xc.Buffer) -> Any: """Boxes raw buffers into their user-facing representation.""" def aval_to_result_handler(sticky_device: Optional[Device], @@ -894,7 +893,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect], return input_bufs, _remove_tokens -def _execute_compiled(name: str, compiled: XlaLoadedExecutable, +def _execute_compiled(name: str, compiled: xc.LoadedExecutable, input_handler: Optional[Callable], output_buffer_counts: Sequence[int], result_handler: 
Callable, has_unordered_effects: bool, @@ -919,7 +918,7 @@ def _execute_compiled(name: str, compiled: XlaLoadedExecutable, def _execute_replicated(name: str, - compiled: XlaLoadedExecutable, + compiled: xc.LoadedExecutable, input_handler: Optional[Callable], output_buffer_counts: Sequence[int], result_handler: Callable, @@ -939,7 +938,7 @@ def _execute_replicated(name: str, for device in compiled.local_devices()] input_bufs_flip = list(unsafe_zip(*input_bufs)) out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip) - out_flat = [bufs[0] for bufs in out_bufs_flat_rep] + out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore check_special(name, out_flat) out_bufs = unflatten(out_flat, output_buffer_counts) if from_lower_sharding_computation: @@ -1098,7 +1097,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options, def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str, compile_options: CompileOptions, - backend: Backend) -> Optional[XlaLoadedExecutable]: + backend: Backend) -> Optional[xc.LoadedExecutable]: """Looks up `computation` in the persisent compilation cache.""" # Avoid import cycle between jax and jax.experimental from jax.experimental.compilation_cache import compilation_cache as cc @@ -1117,7 +1116,7 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str, def _cache_write(serialized_computation: Union[str, bytes, ir.Module], compile_time_secs: float, module_name: str, compile_options: CompileOptions, - backend: Backend, compiled: XlaLoadedExecutable, + backend: Backend, compiled: xc.LoadedExecutable, host_callbacks: List[Any]): """Writes `serialized_computation` to the persistent compilation cache.""" # Avoid import cycle between jax and jax.experimental diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 96bf1d745..fc78dcb8d 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1016,8 +1016,8 @@ def lower_jaxpr_to_fun( args.append(hlo.CreateTokenOp().results) else: args.append(arg) - callee_name_stack = xla.extend_name_stack(ctx.name_stack, - util.wrap_name(name, api_name)) + callee_name_stack = util.extend_name_stack( + ctx.name_stack, util.wrap_name(name, api_name)) out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts), *args, dim_var_values=dim_var_values) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 97e27283f..8aa511b50 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1261,11 +1261,11 @@ def parallel_callable(fun: lu.WrappedFun, @dataclasses.dataclass(frozen=True) class ParallelCallableInfo: name: str - backend: xla.Backend + backend: xc.Client axis_name: core.AxisName axis_size: int global_axis_size: int - devices: Optional[Sequence[xla.Device]] + devices: Optional[Sequence[xc.Device]] in_axes: Iterable[Optional[int]] out_axes_thunk: Callable[[], Sequence[Optional[int]]] avals: Sequence[core.AbstractValue] @@ -1370,7 +1370,7 @@ def lower_parallel_callable( axis_name: core.AxisName, axis_size: int, global_axis_size: int, - devices: Optional[Sequence[xla.Device]], + devices: Optional[Sequence[xc.Device]], name: str, in_axes: Iterable[Optional[int]], out_axes_thunk: Callable[[], Sequence[Optional[int]]], @@ -1733,6 +1733,7 @@ def _get_pmap_sharding(devices, specs): multi_host_supported_collectives: Set[core.Primitive] = set() + def 
check_multihost_collective_allowlist(jaxpr): used_collectives = set(xla.jaxpr_collectives(jaxpr)) if not used_collectives.issubset(multi_host_supported_collectives): @@ -2295,8 +2296,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore sub_ctx = ctx.module_context.replace( axis_context=mlir.ReplicaAxisContext(new_env), - name_stack=xla.extend_name_stack(ctx.module_context.name_stack, - util.wrap_name(name, 'pmap'))) + name_stack=util.extend_name_stack(ctx.module_context.name_stack, + util.wrap_name(name, 'pmap'))) sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (), *in_nodes_sharded, dim_var_values=ctx.dim_var_values) @@ -2778,7 +2779,7 @@ ShardingInfo = Tuple[ def _get_and_check_device_assignment( shardings: Iterable[ShardingInfo], devices: Optional[Sequence[xc.Device]], -) -> Tuple[xla.Backend, Sequence[xc.Device]]: +) -> Tuple[xc.Client, Sequence[xc.Device]]: from jax._src.api import local_devices first_sharding_info = None @@ -3572,7 +3573,7 @@ class UnloadedMeshExecutable: class MeshExecutableFastpathData(NamedTuple): - xla_executable: xla.XlaLoadedExecutable + xla_executable: xc.LoadedExecutable out_pytree_def: Any in_shardings: Sequence[sharding_internal.XLACompatibleSharding] out_shardings: Sequence[sharding_internal.XLACompatibleSharding] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 4e113872b..e397e90d0 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -35,14 +35,11 @@ from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact +from jax._src.core import ConcreteArray, ShapedArray from jax._src.interpreters import ad from jax._src.util import (prod, new_name_stack, safe_zip, safe_map, partition_list) -# TODO: update callers to refer to new location. -from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401 -from jax._src.util import wrap_name as wrap_name # noqa: F401 from jax._src.typing import Shape from jax._src.lib import xla_bridge as xb @@ -55,28 +52,12 @@ xe = xc._xla xops = xc._xla.ops # Types -Backend = xe.Client -Device = xc.Device -Buffer = xe.Buffer - -XlaOp = xc.XlaOp -XlaShape = xc.Shape -XlaBuilder = xc.XlaBuilder -XlaLoadedExecutable = Any -XlaLoadedExecutable = xc.LoadedExecutable # type:ignore - -# TODO(phawkins): update code to point to new locations. 
-DeviceArray = device_array.DeviceArray -_DeviceArray = device_array._DeviceArray -_CppDeviceArray = xe.Buffer -make_device_array = device_array.make_device_array - def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() -def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]: +def _make_array_shape(a: ShapedArray) -> Sequence[xc.Shape]: if a.dtype == dtypes.float0: return (xc.Shape.array_shape(np.dtype('bool'), a.shape),) else: @@ -89,22 +70,6 @@ def get_canonical_source_file(frame: source_info_util.Frame): '', source_file) return source_file -tracebacks = {} -def make_op_metadata(primitive: core.Primitive, - params: Dict, *, - source_info: source_info_util.SourceInfo, - name_stack: Union[str, source_info_util.NameStack] = "", - ) -> xc.OpMetadata: - eqn_str = (str(source_info.name_stack) + '/' - + str_eqn_compact(primitive.name, params)) - tracebacks[eqn_str] = source_info.traceback - frame = source_info_util.user_frame(source_info) - return xc.OpMetadata( - op_type=primitive.name, - op_name=eqn_str, - source_file=get_canonical_source_file(frame) if frame else None, - source_line=frame.start_line if frame else None) - # Utilities def parameter(builder, num, shape, name=None, replicated=None): @@ -176,47 +141,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs): ### handlers -# Numpy dtypes -> XLA primitive types - -_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = { - np.dtype('bool'): xc.PrimitiveType.PRED, - np.dtype('int8'): xc.PrimitiveType.S8, - np.dtype('int16'): xc.PrimitiveType.S16, - np.dtype('int32'): xc.PrimitiveType.S32, - np.dtype('int64'): xc.PrimitiveType.S64, - np.dtype('uint8'): xc.PrimitiveType.U8, - np.dtype('uint16'): xc.PrimitiveType.U16, - np.dtype('uint32'): xc.PrimitiveType.U32, - np.dtype('uint64'): xc.PrimitiveType.U64, - np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16, - np.dtype('float16'): xc.PrimitiveType.F16, - np.dtype('float32'): xc.PrimitiveType.F32, - np.dtype('float64'): xc.PrimitiveType.F64, - np.dtype('complex64'): xc.PrimitiveType.C64, - np.dtype('complex128'): xc.PrimitiveType.C128, -} - -def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType: - """Converts a NumPy dtype into an XLA PrimitiveType.""" - # Many things (e.g., strings, scalar types) can be compared with NumPy dtypes, - # but may not hash correctly. Make sure we have a true np.dtype. 
- assert isinstance(dtype, np.dtype), type(dtype) - try: - return _dtype_to_primitive_type[dtype] - except KeyError as err: - raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err - - # JAX abstract values -> XLA shapes -def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]: +def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: try: return xla_shape_handlers[type(aval)](aval) except KeyError as err: raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err xla_shape_handlers: Dict[Type[core.AbstractValue], - Callable[[Any], Sequence[XlaShape]]] = { + Callable[[Any], Sequence[xc.Shape]]] = { ShapedArray: _make_array_shape, ConcreteArray: _make_array_shape, } @@ -493,8 +427,8 @@ if not MYPY: def __call__(self, ctx: TranslationContext, avals_in: Sequence[core.AbstractValue], avals_out: Sequence[core.AbstractValue], - *args: XlaOp, **kw - ) -> Sequence[XlaOp]: + *args: xc.XlaOp, **kw + ) -> Sequence[xc.XlaOp]: """A translation rule lowers a primitive invocation into an XLA HLO.""" else: TranslationRule = Any @@ -545,7 +479,7 @@ def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule: @functools.wraps(f) def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue], avals_out: Sequence[core.AbstractValue], - *args: XlaOp, **kw) -> Sequence[XlaOp]: + *args: xc.XlaOp, **kw) -> Sequence[xc.XlaOp]: ans = f(ctx.builder, *args, **kw) if (prim.multiple_results or any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index adfaad290..52754d5fc 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -837,7 +837,7 @@ def _cond_lowering(ctx, index, *args, branches, linear): branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): sub_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun')) + name_stack=extend_name_stack(name_stack, f'branch_{i}_fun')) out_vals, tokens_out = mlir.jaxpr_subcomp( sub_ctx, jaxpr.jaxpr, tokens_in, map(mlir.ir_constants, jaxpr.consts), diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 1a06c2c1d..7755660af 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1472,7 +1472,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, cond_args = cond_args[num_tokens:] x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(name_stack, 'cond')) + name_stack=extend_name_stack(name_stack, 'cond')) ((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(), _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z), dim_var_values=ctx.dim_var_values) @@ -1504,15 +1504,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, tokens_in = mlir.TokenSet(zip(body_effects, token_args)) x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts]) body_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(name_stack, 'body')) + name_stack=extend_name_stack(name_stack, 'body')) new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr, tokens_in, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z), dim_var_values=ctx.dim_var_values) out_tokens = [tokens_out.get(eff) for eff in body_effects] if batched: body_pred_ctx = 
ctx.module_context.replace( - name_stack=xla.extend_name_stack(name_stack, - 'body_pred')) + name_stack=extend_name_stack(name_stack, 'body_pred')) ((body_pred,),), _ = mlir.jaxpr_subcomp( body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(), _map(mlir.ir_constants, cond_jaxpr.consts), diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 5a544572d..2a33c8536 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -40,13 +40,13 @@ from jax._src.sharding import NamedSharding from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax.interpreters import xla +from jax._src.interpreters import xla from jax.interpreters import batching from jax.interpreters import ad from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3, as_hashable_function, distributed_debug_log, tuple_insert, moveaxis, split_list, wrap_name, - merge_lists, partition_list) + merge_lists, partition_list, extend_name_stack) from jax import lax source_info_util.register_exclusion(__file__) @@ -1373,8 +1373,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. sub_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(ctx.module_context.name_stack, - wrap_name(name, 'xmap'))) + name_stack=extend_name_stack(ctx.module_context.name_stack, + wrap_name(name, 'xmap'))) if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(), @@ -1440,8 +1440,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. sub_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(ctx.module_context.name_stack, - wrap_name(name, 'xmap'))) + name_stack=extend_name_stack(ctx.module_context.name_stack, + wrap_name(name, 'xmap'))) if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, @@ -1491,8 +1491,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, # translation rule just because the extra tuple stuff is a pain. 
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext) sub_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(ctx.module_context.name_stack, - wrap_name(name, 'xmap')), + name_stack=extend_name_stack(ctx.module_context.name_stack, + wrap_name(name, 'xmap')), axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes)) if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') diff --git a/jax/_src/stages.py b/jax/_src/stages.py index aa95cac79..123f1b41e 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -37,7 +37,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tu import jax from jax import tree_util -from jax.lib import xla_client as xc +from jax.interpreters import mlir from jax._src import core from jax._src import source_info_util @@ -45,8 +45,7 @@ from jax._src import traceback_util from jax._src import util from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import use_stablehlo -from jax.interpreters import mlir -from jax._src.interpreters import xla +from jax._src.lib import xla_client as xc source_info_util.register_exclusion(__file__) @@ -200,7 +199,7 @@ class Lowering(Protocol): class XlaExecutable(Executable): - def xla_extension_executable(self) -> xla.XlaLoadedExecutable: + def xla_extension_executable(self) -> xc.LoadedExecutable: raise NotImplementedError("must override") def call(self, *args_flat) -> Sequence[Any]: diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index ba35db29b..1ba5e024b 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -21,6 +21,7 @@ from typing import List, Optional from jax.experimental.compilation_cache.gfile_cache import GFileCache from jax._src import path as pathlib +from jax._src.lib import xla_client from jax._src.lib import xla_extension_version from jax._src.lib import version_str as jaxlib_version_str from jax.interpreters import xla @@ -50,7 +51,7 @@ def initialize_cache(path): def get_executable(xla_computation, compile_options, - backend) -> Optional[xla.XlaLoadedExecutable]: + backend) -> Optional[xla_client.LoadedExecutable]: """Returns the cached executable if present, or None otherwise.""" assert _cache is not None, "initialize_cache must be called before you can call get_executable()" cache_key = get_cache_key(xla_computation, compile_options, backend) @@ -63,7 +64,7 @@ def get_executable(xla_computation, compile_options, return xla_executable_deserialized def put_executable(module_name, xla_computation, compile_options, - executable: xla.XlaLoadedExecutable, backend): + executable: xla_client.LoadedExecutable, backend): """Adds 'executable' to the cache, possibly evicting older entries.""" assert _cache is not None, "initialize_cache must be called before you can call put_executable()" cache_key = get_cache_key(xla_computation, compile_options, backend) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 33a8661b8..cec35c28b 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -411,7 +411,7 @@ def _code_generator_and_avals( xla_comp = xla_client.XlaComputation(func_tf_hlo) # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode - def canonical_res_aval(res_shape: xla.XlaShape) -> 
core.ShapedArray: + def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: if not res_shape.is_static(): msg = ("Compiled TensorFlow function has dynamic output shape " + f"{res_shape}. call_tf can used " + diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 74a28fb60..ade89af11 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1083,6 +1083,19 @@ class TensorFlowTracer(core.Tracer): def full_lower(self): return self +def _make_op_metadata(primitive: core.Primitive, + params: Dict, *, + source_info: source_info_util.SourceInfo, + ) -> xla_client.OpMetadata: + eqn_str = (str(source_info.name_stack) + '/' + + core.str_eqn_compact(primitive.name, params)) + frame = source_info_util.user_frame(source_info) + return xla_client.OpMetadata( + op_type=primitive.name, + op_name=eqn_str, + source_file=xla.get_canonical_source_file(frame) if frame else None, + source_line=frame.start_line if frame else None) + class TensorFlowTrace(core.Trace): """Trace class that underlies the jax2tf transformation. @@ -1168,9 +1181,8 @@ class TensorFlowTrace(core.Trace): with tf.name_scope(_sanitize_scope_name(scope)): if _thread_local_state.include_xla_op_metadata: - op_metadata = xla.make_op_metadata(primitive, params, - name_stack=current_name_stack, - source_info=source_info_util.current()) + op_metadata = _make_op_metadata(primitive, params, + source_info=source_info_util.current()) op_metadata_proto = xla_data_pb2.OpMetadata( op_type=op_metadata.op_type, op_name=op_metadata.op_name, diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 0312df701..c204e1df8 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -14,53 +14,32 @@ from jax._src.interpreters.xla import ( AxisEnv as AxisEnv, - Backend as Backend, - Buffer as Buffer, - ConcreteArray as ConcreteArray, - Shape as Shape, - ShapedArray as ShapedArray, - SpatialSharding as SpatialSharding, TranslationContext as TranslationContext, TranslationRule as TranslationRule, - XlaBuilder as XlaBuilder, - XlaLoadedExecutable as XlaLoadedExecutable, - XlaOp as XlaOp, - XlaShape as XlaShape, - _CppDeviceArray as _CppDeviceArray, - _DeviceArray as _DeviceArray, abstractify as abstractify, - aval_to_xla_shapes as aval_to_xla_shapes, axis_groups as axis_groups, - axis_read as axis_read, backend_specific_translations as backend_specific_translations, canonicalize_dtype as canonicalize_dtype, canonicalize_dtype_handlers as canonicalize_dtype_handlers, check_backend_matches as check_backend_matches, - dtype_to_primitive_type as dtype_to_primitive_type, - extend_axis_env as extend_axis_env, - extend_name_stack as extend_name_stack, - jaxpr_collectives as jaxpr_collectives, - make_device_array as make_device_array, - make_op_metadata as make_op_metadata, - new_name_stack as new_name_stack, parameter as parameter, - partition_list as partition_list, - primitive_subcomputation as primitive_subcomputation, pytype_aval_mappings as pytype_aval_mappings, register_collective_primitive as register_collective_primitive, register_initial_style_primitive as register_initial_style_primitive, register_translation as register_translation, sharding_to_proto as sharding_to_proto, translations as translations, - xb as xb, - xc as xc, - xe as xe, xla_call as xla_call, xla_call_p as xla_call_p, xla_destructure as xla_destructure, xla_shape_handlers as xla_shape_handlers, ) +from jax._src.core import ( + ShapedArray as ShapedArray, + ConcreteArray as ConcreteArray, +) + 
# TODO(phawkins): update users. from jax._src.dispatch import ( apply_primitive as apply_primitive, @@ -68,9 +47,19 @@ from jax._src.dispatch import ( device_put as device_put, ) +from jax._src.lib import xla_bridge as xb +from jax._src.lib import xla_client as xc # type: ignore -from jax._src.interpreters.xla import ( - Device as _deprecated_Device, +_deprecated_Device = xc.Device +XlaOp = xc.XlaOp +xe = xc._xla +Backend = xe.Client +Buffer = xc.Buffer +_CppDeviceArray = xe.Buffer + +from jax._src.device_array import ( + make_device_array as make_device_array, + _DeviceArray as _DeviceArray, DeviceArray as _deprecated_DeviceArray, ) @@ -92,8 +81,8 @@ del _deprecation_getattr import typing if typing.TYPE_CHECKING: - from jax._src.interpreters.xla import ( - Device as Device, + Device = xc.Device + from jax._src.device_array import ( DeviceArray as DeviceArray, ) del typing
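
Editor's note (not part of the patch): for readers porting code that still relied on the removed `jax.interpreters.xla` aliases, the sketch below mirrors the substitutions the patch itself performs (`xc.LoadedExecutable` for `xla.XlaLoadedExecutable`, `util.extend_name_stack` / `util.wrap_name` for the old re-exports). It assumes a jax/jaxlib build contemporary with this change, where these private locations exist; it is illustrative only, not an API guarantee.

```python
# Illustrative migration sketch only -- not part of the patch.
# Assumes a jaxlib recent enough to expose xla_client.LoadedExecutable,
# which is what the patch itself switches to internally.

# Before this change, downstream code could reach these objects through
# re-exported aliases on jax.interpreters.xla:
#
#   from jax.interpreters import xla
#   ExecutableType = xla.XlaLoadedExecutable
#   name_stack = xla.extend_name_stack(name_stack, "outer")

# After this change, use the canonical homes the patch substitutes in:
from jax._src.lib import xla_client as xc
from jax._src import util

ExecutableType = xc.LoadedExecutable        # replaces xla.XlaLoadedExecutable
extend_name_stack = util.extend_name_stack  # replaces xla.extend_name_stack
wrap_name = util.wrap_name                  # replaces xla.wrap_name
```

As the final hunk shows, `jax.interpreters.xla` keeps deprecation shims (via `_deprecation_getattr`) only for `Device` and `DeviceArray`; the other removed names have no fallback.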