Clean up token handling during lowering

jaxlib 0.4.27 is now the minimum supported version, and it supports
real StableHLO tokens as module inputs and outputs. Hence we can
clean up `mlir.lower_jaxpr_to_fun` to drop the kwargs
`create_tokens` and `replace_tokens_with_dummy`, both of which
are now always False.

We also remove the unused `num_output_tokens` kwarg.
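
A minimal sketch of the new behavior (assuming jaxlib >= 0.4.27; `jax.debug.print` with `ordered=True` is just a convenient way to introduce an ordered effect): the lowered module's `main` should now carry a real `!stablehlo.token` input and output for the effect, instead of the old dummy `bool[0]` array.

    # Sketch: lower a function with an ordered effect and inspect the
    # StableHLO text; `main` should thread a !stablehlo.token through
    # its arguments and results.
    import jax

    def f(x):
        jax.debug.print("x = {x}", x=x, ordered=True)  # ordered effect
        return x + 1

    print(jax.jit(f).lower(1.0).as_text())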
George Necula 2024-05-14 13:20:58 +03:00
parent 66a92c41f6
commit 41153b168c
2 changed files with 16 additions and 69 deletions

@@ -963,7 +963,6 @@ def lower_jaxpr_to_module(
         ctx, "main", jaxpr, ordered_effects,
         name_stack=name_stack,
         public=True,
-        num_output_tokens=0,
         replicated_args=replicated_args,
         arg_shardings=arg_shardings,
         result_shardings=result_shardings,
@@ -1091,15 +1090,6 @@ class TokenSet:
         new_tokens.append((eff, self._tokens[eff]))
     return TokenSet(new_tokens)
 
-def dummy_token_type() -> Sequence[ir.Type]:
-  # TODO(b/302258959): For now HLO does not allow hlo.TokenType among
-  # arguments and results, so we use bool[0] to pass tokens to the
-  # top-level function only.
-  return aval_to_ir_types(core.ShapedArray((0,), np.bool_))
-
-def dummy_token() -> Sequence[ir.Value]:
-  return ir_constants(np.zeros(0, np.bool_))
-
 def lower_jaxpr_to_fun(
     ctx: ModuleContext,
     name: str,
@@ -1107,16 +1097,13 @@ def lower_jaxpr_to_fun(
     effects: Sequence[core.Effect],
     name_stack: source_info_util.NameStack,
     *,
-    create_tokens: bool = False,
     public: bool = False,
-    replace_tokens_with_dummy: bool = False,
     replicated_args: Sequence[bool] | None = None,
     arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
-    num_output_tokens: int = 0,
     api_name: str = "jit",
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
@@ -1137,11 +1124,7 @@ def lower_jaxpr_to_fun(
     jaxpr: the jaxpr to lower.
     effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
       that will be created in or used by the lowered function.
-    create_tokens: if true, the HLO will create tokens and ignore dummy input
-      tokens. See b/302258959.
     public: if true, the function's visibility is set to "public".
-    replace_tokens_with_dummy: if true, token arguments/return values are
-      replaced with bool arrays of size [0]. See b/302258959.
     replicated_args: if present, annotates arguments as replicated.
     arg_shardings: sharding annotations for each argument (optional).
     result_shardings: sharding annotations for each result (optional).
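
For reference, the deleted workaround (see b/302258959) amounted to the following sketch: a zero-length boolean array stood in for each token at the top-level function boundary. The helper name below is a stand-in for the removed `dummy_token`.

    # Sketch of the removed workaround: a bool[0] array modeled an hlo
    # token at the module boundary, since tokens were not yet allowed
    # among top-level arguments and results.
    import numpy as np

    def dummy_token_placeholder() -> np.ndarray:
        return np.zeros(0, np.bool_)
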
@@ -1158,50 +1141,38 @@ def lower_jaxpr_to_fun(
   Returns:
     MLIR func op
   """
 
-  def aval_to_types(aval):
-    if replace_tokens_with_dummy and aval is core.abstract_token:
-      aval = core.ShapedArray((), np.dtype(np.bool_))
-    return aval_to_ir_types(aval)
 
   # The first dimension variable may be the platform index
   num_dim_vars = len(ctx.shape_poly_state.dim_vars)
   dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
-  dim_var_types = map(aval_to_types, dim_var_avals)
+  dim_var_types = map(aval_to_ir_types, dim_var_avals)
 
   # Function inputs: *dim_var_values, *tokens, *actual_inputs
-  input_types = map(aval_to_types, jaxpr.in_avals)
-  output_types = map(aval_to_types, jaxpr.out_avals)
+  input_types = map(aval_to_ir_types, jaxpr.in_avals)
+  output_types = map(aval_to_ir_types, jaxpr.out_avals)
   num_tokens = len(effects)
-  if create_tokens:
-    # TODO(b/302258959): Use actual tokens
-    token_types = [dummy_token_type() for _ in effects]
-    output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
-  else:
-    # If we aren't creating tokens they will be the initial inputs to the
-    # MLIR function.
-    output_token_types = []
-    token_types = [token_type() for _ in effects]
+  token_types = [token_type() for _ in effects]
   token_avals = [core.abstract_token] * num_tokens
   # Order of arguments: dim vars, tokens, array inputs
   input_avals = dim_var_avals + token_avals + jaxpr.in_avals
   input_types = [*dim_var_types, *token_types, *input_types]
-  output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
-  output_types = [*output_token_types, *token_types, *output_types]
+  output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals
+  output_types = [*token_types, *output_types]
 
   if input_output_aliases is not None:
     token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
     input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
     # Update the existing aliases to account for the new output values
     input_output_aliases = [None if a is None
-                            else a + num_output_tokens + num_tokens
+                            else a + num_tokens
                             for a in input_output_aliases]  # type: ignore
   if arg_shardings is not None:
     token_shardings = [None] * (num_dim_vars + num_tokens)
     arg_shardings = [*token_shardings, *arg_shardings]
   if result_shardings is not None:
-    token_shardings = [None] * (num_tokens + num_output_tokens)
+    token_shardings = [None] * num_tokens
     result_shardings = [*token_shardings, *result_shardings]
   if replicated_args is not None:
     token_replicated_args = [False] * (num_dim_vars + num_tokens)
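
A toy walkthrough of the alias bookkeeping above, with made-up sizes: the inputs gain `num_dim_vars + num_tokens` leading no-alias slots, and each surviving alias index shifts by `num_tokens` because the token results are prepended to the outputs.

    # Toy walkthrough of the alias shift (made-up sizes, pure Python).
    num_dim_vars, num_tokens = 1, 2
    input_output_aliases = [0, None, 1]  # array input i -> array output a

    aliases = [None] * (num_dim_vars + num_tokens)
    aliases += [None if a is None else a + num_tokens
                for a in input_output_aliases]
    print(aliases)  # [None, None, None, 2, None, 3]
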
@@ -1210,13 +1181,13 @@ def lower_jaxpr_to_fun(
     token_memory_kinds = [None] * (num_dim_vars + num_tokens)
     arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
   if result_memory_kinds is not None:
-    token_memory_kinds = [None] * (num_tokens + num_output_tokens)
+    token_memory_kinds = [None] * num_tokens
     result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
   if arg_layouts is not None:
     token_layouts = [None] * (num_dim_vars + num_tokens)
     arg_layouts = [*token_layouts, *arg_layouts]
   if result_layouts is not None:
-    token_layouts = [None] * (num_tokens + num_output_tokens)
+    token_layouts = [None] * num_tokens
     result_layouts = [*token_layouts, *result_layouts]
   if xla_donated_args is not None:
     xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), *xla_donated_args]
@@ -1427,35 +1398,17 @@ def lower_jaxpr_to_fun(
     _, token_args, unflattened_args = util.split_list(
         util.unflatten(flat_args, map(len, input_types)),
         [num_dim_vars, num_tokens])
-    if create_tokens:
-      tokens_in = TokenSet.create(effects)
-    else:
-      tokens_in = TokenSet(zip(effects, token_args))
-    args: list[list[ir.Value]] = []
-    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
-      if replace_tokens_with_dummy and aval is core.abstract_token:
-        args.append([hlo.create_token()])
-      else:
-        args.append(arg)
+    tokens_in = TokenSet(zip(effects, token_args))
+    args: list[list[ir.Value]] = unflattened_args
     callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
     consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
     out_vals, tokens_out = jaxpr_subcomp(
         ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
         consts, *args, dim_var_values=dim_var_values)
     outs = []
-    if create_tokens:
-      for _ in range(num_output_tokens):
-        outs.append(dummy_token())
-      for _ in effects:
-        outs.append(dummy_token())
-    else:
-      for eff in effects:
-        outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
-    for aval, out in zip(jaxpr.out_avals, out_vals):
-      if replace_tokens_with_dummy and aval is core.abstract_token:
-        outs.append(ir_constants(np.zeros((), np.bool_)))
-      else:
-        outs.append(out)
+    for eff in effects:
+      outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
+    outs.extend(out_vals)
     flat_outputs = util.flatten(outs)
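
A toy trace of the simplified output assembly (strings stand in for the real `ir.Value` lists): token results come first, in effect order, followed by the jaxpr outputs.

    # Toy trace of the output ordering: tokens first, then array outputs.
    effects = ["eff_a", "eff_b"]
    tokens_out = {"eff_a": "tok_a", "eff_b": "tok_b"}
    out_vals = ["y0", "y1"]

    outs = [tokens_out[eff] for eff in effects]
    outs.extend(out_vals)
    print(outs)  # ['tok_a', 'tok_b', 'y0', 'y1']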

@@ -578,13 +578,7 @@ def _wrap_main_func(
   orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
 
   def is_token(typ, attrs):
-    if typ == mlir.token_type()[0]:
-      return True
-    # TODO(b/302258959): in older versions we cannot use the token type
-    try:
-      return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value
-    except KeyError:
-      return False
+    return (typ == mlir.token_type()[0])
 
   orig_input_types = orig_main.type.inputs
   arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
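
With real token support, detecting a token argument reduces to a type check; the fallback on a "jax.token" argument attribute is gone. A sketch of the predicate (the helper name is hypothetical; note that `mlir.token_type()` returns a sequence here, hence the `[0]`):

    # Sketch: token detection is a type equality check against the MLIR
    # token type (requires an active MLIR context when actually called).
    from jax.interpreters import mlir

    def is_token_type(typ) -> bool:
        return typ == mlir.token_type()[0]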