llava : Add Granite Vision Support (#11794)
* Add super wip scripts for multimodal granite gguf Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Add example for converting mmgranite to gguf Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Remove hardcoded path Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Add vision feature layer to gguf params Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Clean up llava surgery and remove name substitution hacks Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Add transformers llava next tensor name mapping Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Make siglip / openclip mutually exclusive Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix projector linear substitution Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix linear 2 substitution index Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Increase max flattened gridpoints to 64 Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix hardcoded concat for multiple feature layers Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Pull vision feature layers out of gguf keys Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix num gridpoints and use all layers Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Avoid dropping last image encoder layer in llava models Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use 10 for max number of patches Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Standardize vision feature layers Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Clean up logs Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Update comment for vision feature layer init Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Update notes for alternative to legacy llm conversion script Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix notes rendering Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Add v prefix to vision feature layer log Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use current defaults for feature layer Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use constant for max gridpoints / feat layers, style fixes Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Clarify non-negative feature layers Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Remove CLIP_API from func signature Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use MAX_IMAGE_FEATURE_LAYERS const in layer calc Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Clarify feature layers are non-negative ints and not uint Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix condition for reading feature layers Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Pop last llava layer when feature layers are unset Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix unset vision layer 0 Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Update examples/llava/clip.cpp Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
* Reenable assertion for out of bounds get_rows Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use std vector for gridpoints and feature layers Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Calculate max feature layer at load time Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Include base patch for granite vision allocation Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Fix trailing whitespace Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Add max num patches = 10 back for minicpmv Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use unordered set to store feature layers Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use max feature layer for postnorm Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Apply suggestions from code review

---------

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
This commit is contained in:

parent 08d5986290
commit 7a2c913e66
examples/llava/README.md

@@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
```

**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)

**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)

**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way to handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model.

```python
import os
import transformers

model_path = ...
llm_export_path = ...

tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)

tokenizer.save_pretrained(llm_export_path)
model.language_model.save_pretrained(llm_export_path)
```

Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.

## llava-cli templating and llava-1.6 prompting

llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
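One extra check that may save a conversion round-trip (editorial note, not part of this change): confirm that the exported directory loads as a standalone LLM before running `convert_hf_to_gguf.py` on it. A minimal sketch, reusing the `llm_export_path` placeholder from the snippet above:

```python
# Not part of the patch: quick sanity check that the exported LLM directory
# stands on its own. llm_export_path is the placeholder used in the README snippet.
import transformers

llm_export_path = ...  # same directory used in the export snippet above

llm = transformers.AutoModelForCausalLM.from_pretrained(llm_export_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(llm_export_path)
print(llm.config.architectures)  # should list only the language model architecture
```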
examples/llava/clip.cpp

@@ -40,6 +40,7 @@
#include <map>
#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>

@@ -120,6 +121,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"

@@ -444,8 +446,9 @@ struct clip_hparams {

    char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)

    int32_t image_grid_pinpoints[32];
    std::vector<int32_t> image_grid_pinpoints;
    int32_t image_crop_resolution;
    std::unordered_set<int32_t> vision_feature_layer;
};

struct clip_layer {

@@ -585,6 +588,7 @@ struct clip_ctx {
    struct clip_vision_model vision_model;
    projector_type proj_type = PROJECTOR_TYPE_MLP;

    int32_t max_feature_layer;
    float image_mean[3];
    float image_std[3];
    bool use_gelu = false;

@@ -651,7 +655,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
    const int hidden_size = hparams.hidden_size;
    const int n_head = hparams.n_head;
    const int d_head = hidden_size / n_head;
    int n_layer = hparams.n_layer;
    const float eps = hparams.eps;
    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

@@ -752,13 +755,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
    }

    std::vector<struct ggml_tensor *> embedding_stack;
    const auto & vision_feature_layer = hparams.vision_feature_layer;

    // loop over layers
    if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
        n_layer += 1;
    }
    for (int il = 0; il < n_layer - 1; il++) {
    for (int il = 0; il < ctx->max_feature_layer; il++) {
        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states

        // If this is an embedding feature layer, save the output.
        // NOTE: 0 index here refers to the input to the encoder.
        if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
            embedding_stack.push_back(embeddings);
        }

        //const size_t nb_q_w = model.layers[il].q_w->nb[0];

        // layernorm1

@@ -846,7 +855,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        cur = ggml_add(ctx0, embeddings, cur);

        embeddings = cur;

    }

    // post-layernorm

@@ -857,6 +865,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
    }

    // final layer is a vision feature layer
    if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
        embedding_stack.push_back(embeddings);
    }

    // If feature layers are explicitly set, stack them (if we have multiple)
    if (!embedding_stack.empty()) {
        embeddings = embedding_stack[0];
        for (size_t i = 1; i < embedding_stack.size(); i++) {
            embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
        }
    }

    // llava projector
    if (ctx->has_llava_projector) {
        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
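The hunk above is the heart of the multi-feature-layer support: the hidden state at every selected encoder layer is pushed onto `embedding_stack`, and the saved states are concatenated along the embedding dimension before the projector runs. A rough NumPy illustration of that bookkeeping (not part of the patch; shapes and layer indices are invented for the example):

```python
import numpy as np

n_patches, hidden = 729, 1152                        # assumed SigLIP-like shapes
hidden_states = [np.random.rand(n_patches, hidden)   # stand-ins for encoder states
                 for _ in range(28)]                 # index 0 = input to the encoder
feature_layers = {3, 7, 15, 26}                      # assumed non-negative GGUF values

# Collect the selected states in layer order, then widen the embedding.
stack = [hidden_states[il] for il in sorted(feature_layers)]
features = np.concatenate(stack, axis=-1)            # width grows per selected layer
print(features.shape)                                # (729, 1152 * len(stack))
```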
@@ -1443,14 +1464,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
        int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
        int n = gguf_get_arr_n(ctx, idx);
        const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
        for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
            hparams.image_grid_pinpoints[i] = pinpoints[i];
        for (int i = 0; i < n; ++i) {
            hparams.image_grid_pinpoints.push_back(pinpoints[i]);
        }
        if (n < 32)
            hparams.image_grid_pinpoints[n] = 0;
    } catch (std::runtime_error & /*e*/) {
        hparams.image_grid_pinpoints[0]=0;
    }
    } catch (std::runtime_error & /*e*/) { }

    // Load the vision feature layer indices if they are explicitly provided;
    // if multiple vision feature layers are present, the values will be concatenated
    // to form the final visual features.
    // NOTE: gguf conversions should standardize the values of the vision feature layer to
    // be non-negative, since we use -1 to mark values as unset here.
    try {
        int idx = get_key_idx(ctx, KEY_FEATURE_LAYER);
        int n = gguf_get_arr_n(ctx, idx);

        const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);

        for (int i = 0; i < n; ++i) {
            hparams.vision_feature_layer.insert(vision_feature_layer[i]);
        }
    } catch (std::runtime_error & /*e*/) { }

    try {
        int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);

@@ -1476,6 +1509,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
        new_clip->image_std[i] = std_data[i];
    }

    // Calculate the deepest feature layer based on hparams and projector type
    new_clip->max_feature_layer = get_deepest_feature_layer(new_clip);

    if (verbosity >= 2) {
        LOG_INF("\n%s: vision model hparams\n", __func__);
        LOG_INF("image_size %d\n", hparams.image_size);

@@ -1489,8 +1525,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
        LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
        LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
        LOG_INF("v_image_grid_pinpoints: ");
        for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
            LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
        for (const auto & pp : hparams.image_grid_pinpoints) {
            LOG_INF("%d ", pp);
        }
        LOG_INF("\n");
        LOG_INF("v_vision_feature_layer: ");
        for (const auto & feature_layer : hparams.vision_feature_layer) {
            LOG_INF("%d ", feature_layer);
        }
        LOG_INF("\n");
        LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);

@@ -2235,10 +2276,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
            }
        }
    } else {
        if (params.image_grid_pinpoints[0] != 0) {
        if (!params.image_grid_pinpoints.empty()) {
            // "spatial_unpad" with "anyres" processing for llava-1.6
            std::vector<std::pair<int, int>> possible_resolutions;
            for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
            for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) {
                possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
            }
            std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
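For context on the `select_best_resolution` call above (its implementation is not touched by this patch): it chooses, from the pinpoint grid, the target resolution that preserves the most image area under an aspect-ratio-preserving resize, breaking ties by the least wasted area. A hedged Python sketch of that heuristic, based on the upstream LLaVA "anyres" logic rather than on this file:

```python
def select_best_resolution(original, possible_resolutions):
    """Pick the (width, height) candidate that keeps the most effective image
    area after an aspect-preserving downscale, preferring less wasted space on ties."""
    orig_w, orig_h = original
    best, best_fit, min_waste = None, 0, float("inf")
    for w, h in possible_resolutions:
        scale = min(w / orig_w, h / orig_h)
        fit = min(int(orig_w * scale) * int(orig_h * scale), orig_w * orig_h)
        waste = w * h - fit
        if fit > best_fit or (fit == best_fit and waste < min_waste):
            best, best_fit, min_waste = (w, h), fit, waste
    return best

print(select_best_resolution((800, 600), [(336, 672), (672, 336), (672, 672)]))  # (672, 672)
```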
@@ -2404,7 +2445,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
}

const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
    return ctx->vision_model.hparams.image_grid_pinpoints;
    if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
        return &ctx->vision_model.hparams.image_grid_pinpoints.front();
    }
    return nullptr;
}

size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
    return ctx->vision_model.hparams.image_grid_pinpoints.size();
}

int clip_n_patches(const struct clip_ctx * ctx) {

@@ -2929,6 +2977,28 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
    return ctx->has_qwen2vl_merger;
}

// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
    // Get the index of the second to last layer; this is the
    // default for models that have a llava projector
    const auto & hparams = ctx->vision_model.hparams;
    int n_layer = hparams.n_layer - 1;
    int deepest_feature_layer = -1;

    // Handle other projectors; incrementing here indicates that we
    // should use the last encoder layer for the vision features.
    if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
        n_layer += 1;
    }

    // If we set explicit vision feature layers, only go up to the deepest one
    for (const auto & feature_layer : hparams.vision_feature_layer) {
        if (feature_layer > deepest_feature_layer) {
            deepest_feature_layer = feature_layer;
        }
    }
    return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}

bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
    clip_image_f32 clip_img;
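To make the bookkeeping in `get_deepest_feature_layer()` easier to follow, here is a rough Python rendering of the same logic (not part of the patch; it assumes feature layers are stored as non-negative indices and that an empty set means unset):

```python
def deepest_feature_layer(n_layer, feature_layers, last_layer_projector=False):
    # Default: second-to-last encoder block, matching the llava projector behavior.
    default = n_layer - 1
    # minicpmv / glm / qwen2vl-style projectors use the last encoder block instead.
    if last_layer_projector:
        default += 1
    # Explicit feature layers win; otherwise fall back to the default.
    return max(feature_layers) if feature_layers else default

print(deepest_feature_layer(27, set()))            # 26 (llava default)
print(deepest_feature_layer(27, {4, 8, 16, 27}))   # 27 (deepest explicit layer)
```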
examples/llava/clip.h

@@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);

CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);

CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);

@@ -92,11 +93,13 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);

CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);

CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);

CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);

#ifdef __cplusplus
}
examples/llava/convert_image_encoder_to_gguf.py

@@ -6,7 +6,7 @@ import re
import torch
import numpy as np
from gguf import *
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel

TEXT = "clip.text"
VISION = "clip.vision"

@@ -37,6 +37,18 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b


def get_tensor_name(name: str) -> str:
    # Standardize the transformers llava next keys for
    # image newline / mm projector with the classes in haotian-liu LLaVA
    if name == "image_newline":
        return "model.image_newline"
    if name.startswith("multi_modal_projector"):
        name = name.replace("multi_modal_projector", "mm")
        if "linear_1" in name:
            name = name.replace("linear_1", "0")
        if "linear_2" in name:
            name = name.replace("linear_2", "2")
        return name

    if "projection" in name:
        return name
    if "mm_projector" in name:
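A quick illustration of the renaming added above (not part of the patch; the input names follow the transformers LLaVA-Next checkpoint convention):

```python
# Assumes get_tensor_name() from the diff above is in scope.
for name in ("multi_modal_projector.linear_1.weight",
             "multi_modal_projector.linear_2.bias",
             "image_newline"):
    print(name, "->", get_tensor_name(name))
# multi_modal_projector.linear_1.weight -> mm.0.weight
# multi_modal_projector.linear_2.bias -> mm.2.bias
# image_newline -> model.image_newline
```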
@@ -83,8 +95,14 @@ ap.add_argument("--vision-only", action="store_true", required=False,
                help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,

# Selectable visual encoders that are compatible with this script
encoder_group = ap.add_mutually_exclusive_group()
encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False,
                help="The clip model is from openclip (for ViT-SO400M type))")
encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False,
                help="the visual encoder is Siglip.")

ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)

@@ -109,7 +127,12 @@ if args.use_f32:
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir

if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
if (
    args.clip_model_is_vision or
    not os.path.exists(dir_model + "/vocab.json") or
    args.clip_model_is_openclip or
    args.clip_model_is_siglip
):
    vocab = None
    tokens = None
else:

@@ -137,7 +160,10 @@ ftype = 1
if args.use_f32:
    ftype = 0

if args.clip_model_is_vision or args.clip_model_is_openclip:
if args.clip_model_is_siglip:
    model = SiglipVisionModel.from_pretrained(dir_model)
    processor = None
elif args.clip_model_is_vision or args.clip_model_is_openclip:
    model = CLIPVisionModel.from_pretrained(dir_model)
    processor = None
else:

@@ -187,26 +213,71 @@ else:
if has_text_encoder:
    assert t_hparams is not None
    assert tokens is not None
    if args.clip_model_is_siglip:
        text_projection_dim = 0
    else:
        text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"])
    # text_model hparams
    fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
    fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
    fout.add_uint32("clip.text.projection_dim", text_projection_dim)
    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
    fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
    fout.add_token_list(tokens)


def get_non_negative_vision_feature_layers(v_hparams):
    """
    Determine the vision feature layer(s) for the llava model, which are indices into the
    hidden states of the visual encoder. Note that the hidden states array generally takes the
    form:

        [<emb input>, <output of enc block 0>, ... <output of enc block num_hidden_layers>]

    so feature indices should be offset as n+1 to get the output of encoder block n.
    We convert all vision feature layers to non-negative so that -1 can be used in
    the model as an unset value. If no vision feature layer is found, we leave it unset.
    """
    num_hidden_layers = v_hparams["num_hidden_layers"]
    to_non_negative = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
    feature_layers_key = None
    # Key used for llava models in transformers
    if "vision_feature_layer" in config:
        feature_layers_key = "vision_feature_layer"
    # Key used for llava models in the original format
    elif "mm_vision_select_layer" in config:
        feature_layers_key = "mm_vision_select_layer"
    if feature_layers_key is not None:
        feature_layers = config[feature_layers_key]
        if isinstance(feature_layers, int):
            feature_layers = [feature_layers]
        return [to_non_negative(feature_layer) for feature_layer in feature_layers]

# Determine if we have explicitly specified vision feature layers in our config
feature_layers = get_non_negative_vision_feature_layers(v_hparams)

if has_vision_encoder:
    # vision_model hparams
    # Siglip does not have a visual projector; set projection dim to 0
    if args.clip_model_is_siglip:
        visual_projection_dim = 0
    else:
        visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])

    # set vision_model hparams
    fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
    fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
    fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
    fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
    block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
    if feature_layers:
        block_count = max(feature_layers)
    else:
        block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
    # /**
    #     "image_grid_pinpoints": [
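A small worked example of the index normalization performed by `get_non_negative_vision_feature_layers` above (not part of the patch; the config values are invented): with `num_hidden_layers = 27`, a negative index `n` maps to `27 + n + 1`, so `-1` (the last hidden state) becomes `27` and `-24` becomes `4`.

```python
num_hidden_layers = 27                              # assumed SigLIP-style encoder depth
to_non_negative = lambda i: i if i >= 0 else num_hidden_layers + i + 1

config_vision_feature_layer = [-24, -20, -12, -1]   # hypothetical multi-layer config
print([to_non_negative(i) for i in config_vision_feature_layer])   # [4, 8, 16, 27]
```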
@@ -258,7 +329,8 @@ if has_vision_encoder:
        fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
    if "mm_projector_type" in v_hparams:
        fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])

    if feature_layers:
        fout.add_array("clip.vision.feature_layer", feature_layers)

    if processor is not None:
        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean  # pyright: ignore[reportAttributeAccessIssue]

@@ -274,7 +346,13 @@ fout.add_bool("clip.use_gelu", use_gelu)


if has_llava_projector:
    model.vision_model.encoder.layers.pop(-1)
    # By default, we drop the last layer for llava projector
    # models unless we have explicitly set vision feature layers
    if feature_layers is None:
        model.vision_model.encoder.layers.pop(-1)
    else:
        model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)]

    projector = torch.load(args.llava_projector)
    for name, data in projector.items():
        name = get_tensor_name(name)
examples/llava/llava.cpp

@@ -353,9 +353,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
        LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);

        const int32_t * image_grid = clip_image_grid(ctx_clip);
        const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);

        std::vector<std::pair<int, int>> grid_pinpoints;
        for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
        for (size_t i = 0; i < num_gridpoints; i += 2) {
            grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
        }
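The grid pinpoints are stored flat as `[w0, h0, w1, h1, ...]`, which is why both the clip.cpp loop earlier and this one step by two. A tiny illustration (not from the patch; the values are the usual llava-1.6 "anyres" grid, used here only as an example):

```python
pinpoints = [336, 672, 672, 336, 672, 672, 1008, 336, 336, 1008]   # flat [w, h, w, h, ...]
resolutions = [(pinpoints[i], pinpoints[i + 1]) for i in range(0, len(pinpoints), 2)]
print(resolutions)   # [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
```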
@@ -405,7 +406,8 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
}

bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
    int num_max_patches = 6;
    // Granite vision uses up to 10 patches + base patch
    int num_max_patches = 11;
    if (clip_is_minicpmv(ctx_clip)) {
        num_max_patches = 10;
    }
examples/llava/llava_surgery_v2.py

@@ -33,6 +33,33 @@ def save_model(model, file_path, file_type):
    else:
        torch.save(model, file_path)

# Helpers to match weight names from specific components or
# determine if a saved shard contains that component
def is_vision_tower(weight_name):
    return (
        weight_name.startswith("model.vision_tower") or
        weight_name.startswith("vit.") or
        weight_name.startswith("vision_tower")
    )

def is_newline(weight_name):
    return (
        weight_name.startswith("model.image_newline") or
        weight_name.startswith("image_newline")
    )

def is_mm_projector(weight_name):
    return (
        weight_name.startswith("model.mm_projector") or
        weight_name.startswith("vision_proj.") or
        weight_name.startswith("multi_modal_projector")
    )

def newline_criteria(checkpoint):
    return any(is_newline(k) for k in checkpoint.keys())

def proj_criteria(checkpoint):
    return any(is_mm_projector(k) for k in checkpoint.keys())

# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
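A hedged example of how the new predicates partition a checkpoint's keys, which is how the surgery script decides what to extract (the key names are hypothetical, in the style of a transformers LLaVA-Next / granite vision checkpoint):

```python
# Assumes is_vision_tower(), is_newline(), is_mm_projector(), newline_criteria()
# and proj_criteria() from the diff above are in scope.
checkpoint = {
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight": ...,
    "model.image_newline": ...,
    "multi_modal_projector.linear_1.weight": ...,
    "model.layers.0.mlp.gate_proj.weight": ...,   # plain LLM weight, left alone
}

clip_keys    = [k for k in checkpoint if is_vision_tower(k)]
newline_keys = [k for k in checkpoint if is_newline(k)]
proj_keys    = [k for k in checkpoint if is_mm_projector(k)]

print(clip_keys, newline_keys, proj_keys, sep="\n")
print(newline_criteria(checkpoint), proj_criteria(checkpoint))   # True True
```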
@@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
    # file_type = 'pytorch'
    model_path = os.path.dirname(checkpoint_path)
    print(f"Searching for vision tower tensors in {checkpoint_path}")
    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
    clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]

    if len(clip_tensors) > 0:
        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")

@@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):

    return newline_checkpoint_path, projector_checkpoint_path

def newline_criteria(checkpoint):
    return any(k.startswith("model.image_newline") for k in checkpoint.keys())

def proj_criteria(checkpoint):
    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())


# Command-line interface setup
ap = argparse.ArgumentParser()

@@ -123,14 +144,14 @@ first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, file_type = load_model(newline_checkpoint_path)
    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
    first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]

# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, file_type = load_model(projector_checkpoint_path)
    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
    mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]

if len(mm_tensors) == 0:
    if last_checkpoint is not None:

@@ -155,5 +176,5 @@ if len(projector) > 0:
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")