convert : fix squeeze for ssm_conv tensors

This commit is contained in:
Georgi Gerganov 2025-03-25 19:54:18 +02:00
parent 053b3f9aae
commit 9c60fc4c78
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -3814,6 +3814,10 @@ class MambaModel(Model):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
# [4 1 8192 1] -> [4 8192 1 1]
if new_name.endswith(".ssm_conv1d"):
data_torch = data_torch.squeeze()
# assuming token_embd.weight is seen before output.weight
if self._tok_embd is not None and new_name == output_name:
if torch.equal(self._tok_embd, data_torch):