From ca17b6c08f19daa60d681533c51fa019929fcc43 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 8 Aug 2023 14:39:57 -0700 Subject: [PATCH] Move functions out of xla.py closer to their users. Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility. Remove an unused top_k translation rule as well. PiperOrigin-RevId: 554946059 --- jax/_src/interpreters/mlir.py | 25 ++++++- jax/_src/interpreters/pxla.py | 79 ++++++++++++++++++++-- jax/_src/interpreters/xla.py | 106 +----------------------------- jax/_src/lax/lax.py | 3 - jax/_src/lax/parallel.py | 2 +- jax/_src/maps.py | 9 ++- jax/experimental/host_callback.py | 13 +++- jax/experimental/jax2tf/jax2tf.py | 3 +- jax/interpreters/xla.py | 9 +-- tests/custom_object_test.py | 2 - tests/pmap_test.py | 9 ++- tests/xla_bridge_test.py | 2 +- 12 files changed, 123 insertions(+), 139 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 0bc185e2b..612d2c374 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -333,6 +333,13 @@ register_constant_handler(core.Token, _token_constant_handler) # Source locations +def get_canonical_source_file(frame: source_info_util.Frame) -> str: + source_file = frame.file_name + if config.jax_hlo_source_file_canonicalization_regex: + source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex, + '', source_file) + return source_file + def _traceback_to_location(tb: xc.Traceback) -> ir.Location: """Converts a full traceback to a callsite() MLIR location.""" frame_locs = [] @@ -340,7 +347,7 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location: frame = source_info_util.raw_frame_to_frame(code, lasti) if source_info_util.is_user_filename(frame.file_name): file_loc = ir.Location.file( - xla.get_canonical_source_file(frame), + get_canonical_source_file(frame), frame.start_line, frame.start_column, ) @@ -371,7 +378,7 @@ def _source_info_to_location( if frame is None: loc = ir.Location.unknown() else: - loc = ir.Location.file(xla.get_canonical_source_file(frame), + loc = ir.Location.file(get_canonical_source_file(frame), frame.start_line, frame.start_column) loc = ir.Location.name(eqn_str, childLoc=loc) # TODO(phawkins): also include primitive.name as the operator type. @@ -1383,13 +1390,25 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, return func_op +def check_backend_matches(inner_backend, outer_backend): + # For nested calls, the outermost call sets the backend for all inner calls; + # it's an error if the inner call has a conflicting explicit backend spec. + if inner_backend is None: + return + if (inner_backend != outer_backend and + outer_backend not in xb.expand_platform_alias(inner_backend)): + raise ValueError( + f"Outer-jit backend specification {outer_backend} must match explicit " + f"inner-jit backend specification {inner_backend}.") + + def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, avals_out, tokens_in, *args, dim_var_values: Sequence[ir.Value], arg_names=None, result_names=None): if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) - xla.check_backend_matches(backend, ctx.platform) + check_backend_matches(backend, ctx.platform) effects = list(tokens_in.effects()) output_types = map(aval_to_ir_types, avals_out) output_types = [token_type()] * len(effects) + output_types diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index ab29c9ca1..4176ec44a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1235,15 +1235,43 @@ def _pmap_dce_rule(used_outputs, eqn): return used_inputs, new_eqn +def _xla_call_partial_eval_update_params( + params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int + ) -> core.ParamDict: + donated_invars = params['donated_invars'] + if not kept_inputs and donated_invars: + # JaxprTrace.post_process_call creates a call with no input tracers + donated_invars = (False,) * num_new_inputs + else: + assert len(kept_inputs) == len(donated_invars) + # JaxprTrace.process_call drops known input tracers + donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept] + # Any new inputs are prepended to the left, so mark those as not donated. + donated_invars = [False] * num_new_inputs + donated_invars + return dict(params, donated_invars=tuple(donated_invars)) + +def xla_call_jvp_update_params(params, nz_tangents): + donated_invars = params['donated_invars'] + donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz] + new_donated_invars = (*donated_invars, *donated_tangents) + return dict(params, donated_invars=new_donated_invars) + +def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): + donated_invars = params['donated_invars'] + donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u] + donated_cotangents = [False for nz in nonzero_cts if nz] + return dict(params, donated_invars=(*donated_primals, *donated_cotangents)) + + # Set param update handlers to update `donated_invars` just like xla_call_p -pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params +pe.call_param_updaters[xla_pmap_p] = _xla_call_partial_eval_update_params pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \ partial(pe.call_partial_eval_custom_rule, 'call_jaxpr', _pmap_partial_eval_custom_params_updater, res_aval=_pmap_partial_eval_custom_res_maker) pe.dce_rules[xla_pmap_p] = _pmap_dce_rule -ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params -ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params +ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params +ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) @@ -1289,6 +1317,38 @@ def _hlo_shard(aval, axis_env, xs, in_axis): raise TypeError(aval) +def _axis_read(axis_env, axis_name): + try: + return max(i for i, name in enumerate(axis_env.names) if name == axis_name) + except ValueError: + raise NameError(f"unbound axis name: {axis_name}") from None + +def axis_groups(axis_env: sharding_impls.AxisEnv, name) -> tuple[tuple[int, ...]]: + if not isinstance(name, (list, tuple)): + name = (name,) + mesh_axes = tuple(unsafe_map(partial(_axis_read, axis_env), name)) + trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes)) + assert not ragged + mesh_spec = axis_env.sizes + (trailing_size,) + return _axis_groups(mesh_spec, mesh_axes) + +def _axis_groups(mesh_spec, mesh_axes): + """Computes replica group ids for a collective performed over a subset of the mesh. + + Args: + mesh_spec: A sequence of integers representing the mesh shape. + mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive) + indicating over which axes the collective is performed. + Returns: + A tuple of replica groups (i.e. tuples containing replica ids). + """ + iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec) + groups = np.reshape( + np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))), + (math.prod(np.take(mesh_spec, mesh_axes)), -1)) + return tuple(unsafe_map(tuple, groups.T)) + + # TODO(b/110096942): more efficient gather def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform): if aval is core.abstract_token: @@ -1311,7 +1371,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl x, mlir.dense_int_elements([1])).result padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result replica_groups = mlir.dense_int_elements( - xla.axis_groups(axis_env, axis_env.names[-1])) + axis_groups(axis_env, axis_env.names[-1])) out = hlo.CrossReplicaSumOp(padded, replica_groups).result if out_axis != 0: # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead @@ -1335,18 +1395,23 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl raise TypeError(aval) +def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int): + return sharding_impls.AxisEnv(env.nreps, env.names + (name,), + env.sizes + (size,)) + + def _pmap_lowering(ctx, *in_nodes, axis_name, axis_size, global_axis_size, devices, name, call_jaxpr, backend=None, in_axes, out_axes, donated_invars, is_explicit_global_axis_size): del donated_invars # Unused. - xla.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platform) # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. if ctx.module_context.axis_env.names and devices is not None: raise ValueError("Nested pmap with explicit devices argument.") - new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name, - global_axis_size) + new_env = _extend_axis_env(ctx.module_context.axis_env, axis_name, + global_axis_size) # Shard the in_nodes that are mapped in_avals = [v.aval for v in call_jaxpr.invars] in_nodes_sharded = ( diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 067011541..4b5bcc27d 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -20,15 +20,11 @@ import dataclasses import functools from functools import partial import itertools as it -import math import operator -import re from typing import Any, Callable, Optional, Protocol, Union import numpy as np -from jax._src.config import config - from jax._src import core from jax._src import dtypes from jax._src import source_info_util @@ -59,13 +55,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype return (xc.Shape.array_shape(dtype, aval.shape),) -def get_canonical_source_file(frame: source_info_util.Frame): - source_file = frame.file_name - if config.jax_hlo_source_file_canonicalization_regex: - source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex, - '', source_file) - return source_file - # Utilities def parameter(builder, num, shape, name=None, replicated=None): @@ -121,18 +110,6 @@ def tuple_sharding_proto(elems): return proto -def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): - """Builds op_fn(*args, **kwargs) with sharding annotation.""" - builder.set_sharding(sharding_proto) - try: - return op_fn(*args, **kwargs) - finally: - builder.clear_sharding() - -def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs): - """Builds op_fn(*args, **kwargs) with sharding annotation.""" - return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args, - **kwargs) ### handlers @@ -141,16 +118,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs): def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: try: - return xla_shape_handlers[type(aval)](aval) + return _xla_shape_handlers[type(aval)](aval) except KeyError as err: raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err -xla_shape_handlers: dict[type[core.AbstractValue], +_xla_shape_handlers: dict[type[core.AbstractValue], Callable[[Any], Sequence[xc.Shape]]] = { ShapedArray: _make_array_shape, ConcreteArray: _make_array_shape, } -xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) +_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) # IR constants @@ -270,52 +247,6 @@ def xla_destructure(c, ans): num_elements = len(c.get_shape(ans).tuple_shapes()) return [xops.GetTupleElement(ans, i) for i in range(num_elements)] -def check_backend_matches(inner_backend, outer_backend): - # For nested calls, the outermost call sets the backend for all inner calls; - # it's an error if the inner call has a conflicting explicit backend spec. - if inner_backend is None: - return - if (inner_backend != outer_backend and - outer_backend not in xb.expand_platform_alias(inner_backend)): - raise ValueError( - f"Outer-jit backend specification {outer_backend} must match explicit " - f"inner-jit backend specification {inner_backend}.") - - -def extend_axis_env(env: AxisEnv, name, size: int): - return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,)) - -def axis_read(axis_env, axis_name): - try: - return max(i for i, name in enumerate(axis_env.names) if name == axis_name) - except ValueError: - raise NameError(f"unbound axis name: {axis_name}") from None - -def axis_groups(axis_env: AxisEnv, name) -> tuple[tuple[int, ...]]: - if not isinstance(name, (list, tuple)): - name = (name,) - mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name)) - trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes)) - assert not ragged - mesh_spec = axis_env.sizes + (trailing_size,) - return _axis_groups(mesh_spec, mesh_axes) - -def _axis_groups(mesh_spec, mesh_axes): - """Computes replica group ids for a collective performed over a subset of the mesh. - - Args: - mesh_spec: A sequence of integers representing the mesh shape. - mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive) - indicating over which axes the collective is performed. - Returns: - A tuple of replica groups (i.e. tuples containing replica ids). - """ - iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec) - groups = np.reshape( - np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))), - (math.prod(np.take(mesh_spec, mesh_axes)), -1)) - return tuple(unsafe_map(tuple, groups.T)) - # TODO(mattjj,skyewm): the functions here are utilities for checking if # not-yet-supported features are used with multi-host programming @@ -329,37 +260,6 @@ def jaxpr_collectives(jaxpr): for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr) -### xla_call underlying jit - - -def xla_call_partial_eval_update_params( - params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int - ) -> core.ParamDict: - donated_invars = params['donated_invars'] - if not kept_inputs and donated_invars: - # JaxprTrace.post_process_call creates a call with no input tracers - donated_invars = (False,) * num_new_inputs - else: - assert len(kept_inputs) == len(donated_invars) - # JaxprTrace.process_call drops known input tracers - donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept] - # Any new inputs are prepended to the left, so mark those as not donated. - donated_invars = [False] * num_new_inputs + donated_invars - return dict(params, donated_invars=tuple(donated_invars)) - -def xla_call_jvp_update_params(params, nz_tangents): - donated_invars = params['donated_invars'] - donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz] - new_donated_invars = (*donated_invars, *donated_tangents) - return dict(params, donated_invars=new_donated_invars) - -def xla_call_transpose_update_params(params, undef_primals, nonzero_cts): - donated_invars = params['donated_invars'] - donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u] - donated_cotangents = [False for nz in nonzero_cts if nz] - return dict(params, donated_invars=(*donated_primals, *donated_cotangents)) - - ### translation tables MYPY = False diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d5e4e3089..ad8ab6e0e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4184,9 +4184,6 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k): else: return top_k(operand, k=k), (bdim, bdim) -def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k): - return xla.xla_destructure(ctx.builder, xops.TopK(x, k)) - top_k_p = Primitive('top_k') top_k_p.multiple_results = True top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p)) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 751a3ab84..c31e41693 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -690,7 +690,7 @@ def _batched_reduction_collective( return vals_out, [batching.not_mapped] * len(vals_out) def _replica_groups(axis_env, axis_name, axis_index_groups): - replica_groups = xla.axis_groups(axis_env, axis_name) + replica_groups = pxla.axis_groups(axis_env, axis_name) if axis_index_groups is not None: replica_groups = [[axis_group[i] for i in axis_index_group] for axis_group in replica_groups diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 20cf8490b..615441da8 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -50,7 +50,6 @@ from jax._src.interpreters.partial_eval import ( trace_to_subjaxpr_dynamic, DynamicJaxprTracer, convert_constvars_jaxpr, new_jaxpr_eqn) from jax._src.interpreters import pxla -from jax._src.interpreters import xla from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims, GSPMDSharding) from jax._src.sharding_impls import ( @@ -868,7 +867,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst # NOTE: We don't have to handle spmd_{in|out}_axes here, because # SPMD batching always gets involved as the last transform before XLA translation ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore -ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params +ad.call_param_updaters[xmap_p] = pxla.xla_call_jvp_update_params def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts @@ -1305,7 +1304,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): - xla.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platform) # The only way for any of those two assertions to be violated is when xmap # is using the SPMD lowering, but then this rule shouldn't even trigger. assert spmd_in_axes is None and spmd_out_axes is None @@ -1382,7 +1381,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): - xla.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platform) plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes) @@ -1450,7 +1449,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, resource_env, backend): assert spmd_in_axes is None and spmd_out_axes is None # This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule. - xla.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platform) plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes) manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index aa3c3ae93..90afe770a 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1061,6 +1061,15 @@ def _outside_call_impl(*args, **params): outside_call_p.def_impl(_outside_call_impl) +def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): + """Builds op_fn(*args, **kwargs) with sharding annotation.""" + builder.set_sharding(sharding_proto) + try: + return op_fn(*args, **kwargs) + finally: + builder.clear_sharding() + + def _outside_call_translation_rule(ctx, avals_in, avals_out, @@ -1137,8 +1146,8 @@ def _outside_call_translation_rule(ctx, build_infeed = functools.partial(xops.InfeedWithToken, after_outfeed_itoken, xla_client.Shape.tuple_shape(shape)) - outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto, - build_infeed) + outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto, + build_infeed) outs = xops.GetTupleElement(outs_and_token, 0) next_itoken = xops.GetTupleElement(outs_and_token, 1) non_empty_results = [ diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 2cc61c64d..15c26609f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -57,6 +57,7 @@ from jax._src import random as random_internal from jax._src import source_info_util from jax._src import util from jax._src.interpreters import ad +from jax._src.interpreters import mlir from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg @@ -1217,7 +1218,7 @@ def _make_op_metadata(primitive: core.Primitive, return xla_client.OpMetadata( op_type=primitive.name, op_name=eqn_str, - source_file=xla.get_canonical_source_file(frame) if frame else None, + source_file=mlir.get_canonical_source_file(frame) if frame else None, source_line=frame.start_line if frame else None) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 0f6511a97..14698c631 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -16,20 +16,17 @@ from jax._src.interpreters.xla import ( TranslationContext as TranslationContext, TranslationRule as TranslationRule, abstractify as abstractify, - axis_groups as axis_groups, backend_specific_translations as backend_specific_translations, canonicalize_dtype as canonicalize_dtype, canonicalize_dtype_handlers as canonicalize_dtype_handlers, - check_backend_matches as check_backend_matches, - parameter as parameter, pytype_aval_mappings as pytype_aval_mappings, register_collective_primitive as register_collective_primitive, - register_initial_style_primitive as register_initial_style_primitive, register_translation as register_translation, - sharding_to_proto as sharding_to_proto, translations as translations, xla_destructure as xla_destructure, - xla_shape_handlers as xla_shape_handlers, +) +from jax._src.interpreters.pxla import ( + axis_groups as axis_groups, ) from jax._src.core import ( diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index f24ad5e6f..e64d88372 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -123,7 +123,6 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x -xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler def sparse_array_mlir_type_handler(a): return ( @@ -258,7 +257,6 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty() core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty() xla.canonicalize_dtype_handlers[Empty] = lambda x: x -xla.xla_shape_handlers[AbstractEmpty] = lambda _: () @unittest.skip("Test does not work with jax.Array") diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 58d05b5a7..7bc72ca0e 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -50,7 +50,6 @@ from jax._src.lib import xla_extension from jax._src.util import safe_map, safe_zip from jax._src.interpreters import mlir from jax._src.interpreters import pxla -from jax.interpreters import xla from jax._src import array from jax._src.sharding_impls import PmapSharding from jax.ad_checkpoint import checkpoint as new_checkpoint @@ -1082,16 +1081,16 @@ class PythonPmapTest(jtu.JaxTestCase): def testAxisGroups(self): axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2)) - groups = xla.axis_groups(axis_env, 'i') + groups = pxla.axis_groups(axis_env, 'i') self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7))) - groups = xla.axis_groups(axis_env, 'j') + groups = pxla.axis_groups(axis_env, 'j') self.assertEqual(groups, ((0, 1), (2, 3), (4, 5), (6, 7))) - groups = xla.axis_groups(axis_env, ('i', 'j')) + groups = pxla.axis_groups(axis_env, ('i', 'j')) self.assertEqual(groups, ((0, 1, 2, 3, 4, 5, 6, 7,),)) - groups = xla.axis_groups(axis_env, ('j', 'i')) + groups = pxla.axis_groups(axis_env, ('j', 'i')) self.assertEqual(len(groups), 1) self.assertEqual((tuple(sorted(groups[0])),), ((0, 1, 2, 3, 4, 5, 6, 7,),)) # order doesn't matter diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 2373691c0..b6fb5d53e 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -22,7 +22,7 @@ from absl.testing import absltest from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax.interpreters import xla +from jax._src.interpreters import xla from jax._src.config import config config.parse_flags_with_absl()