mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix JAX after upstream MLIR Python API change
Autogenerated MLIR Python API was changed to only accept optional operation arguments as keyword arguments in Python. PiperOrigin-RevId: 450651273
This commit is contained in:
parent
bf5d38c213
commit
c888a7e283
@ -4013,8 +4013,11 @@ def _infeed_lowering(ctx, token, *, shapes, partitions):
|
||||
for i in range(len(aval.shape) - 1, -1, -1)])
|
||||
for aval in shapes
|
||||
])
|
||||
infeed = mhlo.InfeedOp(flat_output_types + [mhlo.TokenType.get()], token,
|
||||
ir.StringAttr.get(''), layouts)
|
||||
infeed = mhlo.InfeedOp(
|
||||
flat_output_types + [mhlo.TokenType.get()],
|
||||
token,
|
||||
infeed_config=ir.StringAttr.get(''),
|
||||
layout=layouts)
|
||||
if partitions is not None:
|
||||
mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
|
||||
token = infeed.results[-1]
|
||||
@ -4053,8 +4056,10 @@ outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
|
||||
def _outfeed_lowering(ctx, token, *xs, partitions):
|
||||
token_aval = ctx.avals_in[0]
|
||||
outfeed = mhlo.OutfeedOp(
|
||||
mlir.aval_to_ir_type(token_aval), mlir.flatten_lowering_ir_args(xs),
|
||||
token, ir.StringAttr.get(''))
|
||||
mlir.aval_to_ir_type(token_aval),
|
||||
mlir.flatten_lowering_ir_args(xs),
|
||||
token,
|
||||
outfeed_config=ir.StringAttr.get(''))
|
||||
if partitions is not None:
|
||||
mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
|
||||
return outfeed.results
|
||||
|
@ -1996,10 +1996,14 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
core.ShapedArray(aval_out.shape, real_dtype))
|
||||
|
||||
def _scatter(operand_part, updates_part):
|
||||
scatter = mhlo.ScatterOp(operand_type_part, operand_part, indices,
|
||||
updates_part, scatter_dnums,
|
||||
ir.BoolAttr.get(indices_are_sorted),
|
||||
ir.BoolAttr.get(unique_indices))
|
||||
scatter = mhlo.ScatterOp(
|
||||
operand_type_part,
|
||||
operand_part,
|
||||
indices,
|
||||
updates_part,
|
||||
scatter_dnums,
|
||||
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted),
|
||||
unique_indices=ir.BoolAttr.get(unique_indices))
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
|
||||
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
|
@ -492,10 +492,13 @@ def _select_and_scatter_lower(
|
||||
scalar_aval = operand_aval.update(shape=())
|
||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||
op = mhlo.SelectAndScatterOp(
|
||||
mlir.aval_to_ir_type(aval_out), operand, source,
|
||||
init_value, mlir.dense_int_elements(window_dimensions),
|
||||
mlir.dense_int_elements(window_strides),
|
||||
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
operand,
|
||||
source,
|
||||
init_value,
|
||||
window_dimensions=mlir.dense_int_elements(window_dimensions),
|
||||
window_strides=mlir.dense_int_elements(window_strides),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
select = op.select.blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(select):
|
||||
if select_jaxpr.effects:
|
||||
@ -734,12 +737,13 @@ def _select_and_gather_add_lowering(
|
||||
init = -np.inf if select_prim is lax.ge_p else np.inf
|
||||
rw = mhlo.ReduceWindowOp(
|
||||
[ir.RankedTensorType.get(out_aval.shape, double_word_type)],
|
||||
pack(operand, tangents), pack(const(dtype, init), const(dtype, 0)),
|
||||
pack(operand, tangents),
|
||||
pack(const(dtype, init), const(dtype, 0)),
|
||||
mlir.dense_int_elements(window_dimensions),
|
||||
mlir.dense_int_elements(window_strides),
|
||||
mlir.dense_int_elements(base_dilation),
|
||||
mlir.dense_int_elements(window_dilation),
|
||||
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
window_strides=mlir.dense_int_elements(window_strides),
|
||||
base_dilations=mlir.dense_int_elements(base_dilation),
|
||||
window_dilations=mlir.dense_int_elements(window_dilation),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
scalar_type = ir.RankedTensorType.get([], double_word_type)
|
||||
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
|
@ -1676,9 +1676,11 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
|
||||
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
|
||||
if convert_bool:
|
||||
float_zero = mlir.full_like_aval(0, padded_aval)
|
||||
out = mhlo.CompareOp(out, float_zero,
|
||||
mhlo.ComparisonDirectionAttr.get("NE"),
|
||||
mhlo.ComparisonTypeAttr.get("FLOAT")).result
|
||||
out = mhlo.CompareOp(
|
||||
out,
|
||||
float_zero,
|
||||
mhlo.ComparisonDirectionAttr.get("NE"),
|
||||
compare_type=mhlo.ComparisonTypeAttr.get("FLOAT")).result
|
||||
return out
|
||||
else:
|
||||
raise TypeError(aval)
|
||||
|
Loading…
x
Reference in New Issue
Block a user