convert : match ssm_conv tensors by type

This commit is contained in:
Francis Couture-Harpin 2025-03-25 14:29:22 -04:00
parent 9c60fc4c78
commit 20b256e0fd

View File

@ -3803,8 +3803,6 @@ class MambaModel(Model):
_tok_embd = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
@ -3815,7 +3813,7 @@ class MambaModel(Model):
data_torch = -torch.exp(data_torch)
# [4 1 8192 1] -> [4 8192 1 1]
if new_name.endswith(".ssm_conv1d"):
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
data_torch = data_torch.squeeze()
# assuming token_embd.weight is seen before output.weight