Consolidate more XLA-lowering logic between jit, pmap, and xmap.

Move the remaining functions related to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
Authored by Peter Hawkins on 2021-11-30 14:24:02 -08:00; committed by jax authors
parent dd6a6f206c
commit 68e9e1c26d
22 changed files with 239 additions and 228 deletions
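
Not part of the diff below: a minimal sketch of the rename this commit performs, assuming a JAX checkout with this commit applied. The helper itself is unchanged; only its home moved from jax._src.lib.xla_bridge to jax.interpreters.xla.

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    c = xc.XlaBuilder("example")
    # Previously: xb.parameter(c, 0, shape); same call, new module.
    p = xla.parameter(c, 0, xc.Shape.array_shape(np.dtype(np.float32), (2,)))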

View File

@@ -2007,7 +2007,7 @@
 " return [xla_consts[id(cnst)] for cnst in consts]\n",
 "\n",
 "def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:\n",
-" return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
+" return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
 "\n",
 "def _xla_shape(aval: ShapedArray) -> xe.Shape:\n",
 " return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)"
@@ -3630,7 +3630,7 @@
 "\n",
 " def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
 " c = xc.XlaBuilder(name)\n",
-" operand = xb.parameter(c, 0, operand_shape)\n",
+" operand = xops.Parameter(c, 0, operand_shape)\n",
 " operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
 " outs = jaxpr_subcomp(c, jaxpr, operands)\n",
 " return c.build(xops.Tuple(c, outs))\n",

View File

@@ -1577,7 +1577,7 @@ def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:
 return [xla_consts[id(cnst)] for cnst in consts]
 def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:
-return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
+return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
 def _xla_shape(aval: ShapedArray) -> xe.Shape:
 return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
@@ -2844,7 +2844,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
 def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
 c = xc.XlaBuilder(name)
-operand = xb.parameter(c, 0, operand_shape)
+operand = xops.Parameter(c, 0, operand_shape)
 operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
 outs = jaxpr_subcomp(c, jaxpr, operands)
 return c.build(xops.Tuple(c, outs))

View File

@@ -1569,7 +1569,7 @@ def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:
 return [xla_consts[id(cnst)] for cnst in consts]
 def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:
-return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
+return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
 def _xla_shape(aval: ShapedArray) -> xe.Shape:
 return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
@@ -2836,7 +2836,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
 def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
 c = xc.XlaBuilder(name)
-operand = xb.parameter(c, 0, operand_shape)
+operand = xops.Parameter(c, 0, operand_shape)
 operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
 outs = jaxpr_subcomp(c, jaxpr, operands)
 return c.build(xops.Tuple(c, outs))

View File

@@ -860,7 +860,7 @@ def xla_computation(fun: Callable,
 out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
 build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
 if out_parts is not None:
-out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)
+out_tuple = xla.with_sharding(c, out_parts_flat, build_out_tuple)
 else:
 out_tuple = build_out_tuple()
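
Public behavior is unchanged; only the sharding helper's module moved. A quick usage-level check (illustrative):

    import jax

    c = jax.xla_computation(lambda x: x + 1)(1.0)  # still returns an XlaComputation
    print(c.as_hlo_text())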

View File

@@ -218,42 +218,16 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
 tuple_args = len(abstract_args) > 100
 axis_env = xla.AxisEnv(nreps, (), ())
 name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
-module: Any
+closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+module: Union[str, xc.XlaComputation]
 if config.jax_enable_mlir:
-# TODO(b/203122001): implement buffer donation.
-assert not any(donated_invars), donated_invars
 module = mlir.lower_jaxpr_to_module(
-core.ClosedJaxpr(jaxpr, consts), backend.platform, axis_env, name_stack)
+closed_jaxpr, backend.platform, axis_env, name_stack, donated_invars)
 else:
-# XLA HLO lowering path
-c = xc.XlaBuilder(f"jit_{fun.__name__}")
-xla_consts = xla._xla_consts(c, consts)
-xla_args, donated_invars = xla._xla_callable_args(
-c, abstract_args, tuple_args, donated_invars=donated_invars)
-platform = backend.platform
-ctx = xla.TranslationContext(c, platform, axis_env, name_stack)
-out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-# Replace tokens with a dummy array value, because the runtime cannot
-# handle token arguments.
-out_aval_lens = [len(xla.aval_to_xla_shapes(a)) for a in out_avals]
-out_nodes = util.flatten(
-[[xla._make_token_return_value(c)] if a is core.abstract_token
-else v
-for a, v in zip(out_avals, util.unflatten(out_nodes, out_aval_lens))])
-# There is a non-zero cost to building an output tuple, particularly on TPU.
-# Avoid it if the output arity is 1.
-output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
-if platform in ("gpu", "tpu"):
-donated_invars = xla.set_up_aliases(
-c, xla_args, c.GetShape(output), donated_invars, tuple_args)
-if any(donated_invars):
-# TODO(tomhennigan): At call time we should mark these buffers as deleted.
-unused_donations = [str(c.GetShape(a))
-for a, d in zip(xla_args, donated_invars) if d]
-warnings.warn("Some donated buffers were not usable: {}".format(
-", ".join(unused_donations)))
-module = c.build(output)
+module = xla.lower_jaxpr_to_xla_module(
+f"jit_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
+name_stack, tuple_args, donated_invars, replicated_args=None,
+arg_partitions=None, out_partitions=None)
 return XlaComputation(
 name, module, False, donated_invars, nreps, device, backend, tuple_args,
 abstract_args, out_avals, kept_var_idx)
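
The body removed above now lives behind xla.lower_jaxpr_to_xla_module, which jit, pmap, and xmap all share. A sketch of calling it directly, assuming this commit and a CPU backend:

    import jax
    from jax.interpreters import xla

    closed_jaxpr = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
    built = xla.lower_jaxpr_to_xla_module(
        "jit_f", closed_jaxpr, "cpu", xla.AxisEnv(1, (), ()),
        name_stack="", tuple_args=False, donated_invars=[False],
        replicated_args=None, arg_partitions=None, out_partitions=None)
    print(built.as_hlo_text())  # an xc.XlaComputation, as before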

View File

@@ -48,7 +48,6 @@ from jax.interpreters import batching
 from jax.interpreters import masking
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
@@ -342,7 +341,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
 init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
 cond_c = xla_client.XlaBuilder("cond_computation")
-cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
+cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
 cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
 x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
 cond_ctx = ctx.replace(builder=cond_c,
@@ -359,7 +358,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
 or_, list(range(cond_jaxpr.out_avals[0].ndim)))
 body_c = xla_client.XlaBuilder("body_computation")
-body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
+body_carry = xla.parameter(body_c, 0, c.get_shape(init_carry))
 body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
 x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
 body_ctx = ctx.replace(builder=body_c,
@@ -931,7 +930,7 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
 name_stack = extend_name_stack(ctx.name_stack, "cond")
 def make_computation(name, jaxpr, op_shape):
 c = xla_client.XlaBuilder(name + '_comp')
-op = xb.parameter(c, 0, op_shape)
+op = xla.parameter(c, 0, op_shape)
 ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
 subctx = ctx.replace(
 builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))

View File

@@ -3224,7 +3224,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
 axis_env = xla.AxisEnv(1, (), ()) # no parallel primitives inside reductions
 subc = xc.XlaBuilder("reduction_computation")
 assert len(consts) == 0, "Reduction computations cannot have constants"
-args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
+args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
 ctx = xla.TranslationContext(subc, platform, axis_env, '')
 out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
 if singleton:
@@ -3684,7 +3684,7 @@ def _sort_translation_rule(ctx, avals_in, avals_out, *operands, dimension,
 c = ctx.builder
 types = [c.get_shape(x).xla_element_type() for x in operands]
 subc = xc.XlaBuilder("sort_lt_comparator")
-params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
+params = [xla.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
 for i, typ in enumerate(types) for j in range(2)]
 result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
 backend=ctx.platform,
@@ -3918,7 +3918,7 @@ def _infeed_translation_rule(ctx, avals_in, avals_out, token, *, shapes,
 build_infeed = partial(xops.InfeedWithToken, token,
 xla_client.Shape.tuple_shape(shape))
 if partitions:
-xs_and_token = xb.with_sharding(c, partitions, build_infeed)
+xs_and_token = xla.with_sharding(c, partitions, build_infeed)
 else:
 # Note that infeed will default to replication if inside a sharded
 # computation and no sharding is specified.
@@ -3986,8 +3986,8 @@ def _outfeed_translation_rule(ctx, avals_in, avals_out, token, *xs, partitions):
 c = ctx.builder
 t = xops.Tuple(c, xs)
 if partitions is not None:
-return [xb.with_sharding(c, partitions, xops.OutfeedWithToken,
-t, token, c.get_shape(t))]
+return [xla.with_sharding(c, partitions, xops.OutfeedWithToken,
+t, token, c.get_shape(t))]
 else:
 return [xops.OutfeedWithToken(t, token, c.get_shape(t))]

View File

@@ -1487,7 +1487,7 @@ def _scatter_add_translation_rule(
 def _make_reducer(dtype):
 subc = xc.XlaBuilder("scatter_add_reducer")
 shape = xc.Shape.array_shape(np.dtype(dtype), ())
-args = [xb.parameter(subc, 0, shape), xb.parameter(subc, 1, shape)]
+args = [xla.parameter(subc, 0, shape), xla.parameter(subc, 1, shape)]
 out = xops.Add(args[0], args[1])
 return subc.build(out)

View File

@@ -806,9 +806,9 @@ def _select_and_gather_add_translation(
 def reducer():
 c = xc.XlaBuilder("select_and_gather_pair_reducer")
-x = xb.parameter(c, 0,
+x = xla.parameter(c, 0,
 xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
-y = xb.parameter(c, 1,
+y = xla.parameter(c, 1,
 xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
 assert select_prim is lax.ge_p or select_prim is lax.le_p
 which = xops.Ge if select_prim is lax.ge_p else xops.Le
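
The pattern above, scalar parameters on a fresh sub-builder, is how these rules construct reducer subcomputations. An illustrative standalone reducer under this commit:

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla
    xops = xc.ops

    subc = xc.XlaBuilder("max_reducer")
    s = xc.Shape.array_shape(np.dtype(np.float32), ())
    x = xla.parameter(subc, 0, s)
    y = xla.parameter(subc, 1, s)
    reducer = subc.build(xops.Max(x, y))  # pass to xops.Reduce and friends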

View File

@@ -23,7 +23,7 @@ XLA. There are also a handful of related casting utilities.
 from functools import partial, lru_cache
 import os
 import threading
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 import warnings
 from absl import logging
@@ -435,81 +435,3 @@ def host_ids(backend=None):
 "instead. jax.host_ids will eventually be removed; please update your "
 "code.")
 return list(range(process_count(backend)))
-### utility functions
-def parameter(builder, num, shape, name=None, replicated=None):
-if name is None:
-name = ''
-if replicated is None:
-replicated = []
-elif isinstance(replicated, bool):
-replicated = [replicated] * shape.leaf_count()
-return xops.Parameter(builder, num,
-shape.with_major_to_minor_layout_if_absent(), name,
-replicated)
-# HLO instructions optionally can be annotated to say how the output should be
-# spatially partitioned (represented in XLA as OpSharding protos, see
-# _sharding_to_proto). For array outputs, the annotation is either an int per
-# dimension specifying the number of ways that dimension divided (i.e. the total
-# number of shards is the product), or None to indicate the array should be
-# replicated. Tuple outputs are represented as tuples thereof. XLA supports
-# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
-# checkers don't support recursive types), so we only represent one level of
-# nesting in this type definition.
-SpatialSharding = Union[Tuple[int, ...],
-None,
-Tuple[Union[Tuple[int, ...], None], ...]]
-def _sharding_to_proto(sharding: SpatialSharding):
-"""Converts a SpatialSharding to an OpSharding.
-See
-https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
-for details on the OpSharding proto.
-"""
-proto = xla_client.OpSharding()
-if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
-assert all(s is None or isinstance(s, tuple) for s in sharding)
-return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore
-if sharding is None:
-proto.type = xla_client.OpSharding.Type.REPLICATED
-else:
-proto.type = xla_client.OpSharding.Type.OTHER
-proto.tile_assignment_dimensions = list(sharding)
-proto.tile_assignment_devices = list(range(np.product(sharding)))
-return proto
-def tuple_sharding_proto(elems):
-proto = xla_client.OpSharding()
-assert all(isinstance(e, type(proto)) for e in elems)
-proto.type = xla_client.OpSharding.Type.TUPLE
-proto.tuple_shardings = elems
-return proto
-def set_sharding_proto(builder, op, sharding_proto):
-"""Uses CustomCall to annotate a value as sharded."""
-# "Sharding" is a built-in custom call target that acts like an identity
-# function, and is used to attach an OpSharding to.
-return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
-builder, b"Sharding", [op], builder.get_shape(op))
-def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
-"""Builds op_fn(*args, **kwargs) with sharding annotation."""
-builder.set_sharding(sharding_proto)
-try:
-return op_fn(*args, **kwargs)
-finally:
-builder.clear_sharding()
-def set_sharding(builder, op, sharding: SpatialSharding):
-"""Uses CustomCall to annotate a value as sharded."""
-return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
-def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
-"""Builds op_fn(*args, **kwargs) with sharding annotation."""
-return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
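
Everything deleted above reappears verbatim in jax.interpreters.xla (see the hunk later in this commit). As a reminder of what _sharding_to_proto computes, a hypothetical spot check against the new location (private helper, used here purely for illustration):

    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    proto = xla._sharding_to_proto((2, 1))  # tile dim 0 two ways, dim 1 unsplit
    assert proto.type == xc.OpSharding.Type.OTHER
    assert list(proto.tile_assignment_dimensions) == [2, 1]
    assert list(proto.tile_assignment_devices) == [0, 1]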

View File

@@ -63,7 +63,6 @@ from typing import (Any, Tuple)
 import numpy as np
 from jax import lax, core
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src import ad_util, dtypes
@@ -167,10 +166,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
 def _comparator_builder(operand, op_type, is_max_k):
 c = xc.XlaBuilder(
 'top_k_{}_comparator'.format('gt' if is_max_k else 'lt'))
-p0 = xb.parameter(c, 0, xc.Shape.scalar_shape(op_type))
-p1 = xb.parameter(c, 1, xc.Shape.scalar_shape(op_type))
-xb.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
-xb.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
+p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type))
+p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(op_type))
+xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
+xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
 if is_max_k:
 cmp_result = xc.ops.Gt(p0, p1)
 else:
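
A self-contained sketch of the four-parameter comparator layout used here (two value operands plus two int32 index operands that the comparator declares but ignores), under this commit:

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    c = xc.XlaBuilder("top_k_gt_comparator")
    p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(np.dtype(np.float32)))
    p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(np.dtype(np.float32)))
    xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))  # index, unused
    xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))  # index, unused
    comparator = c.build(xc.ops.Gt(p0, p1))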

View File

@@ -640,7 +640,7 @@ xla.canonicalize_dtype_handlers[BoundedInt] = _bdint_canoncalize_dtype
 def _make_params(c, dim_in_avals, in_avals):
 n = it.count()
-make = lambda a: [xb.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
+make = lambda a: [xla.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
 return map(make, dim_in_avals), map(make, in_avals)
 def _xla_consts(c, consts):

View File

@@ -946,7 +946,7 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,
 token_sharding_proto = xla_client.OpSharding()
 token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
-infeed_sharding_proto = xb.tuple_sharding_proto(
+infeed_sharding_proto = xla.tuple_sharding_proto(
 [array_sharding_proto] * len(non_empty_flat_results_aval) +
 [token_sharding_proto])
@@ -959,8 +959,8 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,
 build_infeed = functools.partial(xops.InfeedWithToken,
 after_outfeed_itoken,
 xla_client.Shape.tuple_shape(shape))
-outs_and_token = xb.with_sharding_proto(comp, infeed_sharding_proto,
-build_infeed)
+outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto,
+build_infeed)
 outs = xops.GetTupleElement(outs_and_token, 0)
 next_itoken = xops.GetTupleElement(outs_and_token, 1)
 non_empty_results = [

View File

@@ -1464,7 +1464,7 @@ def _xmap_translation_rule_spmd(c, axis_env,
 global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
 sharded_global_in_nodes = [
-xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
+xla.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
 if aval_axes else node
 for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
 ]
@@ -1478,7 +1478,7 @@
 *sharded_global_in_nodes)
 sharded_global_out_nodes = [
-xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
+xla.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
 if aval_axes else node
 for node, aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
 ]

View File

@@ -37,7 +37,6 @@ from jax.interpreters import xla
 from jax.interpreters import batching
 from jax.interpreters import partial_eval as pe
 from jax.interpreters.sharded_jit import PartitionSpec
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
 from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
@@ -514,8 +513,8 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
 for i, (n, axis_resources) in enumerate(safe_zip(in_nodes, in_axis_resources)):
 # N.B. inlined calls shouldn't have shardings set directly on the inputs or
 # outputs (set_sharding_proto adds an identity operation).
-arg = xb.parameter(subc, i, c.GetShape(n))
-args.append(xb.set_sharding_proto(subc, arg,
+arg = xla.parameter(subc, i, c.GetShape(n))
+args.append(xla.set_sharding_proto(subc, arg,
 get_sharding_proto(c, n, axis_resources, mesh)))
 # TODO: Think about how to avoid duplicating constants with the outer jaxpr
@@ -525,7 +524,7 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
 out_nodes = xla.jaxpr_subcomp(
 ctx, jaxpr.jaxpr, xla._xla_consts(subc, jaxpr.consts), *args)
 out_nodes = [
-xb.set_sharding_proto(subc, out,
+xla.set_sharding_proto(subc, out,
 get_sharding_proto(subc, out, axis_resources, mesh))
 for out, axis_resources in safe_zip(out_nodes, out_axis_resources)
 ]
@@ -815,7 +814,7 @@ ad.deflinear2(sharding_constraint_p,
 def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
 axis_resources, resource_env):
 mesh = resource_env.physical_mesh
-return [xb.set_sharding_proto(
+return [xla.set_sharding_proto(
 ctx.builder, x_node,
 get_sharding_proto(ctx.builder, x_node, axis_resources, mesh))]
 xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)

View File

@@ -23,6 +23,7 @@ import typing
 from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, Union,
 Tuple)
 from typing_extensions import Protocol
+import warnings
 from jax import core
 from jax import linear_util as lu
@@ -319,11 +320,22 @@ def flatten_lowering_ir_args(
 return util.flatten(map(wrap_singleton_ir_values, xs))
 def lower_jaxpr_to_module(jaxpr: core.ClosedJaxpr, platform: str,
-axis_env: xla.AxisEnv, name_stack: str) -> str:
+axis_env: xla.AxisEnv, name_stack: str,
+donated_invars: Sequence[bool]) -> str:
 """Lowers a top-level jaxpr to an MHLO module.
 Handles the quirks of the argument/return value passing conventions of the
 runtime."""
+if platform in ("gpu", "tpu"):
+# TODO(b/203122001): implement buffer donation.
+assert not any(donated_invars), donated_invars
+if any(donated_invars):
+# TODO(tomhennigan): At call time we should mark these buffers as deleted.
+unused_donations = [str(a) for a, d in zip(jaxpr.in_avals, donated_invars)
+if d]
+warnings.warn("Some donated buffers were not usable: {}".format(
+", ".join(unused_donations)))
 ctx = LoweringContext(platform, axis_env, name_stack)
 if platform == "iree":
 ctx = ctx.replace(tuple_results=False)

View File

@@ -995,31 +995,23 @@ def lower_parallel_callable(
 axis_env = xla.AxisEnv(
 replicas.num_global_replicas, (axis_name,), (global_axis_size,))
-c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
-xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
+name_stack = extend_name_stack(wrap_name(name, 'pmap'))
+closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
 replicated_args = [axis is None for axis in in_axes]
-xla_args, donated_invars = xla._xla_callable_args(
-c, shards.global_sharded_avals, tuple_args(shards),
-replicated=replicated_args,
-partitions=parts.arg_parts,
-donated_invars=donated_invars)
+module: Union[str, xc.XlaComputation]
 with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
-ctx = xla.TranslationContext(c, backend.platform, axis_env,
-extend_name_stack(wrap_name(name, 'pmap')))
-out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-build_out_tuple = partial(xops.Tuple, c, out_nodes)
-if parts.out_parts is not None:
-out_tuple = xb.with_sharding(c, parts.out_parts, build_out_tuple)
-else:
-out_tuple = build_out_tuple()
+if config.jax_enable_mlir:
+# TODO(phawkins): handle replicated_args.
+# TODO(phawkins): handle sharding.
+module = mlir.lower_jaxpr_to_module(
+closed_jaxpr, backend.platform, axis_env, name_stack, donated_invars)
+else:
+module = xla.lower_jaxpr_to_xla_module(
+f"pmap_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
+name_stack, tuple_args(shards), donated_invars, replicated_args,
+parts.arg_parts, parts.out_parts)
-if backend.platform in ("gpu", "tpu"):
-donated_invars = xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple),
-donated_invars, tuple_args(shards))
-built = c.Build(out_tuple)
-return PmapComputation(built, pci, replicas, parts, shards)
+return PmapComputation(module, pci, replicas, parts, shards)
 class PmapComputation:
@@ -1917,8 +1909,6 @@ def lower_mesh_computation(
 jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
 # 3. Build up the HLO
-c = xc.XlaBuilder(f"xmap_{fun.__name__}")
-xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
 tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU
 in_partitions: Optional[List]
 if spmd_lowering:
@@ -1929,40 +1919,29 @@
 for aval, aval_in_axes in safe_zip(global_in_untiled_avals, in_axes)]
 out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
 for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
+out_partitions_t = xla.tuple_sharding_proto(out_partitions)
 partitions_proto = True
 axis_env = xla.AxisEnv(nreps=1, names=(), sizes=()) # All named axes have been vmapped
 else:
 replicated_args = [not axis for axis in in_axes]
 in_partitions = None
+out_partitions_t = None
 partitions_proto = False
 axis_env = xla.AxisEnv(nreps=mesh.size,
 names=tuple(global_axis_sizes.keys()),
 sizes=tuple(global_axis_sizes.values()))
-xla_args, donated_invars = xla._xla_callable_args(
-c, in_jaxpr_avals, tuple_args,
-replicated=replicated_args,
-partitions=in_partitions,
-partitions_proto=partitions_proto,
-donated_invars=donated_invars)
+closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+name_stack = extend_name_stack(wrap_name(transformed_name, 'xmap'))
 with core.extend_axis_env_nd(mesh.shape.items()):
-ctx = xla.TranslationContext(
-c, backend.platform, axis_env,
-extend_name_stack(wrap_name(transformed_name, 'xmap')))
-out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-if spmd_lowering:
-out_partitions_t = xb.tuple_sharding_proto(out_partitions)
-out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
-else:
-out_tuple = xops.Tuple(c, out_nodes)
+# TODO(phawkins): add MLIR lowering.
+module = xla.lower_jaxpr_to_xla_module(
+f"xmap_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
+name_stack, tuple_args, donated_invars, replicated_args,
+in_partitions, out_partitions_t,
+partitions_are_protos=partitions_proto)
-if backend.platform in ("gpu", "tpu"):
-xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple), donated_invars,
-tuple_args)
-# TODO: Warn about unused donations?
-built = c.Build(out_tuple)
 return MeshComputation(
-built, donated_invars, mesh, local_in_untiled_avals,
+module, donated_invars, mesh, local_in_untiled_avals,
 local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
 in_axes, out_axes, spmd_lowering, tuple_args)

View File

@@ -147,7 +147,7 @@ def _sharded_callable(
 ctx = xla.TranslationContext(
 c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
 out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
+out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
 built = c.Build(out_tuple)
 if nparts <= xb.local_device_count():
@@ -191,10 +191,10 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
 args = []
 for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
-# We use xb.set_sharding instead of xb.with_sharding because inlined calls
+# We use xla.set_sharding instead of xla.with_sharding because inlined calls
 # shouldn't have shardings set directly on the inputs or outputs.
-arg = xb.parameter(subc, i, c.GetShape(n))
-args.append(xb.set_sharding(subc, arg, sharding))
+arg = xla.parameter(subc, i, c.GetShape(n))
+args.append(xla.set_sharding(subc, arg, sharding))
 ctx = xla.TranslationContext(
 subc, backend, axis_env,
@@ -202,7 +202,7 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
 out_nodes = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *args)
 out_parts = out_parts_thunk()
 assert len(out_parts) == len(out_nodes)
-out_nodes = [xb.set_sharding(subc, out, sharding)
+out_nodes = [xla.set_sharding(subc, out, sharding)
 for out, sharding in safe_zip(out_nodes, out_parts)]
 subc = subc.build(xops.Tuple(subc, out_nodes))
@@ -218,7 +218,7 @@ def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
 def _xla_sharded_args(c, avals, in_parts):
 xla_args = []
 for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
-param = xb.with_sharding(c, sharding, xb.parameter, c, i,
+param = xla.with_sharding(c, sharding, xla.parameter, c, i,
 *xla.aval_to_xla_shapes(aval))
 xla_args.append(param)
 return xla_args
@@ -413,7 +413,7 @@ def _sharding_constraint_impl(x, partitions):
 def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node,
 partitions):
-return [xb.set_sharding(ctx.builder, x_node, partitions)]
+return [xla.set_sharding(ctx.builder, x_node, partitions)]
 sharding_constraint_p = core.Primitive("sharding_constraint")
 sharding_constraint_p.def_impl(_sharding_constraint_impl)
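
On the set_sharding/with_sharding distinction called out in the comment above, a minimal sketch (values and specs are illustrative):

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    c = xc.XlaBuilder("sharding_styles")
    shape = xc.Shape.array_shape(np.dtype(np.float32), (4,))
    # with_sharding: the annotation lands on the op being built (the parameter).
    p = xla.with_sharding(c, (2,), xla.parameter, c, 0, shape)
    # set_sharding: wraps an existing value in an identity "Sharding" CustomCall.
    q = xla.set_sharding(c, p, (2,))
    computation = c.build(q)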

View File

@@ -22,9 +22,10 @@ from functools import partial
 import itertools as it
 import operator
 import re
-from typing import (Any, Callable, Deque, Dict, List, Optional, Sequence, Set,
-Type, Tuple, NamedTuple)
+from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
+Sequence, Set, Type, Tuple, Union)
 from typing_extensions import Protocol
+import warnings
 import numpy as np
@@ -43,7 +44,6 @@ import jax._src.pretty_printer as pp
 from jax._src import util
 from jax._src.util import (prod, extend_name_stack, wrap_name,
 safe_zip, safe_map, partition_list)
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import ad
@@ -113,6 +113,84 @@ def make_op_metadata(primitive: core.Primitive,
 source_file=_get_canonical_source_file(frame) if frame else None,
 source_line=frame.line_num if frame else None)
+# Utilities
+def parameter(builder, num, shape, name=None, replicated=None):
+if name is None:
+name = ''
+if replicated is None:
+replicated = []
+elif isinstance(replicated, bool):
+replicated = [replicated] * shape.leaf_count()
+return xops.Parameter(builder, num,
+shape.with_major_to_minor_layout_if_absent(), name,
+replicated)
+# HLO instructions optionally can be annotated to say how the output should be
+# spatially partitioned (represented in XLA as OpSharding protos, see
+# _sharding_to_proto). For array outputs, the annotation is either an int per
+# dimension specifying the number of ways that dimension divided (i.e. the total
+# number of shards is the product), or None to indicate the array should be
+# replicated. Tuple outputs are represented as tuples thereof. XLA supports
+# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
+# checkers don't support recursive types), so we only represent one level of
+# nesting in this type definition.
+SpatialSharding = Union[Tuple[int, ...],
+None,
+Tuple[Union[Tuple[int, ...], None], ...]]
+def _sharding_to_proto(sharding: SpatialSharding):
+"""Converts a SpatialSharding to an OpSharding.
+See
+https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
+for details on the OpSharding proto.
+"""
+proto = xc.OpSharding()
+if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
+assert all(s is None or isinstance(s, tuple) for s in sharding)
+return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore
+if sharding is None:
+proto.type = xc.OpSharding.Type.REPLICATED
+else:
+proto.type = xc.OpSharding.Type.OTHER
+proto.tile_assignment_dimensions = list(sharding)
+proto.tile_assignment_devices = list(range(np.product(sharding)))
+return proto
+def tuple_sharding_proto(elems):
+proto = xc.OpSharding()
+assert all(isinstance(e, type(proto)) for e in elems)
+proto.type = xc.OpSharding.Type.TUPLE
+proto.tuple_shardings = elems
+return proto
+def set_sharding_proto(builder, op, sharding_proto):
+"""Uses CustomCall to annotate a value as sharded."""
+# "Sharding" is a built-in custom call target that acts like an identity
+# function, and is used to attach an OpSharding to.
+return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
+builder, b"Sharding", [op], builder.get_shape(op))
+def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
+"""Builds op_fn(*args, **kwargs) with sharding annotation."""
+builder.set_sharding(sharding_proto)
+try:
+return op_fn(*args, **kwargs)
+finally:
+builder.clear_sharding()
+def set_sharding(builder, op, sharding: SpatialSharding):
+"""Uses CustomCall to annotate a value as sharded."""
+return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
+def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
+"""Builds op_fn(*args, **kwargs) with sharding annotation."""
+return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
 ### handlers
 # Numpy dtypes -> XLA primitive types
@@ -400,7 +478,7 @@ def _xla_callable_args(
 if partitions is None:
 tuple_parts = None
 elif partitions_proto:
-tuple_parts = xb.tuple_sharding_proto(partitions)
+tuple_parts = tuple_sharding_proto(partitions)
 else:
 tuple_parts = tuple(partitions)
 tuple_shape = xc.Shape.tuple_shape(
@@ -419,15 +497,15 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions,
 is_token = xla_shape.is_token()
 if filter_tokens and is_token:
 xla_shape = _token_param_shape()
-make_param = partial(xb.parameter, builder, param_num, xla_shape,
+make_param = partial(parameter, builder, param_num, xla_shape,
 replicated=replicated)
-with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
+with_sharding_fn = with_sharding_proto if parts_proto else with_sharding
 if partitions is None:
 out = make_param()
 elif partitions is _replicated_param:
-out = with_sharding(builder, None, make_param)
+out = with_sharding_fn(builder, None, make_param)
 else:
-out = with_sharding(builder, partitions, make_param)
+out = with_sharding_fn(builder, partitions, make_param)
 if filter_tokens and is_token:
 out = xops.CreateToken(builder)
 return out
@@ -583,9 +661,9 @@ def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
 Given the following computation:
 >>> c = xc.XlaBuilder("example")
->>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
->>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
->>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
+>>> p0 = parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
+>>> p1 = parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
+>>> p2 = parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
 >>> o = xops.Tuple(c, [p0, p1, p2])
 We can query the arrays in the output tuple:
@@ -659,6 +737,55 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
+def lower_jaxpr_to_xla_module(
+fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
+name_stack: str, tuple_args: bool, donated_invars: Sequence[bool],
+replicated_args: Optional[Sequence[bool]],
+arg_partitions: Optional[Any],
+out_partitions: Optional[Any],
+partitions_are_protos: bool = False
+) -> xc.XlaComputation:
+"""Lowers a closed jaxpr to a top-level XLA module."""
+c = xc.XlaBuilder(fn_name)
+xla_consts = _xla_consts(c, jaxpr.consts)
+xla_args, donated_invars = _xla_callable_args(
+c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
+replicated=replicated_args, partitions=arg_partitions,
+partitions_proto=partitions_are_protos)
+ctx = TranslationContext(c, platform, axis_env, name_stack)
+out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
+# Replace tokens with a dummy array value, because the runtime cannot
+# handle token arguments.
+out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
+out_nodes = util.flatten(
+[[_make_token_return_value(c)] if a is core.abstract_token
+else v for a, v in zip(jaxpr.out_avals,
+util.unflatten(out_nodes, out_aval_lens))])
+# There is a non-zero cost to building an output tuple, particularly on TPU.
+# Avoid it if the output arity is 1.
+if out_partitions is None:
+output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
+else:
+build_out_tuple = partial(xops.Tuple, c, out_nodes)
+if partitions_are_protos:
+output = with_sharding_proto(c, out_partitions, build_out_tuple)
+else:
+output = with_sharding(c, out_partitions, build_out_tuple)
+if platform in ("gpu", "tpu"):
+donated_invars = set_up_aliases(
+c, xla_args, c.GetShape(output), donated_invars, tuple_args)
+if any(donated_invars):
+# TODO(tomhennigan): At call time we should mark these buffers as deleted.
+unused_donations = [str(c.GetShape(a))
+for a, d in zip(xla_args, donated_invars) if d]
+warnings.warn("Some donated buffers were not usable: {}".format(
+", ".join(unused_donations)))
+return c.build(output)
 xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
 xla_call = xla_call_p.bind
@@ -698,7 +825,7 @@ def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
 c = ctx.builder
 check_backend_matches(backend, ctx.platform)
 subc = xc.XlaBuilder(f"jit_{name}")
-args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
+args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
 sub_ctx = ctx.replace(
 builder=subc,
 name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
@@ -934,7 +1061,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
 true_op = xops.Tuple(c, in_nodes)
 remat_subc = xc.XlaBuilder("remat_call_subcomputation")
-input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
+input_op = parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
 args = xla_destructure(remat_subc, input_op)
 sub_ctx = ctx.replace(
 builder=remat_subc,
@@ -945,7 +1072,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
 false_op = true_op
 dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
-xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
+parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
 out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
 dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))
@@ -959,7 +1086,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
 # Dummy subc for getting subcomp shapes.
 dummy_inputs = xops.Tuple(c, in_nodes)
 dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
-dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
+dummy_input_op = parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
 dummy_args = xla_destructure(dummy_subc, dummy_input_op)
 dummy_ctx = ctx.replace(
 builder=dummy_subc,
@@ -972,7 +1099,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
 inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)
 cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
-input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
+input_op = parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
 i = xops.GetTupleElement(input_op, 0)
 rng = xops.RngUniform(xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
 xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
@@ -980,7 +1107,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
 cond_subc = cond_subc.build(xops.Lt(i, rng))
 body_subc = xc.XlaBuilder("remat_body_subcomputation")
-input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
+input_op = parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
 i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
 i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
 body_ctx = ctx.replace(
@@ -1019,7 +1146,7 @@ def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
 check_backend_matches(backend, ctx.platform)
 c = ctx.builder
 subc = xc.XlaBuilder(name)
-args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
+args = [parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
 sub_ctx = ctx.replace(builder=subc,
 name_stack=extend_name_stack(ctx.name_stack, name))
 out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
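
Putting the relocated pieces together: a sketch of annotating a tuple output the way lower_jaxpr_to_xla_module does when out_partitions is given (shapes and the sharding spec are illustrative):

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla
    xops = xc.ops

    c = xc.XlaBuilder("sharded_out")
    p = xla.parameter(c, 0, xc.Shape.array_shape(np.dtype(np.float32), (4, 2)))
    # A one-element tuple sharding: split dim 0 two ways, dim 1 unsplit.
    out = xla.with_sharding(c, ((2, 1),), xops.Tuple, c, [p])
    computation = c.build(out)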

View File

@@ -244,8 +244,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
 # Jit and Donate arguments
 def test_jit_donate_argnums_warning_raised(self):
-if jax.config.jax_enable_mlir:
-raise unittest.SkipTest("Buffer donation not yet implemented via MLIR")
 x = jnp.array([1.0, 2.0], jnp.float32)
 y = jnp.array([1, 2], jnp.int32)
 f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
@@ -256,7 +254,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
 self.assertLen(w, 1)
 self.assertTrue(issubclass(w[-1].category, UserWarning))
 self.assertIn(
-"Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
+"Some donated buffers were not usable:",
 str(w[-1].message))
 @jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
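
The relaxed assertion reflects that the warning text now comes from the shared lowering path (and, under MLIR lowering, lists avals rather than XLA shapes). A hypothetical repro of the warning on a platform without donation support, such as CPU:

    import warnings
    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: x.sum(), donate_argnums=(0,))
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        f(jnp.ones(2))
    assert any("donated buffers" in str(x.message) for x in w)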

View File

@@ -26,6 +26,7 @@ from jax.interpreters import xla
 from jax._src.lib.mlir import ir
 from jax._src.lib import xla_bridge, xla_client
 xops = xla_client.ops
+xc = xla_client
 xb = xla_bridge
 from jax.config import config
@@ -113,14 +114,14 @@ def sparse_array_result_handler(device, aval):
 def sparse_array_shape_handler(a):
 return (
-xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
-xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
+xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
+xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
 )
 def sparse_array_device_put_handler(a, device):
 return (
-xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
-xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
+xb.get_device_backend(device).buffer_from_pyval(a.data, device),
+xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
 )
 def sparse_array_constant_handler(c, val, canonicalize_dtypes):

View File

@@ -19,6 +19,7 @@ from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax.interpreters import xla
 from jax._src.config import config
 config.parse_flags_with_absl()
@@ -47,13 +48,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
 def test_parameter_replication_default(self):
 c = xc.XlaBuilder("test")
-_ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
+_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
 built_c = c.Build()
 assert "replication" not in built_c.as_hlo_text()
 def test_parameter_replication(self):
 c = xc.XlaBuilder("test")
-_ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
+_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
 False)
 built_c = c.Build()
 assert "parameter_replication={false}" in built_c.as_hlo_text()