mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Migrated tl.semantic.cast to lower directly to MLIR
PiperOrigin-RevId: 604255461
This commit is contained in:
parent
f2613387bd
commit
c77d45d511
@ -1266,11 +1266,145 @@ class math:
|
||||
})
|
||||
|
||||
|
||||
def _full(t: ir.Type, v: object) -> ir.Type:
|
||||
element_type = t
|
||||
if ir.RankedTensorType.isinstance(t):
|
||||
element_type = ir.RankedTensorType(t).element_type
|
||||
|
||||
if isinstance(element_type, ir.IntegerType):
|
||||
result = arith_dialect.constant(element_type, int(v))
|
||||
elif isinstance(element_type, _FLOAT_TYPES):
|
||||
result = arith_dialect.constant(element_type, float(v))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if ir.RankedTensorType.isinstance(t):
|
||||
return tt_dialect.splat(t, result)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
|
||||
class semantic:
|
||||
|
||||
@staticmethod
|
||||
def cast(x: tensor, dst_ty: dtype) -> tensor:
|
||||
return _to_tensor(tl.semantic.cast(x, dst_ty, builder.current))
|
||||
src_ty = x.type
|
||||
if src_ty.is_block():
|
||||
dst_ty = block_type(dst_ty.scalar, x.shape)
|
||||
if src_ty == dst_ty:
|
||||
return x
|
||||
|
||||
src_element_ty = src_ty.scalar
|
||||
dst_element_ty = dst_ty.scalar
|
||||
|
||||
if src_element_ty.is_fp8e4nv() or dst_element_ty.is_fp8e4nv():
|
||||
# TODO(slebedev): Check the CUDA version and raise conditionally.
|
||||
raise NotImplementedError("cannot cast from or to float8_e4m3fnuz")
|
||||
|
||||
if (
|
||||
src_element_ty.is_floating()
|
||||
and dst_element_ty.is_floating()
|
||||
and (src_element_ty.is_fp8() or dst_element_ty.is_fp8())
|
||||
):
|
||||
return tensor(
|
||||
tt_dialect.fp_to_fp(
|
||||
dst_ty.to_ir(builder.current),
|
||||
x.handle,
|
||||
rounding=tt_dialect.RoundingMode.RTNE,
|
||||
),
|
||||
dst_ty,
|
||||
)
|
||||
|
||||
if (
|
||||
src_element_ty.is_fp16() or src_element_ty.is_bf16()
|
||||
) and not dst_element_ty.is_fp32():
|
||||
return semantic.cast(
|
||||
semantic.cast(x, float32), dst_element_ty.to_ir(builder.current)
|
||||
)
|
||||
|
||||
if src_element_ty.is_floating() and dst_element_ty.is_floating():
|
||||
src_width = src_element_ty.primitive_bitwidth
|
||||
dst_width = dst_element_ty.primitive_bitwidth
|
||||
if src_width > dst_width:
|
||||
return tensor(
|
||||
arith_dialect.truncf(dst_ty.to_ir(builder.current), x.handle),
|
||||
dst_ty,
|
||||
)
|
||||
elif src_width < dst_width:
|
||||
return tensor(
|
||||
arith_dialect.extf(dst_ty.to_ir(builder.current), x.handle), dst_ty
|
||||
)
|
||||
|
||||
if (
|
||||
src_element_ty.is_int()
|
||||
and dst_element_ty.is_int()
|
||||
and (
|
||||
src_element_ty.int_bitwidth != dst_element_ty.int_bitwidth
|
||||
or src_element_ty.int_signedness != dst_element_ty.int_signedness
|
||||
)
|
||||
):
|
||||
if dst_element_ty.is_bool():
|
||||
zero = tensor(_full(src_ty.to_ir(builder.current), 0), src_ty)
|
||||
return semantic.not_equal(x, zero)
|
||||
else:
|
||||
sign_extend = (
|
||||
src_element_ty.is_int_signed() and not src_element_ty.is_bool()
|
||||
)
|
||||
return tensor(
|
||||
builder.current.create_int_cast(x.handle, dst_ty.to_ir(builder.current), sign_extend),
|
||||
dst_ty,
|
||||
)
|
||||
|
||||
if src_element_ty.is_standard_floating() and dst_element_ty.is_int():
|
||||
if dst_element_ty.is_bool():
|
||||
zero = tensor(_full(src_ty.to_ir(builder.current), 0), src_ty)
|
||||
return semantic.not_equal(x, zero)
|
||||
elif dst_element_ty.is_int_signed():
|
||||
return tensor(
|
||||
arith_dialect.fptosi(dst_ty.to_ir(builder.current), x.handle),
|
||||
dst_ty,
|
||||
)
|
||||
else:
|
||||
return tensor(
|
||||
arith_dialect.fptoui(dst_ty.to_ir(builder.current), x.handle),
|
||||
dst_ty,
|
||||
)
|
||||
|
||||
if src_element_ty.is_int() and dst_element_ty.is_standard_floating():
|
||||
if src_element_ty.is_bool() or not src_element_ty.is_int_signed():
|
||||
return tensor(
|
||||
arith_dialect.uitofp(dst_ty.to_ir(builder.current), x.handle),
|
||||
dst_ty,
|
||||
)
|
||||
else:
|
||||
return tensor(
|
||||
arith_dialect.sitofp(dst_ty.to_ir(builder.current), x.handle),
|
||||
dst_ty,
|
||||
)
|
||||
|
||||
if src_element_ty.is_ptr() and dst_element_ty.is_int():
|
||||
if dst_element_ty.int_bitwidth == 64:
|
||||
return tensor(
|
||||
tt_dialect.ptr_to_int(dst_ty.to_ir(builder.current), x.handle),
|
||||
dst_ty,
|
||||
)
|
||||
else:
|
||||
x = semantic.cast(x, int64)
|
||||
zero = tensor(_full(x.type.to_ir(builder.current), 0), x.type)
|
||||
return semantic.not_equal(x, zero)
|
||||
|
||||
if src_element_ty.is_int() and dst_element_ty.is_ptr():
|
||||
return tensor(
|
||||
tt_dialect.int_to_ptr(dst_ty.to_ir(builder.current), x.handle), dst_ty
|
||||
)
|
||||
|
||||
if src_element_ty.is_ptr() and dst_element_ty.is_ptr():
|
||||
return tensor(
|
||||
tt_dialect.bitcast(dst_ty.to_ir(builder.current), x.handle), dst_ty
|
||||
)
|
||||
|
||||
raise NotImplementedError(f"cannot cast {x} to {dst_ty}")
|
||||
|
||||
@staticmethod
|
||||
def where(cond: tensor, x: tensor, y: tensor) -> tensor:
|
||||
|
Loading…
x
Reference in New Issue
Block a user