Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
[Pallas] Allow matmul kernel to accept bfloat16 and int8 inputs
PiperOrigin-RevId: 580966611
commit 9b1572c028
parent cf3c041366
@@ -850,10 +850,11 @@ def _dot_general_lowering_rule(
   (lhs_dims, rhs_dims), _ = dimension_numbers
   (aval_out,) = ctx.avals_out
   out_type = aval_to_ir_type(aval_out)
-  if ctx.avals_out[0].dtype == jnp.float32:
-    val = ir.FloatAttr.get(ir.F32Type.get(), 0.0)
-  elif ctx.avals_out[0].dtype == jnp.float16:
-    val = ir.FloatAttr.get(ir.F16Type.get(), 0.0)
+  val_type = out_type.element_type
+  if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]):
+    val = ir.FloatAttr.get(val_type, 0.0)
+  elif ir.IntegerType.isinstance(val_type):
+    val = ir.IntegerAttr.get(val_type, 0)
   else:
     raise NotImplementedError(ctx.avals_out[0].dtype)
   if any(len(a.shape) != 2 for a in ctx.avals_in):
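For reference, the new branches build the matmul's zero/padding value from the IR element type of the output rather than from the JAX dtype. The snippet below is an illustrative sketch (not part of this commit) of what those attribute constructors produce, assuming the MLIR Python bindings bundled with jaxlib (jaxlib.mlir.ir):

import jaxlib.mlir.ir as ir

with ir.Context():
  # Float element types (bf16 here, f32 likewise) get a FloatAttr zero ...
  bf16 = ir.BF16Type.get()
  print(ir.FloatAttr.get(bf16, 0.0))   # -> 0.000000e+00 : bf16
  # ... while integer element types (e.g. an i32 accumulator for int8
  # matmuls) get an IntegerAttr zero, matching the elif branch above.
  i32 = ir.IntegerType.get_signless(32)
  print(ir.IntegerAttr.get(i32, 0))    # -> 0 : i32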
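To illustrate what the change enables at the user level, here is a minimal sketch of a Pallas matmul whose operands and output are bfloat16. The kernel body, shapes, and the name matmul_bf16 are assumptions for illustration, not code from this commit:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(x_ref, y_ref, o_ref):
  # jnp.dot inside the kernel lowers through _dot_general_lowering_rule;
  # with bfloat16 operands and output it now hits the BF16Type branch.
  o_ref[...] = jnp.dot(x_ref[...], y_ref[...])


@jax.jit
def matmul_bf16(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
  )(x, y)


x = jnp.ones((256, 256), dtype=jnp.bfloat16)
y = jnp.ones((256, 256), dtype=jnp.bfloat16)
z = matmul_bf16(x, y)  # bfloat16 result

# An int8 matmul would instead produce an integer output type (e.g. via
# preferred_element_type=jnp.int32), exercising the ir.IntegerType branch.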