mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-16 03:26:08 +00:00
llama: Add support for RWKV v7 architecture (#12412)
* ggml: Add op l2_norm
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* ggml: Add op rwkv_wkv7
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: Add support for RWKV7 and ARWKV7 models
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: fix inference with RWKV6Qwen2
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: add more (a)rwkv7 variants in size
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Apply code-format changes
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* fix MUSA build
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: fix shape error with rwkv using llama-parallel
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
---------
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
60c902926c
commit
7dfad387e3
@ -908,6 +908,40 @@ class Model:
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_rwkv_world(self):
    """Export the RWKV "world" vocabulary to the GGUF writer.

    Reads ``rwkv_vocab_v20230424.txt`` from the model directory. Each line is
    split on single spaces: the first field is ignored, the last field is the
    token's byte length, and everything in between is a Python literal
    (``str`` or ``bytes``) holding the token itself. Token 0 is a synthetic
    ``<s>`` control token; the list is padded with ``[PAD<i>]`` entries up to
    ``vocab_size`` (taken from hparams, default 65536).
    """
    vocab_path = self.dir_model / "rwkv_vocab_v20230424.txt"
    assert vocab_path.is_file()
    vocab_size = self.hparams.get("vocab_size", 65536)

    tokens: list[bytes] = ['<s>'.encode("utf-8")]
    toktypes: list[int] = [gguf.TokenType.CONTROL]

    with open(vocab_path, "r", encoding="utf-8") as vocab_file:
        for entry in vocab_file.readlines():
            fields = entry.split(' ')
            assert len(fields) >= 3
            # The literal may itself contain spaces, so re-join the middle fields.
            literal = ast.literal_eval(' '.join(fields[1:-1]))
            expected_len = int(fields[-1])
            raw = literal.encode("utf-8") if isinstance(literal, str) else literal
            assert isinstance(raw, bytes)
            assert len(raw) == expected_len
            # Store the escaped repr without the b'' wrapper: "b'\xff'" -> "\xff"
            escaped: str = repr(raw)[2:-1]
            tokens.append(escaped.encode("utf-8"))
            toktypes.append(gguf.TokenType.NORMAL)

    first_pad = len(tokens)
    remainder = vocab_size - first_pad
    assert remainder >= 0
    # Fill the unused tail of the vocabulary with placeholder tokens.
    tokens.extend(f"[PAD{i}]".encode("utf-8") for i in range(first_pad, vocab_size))
    toktypes.extend([gguf.TokenType.UNUSED] * remainder)

    writer = self.gguf_writer
    writer.add_tokenizer_model("rwkv")
    writer.add_token_list(tokens)
    writer.add_token_types(toktypes)

    special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
    special_vocab.chat_template = "rwkv-world"
    # hack: Add '\n\n' as the EOT token to make it chat normally
    special_vocab._set_special_token("eot", 261)
    special_vocab.add_to_gguf(writer)
|
||||
|
||||
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
|
||||
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
|
||||
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
||||
@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV6
|
||||
|
||||
def set_vocab(self):
    """Write the RWKV "world" tokenizer metadata.

    The vocabulary export lives in the shared ``_set_vocab_rwkv_world`` helper
    so that all RWKV converters use one implementation. This method must only
    delegate: running an inline copy of the export and then the helper would
    write the tokenizer model, token list and token types twice.
    """
    self._set_vocab_rwkv_world()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
@ -3565,6 +3568,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||
yield (new_name, data)
|
||||
|
||||
|
||||
@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
class Rwkv7Model(Model):
    """Converter for checkpoints registered as Rwkv7ForCausalLM / RWKV7ForCausalLM."""

    model_arch = gguf.MODEL_ARCH.RWKV7

    def set_vocab(self):
        """Export the RWKV "world" vocabulary."""
        self._set_vocab_rwkv_world()

    def calc_lora_rank(self, hidden_size, exponent, multiplier):
        """Return ``hidden_size ** exponent * multiplier`` rounded to a multiple of 32 (minimum 32)."""
        n_chunks = round(hidden_size ** exponent * multiplier / 32)
        return max(1, n_chunks) * 32

    def set_gguf_parameters(self):
        """Write the RWKV7 hyper-parameters to the GGUF header.

        Two hparams naming schemes are accepted; a ``KeyError`` on the first
        scheme makes the code fall back to the second one for the whole group.
        LoRA ranks that are present but ``None`` are derived from
        ``hidden_size`` via ``calc_lora_rank``.
        """
        hp = self.hparams
        block_count = hp["num_hidden_layers"]
        try:
            head_size = hp["head_size"]
            layer_norm_eps = hp["layer_norm_epsilon"]
        except KeyError:
            head_size = hp["head_dim"]
            layer_norm_eps = hp["norm_eps"]
        hidden_size = hp["hidden_size"]
        intermediate_size = hp["intermediate_size"]
        if intermediate_size is None:
            intermediate_size = hidden_size * 4

        def rank_or_default(key, exponent, multiplier):
            # A missing key raises KeyError on purpose: the caller uses it to
            # switch to the alternative naming scheme.
            configured = hp[key]
            if configured is not None:
                return configured
            return self.calc_lora_rank(hidden_size, exponent, multiplier)

        # ICLR: In-Context-Learning-Rate
        try:
            lora_rank_decay = rank_or_default("lora_rank_decay", 0.5, 1.8)
            lora_rank_iclr = rank_or_default("lora_rank_iclr", 0.5, 1.8)
            lora_rank_value_residual_mix = rank_or_default("lora_rank_value_residual_mix", 0.5, 1.3)
            lora_rank_gate = rank_or_default("lora_rank_gate", 0.8, 0.6)
        except KeyError:
            lora_rank_decay = rank_or_default("decay_low_rank_dim", 0.5, 1.8)
            lora_rank_iclr = rank_or_default("a_low_rank_dim", 0.5, 1.8)
            lora_rank_value_residual_mix = rank_or_default("v_low_rank_dim", 0.5, 1.3)
            lora_rank_gate = rank_or_default("gate_low_rank_dim", 0.8, 0.6)

        writer = self.gguf_writer
        # RWKV isn't context limited
        writer.add_context_length(1048576)
        writer.add_embedding_length(hidden_size)
        writer.add_block_count(block_count)
        writer.add_layer_norm_eps(layer_norm_eps)
        writer.add_wkv_head_size(head_size)
        writer.add_decay_lora_rank(lora_rank_decay)
        writer.add_iclr_lora_rank(lora_rank_iclr)
        writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
        writer.add_gate_lora_rank(lora_rank_gate)
        writer.add_feed_forward_length(intermediate_size)
        writer.add_file_type(self.ftype)

        # required by llama.cpp, unused
        writer.add_head_count(0)

    # Per-layer cache of the separate x_r/x_w/... lerp tensors until all of them
    # have been seen and can be stacked into one fused tensor.
    lerp_weights: dict[int, dict[str, Tensor]] = {}
    # Cleared as soon as a "_lora.lora" tensor name is seen (fla-hub layout).
    lora_needs_transpose: bool = True

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename, fuse and reshape one checkpoint tensor.

        Yields zero, one or two ``(gguf_name, tensor)`` pairs.
        """
        # unify tensor names here to make life easier
        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
        name = name.replace("self_attn", "attention").replace("attn", "attention")
        name = name.replace("time_mixer.", "")
        # lora layer names in fla-hub's impl
        if "_lora.lora" in name:
            self.lora_needs_transpose = False
        name = name.replace("_lora.lora.0.weight", "1.weight")
        name = name.replace("_lora.lora.2.weight", "2.weight")
        name = name.replace("_lora.lora.2.bias", "0.weight")

        name = name.replace("feed_forward_norm", "ln2")
        name = name.replace("g_norm", "ln_x")

        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
            # some models have dummy v0/v1/v2 on first layer while others don't
            # ignore them all since they are not used
            return

        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
        lerp_list = ["r", "w", "k", "v", "a"]
        if wkv_has_gate:
            lerp_list.append("g")

        if bid is not None and "attention.x_" in name:
            fused_name = f"blk.{bid}.time_mix_lerp_fused.weight"
            if "attention.x_x" in name:
                # already concatenated
                yield (fused_name, data_torch.reshape(len(lerp_list), 1, 1, -1))
            else:
                layer_cache = self.lerp_weights.setdefault(bid, {})
                layer_cache[name] = data_torch
                wanted = [f"model.layers.{bid}.attention.x_{i}" for i in lerp_list]
                if all(key in layer_cache.keys() for key in wanted):
                    # every component has arrived: stack them in lerp_list order
                    yield (fused_name, torch.stack([layer_cache[key] for key in wanted], dim=0))
            return

        data_torch = data_torch.squeeze()
        new_name = self.map_tensor_name(name)

        if not new_name.endswith((".weight", ".bias")):
            new_name += ".weight"

        lora_suffixes = (
            "time_mix_w1.weight", "time_mix_w2.weight",
            "time_mix_a1.weight", "time_mix_a2.weight",
            "time_mix_v1.weight", "time_mix_v2.weight",
            "time_mix_g1.weight", "time_mix_g2.weight",
        )
        if self.lora_needs_transpose and new_name.endswith(lora_suffixes):
            data_torch = data_torch.transpose(0, 1)

        if 'r_k' in new_name:
            data_torch = data_torch.flatten()

        if bid == 0 and "time_mix_a" in new_name:
            # dummy v0/v1/v2 on first layer
            # easiest way to make llama happy
            yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)

        yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Rwkv7Model):
    """Converter for checkpoints registered as RwkvHybridForCausalLM (ARWKV7 architecture)."""

    model_arch = gguf.MODEL_ARCH.ARWKV7

    def set_vocab(self):
        """Use the sentencepiece vocabulary, falling back to GPT-2 style when its file is missing."""
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Write the ARWKV7 hyper-parameters; LoRA ranks are fixed constants here."""
        hp = self.hparams
        block_count = hp["num_hidden_layers"]
        hidden_size = hp["hidden_size"]
        head_size = hp["head_size"]
        rms_norm_eps = hp["rms_norm_eps"]
        intermediate_size = hp["intermediate_size"]
        wkv_has_gate = hp["wkv_has_gate"]
        assert hp["wkv_version"] == 7

        # ICLR: In-Context-Learning-Rate
        lora_rank_decay = 64
        lora_rank_iclr = 64
        lora_rank_value_residual_mix = 32
        # the gate LoRA rank is written as 0 when the model has no gate
        lora_rank_gate = 128 if wkv_has_gate else 0

        writer = self.gguf_writer
        # RWKV isn't context limited
        writer.add_context_length(1048576)
        writer.add_embedding_length(hidden_size)
        writer.add_block_count(block_count)
        writer.add_layer_norm_rms_eps(rms_norm_eps)
        writer.add_wkv_head_size(head_size)
        writer.add_decay_lora_rank(lora_rank_decay)
        writer.add_iclr_lora_rank(lora_rank_iclr)
        writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
        writer.add_gate_lora_rank(lora_rank_gate)
        writer.add_feed_forward_length(intermediate_size)
        writer.add_file_type(self.ftype)
        writer.add_token_shift_count(1)

        # required by llama.cpp, unused
        writer.add_head_count(0)
|
||||
|
||||
|
||||
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
|
||||
class MambaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
@ -454,6 +454,7 @@ extern "C" {
|
||||
GGML_OP_RMS_NORM,
|
||||
GGML_OP_RMS_NORM_BACK,
|
||||
GGML_OP_GROUP_NORM,
|
||||
GGML_OP_L2_NORM,
|
||||
|
||||
GGML_OP_MUL_MAT,
|
||||
GGML_OP_MUL_MAT_ID,
|
||||
@ -502,6 +503,7 @@ extern "C" {
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@ -1095,6 +1097,18 @@ extern "C" {
|
||||
int n_groups,
|
||||
float eps);
|
||||
|
||||
// l2 normalize along rows
|
||||
// used in rwkv v7
|
||||
GGML_API struct ggml_tensor * ggml_l2_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps);
|
||||
|
||||
// Same operation as ggml_l2_norm.
// NOTE(review): by ggml naming convention the result presumably aliases `a`'s
// data instead of using a new buffer - confirm against the definition in ggml.c.
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        float                 eps);
|
||||
|
||||
// a - x
|
||||
// b - dy
|
||||
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
||||
@ -1890,6 +1904,16 @@ extern "C" {
|
||||
struct ggml_tensor * state,
|
||||
float scale);
|
||||
|
||||
// RWKV v7 WKV operator.
// Per head, with S the head_size x head_size state, the CPU kernel evaluates
// for every token:
//   sa      = sum_j a[j] * S[i][j]
//   S[i][j] = S[i][j] * w[j] + v[i] * k[j] + sa * b[j]
//   out[i]  = sum_j S[i][j] * r[j]
// The CPU kernel reads the token count from ne[2] and the head count from
// ne[1] of its second source, and the sequence count from ne[1] of its
// seventh source.
// NOTE(review): the second/seventh sources are presumably `w` and `state`
// (arguments stored in declaration order) - confirm against ggml.c.
// The result stores the per-token outputs first, followed by the updated states.
GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
        struct ggml_context * ctx,
        struct ggml_tensor  * r,
        struct ggml_tensor  * w,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * state);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
