Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Cleanup token handling during lowering
jaxlib 0.4.27 is now the minimum supported version, and it supports real stablehlo tokens as module inputs and outputs. Hence we can clean up `mlir.lower_jaxpr_to_fun` to drop the kwargs `create_tokens` and `replace_tokens_with_dummy` (both are now always False). We also remove `num_output_tokens`, which was unused.
parent 66a92c41f6
commit 41153b168c
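As a quick illustration of the cleanup's premise, the hypothetical sketch below (not part of this change; `host_print` and `f` are made-up names) lowers a function with an ordered effect so the token plumbing becomes visible. With jaxlib >= 0.4.27, the ordered effect should be threaded through the lowered module as a real stablehlo token rather than a bool[0] dummy:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def host_print(x):
  # Runs on the host; ordered=True below gives the callback an ordered effect.
  print("saw:", x)
  return np.asarray(x)

def f(x):
  y = io_callback(host_print, jax.ShapeDtypeStruct((), jnp.float32), x,
                  ordered=True)
  return y + 1.0

# Inspect the stablehlo text: look for `!stablehlo.token` among the
# lowered module's values.
print(jax.jit(f).lower(jnp.float32(0.0)).as_text())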
@@ -963,7 +963,6 @@ def lower_jaxpr_to_module(
       ctx, "main", jaxpr, ordered_effects,
       name_stack=name_stack,
       public=True,
-      num_output_tokens=0,
       replicated_args=replicated_args,
       arg_shardings=arg_shardings,
       result_shardings=result_shardings,
@@ -1091,15 +1090,6 @@ class TokenSet:
       new_tokens.append((eff, self._tokens[eff]))
     return TokenSet(new_tokens)
 
-def dummy_token_type() -> Sequence[ir.Type]:
-  # TODO(b/302258959): For now HLO does not allow hlo.TokenType among
-  # arguments and results, so we use bool[0] to pass tokens to the
-  # top-level function only.
-  return aval_to_ir_types(core.ShapedArray((0,), np.bool_))
-
-def dummy_token() -> Sequence[ir.Value]:
-  return ir_constants(np.zeros(0, np.bool_))
-
 def lower_jaxpr_to_fun(
     ctx: ModuleContext,
     name: str,
@@ -1107,16 +1097,13 @@ def lower_jaxpr_to_fun(
     effects: Sequence[core.Effect],
     name_stack: source_info_util.NameStack,
     *,
-    create_tokens: bool = False,
     public: bool = False,
-    replace_tokens_with_dummy: bool = False,
     replicated_args: Sequence[bool] | None = None,
     arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
-    num_output_tokens: int = 0,
     api_name: str = "jit",
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
@@ -1137,11 +1124,7 @@ def lower_jaxpr_to_fun(
     jaxpr: the jaxpr to lower.
     effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
       that will be created in or used by the lowered function.
-    create_tokens: if true, the HLO will create tokens and ignore dummy input
-      tokens. See b/302258959.
     public: if true, the function's visibility is set to "public".
-    replace_tokens_with_dummy: if true, token arguments/return values are
-      replaced with bool arrays of size [0]. See b/302258959.
     replicated_args: if present, annotates arguments as replicated.
     arg_shardings: sharding annotations for each argument (optional).
     result_shardings: sharding annotations for each result (optional).
@@ -1158,50 +1141,38 @@ def lower_jaxpr_to_fun(
   Returns:
     MLIR func op
   """
-  def aval_to_types(aval):
-    if replace_tokens_with_dummy and aval is core.abstract_token:
-      aval = core.ShapedArray((), np.dtype(np.bool_))
-    return aval_to_ir_types(aval)
-
   # The first dimension variable may be the platform index
   num_dim_vars = len(ctx.shape_poly_state.dim_vars)
   dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
-  dim_var_types = map(aval_to_types, dim_var_avals)
+  dim_var_types = map(aval_to_ir_types, dim_var_avals)
 
   # Function inputs: *dim_var_values, *tokens, *actual_inputs
-  input_types = map(aval_to_types, jaxpr.in_avals)
-  output_types = map(aval_to_types, jaxpr.out_avals)
+  input_types = map(aval_to_ir_types, jaxpr.in_avals)
+  output_types = map(aval_to_ir_types, jaxpr.out_avals)
   num_tokens = len(effects)
 
-  if create_tokens:
-    # TODO(b/302258959): Use actual tokens
-    token_types = [dummy_token_type() for _ in effects]
-    output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
-  else:
-    # If we aren't creating tokens they will be the initial inputs to the
-    # MLIR function.
-    output_token_types = []
-    token_types = [token_type() for _ in effects]
+  token_types = [token_type() for _ in effects]
   token_avals = [core.abstract_token] * num_tokens
   # Order of arguments: dim vars, tokens, array inputs
   input_avals = dim_var_avals + token_avals + jaxpr.in_avals
   input_types = [*dim_var_types, *token_types, *input_types]
-  output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
-  output_types = [*output_token_types, *token_types, *output_types]
+  output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals
+  output_types = [*token_types, *output_types]
 
   if input_output_aliases is not None:
     token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
     input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
     # Update the existing aliases to account for the new output values
     input_output_aliases = [None if a is None
-                            else a + num_output_tokens + num_tokens
+                            else a + num_tokens
                             for a in input_output_aliases]  # type: ignore
 
   if arg_shardings is not None:
     token_shardings = [None] * (num_dim_vars + num_tokens)
     arg_shardings = [*token_shardings, *arg_shardings]
   if result_shardings is not None:
-    token_shardings = [None] * (num_tokens + num_output_tokens)
+    token_shardings = [None] * num_tokens
     result_shardings = [*token_shardings, *result_shardings]
   if replicated_args is not None:
     token_replicated_args = [False] * (num_dim_vars + num_tokens)
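The `input_output_aliases` bookkeeping in the hunk above is easy to misread, so here is a standalone sketch with made-up counts (not from the patch) mirroring the arithmetic the change simplifies to `a + num_tokens`: user aliases map input positions to output positions, so the new dim-var and token inputs get `None` entries prepended, and every aliased output index shifts past the token outputs that now lead the results.

num_dim_vars, num_tokens = 1, 2   # hypothetical counts
user_aliases = [0, None, 1]       # user's input -> output alias map
# Prepend None for the dim-var and token inputs:
aliases = [None] * (num_dim_vars + num_tokens) + user_aliases
# Shift alias targets past the num_tokens token outputs:
aliases = [None if a is None else a + num_tokens for a in aliases]
print(aliases)  # [None, None, None, 2, None, 3]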
@@ -1210,13 +1181,13 @@ def lower_jaxpr_to_fun(
     token_memory_kinds = [None] * (num_dim_vars + num_tokens)
     arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
   if result_memory_kinds is not None:
-    token_memory_kinds = [None] * (num_tokens + num_output_tokens)
+    token_memory_kinds = [None] * num_tokens
     result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
   if arg_layouts is not None:
     token_layouts = [None] * (num_dim_vars + num_tokens)
     arg_layouts = [*token_layouts, *arg_layouts]
   if result_layouts is not None:
-    token_layouts = [None] * (num_tokens + num_output_tokens)
+    token_layouts = [None] * num_tokens
     result_layouts = [*token_layouts, *result_layouts]
   if xla_donated_args is not None:
     xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), *xla_donated_args]
@@ -1427,35 +1398,17 @@ def lower_jaxpr_to_fun(
     _, token_args, unflattened_args = util.split_list(
         util.unflatten(flat_args, map(len, input_types)),
         [num_dim_vars, num_tokens])
-    if create_tokens:
-      tokens_in = TokenSet.create(effects)
-    else:
-      tokens_in = TokenSet(zip(effects, token_args))
-    args: list[list[ir.Value]] = []
-    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
-      if replace_tokens_with_dummy and aval is core.abstract_token:
-        args.append([hlo.create_token()])
-      else:
-        args.append(arg)
+    tokens_in = TokenSet(zip(effects, token_args))
+    args: list[list[ir.Value]] = unflattened_args
     callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
     consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
     out_vals, tokens_out = jaxpr_subcomp(
         ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
         consts, *args, dim_var_values=dim_var_values)
     outs = []
-    if create_tokens:
-      for _ in range(num_output_tokens):
-        outs.append(dummy_token())
-      for _ in effects:
-        outs.append(dummy_token())
-    else:
-      for eff in effects:
-        outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
-    for aval, out in zip(jaxpr.out_avals, out_vals):
-      if replace_tokens_with_dummy and aval is core.abstract_token:
-        outs.append(ir_constants(np.zeros((), np.bool_)))
-      else:
-        outs.append(out)
+    for eff in effects:
+      outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
+    outs.extend(out_vals)
 
     flat_outputs = util.flatten(outs)
@@ -578,13 +578,7 @@ def _wrap_main_func(
   orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
 
   def is_token(typ, attrs):
-    if typ == mlir.token_type()[0]:
-      return True
-    # TODO(b/302258959): in older versions we cannot use the token type
-    try:
-      return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value
-    except KeyError:
-      return False
+    return (typ == mlir.token_type()[0])
 
   orig_input_types = orig_main.type.inputs
   arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
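After this change, `is_token` in `_wrap_main_func` reduces to a pure type comparison, with no `jax.token` attribute fallback. A minimal sketch of that predicate in isolation; it relies on JAX-internal modules, so treat the import paths as assumptions that may differ across versions:

from jax._src.interpreters import mlir

with mlir.make_ir_context():
  # At this commit, token_type() returns a singleton sequence of ir.Type.
  tok = mlir.token_type()[0]
  print(tok)                          # expected: !stablehlo.token
  print(tok == mlir.token_type()[0])  # the entire is_token check: True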