Migrated tl.semantic.cast to lower directly to MLIR

PiperOrigin-RevId: 604255461
This commit is contained in:
Sergei Lebedev 2024-02-05 02:34:17 -08:00 committed by jax authors
parent f2613387bd
commit c77d45d511

View File

@ -1266,11 +1266,145 @@ class math:
})
def _full(t: ir.Type, v: object) -> ir.Type:
element_type = t
if ir.RankedTensorType.isinstance(t):
element_type = ir.RankedTensorType(t).element_type
if isinstance(element_type, ir.IntegerType):
result = arith_dialect.constant(element_type, int(v))
elif isinstance(element_type, _FLOAT_TYPES):
result = arith_dialect.constant(element_type, float(v))
else:
raise NotImplementedError
if ir.RankedTensorType.isinstance(t):
return tt_dialect.splat(t, result)
else:
return result
class semantic:
@staticmethod
def cast(x: tensor, dst_ty: dtype) -> tensor:
return _to_tensor(tl.semantic.cast(x, dst_ty, builder.current))
src_ty = x.type
if src_ty.is_block():
dst_ty = block_type(dst_ty.scalar, x.shape)
if src_ty == dst_ty:
return x
src_element_ty = src_ty.scalar
dst_element_ty = dst_ty.scalar
if src_element_ty.is_fp8e4nv() or dst_element_ty.is_fp8e4nv():
# TODO(slebedev): Check the CUDA version and raise conditionally.
raise NotImplementedError("cannot cast from or to float8_e4m3fnuz")
if (
src_element_ty.is_floating()
and dst_element_ty.is_floating()
and (src_element_ty.is_fp8() or dst_element_ty.is_fp8())
):
return tensor(
tt_dialect.fp_to_fp(
dst_ty.to_ir(builder.current),
x.handle,
rounding=tt_dialect.RoundingMode.RTNE,
),
dst_ty,
)
if (
src_element_ty.is_fp16() or src_element_ty.is_bf16()
) and not dst_element_ty.is_fp32():
return semantic.cast(
semantic.cast(x, float32), dst_element_ty.to_ir(builder.current)
)
if src_element_ty.is_floating() and dst_element_ty.is_floating():
src_width = src_element_ty.primitive_bitwidth
dst_width = dst_element_ty.primitive_bitwidth
if src_width > dst_width:
return tensor(
arith_dialect.truncf(dst_ty.to_ir(builder.current), x.handle),
dst_ty,
)
elif src_width < dst_width:
return tensor(
arith_dialect.extf(dst_ty.to_ir(builder.current), x.handle), dst_ty
)
if (
src_element_ty.is_int()
and dst_element_ty.is_int()
and (
src_element_ty.int_bitwidth != dst_element_ty.int_bitwidth
or src_element_ty.int_signedness != dst_element_ty.int_signedness
)
):
if dst_element_ty.is_bool():
zero = tensor(_full(src_ty.to_ir(builder.current), 0), src_ty)
return semantic.not_equal(x, zero)
else:
sign_extend = (
src_element_ty.is_int_signed() and not src_element_ty.is_bool()
)
return tensor(
builder.current.create_int_cast(x.handle, dst_ty.to_ir(builder.current), sign_extend),
dst_ty,
)
if src_element_ty.is_standard_floating() and dst_element_ty.is_int():
if dst_element_ty.is_bool():
zero = tensor(_full(src_ty.to_ir(builder.current), 0), src_ty)
return semantic.not_equal(x, zero)
elif dst_element_ty.is_int_signed():
return tensor(
arith_dialect.fptosi(dst_ty.to_ir(builder.current), x.handle),
dst_ty,
)
else:
return tensor(
arith_dialect.fptoui(dst_ty.to_ir(builder.current), x.handle),
dst_ty,
)
if src_element_ty.is_int() and dst_element_ty.is_standard_floating():
if src_element_ty.is_bool() or not src_element_ty.is_int_signed():
return tensor(
arith_dialect.uitofp(dst_ty.to_ir(builder.current), x.handle),
dst_ty,
)
else:
return tensor(
arith_dialect.sitofp(dst_ty.to_ir(builder.current), x.handle),
dst_ty,
)
if src_element_ty.is_ptr() and dst_element_ty.is_int():
if dst_element_ty.int_bitwidth == 64:
return tensor(
tt_dialect.ptr_to_int(dst_ty.to_ir(builder.current), x.handle),
dst_ty,
)
else:
x = semantic.cast(x, int64)
zero = tensor(_full(x.type.to_ir(builder.current), 0), x.type)
return semantic.not_equal(x, zero)
if src_element_ty.is_int() and dst_element_ty.is_ptr():
return tensor(
tt_dialect.int_to_ptr(dst_ty.to_ir(builder.current), x.handle), dst_ty
)
if src_element_ty.is_ptr() and dst_element_ty.is_ptr():
return tensor(
tt_dialect.bitcast(dst_ty.to_ir(builder.current), x.handle), dst_ty
)
raise NotImplementedError(f"cannot cast {x} to {dst_ty}")
@staticmethod
def where(cond: tensor, x: tensor, y: tensor) -> tensor: