Move functions out of xla.py closer to their users.

Refactoring only; no functional changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, removing everything that isn't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
Peter Hawkins 2023-08-08 14:39:57 -07:00 committed by jax authors
parent d01695c746
commit ca17b6c08f
12 changed files with 123 additions and 139 deletions
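For callers outside xla.py, the moves in this diff amount to an import-path change. A minimal sketch of the new homes (these are internal APIs, so treat the paths as illustrative rather than stable):

from jax._src import sharding_impls
from jax._src.interpreters import mlir, pxla

# axis_groups moved from xla.py into pxla.py; behavior is unchanged.
axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2))
groups = pxla.axis_groups(axis_env, 'i')   # ((0, 2, 4, 6), (1, 3, 5, 7))

# check_backend_matches and get_canonical_source_file moved into mlir.py.
mlir.check_backend_matches(None, 'cpu')    # no-op: no explicit inner backend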


@@ -333,6 +333,13 @@ register_constant_handler(core.Token, _token_constant_handler)
# Source locations
+def get_canonical_source_file(frame: source_info_util.Frame) -> str:
+source_file = frame.file_name
+if config.jax_hlo_source_file_canonicalization_regex:
+source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
+'', source_file)
+return source_file
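As a quick illustration of the canonicalization: with the jax_hlo_source_file_canonicalization_regex option set to a hypothetical prefix pattern, every match is deleted from the recorded path:

import re

regex = r'.*/site-packages/'   # hypothetical config value
source_file = '/home/user/venv/lib/python3.11/site-packages/jax/_src/lax/lax.py'
print(re.sub(regex, '', source_file))   # -> jax/_src/lax/lax.py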
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
frame_locs = []
@@ -340,7 +347,7 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
frame = source_info_util.raw_frame_to_frame(code, lasti)
if source_info_util.is_user_filename(frame.file_name):
file_loc = ir.Location.file(
-xla.get_canonical_source_file(frame),
+get_canonical_source_file(frame),
frame.start_line,
frame.start_column,
)
@@ -371,7 +378,7 @@ def _source_info_to_location(
if frame is None:
loc = ir.Location.unknown()
else:
-loc = ir.Location.file(xla.get_canonical_source_file(frame),
+loc = ir.Location.file(get_canonical_source_file(frame),
frame.start_line, frame.start_column)
loc = ir.Location.name(eqn_str, childLoc=loc)
# TODO(phawkins): also include primitive.name as the operator type.
@@ -1383,13 +1390,25 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
return func_op
+def check_backend_matches(inner_backend, outer_backend):
+# For nested calls, the outermost call sets the backend for all inner calls;
+# it's an error if the inner call has a conflicting explicit backend spec.
+if inner_backend is None:
+return
+if (inner_backend != outer_backend and
+outer_backend not in xb.expand_platform_alias(inner_backend)):
+raise ValueError(
+f"Outer-jit backend specification {outer_backend} must match explicit "
+f"inner-jit backend specification {inner_backend}.")
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
avals_out, tokens_in, *args,
dim_var_values: Sequence[ir.Value],
arg_names=None, result_names=None):
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
-xla.check_backend_matches(backend, ctx.platform)
+check_backend_matches(backend, ctx.platform)
effects = list(tokens_in.effects())
output_types = map(aval_to_ir_types, avals_out)
output_types = [token_type()] * len(effects) + output_types


@@ -1235,15 +1235,43 @@ def _pmap_dce_rule(used_outputs, eqn):
return used_inputs, new_eqn
+def _xla_call_partial_eval_update_params(
+params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
+) -> core.ParamDict:
+donated_invars = params['donated_invars']
+if not kept_inputs and donated_invars:
+# JaxprTrace.post_process_call creates a call with no input tracers
+donated_invars = (False,) * num_new_inputs
+else:
+assert len(kept_inputs) == len(donated_invars)
+# JaxprTrace.process_call drops known input tracers
+donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
+# Any new inputs are prepended to the left, so mark those as not donated.
+donated_invars = [False] * num_new_inputs + donated_invars
+return dict(params, donated_invars=tuple(donated_invars))
+def xla_call_jvp_update_params(params, nz_tangents):
+donated_invars = params['donated_invars']
+donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
+new_donated_invars = (*donated_invars, *donated_tangents)
+return dict(params, donated_invars=new_donated_invars)
+def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
+donated_invars = params['donated_invars']
+donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
+donated_cotangents = [False for nz in nonzero_cts if nz]
+return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
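A small worked example of the donation bookkeeping (values chosen for illustration). Partial evaluation keeps the donation flags of surviving inputs and prepends False for each new input; the JVP update appends the flags of inputs with nonzero tangents for the corresponding tangent arguments:

params = {'donated_invars': (True, False, True)}

# Drop the second input, add one new (non-donated) input on the left:
_xla_call_partial_eval_update_params(params, kept_inputs=[True, False, True], num_new_inputs=1)
# -> {'donated_invars': (False, True, True)}

# Tangents are nonzero for inputs 0 and 2:
xla_call_jvp_update_params(params, nz_tangents=[True, False, True])
# -> {'donated_invars': (True, False, True, True, True)}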
# Set param update handlers to update `donated_invars` just like xla_call_p
-pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
+pe.call_param_updaters[xla_pmap_p] = _xla_call_partial_eval_update_params
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
partial(pe.call_partial_eval_custom_rule,
'call_jaxpr', _pmap_partial_eval_custom_params_updater,
res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
-ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
-ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
+ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params
+ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
@@ -1289,6 +1317,38 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
raise TypeError(aval)
+def _axis_read(axis_env, axis_name):
+try:
+return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
+except ValueError:
+raise NameError(f"unbound axis name: {axis_name}") from None
+def axis_groups(axis_env: sharding_impls.AxisEnv, name) -> tuple[tuple[int, ...]]:
+if not isinstance(name, (list, tuple)):
+name = (name,)
+mesh_axes = tuple(unsafe_map(partial(_axis_read, axis_env), name))
+trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
+assert not ragged
+mesh_spec = axis_env.sizes + (trailing_size,)
+return _axis_groups(mesh_spec, mesh_axes)
+def _axis_groups(mesh_spec, mesh_axes):
+"""Computes replica group ids for a collective performed over a subset of the mesh.
+Args:
+mesh_spec: A sequence of integers representing the mesh shape.
+mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
+indicating over which axes the collective is performed.
+Returns:
+A tuple of replica groups (i.e. tuples containing replica ids).
+"""
+iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
+groups = np.reshape(
+np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
+(math.prod(np.take(mesh_spec, mesh_axes)), -1))
+return tuple(unsafe_map(tuple, groups.T))
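A worked example, matching the updated testAxisGroups below: 8 replicas arranged as a (4, 2) mesh over axes ('i', 'j'). A collective over 'i' groups the replicas that differ only in their 'i' coordinate:

axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2))
axis_groups(axis_env, 'i')          # ((0, 2, 4, 6), (1, 3, 5, 7))
axis_groups(axis_env, 'j')          # ((0, 1), (2, 3), (4, 5), (6, 7))
axis_groups(axis_env, ('i', 'j'))   # ((0, 1, 2, 3, 4, 5, 6, 7),)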
# TODO(b/110096942): more efficient gather
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
if aval is core.abstract_token:
@@ -1311,7 +1371,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
x, mlir.dense_int_elements([1])).result
padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
replica_groups = mlir.dense_int_elements(
-xla.axis_groups(axis_env, axis_env.names[-1]))
+axis_groups(axis_env, axis_env.names[-1]))
out = hlo.CrossReplicaSumOp(padded, replica_groups).result
if out_axis != 0:
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
@@ -1335,17 +1395,22 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
raise TypeError(aval)
+def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
+return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
+env.sizes + (size,))
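For example, entering a nested pmap over a new axis 'j' of size 2 extends the environment without changing the overall replica count:

env = sharding_impls.AxisEnv(8, ('i',), (4,))
new_env = _extend_axis_env(env, 'j', 2)
# new_env.names == ('i', 'j'), new_env.sizes == (4, 2), nreps stays 8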
def _pmap_lowering(ctx, *in_nodes, axis_name,
axis_size, global_axis_size, devices, name,
call_jaxpr, backend=None, in_axes, out_axes,
donated_invars, is_explicit_global_axis_size):
del donated_invars # Unused.
-xla.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platform)
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
if ctx.module_context.axis_env.names and devices is not None:
raise ValueError("Nested pmap with explicit devices argument.")
-new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
+new_env = _extend_axis_env(ctx.module_context.axis_env, axis_name,
global_axis_size)
# Shard the in_nodes that are mapped
in_avals = [v.aval for v in call_jaxpr.invars]


@@ -20,15 +20,11 @@ import dataclasses
import functools
from functools import partial
import itertools as it
-import math
import operator
-import re
from typing import Any, Callable, Optional, Protocol, Union
import numpy as np
-from jax._src.config import config
from jax._src import core
from jax._src import dtypes
from jax._src import source_info_util
@@ -59,13 +55,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
return (xc.Shape.array_shape(dtype, aval.shape),)
-def get_canonical_source_file(frame: source_info_util.Frame):
-source_file = frame.file_name
-if config.jax_hlo_source_file_canonicalization_regex:
-source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
-'', source_file)
-return source_file
# Utilities
def parameter(builder, num, shape, name=None, replicated=None):
@@ -121,18 +110,6 @@ def tuple_sharding_proto(elems):
return proto
-def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
-"""Builds op_fn(*args, **kwargs) with sharding annotation."""
-builder.set_sharding(sharding_proto)
-try:
-return op_fn(*args, **kwargs)
-finally:
-builder.clear_sharding()
-def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
-"""Builds op_fn(*args, **kwargs) with sharding annotation."""
-return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
-**kwargs)
### handlers
@@ -141,16 +118,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
try:
-return xla_shape_handlers[type(aval)](aval)
+return _xla_shape_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
-xla_shape_handlers: dict[type[core.AbstractValue],
+_xla_shape_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[xc.Shape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
-xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
+_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
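The handler table is a plain type-indexed registry, so the lookup in aval_to_xla_shapes dispatches on the aval's Python type. A sketch for a ShapedArray (the resulting xc.Shape mirrors _make_array_shape above):

import numpy as np
aval = ShapedArray((3, 4), np.dtype('float32'))
aval_to_xla_shapes(aval)   # a 1-tuple holding xc.Shape.array_shape(float32, (3, 4))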
# IR constants
@@ -270,52 +247,6 @@ def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
-def check_backend_matches(inner_backend, outer_backend):
-# For nested calls, the outermost call sets the backend for all inner calls;
-# it's an error if the inner call has a conflicting explicit backend spec.
-if inner_backend is None:
-return
-if (inner_backend != outer_backend and
-outer_backend not in xb.expand_platform_alias(inner_backend)):
-raise ValueError(
-f"Outer-jit backend specification {outer_backend} must match explicit "
-f"inner-jit backend specification {inner_backend}.")
-def extend_axis_env(env: AxisEnv, name, size: int):
-return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
-def axis_read(axis_env, axis_name):
-try:
-return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
-except ValueError:
-raise NameError(f"unbound axis name: {axis_name}") from None
-def axis_groups(axis_env: AxisEnv, name) -> tuple[tuple[int, ...]]:
-if not isinstance(name, (list, tuple)):
-name = (name,)
-mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
-trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
-assert not ragged
-mesh_spec = axis_env.sizes + (trailing_size,)
-return _axis_groups(mesh_spec, mesh_axes)
-def _axis_groups(mesh_spec, mesh_axes):
-"""Computes replica group ids for a collective performed over a subset of the mesh.
-Args:
-mesh_spec: A sequence of integers representing the mesh shape.
-mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
-indicating over which axes the collective is performed.
-Returns:
-A tuple of replica groups (i.e. tuples containing replica ids).
-"""
-iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
-groups = np.reshape(
-np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
-(math.prod(np.take(mesh_spec, mesh_axes)), -1))
-return tuple(unsafe_map(tuple, groups.T))
# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming
@@ -329,37 +260,6 @@ def jaxpr_collectives(jaxpr):
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
### xla_call underlying jit
-def xla_call_partial_eval_update_params(
-params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
-) -> core.ParamDict:
-donated_invars = params['donated_invars']
-if not kept_inputs and donated_invars:
-# JaxprTrace.post_process_call creates a call with no input tracers
-donated_invars = (False,) * num_new_inputs
-else:
-assert len(kept_inputs) == len(donated_invars)
-# JaxprTrace.process_call drops known input tracers
-donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
-# Any new inputs are prepended to the left, so mark those as not donated.
-donated_invars = [False] * num_new_inputs + donated_invars
-return dict(params, donated_invars=tuple(donated_invars))
-def xla_call_jvp_update_params(params, nz_tangents):
-donated_invars = params['donated_invars']
-donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
-new_donated_invars = (*donated_invars, *donated_tangents)
-return dict(params, donated_invars=new_donated_invars)
-def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
-donated_invars = params['donated_invars']
-donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
-donated_cotangents = [False for nz in nonzero_cts if nz]
-return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
### translation tables
MYPY = False


@@ -4184,9 +4184,6 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k):
else:
return top_k(operand, k=k), (bdim, bdim)
-def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
-return xla.xla_destructure(ctx.builder, xops.TopK(x, k))
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))


@@ -690,7 +690,7 @@ def _batched_reduction_collective(
return vals_out, [batching.not_mapped] * len(vals_out)
def _replica_groups(axis_env, axis_name, axis_index_groups):
-replica_groups = xla.axis_groups(axis_env, axis_name)
+replica_groups = pxla.axis_groups(axis_env, axis_name)
if axis_index_groups is not None:
replica_groups = [[axis_group[i] for i in axis_index_group]
for axis_group in replica_groups
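To illustrate the axis_index_groups branch with hypothetical values: given the groups ((0, 2, 4, 6), (1, 3, 5, 7)) for axis 'i' and axis_index_groups of [[0, 1], [2, 3]], each replica group is re-indexed into subgroups:

replica_groups = [(0, 2, 4, 6), (1, 3, 5, 7)]
axis_index_groups = [[0, 1], [2, 3]]
[[g[i] for i in idx] for g in replica_groups for idx in axis_index_groups]
# -> [[0, 2], [4, 6], [1, 3], [5, 7]]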


@@ -50,7 +50,6 @@ from jax._src.interpreters.partial_eval import (
trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
convert_constvars_jaxpr, new_jaxpr_eqn)
from jax._src.interpreters import pxla
-from jax._src.interpreters import xla
from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims,
GSPMDSharding)
from jax._src.sharding_impls import (
@@ -868,7 +867,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
# NOTE: We don't have to handle spmd_{in|out}_axes here, because
# SPMD batching always gets involved as the last transform before XLA translation
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
-ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params
+ad.call_param_updaters[xmap_p] = pxla.xla_call_jvp_update_params
def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts
@@ -1305,7 +1304,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
global_axis_sizes,
spmd_in_axes, spmd_out_axes,
axis_resources, resource_env, backend):
-xla.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platform)
# The only way for any of those two assertions to be violated is when xmap
# is using the SPMD lowering, but then this rule shouldn't even trigger.
assert spmd_in_axes is None and spmd_out_axes is None
@@ -1382,7 +1381,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
donated_invars, global_axis_sizes, spmd_in_axes,
spmd_out_axes, axis_resources,
resource_env, backend):
-xla.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes)
@@ -1450,7 +1449,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
resource_env, backend):
assert spmd_in_axes is None and spmd_out_axes is None
# This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule.
-xla.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes)
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))


@@ -1061,6 +1061,15 @@ def _outside_call_impl(*args, **params):
outside_call_p.def_impl(_outside_call_impl)
+def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
+"""Builds op_fn(*args, **kwargs) with sharding annotation."""
+builder.set_sharding(sharding_proto)
+try:
+return op_fn(*args, **kwargs)
+finally:
+builder.clear_sharding()
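The try/finally is what makes the sharding annotation scoped: it is cleared even if op_fn raises. Equivalently, and purely as a sketch assuming the builder exposes set_sharding/clear_sharding as used above, this could be written as a context manager:

from contextlib import contextmanager

@contextmanager
def _sharding_scope(builder, sharding_proto):
    # Every op built inside the scope carries sharding_proto; the
    # annotation is always cleared on exit, even on error.
    builder.set_sharding(sharding_proto)
    try:
        yield builder
    finally:
        builder.clear_sharding()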
def _outside_call_translation_rule(ctx,
avals_in,
avals_out,
@@ -1137,7 +1146,7 @@ def _outside_call_translation_rule(ctx,
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
-outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto,
+outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)


@@ -57,6 +57,7 @@ from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
+from jax._src.interpreters import mlir
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
@@ -1217,7 +1218,7 @@ def _make_op_metadata(primitive: core.Primitive,
return xla_client.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
-source_file=xla.get_canonical_source_file(frame) if frame else None,
+source_file=mlir.get_canonical_source_file(frame) if frame else None,
source_line=frame.start_line if frame else None)


@@ -16,20 +16,17 @@ from jax._src.interpreters.xla import (
TranslationContext as TranslationContext,
TranslationRule as TranslationRule,
abstractify as abstractify,
-axis_groups as axis_groups,
backend_specific_translations as backend_specific_translations,
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
-check_backend_matches as check_backend_matches,
parameter as parameter,
pytype_aval_mappings as pytype_aval_mappings,
register_collective_primitive as register_collective_primitive,
register_initial_style_primitive as register_initial_style_primitive,
register_translation as register_translation,
-sharding_to_proto as sharding_to_proto,
translations as translations,
-xla_destructure as xla_destructure,
-xla_shape_handlers as xla_shape_handlers,
)
+from jax._src.interpreters.pxla import (
+axis_groups as axis_groups,
+)
from jax._src.core import (


@@ -123,7 +123,6 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
-xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
def sparse_array_mlir_type_handler(a):
return (
@@ -258,7 +257,6 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
-xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
@unittest.skip("Test does not work with jax.Array")


@@ -50,7 +50,6 @@ from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
-from jax.interpreters import xla
from jax._src import array
from jax._src.sharding_impls import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint
@@ -1082,16 +1081,16 @@ class PythonPmapTest(jtu.JaxTestCase):
def testAxisGroups(self):
axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2))
-groups = xla.axis_groups(axis_env, 'i')
+groups = pxla.axis_groups(axis_env, 'i')
self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7)))
-groups = xla.axis_groups(axis_env, 'j')
+groups = pxla.axis_groups(axis_env, 'j')
self.assertEqual(groups, ((0, 1), (2, 3), (4, 5), (6, 7)))
-groups = xla.axis_groups(axis_env, ('i', 'j'))
+groups = pxla.axis_groups(axis_env, ('i', 'j'))
self.assertEqual(groups, ((0, 1, 2, 3, 4, 5, 6, 7,),))
-groups = xla.axis_groups(axis_env, ('j', 'i'))
+groups = pxla.axis_groups(axis_env, ('j', 'i'))
self.assertEqual(len(groups), 1)
self.assertEqual((tuple(sorted(groups[0])),),
((0, 1, 2, 3, 4, 5, 6, 7,),)) # order doesn't matter


@@ -22,7 +22,7 @@ from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
-from jax.interpreters import xla
+from jax._src.interpreters import xla
from jax._src.config import config
config.parse_flags_with_absl()