Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
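The user-visible effect of the move is a rename of the IR-building helpers: call sites that previously reached into `jax._src.lib.xla_bridge` (imported as `xb`) now go through `jax.interpreters.xla`. A minimal sketch of the pattern, modeled on the `xla_bridge_test.py` change at the bottom of this diff; the builder name and scalar shape below are only illustrative:

```python
from jax._src.lib import xla_client as xc
from jax.interpreters import xla

c = xc.XlaBuilder("example")
f32_scalar = xc.Shape.array_shape(xc.PrimitiveType.F32, ())

# Before this commit the helper lived in jax._src.lib.xla_bridge:
#   param = xb.parameter(c, 0, f32_scalar)
# After it, the same helper is exposed from jax.interpreters.xla:
param = xla.parameter(c, 0, f32_scalar)
```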
commit 68e9e1c26d
parent dd6a6f206c
@@ -2007,7 +2007,7 @@
"  return [xla_consts[id(cnst)] for cnst in consts]\n",
"\n",
"def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:\n",
"  return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
"  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
"\n",
"def _xla_shape(aval: ShapedArray) -> xe.Shape:\n",
"  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)"

@@ -3630,7 +3630,7 @@
"\n",
"  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
"    c = xc.XlaBuilder(name)\n",
"    operand = xb.parameter(c, 0, operand_shape)\n",
"    operand = xops.Parameter(c, 0, operand_shape)\n",
"    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
"    outs = jaxpr_subcomp(c, jaxpr, operands)\n",
"    return c.build(xops.Tuple(c, outs))\n",

@@ -1577,7 +1577,7 @@ def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:
return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:
return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval: ShapedArray) -> xe.Shape:
return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)

@@ -2844,7 +2844,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):

def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xc.XlaBuilder(name)
operand = xb.parameter(c, 0, operand_shape)
operand = xops.Parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)
return c.build(xops.Tuple(c, outs))

@@ -1569,7 +1569,7 @@ def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:
return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:
return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval: ShapedArray) -> xe.Shape:
return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)

@@ -2836,7 +2836,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):

def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xc.XlaBuilder(name)
operand = xb.parameter(c, 0, operand_shape)
operand = xops.Parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)
return c.build(xops.Tuple(c, outs))

@@ -860,7 +860,7 @@ def xla_computation(fun: Callable,
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
if out_parts is not None:
out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)
out_tuple = xla.with_sharding(c, out_parts_flat, build_out_tuple)
else:
out_tuple = build_out_tuple()

@@ -218,42 +218,16 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
tuple_args = len(abstract_args) > 100
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
module: Any
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
if config.jax_enable_mlir:
# TODO(b/203122001): implement buffer donation.
assert not any(donated_invars), donated_invars
module = mlir.lower_jaxpr_to_module(
core.ClosedJaxpr(jaxpr, consts), backend.platform, axis_env, name_stack)
closed_jaxpr, backend.platform, axis_env, name_stack, donated_invars)
else:
# XLA HLO lowering path
c = xc.XlaBuilder(f"jit_{fun.__name__}")
xla_consts = xla._xla_consts(c, consts)
xla_args, donated_invars = xla._xla_callable_args(
c, abstract_args, tuple_args, donated_invars=donated_invars)
platform = backend.platform
ctx = xla.TranslationContext(c, platform, axis_env, name_stack)
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
# Replace tokens with a dummy array value, because the runtime cannot
# handle token arguments.
out_aval_lens = [len(xla.aval_to_xla_shapes(a)) for a in out_avals]
out_nodes = util.flatten(
[[xla._make_token_return_value(c)] if a is core.abstract_token
else v
for a, v in zip(out_avals, util.unflatten(out_nodes, out_aval_lens))])

# There is a non-zero cost to building an output tuple, particularly on TPU.
# Avoid it if the output arity is 1.
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
if platform in ("gpu", "tpu"):
donated_invars = xla.set_up_aliases(
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
if any(donated_invars):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(c.GetShape(a))
for a, d in zip(xla_args, donated_invars) if d]
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))
module = c.build(output)
module = xla.lower_jaxpr_to_xla_module(
f"jit_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
name_stack, tuple_args, donated_invars, replicated_args=None,
arg_partitions=None, out_partitions=None)
return XlaComputation(
name, module, False, donated_invars, nreps, device, backend, tuple_args,
abstract_args, out_avals, kept_var_idx)

@@ -48,7 +48,6 @@ from jax.interpreters import batching
from jax.interpreters import masking
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,

@@ -342,7 +341,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

cond_c = xla_client.XlaBuilder("cond_computation")
cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
cond_ctx = ctx.replace(builder=cond_c,

@@ -359,7 +358,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
or_, list(range(cond_jaxpr.out_avals[0].ndim)))

body_c = xla_client.XlaBuilder("body_computation")
body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
body_carry = xla.parameter(body_c, 0, c.get_shape(init_carry))
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
body_ctx = ctx.replace(builder=body_c,

@@ -931,7 +930,7 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
name_stack = extend_name_stack(ctx.name_stack, "cond")
def make_computation(name, jaxpr, op_shape):
c = xla_client.XlaBuilder(name + '_comp')
op = xb.parameter(c, 0, op_shape)
op = xla.parameter(c, 0, op_shape)
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
subctx = ctx.replace(
builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))

@@ -3224,7 +3224,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
axis_env = xla.AxisEnv(1, (), ())  # no parallel primitives inside reductions
subc = xc.XlaBuilder("reduction_computation")
assert len(consts) == 0, "Reduction computations cannot have constants"
args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
ctx = xla.TranslationContext(subc, platform, axis_env, '')
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
if singleton:

@@ -3684,7 +3684,7 @@ def _sort_translation_rule(ctx, avals_in, avals_out, *operands, dimension,
c = ctx.builder
types = [c.get_shape(x).xla_element_type() for x in operands]
subc = xc.XlaBuilder("sort_lt_comparator")
params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
params = [xla.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
for i, typ in enumerate(types) for j in range(2)]
result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
backend=ctx.platform,

@@ -3918,7 +3918,7 @@ def _infeed_translation_rule(ctx, avals_in, avals_out, token, *, shapes,
build_infeed = partial(xops.InfeedWithToken, token,
xla_client.Shape.tuple_shape(shape))
if partitions:
xs_and_token = xb.with_sharding(c, partitions, build_infeed)
xs_and_token = xla.with_sharding(c, partitions, build_infeed)
else:
# Note that infeed will default to replication if inside a sharded
# computation and no sharding is specified.

@@ -3986,8 +3986,8 @@ def _outfeed_translation_rule(ctx, avals_in, avals_out, token, *xs, partitions):
c = ctx.builder
t = xops.Tuple(c, xs)
if partitions is not None:
return [xb.with_sharding(c, partitions, xops.OutfeedWithToken,
t, token, c.get_shape(t))]
return [xla.with_sharding(c, partitions, xops.OutfeedWithToken,
t, token, c.get_shape(t))]
else:
return [xops.OutfeedWithToken(t, token, c.get_shape(t))]

@@ -1487,7 +1487,7 @@ def _scatter_add_translation_rule(
def _make_reducer(dtype):
subc = xc.XlaBuilder("scatter_add_reducer")
shape = xc.Shape.array_shape(np.dtype(dtype), ())
args = [xb.parameter(subc, 0, shape), xb.parameter(subc, 1, shape)]
args = [xla.parameter(subc, 0, shape), xla.parameter(subc, 1, shape)]
out = xops.Add(args[0], args[1])
return subc.build(out)

@@ -806,9 +806,9 @@ def _select_and_gather_add_translation(

def reducer():
c = xc.XlaBuilder("select_and_gather_pair_reducer")
x = xb.parameter(c, 0,
x = xla.parameter(c, 0,
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
y = xb.parameter(c, 1,
y = xla.parameter(c, 1,
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
assert select_prim is lax.ge_p or select_prim is lax.le_p
which = xops.Ge if select_prim is lax.ge_p else xops.Le

@@ -23,7 +23,7 @@ XLA. There are also a handful of related casting utilities.
from functools import partial, lru_cache
import os
import threading
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
import warnings

from absl import logging

@@ -435,81 +435,3 @@ def host_ids(backend=None):
"instead. jax.host_ids will eventually be removed; please update your "
"code.")
return list(range(process_count(backend)))


### utility functions

def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()

return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)

# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
None,
Tuple[Union[Tuple[int, ...], None], ...]]

def _sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.

See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xla_client.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(_sharding_to_proto, sharding)))  # type: ignore

if sharding is None:
proto.type = xla_client.OpSharding.Type.REPLICATED
else:
proto.type = xla_client.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_devices = list(range(np.product(sharding)))
return proto

def tuple_sharding_proto(elems):
proto = xla_client.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xla_client.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto

def set_sharding_proto(builder, op, sharding_proto):
"""Uses CustomCall to annotate a value as sharded."""
# "Sharding" is a built-in custom call target that acts like an identity
# function, and is used to attach an OpSharding to.
return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
builder, b"Sharding", [op], builder.get_shape(op))

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()

def set_sharding(builder, op, sharding: SpatialSharding):
"""Uses CustomCall to annotate a value as sharded."""
return set_sharding_proto(builder, op, _sharding_to_proto(sharding))

def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)

@@ -63,7 +63,6 @@ from typing import (Any, Tuple)

import numpy as np
from jax import lax, core
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import ad_util, dtypes

@@ -167,10 +166,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
def _comparator_builder(operand, op_type, is_max_k):
c = xc.XlaBuilder(
'top_k_{}_comparator'.format('gt' if is_max_k else 'lt'))
p0 = xb.parameter(c, 0, xc.Shape.scalar_shape(op_type))
p1 = xb.parameter(c, 1, xc.Shape.scalar_shape(op_type))
xb.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
xb.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type))
p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(op_type))
xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
if is_max_k:
cmp_result = xc.ops.Gt(p0, p1)
else:

@@ -640,7 +640,7 @@ xla.canonicalize_dtype_handlers[BoundedInt] = _bdint_canoncalize_dtype

def _make_params(c, dim_in_avals, in_avals):
n = it.count()
make = lambda a: [xb.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
make = lambda a: [xla.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
return map(make, dim_in_avals), map(make, in_avals)

def _xla_consts(c, consts):

@@ -946,7 +946,7 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,

token_sharding_proto = xla_client.OpSharding()
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
infeed_sharding_proto = xb.tuple_sharding_proto(
infeed_sharding_proto = xla.tuple_sharding_proto(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])

@@ -959,8 +959,8 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
outs_and_token = xb.with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)
non_empty_results = [

@@ -1464,7 +1464,7 @@ def _xmap_translation_rule_spmd(c, axis_env,

global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
sharded_global_in_nodes = [
xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
xla.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
if aval_axes else node
for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
]

@@ -1478,7 +1478,7 @@ def _xmap_translation_rule_spmd(c, axis_env,
*sharded_global_in_nodes)

sharded_global_out_nodes = [
xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
xla.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
if aval_axes else node
for node, aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
]

@@ -37,7 +37,6 @@ from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters.sharded_jit import PartitionSpec
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,

@@ -514,8 +513,8 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
for i, (n, axis_resources) in enumerate(safe_zip(in_nodes, in_axis_resources)):
# N.B. inlined calls shouldn't have shardings set directly on the inputs or
# outputs (set_sharding_proto adds an identity operation).
arg = xb.parameter(subc, i, c.GetShape(n))
args.append(xb.set_sharding_proto(subc, arg,
arg = xla.parameter(subc, i, c.GetShape(n))
args.append(xla.set_sharding_proto(subc, arg,
get_sharding_proto(c, n, axis_resources, mesh)))

# TODO: Think about how to avoid duplicating constants with the outer jaxpr

@@ -525,7 +524,7 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
out_nodes = xla.jaxpr_subcomp(
ctx, jaxpr.jaxpr, xla._xla_consts(subc, jaxpr.consts), *args)
out_nodes = [
xb.set_sharding_proto(subc, out,
xla.set_sharding_proto(subc, out,
get_sharding_proto(subc, out, axis_resources, mesh))
for out, axis_resources in safe_zip(out_nodes, out_axis_resources)
]

@@ -815,7 +814,7 @@ ad.deflinear2(sharding_constraint_p,
def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
axis_resources, resource_env):
mesh = resource_env.physical_mesh
return [xb.set_sharding_proto(
return [xla.set_sharding_proto(
ctx.builder, x_node,
get_sharding_proto(ctx.builder, x_node, axis_resources, mesh))]
xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)

@@ -23,6 +23,7 @@ import typing
from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, Union,
Tuple)
from typing_extensions import Protocol
import warnings

from jax import core
from jax import linear_util as lu

@@ -319,11 +320,22 @@ def flatten_lowering_ir_args(
return util.flatten(map(wrap_singleton_ir_values, xs))

def lower_jaxpr_to_module(jaxpr: core.ClosedJaxpr, platform: str,
axis_env: xla.AxisEnv, name_stack: str) -> str:
axis_env: xla.AxisEnv, name_stack: str,
donated_invars: Sequence[bool]) -> str:
"""Lowers a top-level jaxpr to an MHLO module.

Handles the quirks of the argument/return value passing conventions of the
runtime."""
if platform in ("gpu", "tpu"):
# TODO(b/203122001): implement buffer donation.
assert not any(donated_invars), donated_invars
if any(donated_invars):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(a) for a, d in zip(jaxpr.in_avals, donated_invars)
if d]
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))

ctx = LoweringContext(platform, axis_env, name_stack)
if platform == "iree":
ctx = ctx.replace(tuple_results=False)

@@ -995,31 +995,23 @@ def lower_parallel_callable(

axis_env = xla.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))

c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
name_stack = extend_name_stack(wrap_name(name, 'pmap'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
replicated_args = [axis is None for axis in in_axes]
xla_args, donated_invars = xla._xla_callable_args(
c, shards.global_sharded_avals, tuple_args(shards),
replicated=replicated_args,
partitions=parts.arg_parts,
donated_invars=donated_invars)
module: Union[str, xc.XlaComputation]
with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
ctx = xla.TranslationContext(c, backend.platform, axis_env,
extend_name_stack(wrap_name(name, 'pmap')))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xops.Tuple, c, out_nodes)
if parts.out_parts is not None:
out_tuple = xb.with_sharding(c, parts.out_parts, build_out_tuple)
else:
out_tuple = build_out_tuple()
if config.jax_enable_mlir:
# TODO(phawkins): handle replicated_args.
# TODO(phawkins): handle sharding.
module = mlir.lower_jaxpr_to_module(
closed_jaxpr, backend.platform, axis_env, name_stack, donated_invars)
else:
module = xla.lower_jaxpr_to_xla_module(
f"pmap_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
name_stack, tuple_args(shards), donated_invars, replicated_args,
parts.arg_parts, parts.out_parts)

if backend.platform in ("gpu", "tpu"):
donated_invars = xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple),
donated_invars, tuple_args(shards))
built = c.Build(out_tuple)

return PmapComputation(built, pci, replicas, parts, shards)
return PmapComputation(module, pci, replicas, parts, shards)


class PmapComputation:

@@ -1917,8 +1909,6 @@ def lower_mesh_computation(
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

# 3. Build up the HLO
c = xc.XlaBuilder(f"xmap_{fun.__name__}")
xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
tuple_args = len(in_jaxpr_avals) > 100  # pass long arg lists as tuple for TPU
in_partitions: Optional[List]
if spmd_lowering:

@@ -1929,40 +1919,29 @@ def lower_mesh_computation(
for aval, aval_in_axes in safe_zip(global_in_untiled_avals, in_axes)]
out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
out_partitions_t = xla.tuple_sharding_proto(out_partitions)
partitions_proto = True
axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())  # All named axes have been vmapped
else:
replicated_args = [not axis for axis in in_axes]
in_partitions = None
out_partitions_t = None
partitions_proto = False
axis_env = xla.AxisEnv(nreps=mesh.size,
names=tuple(global_axis_sizes.keys()),
sizes=tuple(global_axis_sizes.values()))
xla_args, donated_invars = xla._xla_callable_args(
c, in_jaxpr_avals, tuple_args,
replicated=replicated_args,
partitions=in_partitions,
partitions_proto=partitions_proto,
donated_invars=donated_invars)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
name_stack = extend_name_stack(wrap_name(transformed_name, 'xmap'))
with core.extend_axis_env_nd(mesh.shape.items()):
ctx = xla.TranslationContext(
c, backend.platform, axis_env,
extend_name_stack(wrap_name(transformed_name, 'xmap')))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
if spmd_lowering:
out_partitions_t = xb.tuple_sharding_proto(out_partitions)
out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
else:
out_tuple = xops.Tuple(c, out_nodes)
# TODO(phawkins): add MLIR lowering.
module = xla.lower_jaxpr_to_xla_module(
f"xmap_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
name_stack, tuple_args, donated_invars, replicated_args,
in_partitions, out_partitions_t,
partitions_are_protos=partitions_proto)

if backend.platform in ("gpu", "tpu"):
xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple), donated_invars,
tuple_args)
# TODO: Warn about unused donations?

built = c.Build(out_tuple)
return MeshComputation(
built, donated_invars, mesh, local_in_untiled_avals,
module, donated_invars, mesh, local_in_untiled_avals,
local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
in_axes, out_axes, spmd_lowering, tuple_args)

@@ -147,7 +147,7 @@ def _sharded_callable(
ctx = xla.TranslationContext(
c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
built = c.Build(out_tuple)

if nparts <= xb.local_device_count():

@@ -191,10 +191,10 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,

args = []
for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
# We use xb.set_sharding instead of xb.with_sharding because inlined calls
# We use xla.set_sharding instead of xla.with_sharding because inlined calls
# shouldn't have shardings set directly on the inputs or outputs.
arg = xb.parameter(subc, i, c.GetShape(n))
args.append(xb.set_sharding(subc, arg, sharding))
arg = xla.parameter(subc, i, c.GetShape(n))
args.append(xla.set_sharding(subc, arg, sharding))

ctx = xla.TranslationContext(
subc, backend, axis_env,

@@ -202,7 +202,7 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
out_nodes = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *args)
out_parts = out_parts_thunk()
assert len(out_parts) == len(out_nodes)
out_nodes = [xb.set_sharding(subc, out, sharding)
out_nodes = [xla.set_sharding(subc, out, sharding)
for out, sharding in safe_zip(out_nodes, out_parts)]

subc = subc.build(xops.Tuple(subc, out_nodes))

@@ -218,7 +218,7 @@ def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
def _xla_sharded_args(c, avals, in_parts):
xla_args = []
for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
param = xb.with_sharding(c, sharding, xb.parameter, c, i,
param = xla.with_sharding(c, sharding, xla.parameter, c, i,
*xla.aval_to_xla_shapes(aval))
xla_args.append(param)
return xla_args

@@ -413,7 +413,7 @@ def _sharding_constraint_impl(x, partitions):

def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node,
partitions):
return [xb.set_sharding(ctx.builder, x_node, partitions)]
return [xla.set_sharding(ctx.builder, x_node, partitions)]

sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)

@@ -22,9 +22,10 @@ from functools import partial
import itertools as it
import operator
import re
from typing import (Any, Callable, Deque, Dict, List, Optional, Sequence, Set,
Type, Tuple, NamedTuple)
from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
Sequence, Set, Type, Tuple, Union)
from typing_extensions import Protocol
import warnings

import numpy as np

@@ -43,7 +44,6 @@ import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (prod, extend_name_stack, wrap_name,
safe_zip, safe_map, partition_list)
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad

@@ -113,6 +113,84 @@ def make_op_metadata(primitive: core.Primitive,
source_file=_get_canonical_source_file(frame) if frame else None,
source_line=frame.line_num if frame else None)

# Utilities

def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()

return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
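A small usage sketch of the helper above (not part of the diff); it mirrors the `test_parameter_replication` cases in the test file at the end of this commit and assumes the surrounding module's imports (`xc` is `xla_client`). `replicated` may be omitted or passed as a single bool, which is broadcast to every leaf of the shape:

```python
# Sketch only: two scalar f32 parameters on a fresh builder.
c = xc.XlaBuilder("parameter_example")
shape = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
p0 = parameter(c, 0, shape)             # no replication annotation
p1 = parameter(c, 1, shape, "", False)  # bool -> [False] per shape leaf
```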

# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
None,
Tuple[Union[Tuple[int, ...], None], ...]]

def _sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.

See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xc.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(_sharding_to_proto, sharding)))  # type: ignore

if sharding is None:
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_devices = list(range(np.product(sharding)))
return proto

def tuple_sharding_proto(elems):
proto = xc.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xc.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto

def set_sharding_proto(builder, op, sharding_proto):
"""Uses CustomCall to annotate a value as sharded."""
# "Sharding" is a built-in custom call target that acts like an identity
# function, and is used to attach an OpSharding to.
return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
builder, b"Sharding", [op], builder.get_shape(op))

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()

def set_sharding(builder, op, sharding: SpatialSharding):
"""Uses CustomCall to annotate a value as sharded."""
return set_sharding_proto(builder, op, _sharding_to_proto(sharding))

def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
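A usage sketch for the sharding helpers above (not part of the diff), modeled on the `with_sharding(c, out_parts, xops.Tuple, c, out_nodes)` call sites elsewhere in this commit. The ops `x0` and `x1` are hypothetical; the `SpatialSharding` value says the first (rank-2) output is split two ways along its leading dimension and the second is replicated (`None`):

```python
# Sketch only, assuming an XlaBuilder `c` with two array ops x0, x1 already
# built on it. with_sharding wraps the op construction so the resulting tuple
# carries the OpSharding annotation produced by _sharding_to_proto.
out_tuple = with_sharding(c, ((2, 1), None), xops.Tuple, c, [x0, x1])
```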

### handlers

# Numpy dtypes -> XLA primitive types

@@ -400,7 +478,7 @@ def _xla_callable_args(
if partitions is None:
tuple_parts = None
elif partitions_proto:
tuple_parts = xb.tuple_sharding_proto(partitions)
tuple_parts = tuple_sharding_proto(partitions)
else:
tuple_parts = tuple(partitions)
tuple_shape = xc.Shape.tuple_shape(

@@ -419,15 +497,15 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions,
is_token = xla_shape.is_token()
if filter_tokens and is_token:
xla_shape = _token_param_shape()
make_param = partial(xb.parameter, builder, param_num, xla_shape,
make_param = partial(parameter, builder, param_num, xla_shape,
replicated=replicated)
with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
with_sharding_fn = with_sharding_proto if parts_proto else with_sharding
if partitions is None:
out = make_param()
elif partitions is _replicated_param:
out = with_sharding(builder, None, make_param)
out = with_sharding_fn(builder, None, make_param)
else:
out = with_sharding(builder, partitions, make_param)
out = with_sharding_fn(builder, partitions, make_param)
if filter_tokens and is_token:
out = xops.CreateToken(builder)
return out

@@ -583,9 +661,9 @@ def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
Given the following computation:

>>> c = xc.XlaBuilder("example")
>>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
>>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
>>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
>>> p0 = parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
>>> p1 = parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
>>> p2 = parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
>>> o = xops.Tuple(c, [p0, p1, p2])

We can query the arrays in the output tuple:

@@ -659,6 +737,55 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):


def lower_jaxpr_to_xla_module(
fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
name_stack: str, tuple_args: bool, donated_invars: Sequence[bool],
replicated_args: Optional[Sequence[bool]],
arg_partitions: Optional[Any],
out_partitions: Optional[Any],
partitions_are_protos: bool = False
) -> xc.XlaComputation:
"""Lowers a closed jaxpr to a top-level XLA module."""
c = xc.XlaBuilder(fn_name)
xla_consts = _xla_consts(c, jaxpr.consts)
xla_args, donated_invars = _xla_callable_args(
c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
replicated=replicated_args, partitions=arg_partitions,
partitions_proto=partitions_are_protos)
ctx = TranslationContext(c, platform, axis_env, name_stack)
out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
# Replace tokens with a dummy array value, because the runtime cannot
# handle token arguments.
out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
out_nodes = util.flatten(
[[_make_token_return_value(c)] if a is core.abstract_token
else v for a, v in zip(jaxpr.out_avals,
util.unflatten(out_nodes, out_aval_lens))])

# There is a non-zero cost to building an output tuple, particularly on TPU.
# Avoid it if the output arity is 1.
if out_partitions is None:
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
else:
build_out_tuple = partial(xops.Tuple, c, out_nodes)
if partitions_are_protos:
output = with_sharding_proto(c, out_partitions, build_out_tuple)
else:
output = with_sharding(c, out_partitions, build_out_tuple)

if platform in ("gpu", "tpu"):
donated_invars = set_up_aliases(
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
if any(donated_invars):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(c.GetShape(a))
for a, d in zip(xla_args, donated_invars) if d]
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))
return c.build(output)


xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind

@@ -698,7 +825,7 @@ def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
c = ctx.builder
check_backend_matches(backend, ctx.platform)
subc = xc.XlaBuilder(f"jit_{name}")
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
sub_ctx = ctx.replace(
builder=subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))

@@ -934,7 +1061,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):

true_op = xops.Tuple(c, in_nodes)
remat_subc = xc.XlaBuilder("remat_call_subcomputation")
input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
input_op = parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
args = xla_destructure(remat_subc, input_op)
sub_ctx = ctx.replace(
builder=remat_subc,

@@ -945,7 +1072,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):

false_op = true_op
dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

@@ -959,7 +1086,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
# Dummy subc for getting subcomp shapes.
dummy_inputs = xops.Tuple(c, in_nodes)
dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
dummy_input_op = parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
dummy_args = xla_destructure(dummy_subc, dummy_input_op)
dummy_ctx = ctx.replace(
builder=dummy_subc,

@@ -972,7 +1099,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
input_op = parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
i = xops.GetTupleElement(input_op, 0)
rng = xops.RngUniform(xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
xops.Constant(cond_subc, np.array(2, dtype=np.int32)),

@@ -980,7 +1107,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
cond_subc = cond_subc.build(xops.Lt(i, rng))

body_subc = xc.XlaBuilder("remat_body_subcomputation")
input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
input_op = parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
body_ctx = ctx.replace(

@@ -1019,7 +1146,7 @@ def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
check_backend_matches(backend, ctx.platform)
c = ctx.builder
subc = xc.XlaBuilder(name)
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
args = [parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
sub_ctx = ctx.replace(builder=subc,
name_stack=extend_name_stack(ctx.name_stack, name))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)

@@ -244,8 +244,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# Jit and Donate arguments

def test_jit_donate_argnums_warning_raised(self):
if jax.config.jax_enable_mlir:
raise unittest.SkipTest("Buffer donation not yet implemented via MLIR")
x = jnp.array([1.0, 2.0], jnp.float32)
y = jnp.array([1, 2], jnp.int32)
f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))

@@ -256,7 +254,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertLen(w, 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
self.assertIn(
"Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
"Some donated buffers were not usable:",
str(w[-1].message))

@jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.

@@ -26,6 +26,7 @@ from jax.interpreters import xla
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge, xla_client
xops = xla_client.ops
xc = xla_client
xb = xla_bridge

from jax.config import config

@@ -113,14 +114,14 @@ def sparse_array_result_handler(device, aval):

def sparse_array_shape_handler(a):
return (
xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
)

def sparse_array_device_put_handler(a, device):
return (
xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
xb.get_device_backend(device).buffer_from_pyval(a.data, device),
xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
)

def sparse_array_constant_handler(c, val, canonicalize_dtypes):

@@ -19,6 +19,7 @@ from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.interpreters import xla

from jax._src.config import config
config.parse_flags_with_absl()

@@ -47,13 +48,13 @@ class XlaBridgeTest(jtu.JaxTestCase):

def test_parameter_replication_default(self):
c = xc.XlaBuilder("test")
_ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
built_c = c.Build()
assert "replication" not in built_c.as_hlo_text()

def test_parameter_replication(self):
c = xc.XlaBuilder("test")
_ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
False)
built_c = c.Build()
assert "parameter_replication={false}" in built_c.as_hlo_text()