[Pallas] Allow matmul kernel to accept bfloat16 and int8 inputs

PiperOrigin-RevId: 580966611
Tomás Longeri 2023-11-09 11:09:15 -08:00 committed by jax authors
parent cf3c041366
commit 9b1572c028


@@ -850,10 +850,11 @@ def _dot_general_lowering_rule(
   (lhs_dims, rhs_dims), _ = dimension_numbers
   (aval_out,) = ctx.avals_out
   out_type = aval_to_ir_type(aval_out)
-  if ctx.avals_out[0].dtype == jnp.float32:
-    val = ir.FloatAttr.get(ir.F32Type.get(), 0.0)
-  elif ctx.avals_out[0].dtype == jnp.float16:
-    val = ir.FloatAttr.get(ir.F16Type.get(), 0.0)
+  val_type = out_type.element_type
+  if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]):
+    val = ir.FloatAttr.get(val_type, 0.0)
+  elif ir.IntegerType.isinstance(val_type):
+    val = ir.IntegerAttr.get(val_type, 0)
   else:
     raise NotImplementedError(ctx.avals_out[0].dtype)
   if any(len(a.shape) != 2 for a in ctx.avals_in):
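
The change derives the zero attribute for the dot-general accumulator from the MLIR element type of the output rather than matching on a fixed list of JAX dtypes, so bfloat16 and integer outputs (e.g. int8 inputs accumulated to int32) can now be lowered. Below is a minimal sketch, not part of this commit, of a Pallas matmul kernel invoked with bfloat16 inputs that exercises this lowering path; the kernel name, shapes, and float32 accumulation dtype are illustrative assumptions.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
  # Read the whole blocks and accumulate in float32 even though the
  # inputs are bfloat16; this dot goes through _dot_general_lowering_rule.
  o_ref[...] = jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

def matmul(x, y):
  m, _ = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
  )(x, y)

x = jnp.ones((256, 256), dtype=jnp.bfloat16)
y = jnp.ones((256, 256), dtype=jnp.bfloat16)
out = matmul(x, y)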