mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11468 from gnecula:ds_fix_ctx
PiperOrigin-RevId: 460632380
This commit is contained in:
commit
5b02c0c9df
@ -3980,13 +3980,9 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
|
||||
with ir.InsertionPoint(comparator):
|
||||
lower_comparator = mlir.lower_fun(partial(_sort_lt_comparator),
|
||||
multiple_results=False)
|
||||
sub_ctx = mlir.LoweringRuleContext(
|
||||
module_context = ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
|
||||
avals_out=[core.ShapedArray((), np.bool_)],
|
||||
tokens_in=ctx.tokens_in,
|
||||
tokens_out=ctx.tokens_out)
|
||||
sub_ctx = ctx.replace(primitive=None,
|
||||
avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
|
||||
avals_out=[core.ShapedArray((), np.bool_)])
|
||||
|
||||
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
|
||||
num_keys=num_keys)
|
||||
|
@ -1203,12 +1203,7 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result,
|
||||
lu, _nan_like_mhlo(out_aval))
|
||||
sub_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=[pivot_aval],
|
||||
avals_out=[perm_aval],
|
||||
tokens_in=ctx.tokens_in,
|
||||
tokens_out=ctx.tokens_out)
|
||||
sub_ctx = ctx.replace(primitive=None, avals_in=[pivot_aval], avals_out=[perm_aval])
|
||||
perm_fn = mlir.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
|
||||
multiple_results=False)
|
||||
perm, = perm_fn(sub_ctx, pivot)
|
||||
|
@ -704,13 +704,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
aval_out = aval.update(
|
||||
shape=np.delete(np.array(aval.shape, dtype=np.int64),
|
||||
positional_axes))
|
||||
reducer_ctx = mlir.LoweringRuleContext(
|
||||
module_context=ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=[aval],
|
||||
avals_out=[aval_out],
|
||||
tokens_in=ctx.tokens_in,
|
||||
tokens_out=ctx.tokens_out)
|
||||
reducer_ctx = ctx.replace(primitive=None, avals_in=[aval], avals_out=[aval_out])
|
||||
out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes))[0]
|
||||
return out
|
||||
args = map(_positional_reduce, ctx.avals_in, args)
|
||||
@ -728,10 +722,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_block):
|
||||
lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
|
||||
reducer_ctx = mlir.LoweringRuleContext(
|
||||
module_context = ctx.module_context,
|
||||
primitive=None, avals_in=[scalar_aval] * 2, avals_out=[scalar_aval],
|
||||
tokens_in=ctx.tokens_in, tokens_out=ctx.tokens_out)
|
||||
reducer_ctx = ctx.replace(primitive=None,
|
||||
avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
|
||||
out_nodes = lower_reducer(
|
||||
reducer_ctx, *([a] for a in reducer_block.arguments))
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
@ -1284,10 +1276,9 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
|
||||
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_block):
|
||||
lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
|
||||
reducer_ctx = mlir.LoweringRuleContext(
|
||||
module_context = ctx.module_context,
|
||||
primitive=None, avals_in=[scalar_aval] * 2, avals_out=[scalar_aval],
|
||||
tokens_in=ctx.tokens_in, tokens_out=ctx.tokens_out)
|
||||
reducer_ctx = ctx.replace(primitive=None,
|
||||
avals_in=[scalar_aval] * 2,
|
||||
avals_out=[scalar_aval])
|
||||
out_nodes = lower_reducer(
|
||||
reducer_ctx, *([a] for a in reducer_block.arguments))
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
|
@ -1372,11 +1372,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
mlir.lower_fun(partial(_tile, in_axes=arg_in_axes,
|
||||
axis_sizes=local_mesh_shape),
|
||||
multiple_results=False)(
|
||||
mlir.LoweringRuleContext(module_context=ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=[aval], avals_out=None,
|
||||
tokens_in=ctx.tokens_in,
|
||||
tokens_out=ctx.tokens_out),
|
||||
ctx.replace(primitive=None,
|
||||
avals_in=[aval], avals_out=None),
|
||||
in_node)[0]
|
||||
for v, aval, in_node, arg_in_axes
|
||||
in zip(call_jaxpr.invars, ctx.avals_in, in_nodes, mesh_in_axes))
|
||||
@ -1398,12 +1395,9 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape,
|
||||
platform=ctx.module_context.platform),
|
||||
multiple_results=False)(
|
||||
mlir.LoweringRuleContext(module_context=ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=[vectorized_outvar.aval],
|
||||
avals_out=None,
|
||||
tokens_in=ctx.tokens_in,
|
||||
tokens_out=ctx.tokens_out), tiled_out)[0]
|
||||
ctx.replace(primitive=None,
|
||||
avals_in=[vectorized_outvar.aval],
|
||||
avals_out=None), tiled_out)[0]
|
||||
for v, vectorized_outvar, tiled_out, ans_out_axes
|
||||
in zip(call_jaxpr.outvars, vectorized_jaxpr.outvars, tiled_outs,
|
||||
mesh_out_axes)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user