||||
|
@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_l2_norm
|
||||
|
||||
static void ggml_compute_forward_l2_norm_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_l2_norm(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_l2_norm_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_mul_mat
|
||||
|
||||
static void ggml_compute_forward_mul_mat_one_chunk(
|
||||
@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const int64_t T = dst->src[1]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t HEADS = dst->src[1]->ne[1];
|
||||
const int64_t n_seqs = dst->src[6]->ne[1];
|
||||
const int64_t head_size = C / HEADS;
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * r = (float *) dst->src[0]->data;
|
||||
float * w = (float *) dst->src[1]->data;
|
||||
float * k = (float *) dst->src[2]->data;
|
||||
float * v = (float *) dst->src[3]->data;
|
||||
float * a = (float *) dst->src[4]->data;
|
||||
float * b = (float *) dst->src[5]->data;
|
||||
|
||||
int64_t t_stride = HEADS * head_size; // Same to C
|
||||
|
||||
int64_t h_stride = C / HEADS;
|
||||
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||
int64_t h_stride_2d = head_size * head_size;
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
int64_t h_offset = h * h_stride;
|
||||
int64_t t_h_offset = t_offset + h_offset;
|
||||
int64_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t ii = 0; ii < head_size; ii++) {
|
||||
int64_t t_h_i_offset = t_h_offset + ii;
|
||||
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
|
||||
|
||||
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
|
||||
|
||||
float sa = 0;
|
||||
{
|
||||
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
GGML_F32_VEC ax[GGML_F32_ARR];
|
||||
GGML_F32_VEC ay[GGML_F32_ARR];
|
||||
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
|
||||
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
||||
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
|
||||
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
|
||||
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
|
||||
}
|
||||
}
|
||||
GGML_F32_VEC_REDUCE(sa, sum);
|
||||
}
|
||||
|
||||
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
|
||||
|
||||
int64_t j = 0;
|
||||
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
for (; j < head_size; j += GGML_F32_STEP) {
|
||||
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
|
||||
|
||||
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
|
||||
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
|
||||
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
|
||||
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
|
||||
|
||||
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
|
||||
|
||||
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
|
||||
// kv + s * decay + sa * b
|
||||
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
|
||||
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
|
||||
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
|
||||
|
||||
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
|
||||
}
|
||||
}
|
||||
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
|
||||
|
||||
// There shouldn't be left-overs though.
|
||||
for (; j < head_size; j++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float r_val = r[t_h_j_offset];
|
||||
float w_val = w[t_h_j_offset];
|
||||
float k_val = k[t_h_j_offset];
|
||||
float b_val = b[t_h_j_offset];
|
||||
float kv_val = v[t_h_i_offset] * k_val;
|
||||
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
||||
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
int64_t h_offset = h * h_stride;
|
||||
int64_t t_h_offset = t_offset + h_offset;
|
||||
int64_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
int64_t t_h_i_offset = t_h_offset + i;
|
||||
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float v_val = v[t_h_i_offset];
|
||||
|
||||
float sa = 0, result = 0;
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float r_val = r[t_h_j_offset];
|
||||
float w_val = w[t_h_j_offset];
|
||||
float k_val = k[t_h_j_offset];
|
||||
float b_val = b[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
||||
result += state_cur[h_2d_i_j_offset] * r_val;
|
||||
}
|
||||
dst_data[t_h_i_offset] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_unary
|
||||
|
||||
static void ggml_compute_forward_map_unary_f32(
|
||||
@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_group_norm(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
ggml_compute_forward_l2_norm(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_compute_forward_mul_mat(params, tensor);
|
||||
@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_gla(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
ggml_unary_op_f32_t fun;
|
||||
@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_MUL_MAT:
|
||||
@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
@ -36,7 +36,7 @@
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GROUP_NORM:
|
||||
ggml_cuda_op_group_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cuda_op_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cuda_op_concat(ctx, dst);
|
||||
break;
|
||||
@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_cuda_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
@ -3161,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||
@ -3215,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
|
@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// template <int block_size>
|
||||
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
// const int tid = threadIdx.x;
|
||||
|
||||
// float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// const float xi = x[row*ncols + col];
|
||||
// tmp += xi * xi;
|
||||
// }
|
||||
|
||||
// // sum up partial sums
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// if (block_size > WARP_SIZE) {
|
||||
// __shared__ float s_sum[32];
|
||||
// int warp_id = threadIdx.x / WARP_SIZE;
|
||||
// int lane_id = threadIdx.x % WARP_SIZE;
|
||||
// if (lane_id == 0) {
|
||||
// s_sum[warp_id] = tmp;
|
||||
// }
|
||||
// __syncthreads();
|
||||
// tmp = s_sum[lane_id];
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// }
|
||||
|
||||
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
// }
|
||||
// }
|
||||
|
||||
// L2-normalizes one row per thread block: dst_row = src_row * rsqrt(max(sum(src_row^2), eps^2)).
// blockIdx.{x,y,z} select (row, channel, sample). The source is addressed with
// the given element strides; the destination is indexed as a dense
// [sample][channel][row][col] array.
template <int block_size>
static __global__ void l2_norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int tid     = threadIdx.x;
    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;

    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const float * src_row = x + (sample*stride_sample + channel*stride_channel + row*stride_row);
    float       * dst_row = dst + ((sample*nchannels + channel)*nrows + row)*ncols;

    // per-thread partial sum of squares over a strided subset of the columns
    float sum_sq = 0.0f;
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = src_row[col];
        sum_sq += xi * xi;
    }

    // combine the partial sums, first within each warp ...
    sum_sq = warp_reduce_sum(sum_sq);
    if constexpr (block_size > WARP_SIZE) {
        // ... then across warps through shared memory
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = tid / WARP_SIZE;
        const int lane_id = tid % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum_sq;
        }
        __syncthreads();
        sum_sq = warp_reduce_sum(s_sum[lane_id]);
    }

    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    const float scale = rsqrtf(fmaxf(sum_sq, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst_row[col] = scale * src_row[col];
    }
}
|
||||
|
||||
static void norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
||||
}
|
||||
}
|
||||
|
||||
// Launches l2_norm_f32 with one block per (row, channel, sample).
// Rows shorter than 1024 columns use a single warp per block, longer rows use
// 1024 threads.
static void l2_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);

    if (ncols >= 1024) {
        const dim3 threads(1024, 1, 1);
        l2_norm_f32<1024><<<grid, threads, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }

    const dim3 threads(WARP_SIZE, 1, 1);
    l2_norm_f32<WARP_SIZE><<<grid, threads, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
|
||||
|
||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
||||
|
||||
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
|
||||
// CUDA entry point for GGML_OP_L2_NORM. Requires F32 source and destination
// and a source whose innermost dimension is contiguous; eps is read from
// dst->op_params.
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    // convert the source byte strides into element strides for the kernel
    const size_t elem_size = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == elem_size);
    const int64_t stride_row     = nb01 / elem_size;
    const int64_t stride_channel = nb02 / elem_size;
    const int64_t stride_sample  = nb03 / elem_size;

    const float * src_data = (const float *) src0->data;
    float       * dst_data = (float *) dst->data;

    l2_norm_f32_cuda(src_data, dst_data, ne00, ne01, ne02, ne03,
                     stride_row, stride_channel, stride_sample, eps, stream);
}
|
||||
|
@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
199
ggml/src/ggml-cuda/wkv.cu
Normal file
199
ggml/src/ggml-cuda/wkv.cu
Normal file
@ -0,0 +1,199 @@
|
||||
#include "common.cuh"
|
||||
#include "wkv.cuh"
|
||||
|
||||
// RWKV v6 WKV recurrence.
//
// Block blockIdx.x serves sequence (blockIdx.x / H) and head (blockIdx.x % H); the
// block is expected to be launched with block_size (== head size) threads.
// Thread `lane` keeps the head_size state values found at offsets
// row * head_size + lane of its head's state tile in registers/local memory.
//
// Per token the per-head slices of k, r and td are staged in shared memory, the
// output for channel `lane` is written to dst[t], and after the last token the
// updated state is stored behind the T * C outputs in dst.
template <int block_size>
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
    const int lane = threadIdx.x;

    const int head_size    = block_size;
    const int batch_i      = blockIdx.x / H;
    const int head_i       = blockIdx.x % H;
    const int state_size   = C * head_size;
    const int n_seq_tokens = T / B;

    // start of this (sequence, head) tile inside s and inside the state part of dst
    const int state_base = batch_i * state_size + head_i * head_size * head_size;

    float state[head_size];
    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];

    #pragma unroll
    for (int row = 0; row < head_size; row++) {
        state[row] = s[state_base + row * head_size + lane];
    }

    // tf depends only on the head, so it is staged once
    __syncthreads();
    _tf[lane] = tf[head_i * head_size + lane];
    __syncthreads();

    const int t_first = batch_i * n_seq_tokens * C + head_i * head_size + lane;
    const int t_limit = (batch_i + 1) * n_seq_tokens * C + head_i * head_size + lane;

    for (int t = t_first; t < t_limit; t += C) {
        // first barrier: nobody is still reading the previous token's tiles
        __syncthreads();
        _k[lane]  = k[t];
        _r[lane]  = r[t];
        _td[lane] = td[t];
        // second barrier: the tiles for this token are complete
        __syncthreads();

        const float v_t = v[t];
        float acc = 0;

        for (int j = 0; j < head_size; j += 4) {
            const float4 & k4  = (float4 &)(_k[j]);
            const float4 & r4  = (float4 &)(_r[j]);
            const float4 & tf4 = (float4 &)(_tf[j]);
            const float4 & td4 = (float4 &)(_td[j]);
            float4 &       s4  = (float4 &)(state[j]);

            float4 kv;
            kv.x = k4.x * v_t;
            kv.y = k4.y * v_t;
            kv.z = k4.z * v_t;
            kv.w = k4.w * v_t;

            // output uses the state *before* it is decayed for this token
            acc += r4.x * (tf4.x * kv.x + s4.x);
            acc += r4.y * (tf4.y * kv.y + s4.y);
            acc += r4.z * (tf4.z * kv.z + s4.z);
            acc += r4.w * (tf4.w * kv.w + s4.w);

            s4.x = s4.x * td4.x + kv.x;
            s4.y = s4.y * td4.y + kv.y;
            s4.z = s4.z * td4.z + kv.z;
            s4.w = s4.w * td4.w + kv.w;
        }

        dst[t] = acc;
    }

    // final state goes right after the T * C per-token outputs
    #pragma unroll
    for (int row = 0; row < head_size; row++) {
        dst[T * C + state_base + row * head_size + lane] = state[row];
    }
}
|
||||
|
||||
// RWKV v7 WKV recurrence.
//
// Block blockIdx.x serves sequence (blockIdx.x / H) and head (blockIdx.x % H); the
// block is expected to be launched with block_size (== head size) threads.
// Thread `lane` keeps the head_size contiguous state values starting at offset
// lane * head_size of its head's state tile.
//
// Per token the per-head slices of r, w, k, a and b are staged in shared memory.
// Each thread first forms sa = dot(a, state), then updates
// state = state * w + k * v + sa * b and emits dot(state, r) into dst[t].
// The final state is stored behind the T * C outputs in dst.
template <int block_size>
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
    const int lane = threadIdx.x;

    const int head_size    = block_size;
    const int batch_i      = blockIdx.x / H;
    const int head_i       = blockIdx.x % H;
    const int state_size   = C * head_size;
    const int n_seq_tokens = T / B;

    // start of this (sequence, head) tile inside s and inside the state part of dst
    const int state_base = batch_i * state_size + head_i * head_size * head_size;

    float state[head_size];
    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];

