mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Replace (deprecated) StrEnumAttr with EnumAttr.
ref: https://reviews.llvm.org/D120834 PiperOrigin-RevId: 435550738
This commit is contained in:
parent
1f95273a46
commit
6cd9804163
@ -1579,13 +1579,23 @@ ad.defjvp_zero(sign_p)
|
||||
def _sign_lower_mhlo(ctx, x):
|
||||
x_aval, = ctx.avals_in
|
||||
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
|
||||
return mhlo.SelectOp(
|
||||
mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))),
|
||||
x, mlir.full_like_aval(0, x_aval), ir.StringAttr.get("EQ"),
|
||||
ir.StringAttr.get("UNSIGNED")).result,
|
||||
mlir.full_like_aval(0, x_aval),
|
||||
mlir.full_like_aval(1, x_aval)).results
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
return mhlo.SelectOp(
|
||||
mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x,
|
||||
mlir.full_like_aval(0, x_aval),
|
||||
mhlo.ComparisonDirectionAttr.get('EQ'),
|
||||
mhlo.ComparisonTypeAttr.get('UNSIGNED')).result,
|
||||
mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1,
|
||||
x_aval)).results
|
||||
else:
|
||||
return mhlo.SelectOp(
|
||||
mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x,
|
||||
mlir.full_like_aval(0, x_aval), ir.StringAttr.get('EQ'),
|
||||
ir.StringAttr.get('UNSIGNED')).result,
|
||||
mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1,
|
||||
x_aval)).results
|
||||
return mhlo.SignOp(x).results
|
||||
|
||||
mlir.register_lowering(sign_p, _sign_lower_mhlo)
|
||||
@ -2216,9 +2226,15 @@ def _compare_lower_mhlo(direction: str, ctx, x, y):
|
||||
compare_type = "SIGNED"
|
||||
else:
|
||||
compare_type = "UNSIGNED"
|
||||
return mhlo.CompareOp(mlir.aval_to_ir_type(aval_out), x, y,
|
||||
ir.StringAttr.get(direction),
|
||||
ir.StringAttr.get(compare_type)).results
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
return mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(aval_out), x, y,
|
||||
mhlo.ComparisonDirectionAttr.get(direction),
|
||||
mhlo.ComparisonTypeAttr.get(compare_type)).results
|
||||
else:
|
||||
return mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(aval_out), x, y, ir.StringAttr.get(direction),
|
||||
ir.StringAttr.get(compare_type)).results
|
||||
|
||||
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
|
||||
ad.defjvp_zero(eq_p)
|
||||
@ -2630,7 +2646,13 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
|
||||
full_precision = (precision, precision)
|
||||
else:
|
||||
full_precision = precision
|
||||
return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision])
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
return ir.ArrayAttr.get(
|
||||
[mhlo.PrecisionAttr.get(str(p)) for p in full_precision])
|
||||
else:
|
||||
return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision])
|
||||
|
||||
|
||||
|
||||
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
precision, preferred_element_type: Optional[np.dtype]):
|
||||
@ -3335,19 +3357,26 @@ def _select_mhlo_lowering(ctx, which, *cases):
|
||||
bool_shape = ir.RankedTensorType.get(which_aval.shape,
|
||||
ir.IntegerType.get_signless(1))
|
||||
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
|
||||
compare_type = ir.StringAttr.get("SIGNED")
|
||||
compare_type = 'SIGNED'
|
||||
else:
|
||||
compare_type = ir.StringAttr.get("UNSIGNED")
|
||||
lt = ir.StringAttr.get("LT")
|
||||
compare_type = 'UNSIGNED'
|
||||
lt = 'LT'
|
||||
|
||||
def _select(offset, cases):
|
||||
assert len(cases) > 0
|
||||
if len(cases) == 1:
|
||||
return cases[0]
|
||||
mid = len(cases) // 2
|
||||
pred = mhlo.CompareOp(
|
||||
bool_shape, which, mlir.full_like_aval(offset + mid, which_aval),
|
||||
lt, compare_type)
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
pred = mhlo.CompareOp(bool_shape, which,
|
||||
mlir.full_like_aval(offset + mid, which_aval),
|
||||
mhlo.ComparisonDirectionAttr.get(lt),
|
||||
mhlo.ComparisonTypeAttr.get(compare_type))
|
||||
else:
|
||||
pred = mhlo.CompareOp(bool_shape, which,
|
||||
mlir.full_like_aval(offset + mid, which_aval),
|
||||
ir.StringAttr.get(lt),
|
||||
ir.StringAttr.get(compare_type))
|
||||
return mhlo.SelectOp(pred, _select(offset, cases[:mid]),
|
||||
_select(offset + mid, cases[mid:])).result
|
||||
|
||||
|
@ -28,6 +28,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
|
||||
from typing_extensions import Protocol
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import ad_util
|
||||
@ -820,15 +821,27 @@ def _minmax_mhlo(op, cmp, x, y):
|
||||
ry = mhlo.RealOp(y).result
|
||||
dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
|
||||
bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
|
||||
real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
|
||||
ir.StringAttr.get("FLOAT"))
|
||||
real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
|
||||
ir.StringAttr.get(cmp),
|
||||
ir.StringAttr.get("FLOAT"))
|
||||
imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
|
||||
mhlo.ImagOp(y).result,
|
||||
ir.StringAttr.get(cmp),
|
||||
ir.StringAttr.get("FLOAT"))
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
real_eq = mhlo.CompareOp(bool_shape, rx, ry,
|
||||
mhlo.ComparisonDirectionAttr.get("EQ"),
|
||||
mhlo.ComparisonTypeAttr.get("FLOAT"))
|
||||
real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
|
||||
mhlo.ComparisonDirectionAttr.get(cmp),
|
||||
mhlo.ComparisonTypeAttr.get("FLOAT"))
|
||||
imag_cmp = mhlo.CompareOp(bool_shape,
|
||||
mhlo.ImagOp(x).result,
|
||||
mhlo.ImagOp(y).result,
|
||||
mhlo.ComparisonDirectionAttr.get(cmp),
|
||||
mhlo.ComparisonTypeAttr.get("FLOAT"))
|
||||
else:
|
||||
real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
|
||||
ir.StringAttr.get("FLOAT"))
|
||||
real_cmp = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get(cmp),
|
||||
ir.StringAttr.get("FLOAT"))
|
||||
imag_cmp = mhlo.CompareOp(bool_shape,
|
||||
mhlo.ImagOp(x).result,
|
||||
mhlo.ImagOp(y).result, ir.StringAttr.get(cmp),
|
||||
ir.StringAttr.get("FLOAT"))
|
||||
which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
|
||||
return mhlo.SelectOp(which, x, y)
|
||||
else:
|
||||
@ -850,9 +863,15 @@ def convert_mhlo(x, aval_in, aval_out):
|
||||
compare_type = "SIGNED"
|
||||
else:
|
||||
compare_type = "UNSIGNED"
|
||||
return mhlo.CompareOp(
|
||||
aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
|
||||
ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
return mhlo.CompareOp(
|
||||
aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
|
||||
mhlo.ComparisonDirectionAttr.get("NE"),
|
||||
mhlo.ComparisonTypeAttr.get(compare_type)).result
|
||||
else:
|
||||
return mhlo.CompareOp(
|
||||
aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
|
||||
ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result
|
||||
return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result
|
||||
|
||||
def _wrap_with_spmd_op(name: str,
|
||||
|
@ -44,6 +44,7 @@ import sys
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src.config import config
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
@ -1753,10 +1754,16 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
|
||||
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
|
||||
if convert_bool:
|
||||
float_zero = mlir.full_like_aval(0, padded_aval)
|
||||
out = mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
|
||||
out, float_zero, ir.StringAttr.get("NE"),
|
||||
ir.StringAttr.get("FLOAT")).result
|
||||
if jax._src.lib.mlir_api_version >= 3:
|
||||
out = mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
|
||||
out, float_zero, mhlo.ComparisonDirectionAttr.get("NE"),
|
||||
mhlo.ComparisonTypeAttr.get("FLOAT")).result
|
||||
else:
|
||||
out = mhlo.CompareOp(
|
||||
mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
|
||||
out, float_zero, ir.StringAttr.get("NE"),
|
||||
ir.StringAttr.get("FLOAT")).result
|
||||
return out
|
||||
else:
|
||||
raise TypeError(aval)
|
||||
|
@ -199,29 +199,29 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: eq float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq complex128[] complex128[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
print_ir(np.complex128(1), np.complex128(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq int64[] int64[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "SIGNED"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: tensor<i64>
|
||||
print_ir(np.int64(1), np.int64(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq uint16[] uint16[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "UNSIGNED"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: tensor<ui16>
|
||||
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
|
||||
|
||||
@ -257,15 +257,15 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: ge float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "GE"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE">
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.ge)
|
||||
|
||||
# CHECK-LABEL: TEST: gt float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "GT"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT">
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.gt)
|
||||
|
||||
@ -302,8 +302,8 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: le float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "LE"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE">
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.le)
|
||||
|
||||
@ -324,8 +324,8 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: lt float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "LT"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT">
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.lt)
|
||||
|
||||
@ -346,8 +346,8 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: ne float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "NE"
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE">
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.ne)
|
||||
|
||||
@ -418,7 +418,7 @@ def main(_):
|
||||
# CHECK-SAME: tensor<ui32>
|
||||
print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)
|
||||
|
||||
# CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
|
||||
# CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
|
||||
# CHECK: mhlo.shift_right_arithmetic
|
||||
# CHECK-SAME: tensor<ui8>
|
||||
print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)
|
||||
|
Loading…
x
Reference in New Issue
Block a user