mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
fp8 matmul in pallas
PiperOrigin-RevId: 641254832
This commit is contained in:
parent
3914cb415d
commit
f51af87fc5
@ -1125,7 +1125,15 @@ def _dot_general_lowering_rule(
|
||||
(aval_out,) = ctx.avals_out
|
||||
out_type = aval_to_ir_type(aval_out)
|
||||
val_type = out_type.element_type
|
||||
if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]):
|
||||
if any(
|
||||
cls.isinstance(val_type)
|
||||
for cls in [
|
||||
ir.BF16Type,
|
||||
ir.F32Type,
|
||||
ir.Float8E5M2Type,
|
||||
ir.Float8E4M3FNType,
|
||||
]
|
||||
):
|
||||
val = ir.FloatAttr.get(val_type, 0.0)
|
||||
elif ir.IntegerType.isinstance(val_type):
|
||||
val = ir.IntegerAttr.get(val_type, 0)
|
||||
|
@ -403,8 +403,10 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
|
||||
finfo = jnp.finfo(dtype)
|
||||
nbits, nmant = finfo.bits, finfo.nmant
|
||||
|
||||
if nbits not in (16, 32, 64):
|
||||
raise TypeError(f"uniform only accepts 16-, 32-, or 64-bit dtypes, got {dtype}.")
|
||||
if nbits not in (8, 16, 32, 64):
|
||||
raise TypeError(
|
||||
f"uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}."
|
||||
)
|
||||
|
||||
rng_bits = nbits
|
||||
if nmant < 8:
|
||||
@ -2354,7 +2356,6 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:
|
||||
return tri
|
||||
|
||||
|
||||
|
||||
def lognormal(key: KeyArrayLike,
|
||||
sigma: RealArray = np.float32(1),
|
||||
shape: Shape | None = None,
|
||||
|
@ -875,9 +875,10 @@ LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
|
||||
TPU_ASSERT_OP(layouts_out.front().has_value());
|
||||
auto truncf_op = cast<arith::TruncFOp>(op);
|
||||
if (layouts_in.front()->bitwidth() != 32 ||
|
||||
layouts_out.front()->bitwidth() != 16) {
|
||||
(layouts_out.front()->bitwidth() != 16 &&
|
||||
layouts_out.front()->bitwidth() != 8)) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: Only 32-bit to 16-bit conversion supported");
|
||||
"Not implemented: Only 32-bit to 16-or-8-bit conversion supported");
|
||||
}
|
||||
return trunc_op_rule_impl(ctx, truncf_op, *layouts_in.front(),
|
||||
*layouts_out.front());
|
||||
|
@ -1653,8 +1653,9 @@ class VectorLayoutInferer {
|
||||
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
|
||||
if (dyn_cast<arith::TruncFOp>(op)) {
|
||||
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
|
||||
dst_ty.getElementTypeBitWidth() == 16,
|
||||
"Only 32-bit to 16-bit truncation supported");
|
||||
(dst_ty.getElementTypeBitWidth() == 16 ||
|
||||
dst_ty.getElementTypeBitWidth() == 8),
|
||||
"Only 32-bit to 8-bit or 16-bit truncation supported");
|
||||
} else {
|
||||
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
|
||||
"Only 32-bit truncation supported");
|
||||
|
Loading…
x
Reference in New Issue
Block a user