#ifndef GGML_USE_MUSA
    #pragma unroll
#endif
    for (int col = 0; col < head_size; col++) {
        state[col] = s[state_base + lane * head_size + col];
    }

    const int t_first = batch_i * n_seq_tokens * C + head_i * head_size + lane;
    const int t_limit = (batch_i + 1) * n_seq_tokens * C + head_i * head_size + lane;

    for (int t = t_first; t < t_limit; t += C) {
        // first barrier: nobody is still reading the previous token's tiles
        __syncthreads();
        _r[lane] = r[t];
        _w[lane] = w[t];
        _k[lane] = k[t];
        _a[lane] = a[t];
        _b[lane] = b[t];
        // second barrier: the tiles for this token are complete
        __syncthreads();

        // sa = dot(a, state) using the state from before this token's update
        float sa = 0;
        #pragma unroll
        for (int j = 0; j < head_size; j += 4) {
            const float4 & a4 = (float4 &)(_a[j]);
            const float4 & s4 = (float4 &)(state[j]);
            sa += a4.x * s4.x;
            sa += a4.y * s4.y;
            sa += a4.z * s4.z;
            sa += a4.w * s4.w;
        }

        const float v_t = v[t];
        float acc = 0;

        for (int j = 0; j < head_size; j += 4) {
            const float4 & r4 = (float4 &)(_r[j]);
            const float4 & w4 = (float4 &)(_w[j]);
            const float4 & k4 = (float4 &)(_k[j]);
            const float4 & b4 = (float4 &)(_b[j]);
            float4 &       s4 = (float4 &)(state[j]);

            float4 kv;
            kv.x = k4.x * v_t;
            kv.y = k4.y * v_t;
            kv.z = k4.z * v_t;
            kv.w = k4.w * v_t;

            s4.x = s4.x * w4.x + kv.x + sa * b4.x;
            s4.y = s4.y * w4.y + kv.y + sa * b4.y;
            s4.z = s4.z * w4.z + kv.z + sa * b4.z;
            s4.w = s4.w * w4.w + kv.w + sa * b4.w;

            // output uses the freshly updated state
            acc += s4.x * r4.x;
            acc += s4.y * r4.y;
            acc += s4.z * r4.z;
            acc += s4.w * r4.w;
        }

        dst[t] = acc;
    }

    // final state goes right after the T * C per-token outputs
    #pragma unroll
    for (int col = 0; col < head_size; col++) {
        dst[T * C + state_base + lane * head_size + col] = state[col];
    }
}
|
||||
|
||||
// Launches the RWKV v6 WKV kernel for `dst`.
// Sources 0..5 are passed to the kernel as k, v, r, tf, td and the input state.
// One block is launched per (sequence, head) pair with C / H threads; only head
// sizes of CUDA_WKV_BLOCK_SIZE and 2 * CUDA_WKV_BLOCK_SIZE are accepted.
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // problem dimensions
    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    // device buffers, in kernel-argument order
    const float * key_dev   = (const float *)dst->src[0]->data;
    const float * value_dev = (const float *)dst->src[1]->data;
    const float * recv_dev  = (const float *)dst->src[2]->data;
    const float * first_dev = (const float *)dst->src[3]->data;
    const float * decay_dev = (const float *)dst->src[4]->data;
    const float * state_dev = (const float *)dst->src[5]->data;
    float       * out_dev   = (float *)dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);

    // the head size is a template parameter, so pick the matching instantiation
    const int64_t head_size = C / H;
    if (head_size != CUDA_WKV_BLOCK_SIZE) {
        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, head_size, 0, stream>>>(B, T, C, H, key_dev, value_dev, recv_dev, first_dev, decay_dev, state_dev, out_dev);
    } else {
        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, head_size, 0, stream>>>(B, T, C, H, key_dev, value_dev, recv_dev, first_dev, decay_dev, state_dev, out_dev);
    }
}
|
||||
|
||||
// Launches the RWKV v7 WKV kernel for `dst`.
// Sources 0..6 are passed to the kernel as r, w, k, v, a, b and the input state.
// One block is launched per (sequence, head) pair with C / H threads; only head
// sizes of CUDA_WKV_BLOCK_SIZE and 2 * CUDA_WKV_BLOCK_SIZE are accepted.
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // problem dimensions
    const int64_t B = dst->src[6]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    // device buffers, in kernel-argument order
    const float * r_dev     = (const float *)dst->src[0]->data;
    const float * w_dev     = (const float *)dst->src[1]->data;
    const float * k_dev     = (const float *)dst->src[2]->data;
    const float * v_dev     = (const float *)dst->src[3]->data;
    const float * a_dev     = (const float *)dst->src[4]->data;
    const float * b_dev     = (const float *)dst->src[5]->data;
    const float * state_dev = (const float *)dst->src[6]->data;
    float       * out_dev   = (float *)dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);

    // the head size is a template parameter, so pick the matching instantiation
    const int64_t head_size = C / H;
    if (head_size != CUDA_WKV_BLOCK_SIZE) {
        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, head_size, 0, stream>>>(B, T, C, H, r_dev, w_dev, k_dev, v_dev, a_dev, b_dev, state_dev, out_dev);
    } else {
        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, head_size, 0, stream>>>(B, T, C, H, r_dev, w_dev, k_dev, v_dev, a_dev, b_dev, state_dev, out_dev);
    }
}
|
@ -3,3 +3,5 @@
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -1,89 +0,0 @@
|
||||
#include "common.cuh"
|
||||
#include "wkv6.cuh"
|
||||
|
||||
// RWKV v6 WKV recurrence with the head size fixed to CUDA_WKV_BLOCK_SIZE.
//
// Block blockIdx.x serves sequence (blockIdx.x / H) and head (blockIdx.x % H).
// Thread `lane` keeps the head_size state values found at offsets
// row * head_size + lane of its head's state tile. Per token the per-head slices
// of k, r and td are staged in shared memory, the output for channel `lane` is
// written to dst[t], and the final state is stored behind the T * C outputs.
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
    const int lane = threadIdx.x;

    const int head_size    = CUDA_WKV_BLOCK_SIZE;
    const int batch_i      = blockIdx.x / H;
    const int head_i       = blockIdx.x % H;
    const int state_size   = C * head_size;
    const int n_seq_tokens = T / B;

    // start of this (sequence, head) tile inside s and inside the state part of dst
    const int state_base = batch_i * state_size + head_i * head_size * head_size;

    float state[head_size];
    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];

    #pragma unroll
    for (int row = 0; row < head_size; row++) {
        state[row] = s[state_base + row * head_size + lane];
    }

    // tf depends only on the head, so it is staged once
    __syncthreads();
    _tf[lane] = tf[head_i * head_size + lane];
    __syncthreads();

    const int t_first = batch_i * n_seq_tokens * C + head_i * head_size + lane;
    const int t_limit = (batch_i + 1) * n_seq_tokens * C + head_i * head_size + lane;

    for (int t = t_first; t < t_limit; t += C) {
        // first barrier: nobody is still reading the previous token's tiles
        __syncthreads();
        _k[lane]  = k[t];
        _r[lane]  = r[t];
        _td[lane] = td[t];
        // second barrier: the tiles for this token are complete
        __syncthreads();

        const float v_t = v[t];
        float acc = 0;

        for (int j = 0; j < head_size; j += 4) {
            const float4 & k4  = (float4 &)(_k[j]);
            const float4 & r4  = (float4 &)(_r[j]);
            const float4 & tf4 = (float4 &)(_tf[j]);
            const float4 & td4 = (float4 &)(_td[j]);
            float4 &       s4  = (float4 &)(state[j]);

            float4 kv;
            kv.x = k4.x * v_t;
            kv.y = k4.y * v_t;
            kv.z = k4.z * v_t;
            kv.w = k4.w * v_t;

            // output uses the state *before* it is decayed for this token
            acc += r4.x * (tf4.x * kv.x + s4.x);
            acc += r4.y * (tf4.y * kv.y + s4.y);
            acc += r4.z * (tf4.z * kv.z + s4.z);
            acc += r4.w * (tf4.w * kv.w + s4.w);

            s4.x = s4.x * td4.x + kv.x;
            s4.y = s4.y * td4.y + kv.y;
            s4.z = s4.z * td4.z + kv.z;
            s4.w = s4.w * td4.w + kv.w;
        }

        dst[t] = acc;
    }

    // final state goes right after the T * C per-token outputs
    #pragma unroll
    for (int row = 0; row < head_size; row++) {
        dst[T * C + state_base + row * head_size + lane] = state[row];
    }
}
|
||||
|
||||
// Launches the fixed-head-size RWKV v6 WKV kernel for `dst`.
// Sources 0..5 are passed to the kernel as k, v, r, tf, td and the input state;
// one block is launched per (sequence, head) pair with C / H threads.
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // problem dimensions
    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    // device buffers, in kernel-argument order
    const float * key_dev   = (const float *)dst->src[0]->data;
    const float * value_dev = (const float *)dst->src[1]->data;
    const float * recv_dev  = (const float *)dst->src[2]->data;
    const float * first_dev = (const float *)dst->src[3]->data;
    const float * decay_dev = (const float *)dst->src[4]->data;
    const float * state_dev = (const float *)dst->src[5]->data;
    float       * out_dev   = (float *)dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64

    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, key_dev, value_dev, recv_dev, first_dev, decay_dev, state_dev, out_dev);
}
|
@ -285,6 +285,13 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_rms_norm;
|
||||
|
||||
// Arguments passed by value from the host encoder to kernel_l2_norm.
typedef struct {
|
||||
// number of elements in one row (the host passes ne00 here)
int32_t ne00;
|
||||
// ne00 / 4: the row length counted in float4 vectors, used as the kernel's loop bound
int32_t ne00_4;
|
||||
// stride between consecutive source rows, in bytes (applied to a char pointer in the kernel)
uint64_t nb01;
|
||||
// epsilon read from the op's op_params; clamps the sum of squares before the square root
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
||||
@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
||||
@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_ARGMAX:
|
||||
return true;
|
||||
@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:7];
|
||||
[encoder setBytes:&T length:sizeof(T) atIndex:8];
|
||||
[encoder setBytes:&C length:sizeof(C) atIndex:9];
|
||||
[encoder setBytes:&H length:sizeof(H) atIndex:10];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
size_t offs_src6 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:8];
|
||||
[encoder setBytes:&T length:sizeof(T) atIndex:9];
|
||||
[encoder setBytes:&C length:sizeof(C) atIndex:10];
|
||||
[encoder setBytes:&H length:sizeof(H) atIndex:11];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00_4 =*/ ne00/4,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// RWKV v6 WKV recurrence.
//
// Threadgroup tgpig.x serves sequence (tgpig.x / H) and head (tgpig.x % H).
// Thread `lane` keeps the head_size state values found at offsets
// row * head_size + lane of its head's state tile. Per token the per-head slices
// of k, r and td are staged in threadgroup memory, the output for channel `lane`
// is written to dst[t], and the final state is stored behind the T * C outputs.
kernel void kernel_rwkv_wkv6_f32(
    device const float * k,
    device const float * v,
    device const float * r,
    device const float * tf,
    device const float * td,
    device const float * state_in,
    device       float * dst,
    constant    uint & B,
    constant    uint & T,
    constant    uint & C,
    constant    uint & H,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]])  {

    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id  = tgpig.x / H;
    const uint head_id   = tgpig.x % H;
    const uint lane      = tpitg.x;

    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size   = C * head_size;
    const uint n_seq_tokens = T / B;

    // start of this (sequence, head) tile inside state_in and inside the state part of dst
    const uint state_base = batch_id * state_size + head_id * head_size * head_size;

    threadgroup float _k[head_size];
    threadgroup float _r[head_size];
    threadgroup float _tf[head_size];
    threadgroup float _td[head_size];

    float state[head_size];

    for (uint row = 0; row < head_size; row++) {
        state[row] = state_in[state_base + row * head_size + lane];
    }

    // tf depends only on the head, so it is staged once
    threadgroup_barrier(mem_flags::mem_threadgroup);
    _tf[lane] = tf[head_id * head_size + lane];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint t_first = batch_id * n_seq_tokens * C + head_id * head_size + lane;
    const uint t_limit = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + lane;

    for (uint t = t_first; t < t_limit; t += C) {
        // first barrier: nobody is still reading the previous token's tiles
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _k[lane]  = k[t];
        _r[lane]  = r[t];
        _td[lane] = td[t];
        // second barrier: the tiles for this token are complete
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_t = v[t];
        float acc = 0.0;

        for (uint j = 0; j < head_size; j += 4) {
            const float4 k4  = float4(_k[j],  _k[j+1],  _k[j+2],  _k[j+3]);
            const float4 r4  = float4(_r[j],  _r[j+1],  _r[j+2],  _r[j+3]);
            const float4 tf4 = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            const float4 td4 = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            float4 s4 = float4(state[j], state[j+1], state[j+2], state[j+3]);

            const float4 kv = k4 * v_t;

            // output uses the state *before* it is decayed for this token
            const float4 boosted = tf4 * kv + s4;
            acc += dot(r4, boosted);

            s4 = s4 * td4 + kv;
            state[j]   = s4[0];
            state[j+1] = s4[1];
            state[j+2] = s4[2];
            state[j+3] = s4[3];
        }

        dst[t] = acc;
    }

    // final state goes right after the T * C per-token outputs
    for (uint row = 0; row < head_size; row++) {
        dst[T * C + state_base + row * head_size + lane] = state[row];
    }
}
|
||||
|
||||
// RWKV v7 WKV recurrence.
//
// Threadgroup tgpig.x serves sequence (tgpig.x / H) and head (tgpig.x % H).
// Thread `lane` keeps the head_size contiguous state values starting at offset
// lane * head_size of its head's state tile. Per token the per-head slices of
// r, w, k, a and b are staged in threadgroup memory; each thread forms
// sa = dot(a, state), updates state = state * w + k * v + sa * b and writes
// dot(state, r) to dst[t]. The final state is stored behind the T * C outputs.
kernel void kernel_rwkv_wkv7_f32(
    device const float * r,
    device const float * w,
    device const float * k,
    device const float * v,
    device const float * a,
    device const float * b,
    device const float * state_in,
    device       float * dst,
    constant    uint & B,
    constant    uint & T,
    constant    uint & C,
    constant    uint & H,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]])  {

    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id  = tgpig.x / H;
    const uint head_id   = tgpig.x % H;
    const uint lane      = tpitg.x;

    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size   = C * head_size;
    const uint n_seq_tokens = T / B;

    // start of this (sequence, head) tile inside state_in and inside the state part of dst
    const uint state_base = batch_id * state_size + head_id * head_size * head_size;

    threadgroup float _r[head_size];
    threadgroup float _w[head_size];
    threadgroup float _k[head_size];
    threadgroup float _a[head_size];
    threadgroup float _b[head_size];

    float state[head_size];

    for (uint col = 0; col < head_size; col++) {
        state[col] = state_in[state_base + lane * head_size + col];
    }

    const uint t_first = batch_id * n_seq_tokens * C + head_id * head_size + lane;
    const uint t_limit = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + lane;

    for (uint t = t_first; t < t_limit; t += C) {
        // first barrier: nobody is still reading the previous token's tiles
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _r[lane] = r[t];
        _w[lane] = w[t];
        _k[lane] = k[t];
        _a[lane] = a[t];
        _b[lane] = b[t];
        // second barrier: the tiles for this token are complete
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_t = v[t];
        float acc = 0.0;

        // sa = dot(a, state), accumulated per vector component and folded afterwards
        float4 sa4(0.0);
        for (uint j = 0; j < head_size; j += 4) {
            const float4 a4 = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
            const float4 s4 = float4(state[j], state[j+1], state[j+2], state[j+3]);
            sa4 += a4 * s4;
        }
        const float sa = sa4[0] + sa4[1] + sa4[2] + sa4[3];

        for (uint j = 0; j < head_size; j += 4) {
            const float4 r4 = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            const float4 w4 = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
            const float4 k4 = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            const float4 b4 = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
            float4 s4 = float4(state[j], state[j+1], state[j+2], state[j+3]);

            const float4 kv = k4 * v_t;

            // output uses the freshly updated state
            s4 = s4 * w4 + kv + sa * b4;
            acc += dot(s4, r4);

            state[j]   = s4[0];
            state[j+1] = s4[1];
            state[j+2] = s4[2];
            state[j+3] = s4[3];
        }

        dst[t] = acc;
    }

    // final state goes right after the T * C per-token outputs
    for (uint col = 0; col < head_size; col++) {
        dst[T * C + state_base + lane * head_size + col] = state[col];
    }
}
|
||||
|
||||
kernel void kernel_argmax(
|
||||
device const void * x,
|
||||
device int32_t * dst,
|
||||
@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
}
|
||||
|
||||
// L2-normalizes one row per threadgroup: y = x / max(||x||_2, eps).
//
// args.nb01 is the byte stride between source rows; the destination is addressed
// as densely packed rows of args.ne00_4 float4 vectors. shmem_f32 is scratch for
// the per-simdgroup partial sums; every thread reads shmem_f32[tiisg] in the
// final reduction, so it must hold at least one float per simdgroup lane.
kernel void kernel_l2_norm(
        constant ggml_metal_kargs_l2_norm & args,
        device const char * src0,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint   tgpig[[threadgroup_position_in_grid]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort   ntg[[threads_per_threadgroup]]) {
    // clear the scratch so slots of simdgroups that do not exist contribute 0
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);

    float sumf = 0.0f;

    // parallel sum of squares: each thread strides over the row, then the
    // simdgroup folds its lanes together
    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = simd_sum(sumf);

    // make the zero-initialization visible before the partial sums are stored
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // fold the per-simdgroup partial sums; every thread ends up with the full sum
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    // sumf is the *squared* norm, so it has to be clamped with eps^2 to obtain
    // 1 / max(||x||, eps). Clamping with eps itself would floor the norm at
    // sqrt(eps) and disagree with the other backends for near-zero rows.
    const float scale = 1.0f/sqrt(max(sumf, args.eps*args.eps));

    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
        y[i00] = x[i00] * scale;
    }
}
|
||||
|
||||
kernel void kernel_group_norm(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
@ -26,7 +26,7 @@
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv6.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "cpy.hpp"
|
||||
|
@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
// Runs the L2-norm op for `dst` through the generic flatten helper, with debug
// traces on entry and exit.
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);

    // both source slots are forwarded as-is; the callback does the actual work
    ggml_tensor * const first_src  = dst->src[0];
    ggml_tensor * const second_src = dst->src[1];
    ggml_sycl_op_flatten(ctx, first_src, second_src, dst, ggml_sycl_op_l2_norm);

    GGML_SYCL_DEBUG("call %s done\n", __func__);
}
|
||||
|
||||
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
||||
@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_sycl_rms_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_sycl_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
||||
return false;
|
||||
@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_sycl_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_sycl_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return (op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
default:
|
||||
|
@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row * ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
|
||||
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
/*
|
||||
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
tmp = 0.f;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row * ncols + col] = scale * x[row * ncols + col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
||||
ggml_tensor* dst, const float* src0_dd,
|
||||
const float* src1_dd, float* dst_dd,
|
||||
@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
// ggml op entry point for L2 normalisation on the SYCL backend.
// Only F32 -> F32 is supported. The epsilon is read from the first float of
// dst->op_params; src1 / src1_dd are part of the common op signature but unused.
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                          const ggml_tensor* src1, ggml_tensor* dst,
                          const float* src0_dd, const float* src1_dd,
                          float* dst_dd,
                          const queue_ptr& main_stream) {
    (void)src1;
    (void)dst;
    (void)src1_dd;

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // op_params is copied bytewise rather than read through a cast pointer.
    float epsilon;
    memcpy(&epsilon, dst->op_params, sizeof(float));

    const int64_t row_len = src0->ne[0];
    const int64_t n_rows  = ggml_nrows(src0);

    l2_norm_f32_sycl(src0_dd, dst_dd, row_len, n_rows, epsilon, main_stream, ctx.device);
}
|
||||
|
@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
|
||||
// Row-wise L2 normalisation of src0 into dst on the SYCL backend (F32 only).
// src0_dd / dst_dd are the data pointers handed to the kernel launcher;
// src1 and src1_dd are accepted for signature uniformity and are unused.
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
|
||||
#endif // GGML_SYCL_NORM_HPP
|
||||
|
305
ggml/src/ggml-sycl/wkv.cpp
Normal file
305
ggml/src/ggml-sycl/wkv.cpp
Normal file
@ -0,0 +1,305 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "wkv.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
template <int block_size>
|
||||
static void rwkv_wkv6_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[block_size];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static void rwkv_wkv7_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* r, const float* w, const float* k, const float* v,
|
||||
const float* a, const float* b, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float* _r = shared_mem;
|
||||
float* _w = _r + head_size;
|
||||
float* _k = _w + head_size;
|
||||
float* _a = _k + head_size;
|
||||
float* _b = _a + head_size;
|
||||
|
||||
float state[block_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0, sa = 0;
|
||||
sycl::float4 a4, s4;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa += sycl::dot(a4, s4);
|
||||
}
|
||||
|
||||
sycl::float4 r4, w4, k4, b4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
s4 = s4 * w4 + kv4 + sa * b4;
|
||||
y += sycl::dot(r4, s4);
|
||||
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side launcher for the RWKV v6 WKV operator.
//
// Reads k, v, r, tf, td and the incoming state from dst->src[0..5] and
// enqueues rwkv_wkv6_f32_kernel on ctx.stream(), writing into dst->data.
// Supported head sizes (C / H) are WKV_BLOCK_SIZE and 2 * WKV_BLOCK_SIZE;
// anything else aborts.
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const float* k_d  = (const float*)dst->src[0]->data;
    const float* v_d  = (const float*)dst->src[1]->data;
    const float* r_d  = (const float*)dst->src[2]->data;
    const float* tf_d = (const float*)dst->src[3]->data;
    const float* td_d = (const float*)dst->src[4]->data;
    const float* s_d  = (const float*)dst->src[5]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    // The kernel template is only instantiated for head sizes 64 and 128.
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);

    dpct::queue_ptr stream = ctx.stream();

    // Execution configuration: one work-group per (sequence, head) pair and
    // one work-item per channel of the head.
    // sycl::local_accessor<float, 1> is sized in *elements*, not bytes, so
    // request exactly the four head-sized rows (k, r, tf, td) the kernel uses.
    const size_t shared_mem_floats = C / H * 4;
    sycl::range<3> block_dims(1, 1, C / H);
    sycl::range<3> grid_dims(1, 1, B * H);

    // The head size is a template parameter of the kernel, hence two launches.
    if (C / H == WKV_BLOCK_SIZE) {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    } else {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    }
}
|
||||
|
||||
// Host-side launcher for the RWKV v7 WKV operator.
//
// Reads r, w, k, v, a, b and the incoming state from dst->src[0..6] and
// enqueues rwkv_wkv7_f32_kernel on ctx.stream(), writing into dst->data.
// Supported head sizes (C / H) are WKV_BLOCK_SIZE and 2 * WKV_BLOCK_SIZE;
// anything else aborts.
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const float* r_d = (const float*)dst->src[0]->data;
    const float* w_d = (const float*)dst->src[1]->data;
    const float* k_d = (const float*)dst->src[2]->data;
    const float* v_d = (const float*)dst->src[3]->data;
    const float* a_d = (const float*)dst->src[4]->data;
    const float* b_d = (const float*)dst->src[5]->data;
    const float* s_d = (const float*)dst->src[6]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[6]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    // The kernel template is only instantiated for head sizes 64 and 128.
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);

    dpct::queue_ptr stream = ctx.stream();

    // Execution configuration: one work-group per (sequence, head) pair and
    // one work-item per channel of the head.
    // sycl::local_accessor<float, 1> is sized in *elements*, not bytes, so
    // request exactly the five head-sized rows (r, w, k, a, b) the kernel uses.
    const size_t shared_mem_floats = C / H * 5;
    sycl::range<3> block_dims(1, 1, C / H);
    sycl::range<3> grid_dims(1, 1, B * H);

    // The head size is a template parameter of the kernel, hence two launches.
    if (C / H == WKV_BLOCK_SIZE) {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    } else {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    }
}
|
10
ggml/src/ggml-sycl/wkv.hpp
Normal file
10
ggml/src/ggml-sycl/wkv.hpp
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef GGML_SYCL_WKV_HPP
|
||||
#define GGML_SYCL_WKV_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
// RWKV v6 WKV operator: inputs are taken from dst->src[0..5], output goes to dst.
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
// RWKV v7 WKV operator: inputs are taken from dst->src[0..6], output goes to dst.
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_WKV_HPP
|
@ -1,143 +0,0 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "wkv6.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
static void rwkv_wkv_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[WKV_BLOCK_SIZE];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side launcher for the RWKV v6 WKV operator (head size fixed to
// WKV_BLOCK_SIZE).
//
// Reads k, v, r, tf, td and the incoming state from dst->src[0..5] and
// enqueues rwkv_wkv_f32_kernel on ctx.stream(), writing into dst->data.
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const float* k_d  = (const float*)dst->src[0]->data;
    const float* v_d  = (const float*)dst->src[1]->data;
    const float* r_d  = (const float*)dst->src[2]->data;
    const float* tf_d = (const float*)dst->src[3]->data;
    const float* td_d = (const float*)dst->src[4]->data;
    const float* s_d  = (const float*)dst->src[5]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64

    dpct::queue_ptr stream = ctx.stream();

    // Execution configuration: one work-group per (sequence, head) pair and
    // one work-item per channel of the head.
    // sycl::local_accessor<float, 1> is sized in *elements*, not bytes, so
    // request exactly the four head-sized rows (k, r, tf, td) the kernel uses.
    const size_t shared_mem_floats = WKV_BLOCK_SIZE * 4;
    sycl::range<3> block_dims(1, 1, C / H);
    sycl::range<3> grid_dims(1, 1, B * H);

    stream->submit([&](sycl::handler& cgh) {
        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) {
                rwkv_wkv_f32_kernel(
                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                );
            });
    });
}
|
@ -1,9 +0,0 @@
|
||||
#ifndef GGML_SYCL_WKV6_HPP
|
||||
#define GGML_SYCL_WKV6_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
// RWKV v6 WKV operator: inputs are taken from dst->src[0..5], output goes to dst.
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
|
||||
#endif // GGML_SYCL_WKV6_HPP
|
@ -304,6 +304,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_group_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_back_f32;
|
||||
vk_pipeline pipeline_l2_norm_f32;
|
||||
vk_pipeline pipeline_gelu_f32;
|
||||
vk_pipeline pipeline_gelu_quick_f32;
|
||||
vk_pipeline pipeline_silu_f32;
|
||||
@ -328,6 +329,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
// Push-constant block for the rwkv_wkv7_f32 pipeline; its size is what the
// pipeline is created and dispatched with.
// NOTE(review): ggml_vk_rwkv_wkv7 fills the four slots, in this order, with
// n_seqs, seq_length, n_embed and n_heads — presumably B/T/C/H carry those
// meanings in the shader; confirm against the shader source.
struct vk_op_rwkv_wkv7_push_constants {
|
||||
uint32_t B;
|
||||
uint32_t T;
|
||||
uint32_t C;
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
// Allow pre-recording command buffers
|
||||
struct vk_staging_memcpy {
|
||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||
@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (auto &c : compiles) {
|
||||
@ -5473,6 +5485,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_rms_norm_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_L2_NORM:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_l2_norm_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
@ -5612,6 +5629,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_rwkv_wkv6_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_rwkv_wkv7_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
@ -5859,6 +5881,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
@ -6108,23 +6131,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
|
||||
const ggml_tensor * k = dst->src[0];
|
||||
const ggml_tensor * v = dst->src[1];
|
||||
const ggml_tensor * r = dst->src[2];
|
||||
const ggml_tensor * tf = dst->src[3];
|
||||
const ggml_tensor * td = dst->src[4];
|
||||
const ggml_tensor * state = dst->src[5];
|
||||
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
|
||||
GGML_ASSERT(version == 6 || version == 7);
|
||||
int num_srcs = version == 6 ? 6 : 7;
|
||||
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
|
||||
}
|
||||
|
||||
GGML_ASSERT(!ggml_is_quantized(k->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(v->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(r->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(tf->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(td->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(state->type));
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
||||
GGML_ASSERT(pipeline != nullptr);
|
||||
|
||||
if (dryrun) {
|
||||
@ -6133,89 +6150,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
|
||||
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
|
||||
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
|
||||
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
|
||||
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
||||
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
||||
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
|
||||
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
|
||||
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
|
||||
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
||||
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
||||
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
|
||||
|
||||
if (ctx->device->uma) {
|
||||
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
|
||||
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
|
||||
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
|
||||
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
|
||||
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
|
||||
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
|
||||
srcs_uma[i] = d_srcs[i] != nullptr;
|
||||
}
|
||||
|
||||
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
||||
|
||||
K_uma = d_K != nullptr;
|
||||
V_uma = d_V != nullptr;
|
||||
R_uma = d_R != nullptr;
|
||||
TF_uma = d_TF != nullptr;
|
||||
TD_uma = d_TD != nullptr;
|
||||
STATE_uma = d_State != nullptr;
|
||||
DST_uma = d_D != nullptr;
|
||||
dst_uma = d_D != nullptr;
|
||||
}
|
||||
|
||||
if (!K_uma) {
|
||||
d_K = k_buf_ctx->dev_buffer;
|
||||
k_offset = vk_tensor_offset(k) + k->view_offs;
|
||||
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
src_sizes[i] = ggml_nbytes(dst->src[i]);
|
||||
if (!srcs_uma[i]) {
|
||||
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
|
||||
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
|
||||
}
|
||||
}
|
||||
if (!V_uma) {
|
||||
d_V = v_buf_ctx->dev_buffer;
|
||||
v_offset = vk_tensor_offset(v) + v->view_offs;
|
||||
}
|
||||
if (!R_uma) {
|
||||
d_R = r_buf_ctx->dev_buffer;
|
||||
r_offset = vk_tensor_offset(r) + r->view_offs;
|
||||
}
|
||||
if (!TF_uma) {
|
||||
d_TF = tf_buf_ctx->dev_buffer;
|
||||
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
|
||||
}
|
||||
if (!TD_uma) {
|
||||
d_TD = td_buf_ctx->dev_buffer;
|
||||
td_offset = vk_tensor_offset(td) + td->view_offs;
|
||||
}
|
||||
if (!STATE_uma) {
|
||||
d_State = state_buf_ctx->dev_buffer;
|
||||
state_offset = vk_tensor_offset(state) + state->view_offs;
|
||||
}
|
||||
if (!DST_uma) {
|
||||
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
if (!dst_uma) {
|
||||
d_D = dst_buf_ctx->dev_buffer;
|
||||
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
}
|
||||
|
||||
const uint64_t k_size = ggml_nbytes(k);
|
||||
const uint64_t v_size = ggml_nbytes(v);
|
||||
const uint64_t r_size = ggml_nbytes(r);
|
||||
const uint64_t tf_size = ggml_nbytes(tf);
|
||||
const uint64_t td_size = ggml_nbytes(td);
|
||||
const uint64_t state_size = ggml_nbytes(state);
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
|
||||
std::array<uint32_t, 3> elements = {
|
||||
(uint32_t)(pc.B * pc.H),
|
||||
1,
|
||||
1
|
||||
};
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_K, k_offset, k_size },
|
||||
vk_subbuffer{ d_V, v_offset, v_size },
|
||||
vk_subbuffer{ d_R, r_offset, r_size },
|
||||
vk_subbuffer{ d_TF, tf_offset, tf_size },
|
||||
vk_subbuffer{ d_TD, td_offset, td_size },
|
||||
vk_subbuffer{ d_State, state_offset, state_size },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
if (version == 6) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
} else if (version == 7) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
|
||||
} else {
|
||||
// shouldn't happen
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
@ -6224,7 +6225,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
const size_t n_heads = dst->src[0]->ne[1];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
|
||||
ggml_vk_op_f32_rwkv6(
|
||||
ggml_vk_op_f32_wkv(
|
||||
ctx, subctx, dst,
|
||||
{
|
||||
(uint32_t)n_seqs,
|
||||
@ -6232,6 +6233,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
(uint32_t)n_embed,
|
||||
(uint32_t)n_heads,
|
||||
},
|
||||
6,
|
||||
dryrun
|
||||
);
|
||||
}
|
||||
|
||||
// Schedules the version-7 WKV kernel for the graph node `dst`.
//
// The kernel sizes are read straight from the node: the sequence count is
// dimension 1 of source 6, the token count and head count are dimensions 2
// and 1 of source 0, and the channel count is dimension 0 of `dst`. Each
// value is narrowed to 32 bits and forwarded in the order
// (sequences, tokens, channels, heads), followed by the version tag 7.
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const ggml_tensor * head_src = dst->src[0];
    const ggml_tensor * tail_src = dst->src[6];

    const size_t sequence_count = tail_src->ne[1];
    const size_t token_count    = head_src->ne[2];
    const size_t channel_count  = dst->ne[0];
    const size_t head_count     = head_src->ne[1];

    ggml_vk_op_f32_wkv(
        ctx, subctx, dst,
        {
            static_cast<uint32_t>(sequence_count),
            static_cast<uint32_t>(token_count),
            static_cast<uint32_t>(channel_count),
            static_cast<uint32_t>(head_count),
        },
        7,
        dryrun
    );
}
|
||||
@ -6533,6 +6554,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
// Issues the GGML_OP_L2_NORM kernel for `src0` -> `dst`.
//
// Push constants are: src0->ne[0], src0->ne[1], the first float stored in
// dst->op_params, and a trailing 0.0f for the unused second parameter.
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    const float * params = (const float *)dst->op_params;

    const uint32_t ne0 = (uint32_t)src0->ne[0];
    const uint32_t ne1 = (uint32_t)src0->ne[1];

    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { ne0, ne1, params[0], 0.0f }, dryrun);
}
|
||||
|
||||
// Issues the GGML_OP_UNARY kernel for `src0` -> `dst`.
//
// The only meaningful push constant is the total element count of `src0`;
// the remaining three slots are passed as zero.
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    const uint32_t element_count = (uint32_t)ggml_nelements(src0);

    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { element_count, 0, 0.0f, 0.0f }, dryrun);
}
|
||||
@ -7528,6 +7554,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
@ -7544,6 +7571,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
@ -7590,6 +7618,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@ -7707,6 +7736,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
@ -7797,6 +7830,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
@ -7870,6 +7908,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
@ -7889,6 +7928,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
@ -8806,6 +8846,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
@ -8835,6 +8876,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
@ -9219,6 +9261,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
|
||||
} else if (tensor->op == GGML_OP_SILU_BACK) {
|
||||
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_L2_NORM) {
|
||||
const float eps = ((float *) tensor->op_params)[0];
|
||||
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
||||
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
||||
if (src1 != nullptr) {
|
||||
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
@ -9338,6 +9383,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
|
||||
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
|
||||
src_clone[4], src_clone[5], src_clone[6]);
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||
|
41
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
Normal file
41
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
Normal file
@ -0,0 +1,41 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
|
||||
// One workgroup handles one row of p.KX values: it accumulates the sum of
// squares of the row, then writes every element multiplied by
// inversesqrt(max(sum_of_squares, p.param1)).
void main() {
    const uint lane = gl_LocalInvocationID.x;
    const uint row_idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint row_base = row_idx * p.KX;

    // Each invocation accumulates the squares of the columns it owns
    // (lane, lane + BLOCK_SIZE, ...) into its slot of the shared array.
    sum[lane] = FLOAT_TYPE(0.0f);

    [[unroll]] for (uint c = lane; c < p.KX; c += BLOCK_SIZE) {
        const FLOAT_TYPE value = FLOAT_TYPE(data_a[row_base + c]);
        sum[lane] += value * value;
    }

    // Halving reduction over the shared array; the total ends up in sum[0].
    barrier();
    [[unroll]] for (uint stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
        if (lane < stride) {
            sum[lane] += sum[lane + stride];
        }
        barrier();
    }

    // p.param1 is the lower bound applied to the sum before the inverse square root.
    const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));

    [[unroll]] for (uint c = lane; c < p.KX; c += BLOCK_SIZE) {
        data_d[row_base + c] = D_TYPE(scale * FLOAT_TYPE(data_a[row_base + c]));
    }
}
|
@ -434,6 +434,7 @@ void process_shaders() {
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
@ -528,6 +529,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
|
91
ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp
Normal file
91
ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp
Normal file
@ -0,0 +1,91 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#define BLOCK_SIZE 64
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint B;
|
||||
uint T;
|
||||
uint C;
|
||||
uint H;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
|
||||
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
|
||||
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
|
||||
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
|
||||
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
|
||||
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
|
||||
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
|
||||
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
|
||||
|
||||
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
|
||||
|
||||
// One workgroup processes one (sequence, head) pair; gl_WorkGroupID.x is
// split into the two indices using H. Each invocation owns one row of
// head_size state values, loaded from state_in, updated once per token, and
// stored after the per-token outputs in dst (starting at element T * C).
void main() {
    const uint head_size = BLOCK_SIZE;
    const uint group = gl_WorkGroupID.x;
    const uint seq_idx = group / H;
    const uint head_idx = group % H;
    const uint lane = gl_LocalInvocationID.x;

    const uint state_stride = C * head_size;
    const uint tokens_per_seq = T / B;

    if (seq_idx >= B || head_idx >= H) {
        return;
    }

    // Offset of this invocation's state row; shared by the load below and the
    // final store.
    const uint state_base = seq_idx * state_stride + head_idx * head_size * head_size
                            + lane * head_size;

    A_TYPE state[BLOCK_SIZE];
    [[unroll]] for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[state_base + i];
    }

    // Element this invocation handles inside a token's C values.
    const uint channel = head_idx * head_size + lane;
    const uint t_begin = seq_idx * tokens_per_seq * C + channel;
    const uint t_end = (seq_idx + 1) * tokens_per_seq * C + channel;

    for (uint t = t_begin; t < t_end; t += C) {
        // First barrier: nobody overwrites the shared vectors while another
        // invocation is still reading the previous token's values.
        barrier();
        _r[lane] = r[t];
        _w[lane] = w[t];
        _k[lane] = k[t];
        _a[lane] = a[t];
        _b[lane] = b[t];
        // Second barrier: all head_size entries are written before any are read.
        barrier();

        // sa = dot(state row, a) for the current token.
        A_TYPE sa = 0.0;
        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
            vec4 state4 = vec4(state[j], state[j+1], state[j+2], state[j+3]);
            vec4 a4 = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
            sa += dot(state4, a4);
        }

        const A_TYPE v_val = v[t];
        A_TYPE y = 0.0;

        // state <- state * w + k * v + sa * b, and y accumulates dot(r, state).
        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
            vec4 r4 = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            vec4 w4 = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
            vec4 k4 = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            vec4 b4 = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
            vec4 state4 = vec4(state[j], state[j+1], state[j+2], state[j+3]);

            vec4 kv = k4 * v_val;
            state4 = state4 * w4 + kv + sa * b4;
            y += dot(r4, state4);

            state[j] = state4.x;
            state[j+1] = state4.y;
            state[j+2] = state4.z;
            state[j+3] = state4.w;
        }

        dst[t] = y;
    }

    // Write the updated state row after the T * C output values.
    [[unroll]] for (uint i = 0; i < head_size; i++) {
        dst[T * C + state_base + i] = state[i];
    }
}
|
@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"RMS_NORM",
|
||||
"RMS_NORM_BACK",
|
||||
"GROUP_NORM",
|
||||
"L2_NORM",
|
||||
|
||||
"MUL_MAT",
|
||||
"MUL_MAT_ID",
|
||||
@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"ADD_REL_POS",
|
||||
"RWKV_WKV6",
|
||||
"GATED_LINEAR_ATTN",
|
||||
"RWKV_WKV7",
|
||||
|
||||
"UNARY",
|
||||
|
||||
@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"rms_norm(x)",
|
||||
"rms_norm_back(x)",
|
||||
"group_norm(x)",
|
||||
"l2_norm(x)",
|
||||
|
||||
"X*Y",
|
||||
"X[i]*Y",
|
||||
@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"add_rel_pos(x)",
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
"gated_linear_attn(k, v, q, gate, s)",
|
||||
"rwkv_wkv7(r, w, k, v, a, b, s)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
|
||||
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
|
||||
}
|
||||
|
||||
// ggml_l2_norm
|
||||
|
||||
static struct ggml_tensor * ggml_l2_norm_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_f32(result, 0, eps);
|
||||
|
||||
result->op = GGML_OP_L2_NORM;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// L2 normalization of `a` with epsilon `eps`; the result is a new tensor
// (the non-inplace variant of ggml_l2_norm_impl).
struct ggml_tensor * ggml_l2_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        float                 eps) {
    const bool inplace = false;
    return ggml_l2_norm_impl(ctx, a, eps, inplace);
}
|
||||
|
||||
// L2 normalization of `a` with epsilon `eps`; the result is a view of `a`
// (the inplace variant of ggml_l2_norm_impl).
struct ggml_tensor * ggml_l2_norm_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        float                 eps) {
    const bool inplace = true;
    return ggml_l2_norm_impl(ctx, a, eps, inplace);
}
|
||||
|
||||
// ggml_mul_mat
|
||||
|
||||
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_rwkv_wkv7
|
||||
|
||||
struct ggml_tensor * ggml_rwkv_wkv7(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * r,
|
||||
struct ggml_tensor * w,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state) {
|
||||
GGML_ASSERT(ggml_is_contiguous(r));
|
||||
GGML_ASSERT(ggml_is_contiguous(w));
|
||||
GGML_ASSERT(ggml_is_contiguous(k));
|
||||
GGML_ASSERT(ggml_is_contiguous(v));
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_is_contiguous(b));
|
||||
GGML_ASSERT(ggml_is_contiguous(state));
|
||||
|
||||
const int64_t S = k->ne[0];
|
||||
const int64_t H = k->ne[1];
|
||||
const int64_t n_tokens = k->ne[2];
|
||||
const int64_t n_seqs = state->ne[1];
|
||||
{
|
||||
GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
|
||||
GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
|
||||
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
|
||||
}
|
||||
|
||||
// concat output and new_state
|
||||
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_RWKV_WKV7;
|
||||
result->src[0] = r;
|
||||
result->src[1] = w;
|
||||
result->src[2] = k;
|
||||
result->src[3] = v;
|
||||
result->src[4] = a;
|
||||
result->src[5] = b;
|
||||
result->src[6] = state;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_unary
|
||||
|
||||
static struct ggml_tensor * ggml_unary_impl(
|
||||
|
@ -118,22 +118,26 @@ class Keys:
|
||||
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||
KEY_LENGTH = "{arch}.attention.key_length"
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||
KEY_LENGTH = "{arch}.attention.key_length"
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
|
||||
ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
|
||||
VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
|
||||
GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
RWKV6QWEN2 = auto()
|
||||
RWKV7 = auto()
|
||||
ARWKV7 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
|
||||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
TIME_MIX_W0 = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
TIME_MIX_W2 = auto()
|
||||
TIME_MIX_A0 = auto()
|
||||
TIME_MIX_A1 = auto()
|
||||
TIME_MIX_A2 = auto()
|
||||
TIME_MIX_V0 = auto()
|
||||
TIME_MIX_V1 = auto()
|
||||
TIME_MIX_V2 = auto()
|
||||
TIME_MIX_G1 = auto()
|
||||
TIME_MIX_G2 = auto()
|
||||
TIME_MIX_K_K = auto()
|
||||
TIME_MIX_K_A = auto()
|
||||
TIME_MIX_R_K = auto()
|
||||
TIME_MIX_LERP_X = auto()
|
||||
TIME_MIX_LERP_K = auto()
|
||||
TIME_MIX_LERP_V = auto()
|
||||
@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
|
||||
MODEL_ARCH.RWKV7: "rwkv7",
|
||||
MODEL_ARCH.ARWKV7: "arwkv7",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
|
||||
MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
|
||||
MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
|
||||
MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
|
||||
MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
|
||||
MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
|
||||
MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
|
||||
MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
|
||||
MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
|
||||
MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
|
||||
MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
|
||||
MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
|
||||
@ -1172,6 +1204,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.RWKV7: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM_2,
|
||||
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
MODEL_TENSOR.TIME_MIX_W0,
|
||||
MODEL_TENSOR.TIME_MIX_W1,
|
||||
MODEL_TENSOR.TIME_MIX_W2,
|
||||
MODEL_TENSOR.TIME_MIX_A0,
|
||||
MODEL_TENSOR.TIME_MIX_A1,
|
||||
MODEL_TENSOR.TIME_MIX_A2,
|
||||
MODEL_TENSOR.TIME_MIX_V0,
|
||||
MODEL_TENSOR.TIME_MIX_V1,
|
||||
MODEL_TENSOR.TIME_MIX_V2,
|
||||
MODEL_TENSOR.TIME_MIX_G1,
|
||||
MODEL_TENSOR.TIME_MIX_G2,
|
||||
MODEL_TENSOR.TIME_MIX_K_K,
|
||||
MODEL_TENSOR.TIME_MIX_K_A,
|
||||
MODEL_TENSOR.TIME_MIX_R_K,
|
||||
MODEL_TENSOR.TIME_MIX_KEY,
|
||||
MODEL_TENSOR.TIME_MIX_VALUE,
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
|
||||
MODEL_TENSOR.TIME_MIX_LN,
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT,
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
|
||||
MODEL_TENSOR.CHANNEL_MIX_KEY,
|
||||
MODEL_TENSOR.CHANNEL_MIX_VALUE,
|
||||
],
|
||||
MODEL_ARCH.ARWKV7: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
MODEL_TENSOR.TIME_MIX_W0,
|
||||
MODEL_TENSOR.TIME_MIX_W1,
|
||||
MODEL_TENSOR.TIME_MIX_W2,
|
||||
MODEL_TENSOR.TIME_MIX_A0,
|
||||
MODEL_TENSOR.TIME_MIX_A1,
|
||||
MODEL_TENSOR.TIME_MIX_A2,
|
||||
MODEL_TENSOR.TIME_MIX_V0,
|
||||
MODEL_TENSOR.TIME_MIX_V1,
|
||||
MODEL_TENSOR.TIME_MIX_V2,
|
||||
MODEL_TENSOR.TIME_MIX_G1,
|
||||
MODEL_TENSOR.TIME_MIX_G2,
|
||||
MODEL_TENSOR.TIME_MIX_K_K,
|
||||
MODEL_TENSOR.TIME_MIX_K_A,
|
||||
MODEL_TENSOR.TIME_MIX_R_K,
|
||||
MODEL_TENSOR.TIME_MIX_KEY,
|
||||
MODEL_TENSOR.TIME_MIX_VALUE,
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
|
||||
MODEL_TENSOR.TIME_MIX_LN,
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.MAMBA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -767,6 +767,18 @@ class GGUFWriter:
|
||||
def add_kv_lora_rank(self, length: int) -> None:
    """Write `length` as a uint32 under the arch-formatted KV_LORA_RANK key."""
    key = Keys.Attention.KV_LORA_RANK.format(arch=self.arch)
    self.add_uint32(key, length)
|
||||
|
||||
def add_decay_lora_rank(self, length: int) -> None:
    """Write `length` as a uint32 under the arch-formatted DECAY_LORA_RANK key."""
    key = Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch)
    self.add_uint32(key, length)
|
||||
|
||||
def add_iclr_lora_rank(self, length: int) -> None:
    """Write `length` as a uint32 under the arch-formatted ICLR_LORA_RANK key."""
    key = Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch)
    self.add_uint32(key, length)
|
||||
|
||||
def add_value_residual_mix_lora_rank(self, length: int) -> None:
    """Write `length` as a uint32 under the arch-formatted VALUE_RESIDUAL_MIX_LORA_RANK key."""
    key = Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch)
    self.add_uint32(key, length)
|
||||
|
||||
def add_gate_lora_rank(self, length: int) -> None:
    """Write `length` as a uint32 under the arch-formatted GATE_LORA_RANK key."""
    key = Keys.Attention.GATE_LORA_RANK.format(arch=self.arch)
    self.add_uint32(key, length)
|
||||
|
||||
def add_relative_attn_buckets_count(self, value: int) -> None:
    """Write `value` as a uint32 under the arch-formatted REL_BUCKETS_COUNT key."""
    key = Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch)
    self.add_uint32(key, value)
|
||||
|
||||
|
@ -27,7 +27,8 @@ class TensorNameMap:
|
||||
"embedding.word_embeddings", # chatglm
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
"rwkv.embeddings", # rwkv
|
||||
"rwkv.embeddings", # rwkv6
|
||||
"model.embeddings", # rwkv7
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@ -42,6 +43,9 @@ class TensorNameMap:
|
||||
"emb_ln", # nomic-bert
|
||||
"transformer.norm", # openelm
|
||||
"rwkv.blocks.0.pre_ln", # rwkv
|
||||
"rwkv.blocks.0.pre_ln", # rwkv6
|
||||
"model.pre_ln", # rwkv7
|
||||
"model.layers.0.pre_norm", # rwkv7
|
||||
"backbone.norm", # wavtokenizer
|
||||
),
|
||||
|
||||
@ -81,7 +85,8 @@ class TensorNameMap:
|
||||
"encoder.final_layernorm", # chatglm
|
||||
"transformer.norm", # openelm
|
||||
"model.norm", # nemotron
|
||||
"rwkv.ln_out", # rwkv
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
),
|
||||
|
||||
@ -122,14 +127,16 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"encoder.layers.{bid}.input_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
MODEL_TENSOR.ATTN_NORM_2: (
|
||||
"transformer.h.{bid}.ln_attn", # falcon40b
|
||||
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv6
|
||||
"model.layers.{bid}.ln2", # rwkv7
|
||||
),
|
||||
|
||||
# Attention query-key-value
|
||||
@ -462,112 +469,174 @@ class TensorNameMap:
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
"model.layers.{bid}.attention.w0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.w1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.w2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A0: (
|
||||
"model.layers.{bid}.attention.a0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A1: (
|
||||
"model.layers.{bid}.attention.a1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A2: (
|
||||
"model.layers.{bid}.attention.a2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V0: (
|
||||
"model.layers.{bid}.attention.v0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V1: (
|
||||
"model.layers.{bid}.attention.v1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V2: (
|
||||
"model.layers.{bid}.attention.v2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_G1: (
|
||||
"model.layers.{bid}.attention.g1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_G2: (
|
||||
"model.layers.{bid}.attention.g2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_K_K: (
|
||||
"model.layers.{bid}.attention.k_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_K_A: (
|
||||
"model.layers.{bid}.attention.k_a", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_R_K: (
|
||||
"model.layers.{bid}.attention.r_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_G: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_W: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_FIRST: (
|
||||
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv6
|
||||
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.key", # rwkv7
|
||||
"model.layers.{bid}.attention.k_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv6
|
||||
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.value", # rwkv7
|
||||
"model.layers.{bid}.attention.v_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv6
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.receptance", # rwkv7
|
||||
"model.layers.{bid}.attention.r_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_GATE: (
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv6
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LN: (
|
||||
"rwkv.blocks.{bid}.attention.ln_x", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.ln_x", # rwkv6
|
||||
"model.layers.{bid}.attention.ln_x" # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT: (
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv6
|
||||
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.output", # rwkv7
|
||||
"model.layers.{bid}.attention.o_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.x_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.feed_forward.key", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.key", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.key", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.feed_forward.value", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.value", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.value", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
|
@ -59,6 +59,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_RWKV7, "rwkv7" },
|
||||
{ LLM_ARCH_ARWKV7, "arwkv7" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
@ -110,22 +112,26 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
||||
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@ -1238,6 +1244,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE,
|
||||
{
|
||||
@ -1397,6 +1471,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
@ -1415,6 +1495,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
@ -1422,6 +1505,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
@ -63,6 +63,8 @@ enum llm_arch {
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_RWKV7,
|
||||
LLM_ARCH_ARWKV7,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
@ -127,6 +129,10 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
LLM_KV_ATTENTION_Q_LORA_RANK,
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
LLM_KV_ATTENTION_DECAY_LORA_RANK,
|
||||
LLM_KV_ATTENTION_ICLR_LORA_RANK,
|
||||
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
|
||||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
@ -250,8 +256,20 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
LLM_TENSOR_TIME_MIX_A0,
|
||||
LLM_TENSOR_TIME_MIX_A1,
|
||||
LLM_TENSOR_TIME_MIX_A2,
|
||||
LLM_TENSOR_TIME_MIX_V0,
|
||||
LLM_TENSOR_TIME_MIX_V1,
|
||||
LLM_TENSOR_TIME_MIX_V2,
|
||||
LLM_TENSOR_TIME_MIX_G1,
|
||||
LLM_TENSOR_TIME_MIX_G2,
|
||||
LLM_TENSOR_TIME_MIX_K_K,
|
||||
LLM_TENSOR_TIME_MIX_K_A,
|
||||
LLM_TENSOR_TIME_MIX_R_K,
|
||||
LLM_TENSOR_TIME_MIX_LERP_X,
|
||||
LLM_TENSOR_TIME_MIX_LERP_W,
|
||||
LLM_TENSOR_TIME_MIX_LERP_K,
|
||||
|
@ -76,6 +76,10 @@ struct llama_hparams {
|
||||
uint32_t time_decay_extra_dim = 0;
|
||||
uint32_t wkv_head_size = 0;
|
||||
uint32_t token_shift_count = 2;
|
||||
uint32_t n_lora_decay = 0;
|
||||
uint32_t n_lora_iclr = 0;
|
||||
uint32_t n_lora_value_res_mix = 0;
|
||||
uint32_t n_lora_gate = 0;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
|
@ -32,6 +32,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_109M: return "109M";
|
||||
case LLM_TYPE_137M: return "137M";
|
||||
case LLM_TYPE_160M: return "160M";
|
||||
case LLM_TYPE_190M: return "190M";
|
||||
case LLM_TYPE_220M: return "220M";
|
||||
case LLM_TYPE_250M: return "250M";
|
||||
case LLM_TYPE_270M: return "270M";
|
||||
@ -48,6 +49,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_1_6B: return "1.6B";
|
||||
case LLM_TYPE_2B: return "2B";
|
||||
case LLM_TYPE_2_8B: return "2.8B";
|
||||
case LLM_TYPE_2_9B: return "2.9B";
|
||||
case LLM_TYPE_3B: return "3B";
|
||||
case LLM_TYPE_4B: return "4B";
|
||||
case LLM_TYPE_6B: return "6B";
|
||||
@ -1250,6 +1252,36 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
|
||||
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
|
||||
ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay);
|
||||
ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix);
|
||||
ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false);
|
||||
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 12: type = LLM_TYPE_190M; break;
|
||||
case 24:
|
||||
switch (hparams.n_embd) {
|
||||
case 1024: type = LLM_TYPE_450M; break;
|
||||
case 2048: type = LLM_TYPE_1_5B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 28:
|
||||
switch (hparams.n_embd) {
|
||||
case 1536: type = LLM_TYPE_1_5B; break;
|
||||
case 3584: type = LLM_TYPE_7B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
@ -3366,6 +3398,146 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_RWKV7:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// Block 0, LN0
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
||||
tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
const int n_lora_decay = hparams.n_lora_decay;
|
||||
const int n_lora_iclr = hparams.n_lora_iclr;
|
||||
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
|
||||
const int n_lora_gate = hparams.n_lora_gate;
|
||||
const int attn_hidden_size = n_embd;
|
||||
const int ffn_size = hparams.n_ff_arr[0];
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
|
||||
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
|
||||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
|
||||
|
||||
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
|
||||
if (i == 0) {
|
||||
// actually not used
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
} else {
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0);
|
||||
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0);
|
||||
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
|
||||
|
||||
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
|
||||
|
||||
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
|
||||
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
|
||||
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
||||
|
||||
layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
|
||||
|
||||
layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
|
||||
layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
|
||||
}
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
const int n_lora_decay = hparams.n_lora_decay;
|
||||
const int n_lora_iclr = hparams.n_lora_iclr;
|
||||
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
|
||||
const int n_lora_gate = hparams.n_lora_gate;
|
||||
const int attn_hidden_size = n_embd;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
|
||||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
|
||||
|
||||
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
|
||||
if (i == 0) {
|
||||
// actually not used
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
} else {
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
try {
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
|
||||
} catch(std::runtime_error & e) {
|
||||
// ARWKV models may not have gate tensors
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
|
||||
|
||||
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
|
||||
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@ -10212,6 +10384,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto head_size = hparams.wkv_head_size;
|
||||
const auto n_head = n_embd / head_size;
|
||||
@ -10224,6 +10397,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
bool is_qrwkv = layer.time_mix_first == nullptr;
|
||||
|
||||
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
|
||||
|
||||
sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
|
||||
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);
|
||||
|
||||
xxx = ggml_reshape_4d(
|
||||
@ -10366,7 +10543,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
cur = ggml_mul(ctx0, cur, g);
|
||||
cur = build_lora_mm(layer.time_mix_output, cur);
|
||||
|
||||
return cur;
|
||||
return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
|
||||
}
|
||||
};
|
||||
|
||||
@ -10389,6 +10566,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
@ -10422,9 +10600,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
token_shift = ggml_concat(ctx0,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
|
||||
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
|
||||
@ -10432,6 +10607,18 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
);
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
|
||||
x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
|
||||
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
|
||||
cur = ggml_scale(ctx0, cur, 0.5F);
|
||||
}
|
||||
@ -10444,12 +10631,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
@ -10481,10 +10662,9 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
@ -10508,6 +10688,13 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
@ -10532,10 +10719,358 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
// Constructor: forwards `params` to the llm_graph_context base and keeps a
// reference to `model` in the `model` member. The body is intentionally empty.
llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
|
||||
}
|
||||
|
||||
// Builds the channel-mix graph nodes for one layer.
//
//   layer  - layer weights; reads channel_mix_lerp_k, channel_mix_key and
//            channel_mix_value
//   cur    - current input tensor
//   x_prev - tensor that `cur` is subtracted from to form the interpolation
//            delta (presumably the token-shifted input -- confirm with callers)
//   arch   - only LLM_ARCH_RWKV7 is handled; any other value hits GGML_ABORT
//
// Returns the output of the value projection.
ggml_tensor * build_rwkv7_channel_mix(
        const llama_layer * layer,
        ggml_tensor * cur,
        ggml_tensor * x_prev,
        llm_arch arch) const {
    // Built unconditionally, before the architecture check, as in the original.
    ggml_tensor * delta = ggml_sub(ctx0, x_prev, cur);

    if (arch == LLM_ARCH_RWKV7) {
        // xk = cur + delta * channel_mix_lerp_k
        ggml_tensor * scaled = ggml_mul(ctx0, delta, layer->channel_mix_lerp_k);
        ggml_tensor * xk     = ggml_add(ctx0, scaled, cur);

        // k = relu(channel_mix_key * xk)^2
        ggml_tensor * key_proj  = build_lora_mm(layer->channel_mix_key, xk);
        ggml_tensor * activated = ggml_relu(ctx0, key_proj);
        ggml_tensor * k         = ggml_sqr(ctx0, activated);

        cur = build_lora_mm(layer->channel_mix_value, k);
    } else {
        GGML_ABORT("fatal error");
    }

    return cur;
}
|
||||
|
||||
ggml_tensor * build_rwkv7_time_mix(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto head_size = hparams.wkv_head_size;
|
||||
const auto head_count = n_embd / head_size;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
|
||||
|
||||
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
|
||||
ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
|
||||
sx = ggml_repeat(ctx0, sx, dummy);
|
||||
|
||||
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
|
||||
|
||||
ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
|
||||
ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
|
||||
ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
||||
ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
||||
ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
||||
ggml_tensor * xg = has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
|
||||
|
||||
ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
|
||||
ggml_tensor * w = ggml_add(
|
||||
ctx0,
|
||||
ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
|
||||
layer.time_mix_w0
|
||||
);
|
||||
w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));
|
||||
|
||||
ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
|
||||
ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
|
||||
if (first_layer_value == nullptr) {
|
||||
first_layer_value = v;
|
||||
} else {
|
||||
// Add the first layer value as a residual connection.
|
||||
v = ggml_add(ctx0, v,
|
||||
ggml_mul(ctx0,
|
||||
ggml_sub(ctx0, first_layer_value, v),
|
||||
ggml_sigmoid(ctx0, ggml_add(ctx0,
|
||||
ggml_mul_mat(ctx0, layer.time_mix_v2, ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
|
||||
layer.time_mix_v0
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
ggml_tensor * g = nullptr;
|
||||
if (layer.time_mix_g1 && layer.time_mix_g2) {
|
||||
g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
|
||||
}
|
||||
|
||||
ggml_tensor * a = ggml_sigmoid(ctx0,
|
||||
ggml_add(
|
||||
ctx0,
|
||||
ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
|
||||
layer.time_mix_a0
|
||||
)
|
||||
);
|
||||
|
||||
ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
|
||||
kk = ggml_l2_norm(ctx0, kk, 1e-12);
|
||||
|
||||
ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
|
||||
k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));
|
||||
|
||||
r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
|
||||
w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
|
||||
k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
|
||||
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||
|
||||
ggml_tensor * wkv_state = build_copy_mask_state(
|
||||
gf, kv_self->v_l[il], state_copy, state_mask,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
|
||||
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
|
||||
|
||||
ggml_build_forward_expand(
|
||||
gf,
|
||||
ggml_cpy(
|
||||
ctx0,
|
||||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_self->v_l[il],
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
if (layer.time_mix_ln && layer.time_mix_ln_b) {
|
||||
// group norm with head_count groups
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
|
||||
cur = ggml_norm(ctx0, cur, 64e-5f);
|
||||
|
||||
// Convert back to regular vectors.
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
|
||||
} else {
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
}
|
||||
|
||||
ggml_tensor * rk = ggml_sum_rows(ctx0,
|
||||
ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
|
||||
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));
|
||||
|
||||
if (has_gating) {
|
||||
cur = ggml_mul(ctx0, cur, g);
|
||||
}
|
||||
cur = build_lora_mm(layer.time_mix_output, cur);
|
||||
|
||||
return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
||||
llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
||||
GGML_ASSERT(hparams.token_shift_count == 2);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * v_first = nullptr;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
ggml_tensor * state_mask = build_inp_s_mask();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
);
|
||||
|
||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
||||
ggml_tensor * x_prev = ggml_concat(
|
||||
ctx0,
|
||||
att_shift,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
|
||||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
|
||||
cb(ffn_norm, "ffn_norm", il);
|
||||
|
||||
x_prev = ggml_concat(
|
||||
ctx0,
|
||||
ffn_shift,
|
||||
ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
|
||||
1
|
||||
);
|
||||
|
||||
token_shift = ggml_concat(ctx0,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
|
||||
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
|
||||
1
|
||||
);
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
|
||||
x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||
llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * v_first = nullptr;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
ggml_tensor * state_mask = build_inp_s_mask();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
);
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
||||
ggml_tensor * x_prev = ggml_concat(
|
||||
ctx0,
|
||||
token_shift,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
|
||||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
|
||||
|
||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
@ -10883,9 +11418,11 @@ llama_memory_i * llama_model::create_memory() const {
|
||||
llama_memory_i * res;
|
||||
|
||||
switch (arch) {
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
res = new llama_kv_cache_unified(hparams, {
|
||||
/*.get_rope_factors =*/ nullptr
|
||||
@ -11132,6 +11669,14 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||
{
|
||||
llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_RWKV7:
|
||||
{
|
||||
llm = std::make_unique<llm_build_rwkv7>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
{
|
||||
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
|
||||
@ -11245,6 +11790,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_JAIS:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
return LLAMA_ROPE_TYPE_NONE;
|
||||
|
||||
@ -11399,6 +11946,8 @@ bool llama_model_is_recurrent(const llama_model * model) {
|
||||
case LLM_ARCH_MAMBA: return true;
|
||||
case LLM_ARCH_RWKV6: return true;
|
||||
case LLM_ARCH_RWKV6QWEN2: return true;
|
||||
case LLM_ARCH_RWKV7: return true;
|
||||
case LLM_ARCH_ARWKV7: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ enum llm_type {
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
LLM_TYPE_250M,
|
||||
LLM_TYPE_270M,
|
||||
@ -45,6 +46,7 @@ enum llm_type {
|
||||
LLM_TYPE_1_6B,
|
||||
LLM_TYPE_2B,
|
||||
LLM_TYPE_2_8B,
|
||||
LLM_TYPE_2_9B,
|
||||
LLM_TYPE_3B,
|
||||
LLM_TYPE_4B,
|
||||
LLM_TYPE_6B,
|
||||
@ -260,6 +262,20 @@ struct llama_layer {
|
||||
struct ggml_tensor * time_mix_receptance_b = nullptr;
|
||||
struct ggml_tensor * time_mix_gate = nullptr;
|
||||
|
||||
// rwkv7
|
||||
struct ggml_tensor * time_mix_w0 = nullptr;
|
||||
struct ggml_tensor * time_mix_a0 = nullptr;
|
||||
struct ggml_tensor * time_mix_a1 = nullptr;
|
||||
struct ggml_tensor * time_mix_a2 = nullptr;
|
||||
struct ggml_tensor * time_mix_v0 = nullptr;
|
||||
struct ggml_tensor * time_mix_v1 = nullptr;
|
||||
struct ggml_tensor * time_mix_v2 = nullptr;
|
||||
struct ggml_tensor * time_mix_g1 = nullptr;
|
||||
struct ggml_tensor * time_mix_g2 = nullptr;
|
||||
struct ggml_tensor * time_mix_k_k = nullptr;
|
||||
struct ggml_tensor * time_mix_k_a = nullptr;
|
||||
struct ggml_tensor * time_mix_r_k = nullptr;
|
||||
|
||||
struct ggml_tensor * time_mix_ln = nullptr;
|
||||
struct ggml_tensor * time_mix_ln_b = nullptr;
|
||||
struct ggml_tensor * time_mix_output = nullptr;
|
||||
|
@ -756,10 +756,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// NOTE: can't use LLM_TN here because the layer number is not known
|
||||
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
|
||||
|
||||
// do not quantize RWKV's time_mix_first tensors
|
||||
// do not quantize RWKV's small yet 2D weights
|
||||
quantize &= name.find("time_mix_first.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w0.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_v0.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_v1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_v2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_a0.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_a1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_a2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_g1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_g2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
|
||||
|
@ -1916,6 +1916,40 @@ struct test_gla : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_RWKV_WKV7
|
||||
struct test_rwkv_wkv7 : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
const int64_t head_size;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t n_tokens = n_seq_tokens * n_seqs;
|
||||
ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
// Outputs may become NaN with long seqlen without these normalization
|
||||
a = ggml_l2_norm(ctx, a, 1e-7F);
|
||||
b = ggml_l2_norm(ctx, b, 1e-7F);
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_MUL_MAT
|
||||
struct test_mul_mat : public test_case {
|
||||
const ggml_type type_a;
|
||||
@ -2972,6 +3006,32 @@ struct test_group_norm : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_L2_NORM
|
||||
struct test_l2_norm : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_l2_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||
float eps = 1e-12f)
|
||||
: type(type), ne(ne), eps(eps) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ACC
|
||||
struct test_acc : public test_case {
|
||||
const ggml_type type;
|
||||
@ -4006,8 +4066,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
|
||||
@ -4019,6 +4082,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
|
Loading…
x
Reference in New Issue
Block a user