fp8 matmul in pallas

PiperOrigin-RevId: 641254832
This commit is contained in:
jax authors 2024-06-07 08:16:09 -07:00 committed by jax authors
parent 3914cb415d
commit f51af87fc5
4 changed files with 19 additions and 8 deletions

View File

@@ -1125,7 +1125,15 @@ def _dot_general_lowering_rule(
(aval_out,) = ctx.avals_out
out_type = aval_to_ir_type(aval_out)
val_type = out_type.element_type
if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]):
if any(
cls.isinstance(val_type)
for cls in [
ir.BF16Type,
ir.F32Type,
ir.Float8E5M2Type,
ir.Float8E4M3FNType,
]
):
val = ir.FloatAttr.get(val_type, 0.0)
elif ir.IntegerType.isinstance(val_type):
val = ir.IntegerAttr.get(val_type, 0)

View File

@@ -403,8 +403,10 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
finfo = jnp.finfo(dtype)
nbits, nmant = finfo.bits, finfo.nmant
if nbits not in (16, 32, 64):
raise TypeError(f"uniform only accepts 16-, 32-, or 64-bit dtypes, got {dtype}.")
if nbits not in (8, 16, 32, 64):
raise TypeError(
f"uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}."
)
rng_bits = nbits
if nmant < 8:
@@ -2354,7 +2356,6 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:
return tri
def lognormal(key: KeyArrayLike,
sigma: RealArray = np.float32(1),
shape: Shape | None = None,

View File

@@ -875,9 +875,10 @@ LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(layouts_out.front().has_value());
auto truncf_op = cast<arith::TruncFOp>(op);
if (layouts_in.front()->bitwidth() != 32 ||
layouts_out.front()->bitwidth() != 16) {
(layouts_out.front()->bitwidth() != 16 &&
layouts_out.front()->bitwidth() != 8)) {
return op.emitOpError(
"Not implemented: Only 32-bit to 16-bit conversion supported");
"Not implemented: Only 32-bit to 16-or-8-bit conversion supported");
}
return trunc_op_rule_impl(ctx, truncf_op, *layouts_in.front(),
*layouts_out.front());

View File

@@ -1653,8 +1653,9 @@ class VectorLayoutInferer {
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
if (dyn_cast<arith::TruncFOp>(op)) {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
dst_ty.getElementTypeBitWidth() == 16,
"Only 32-bit to 16-bit truncation supported");
(dst_ty.getElementTypeBitWidth() == 16 ||
dst_ty.getElementTypeBitWidth() == 8),
"Only 32-bit to 8-bit or 16-bit truncation supported");
} else {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
"Only 32-bit truncation supported");