diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6b722c613..5d8de6cef 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -850,10 +850,11 @@ def _dot_general_lowering_rule( (lhs_dims, rhs_dims), _ = dimension_numbers (aval_out,) = ctx.avals_out out_type = aval_to_ir_type(aval_out) - if ctx.avals_out[0].dtype == jnp.float32: - val = ir.FloatAttr.get(ir.F32Type.get(), 0.0) - elif ctx.avals_out[0].dtype == jnp.float16: - val = ir.FloatAttr.get(ir.F16Type.get(), 0.0) + val_type = out_type.element_type + if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]): + val = ir.FloatAttr.get(val_type, 0.0) + elif ir.IntegerType.isinstance(val_type): + val = ir.IntegerAttr.get(val_type, 0) else: raise NotImplementedError(ctx.avals_out[0].dtype) if any(len(a.shape) != 2 for a in ctx.avals_in):