Fix JAX after upstream MLIR Python API change

The autogenerated MLIR Python API was changed to accept optional operation
arguments only as keyword arguments in Python.

PiperOrigin-RevId: 450651273
Author: Alex Zinenko
Date:   2022-05-24 04:32:15 -07:00
Committed by: jax authors
Commit: c888a7e283 (parent bf5d38c213)
4 changed files with 35 additions and 20 deletions
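
The sketch below illustrates the calling-convention change the diffs adapt to. The builder functions are hypothetical stand-ins, not the real autogenerated mhlo constructors from the jaxlib-bundled MLIR Python bindings; only the keyword-only mechanics (the bare "*" in the signature) are the point.

# Hypothetical stand-ins (not the real autogenerated mhlo builders) showing
# the convention change: optional arguments become keyword-only.

def build_op_old(result_types, token, infeed_config=None, layout=None):
  """Old generator output: optional arguments may be passed positionally."""
  return (result_types, token, infeed_config, layout)

def build_op_new(result_types, token, *, infeed_config=None, layout=None):
  """New generator output: the bare * makes the trailing arguments keyword-only."""
  return (result_types, token, infeed_config, layout)

build_op_old(["f32"], "tok", "", "layouts")    # old call sites: positional is fine
try:
  build_op_new(["f32"], "tok", "", "layouts")  # now a TypeError: too many positional args
except TypeError:
  pass
build_op_new(["f32"], "tok", infeed_config="", layout="layouts")  # the fix applied below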

File 1 of 4

@@ -4013,8 +4013,11 @@ def _infeed_lowering(ctx, token, *, shapes, partitions):
            for i in range(len(aval.shape) - 1, -1, -1)])
       for aval in shapes
   ])
-  infeed = mhlo.InfeedOp(flat_output_types + [mhlo.TokenType.get()], token,
-                         ir.StringAttr.get(''), layouts)
+  infeed = mhlo.InfeedOp(
+      flat_output_types + [mhlo.TokenType.get()],
+      token,
+      infeed_config=ir.StringAttr.get(''),
+      layout=layouts)
   if partitions is not None:
     mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
   token = infeed.results[-1]
@@ -4053,8 +4056,10 @@ outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
 def _outfeed_lowering(ctx, token, *xs, partitions):
   token_aval = ctx.avals_in[0]
   outfeed = mhlo.OutfeedOp(
-      mlir.aval_to_ir_type(token_aval), mlir.flatten_lowering_ir_args(xs),
-      token, ir.StringAttr.get(''))
+      mlir.aval_to_ir_type(token_aval),
+      mlir.flatten_lowering_ir_args(xs),
+      token,
+      outfeed_config=ir.StringAttr.get(''))
   if partitions is not None:
     mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
   return outfeed.results

File 2 of 4

@@ -1996,10 +1996,14 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
       core.ShapedArray(aval_out.shape, real_dtype))
   def _scatter(operand_part, updates_part):
-    scatter = mhlo.ScatterOp(operand_type_part, operand_part, indices,
-                             updates_part, scatter_dnums,
-                             ir.BoolAttr.get(indices_are_sorted),
-                             ir.BoolAttr.get(unique_indices))
+    scatter = mhlo.ScatterOp(
+        operand_type_part,
+        operand_part,
+        indices,
+        updates_part,
+        scatter_dnums,
+        indices_are_sorted=ir.BoolAttr.get(indices_are_sorted),
+        unique_indices=ir.BoolAttr.get(unique_indices))
     scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
     reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
     with ir.InsertionPoint(reducer):

File 3 of 4

@@ -492,10 +492,13 @@ def _select_and_scatter_lower(
   scalar_aval = operand_aval.update(shape=())
   scalar_type = mlir.aval_to_ir_type(scalar_aval)
   op = mhlo.SelectAndScatterOp(
-      mlir.aval_to_ir_type(aval_out), operand, source,
-      init_value, mlir.dense_int_elements(window_dimensions),
-      mlir.dense_int_elements(window_strides),
-      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
+      mlir.aval_to_ir_type(aval_out),
+      operand,
+      source,
+      init_value,
+      window_dimensions=mlir.dense_int_elements(window_dimensions),
+      window_strides=mlir.dense_int_elements(window_strides),
+      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
   select = op.select.blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(select):
     if select_jaxpr.effects:
@@ -734,12 +737,13 @@ def _select_and_gather_add_lowering(
   init = -np.inf if select_prim is lax.ge_p else np.inf
   rw = mhlo.ReduceWindowOp(
       [ir.RankedTensorType.get(out_aval.shape, double_word_type)],
-      pack(operand, tangents), pack(const(dtype, init), const(dtype, 0)),
+      pack(operand, tangents),
+      pack(const(dtype, init), const(dtype, 0)),
       mlir.dense_int_elements(window_dimensions),
-      mlir.dense_int_elements(window_strides),
-      mlir.dense_int_elements(base_dilation),
-      mlir.dense_int_elements(window_dilation),
-      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
+      window_strides=mlir.dense_int_elements(window_strides),
+      base_dilations=mlir.dense_int_elements(base_dilation),
+      window_dilations=mlir.dense_int_elements(window_dilation),
+      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
   scalar_type = ir.RankedTensorType.get([], double_word_type)
   reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(reducer):

File 4 of 4

@@ -1676,9 +1676,11 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
     if convert_bool:
       float_zero = mlir.full_like_aval(0, padded_aval)
-      out = mhlo.CompareOp(out, float_zero,
-                           mhlo.ComparisonDirectionAttr.get("NE"),
-                           mhlo.ComparisonTypeAttr.get("FLOAT")).result
+      out = mhlo.CompareOp(
+          out,
+          float_zero,
+          mhlo.ComparisonDirectionAttr.get("NE"),
+          compare_type=mhlo.ComparisonTypeAttr.get("FLOAT")).result
     return out
   else:
     raise TypeError(aval)