Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-04-16 03:26:08 +00:00)

Commit d7db1593ee: Merge branch 'master' into compilade/parallel-convert

@ -21,7 +21,7 @@ COPY . .
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib && \

@ -17,7 +17,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \
fi && \
echo "Building with dynamic libs" && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16} && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${OPT_SYCL_F16} && \
cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib && \

@ -1,4 +1,4 @@
ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8
ARG ASCEND_VERSION=8.1.RC1.alpha001-910b-openeuler22.03-py3.10

FROM ascendai/cann:$ASCEND_VERSION AS build

@ -6,7 +6,7 @@ WORKDIR /app

COPY . .

RUN yum install -y gcc g++ cmake make
RUN yum install -y gcc g++ cmake make libcurl-devel
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}

@ -35,7 +35,7 @@ COPY . .
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib && \

@ -17,8 +17,8 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
# gfx906 is deprecated
#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html

#ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102'
ARG ROCM_DOCKER_ARCH=gfx1100
ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102'
#ARG ROCM_DOCKER_ARCH=gfx1100

# Set nvcc architectured
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}

@ -40,7 +40,7 @@ WORKDIR /app
COPY . .

RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DCMAKE_BUILD_TYPE=Release \
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON \
&& cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib \

@ -16,7 +16,7 @@ WORKDIR /app

COPY . .

RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 && \
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib && \

.github/workflows/build.yml

@ -1771,7 +1771,7 @@ jobs:
strategy:
matrix:
cann:
- '8.0.rc3.beta1-910b-openeuler22.03-py3.10'
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
device:
- 'ascend910b3'
build:

@ -1784,7 +1784,7 @@ jobs:
- name: Dependencies
run: |
yum update -y
yum install -y git gcc gcc-c++ make cmake
yum install -y git gcc gcc-c++ make cmake libcurl-devel

- name: Build
run: |

.github/workflows/docker.yml

@ -36,13 +36,13 @@ jobs:
matrix:
config:
# Multi-stage build
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false}
- { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
- { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false }
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true }
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false }
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false }
# Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete
#- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: true }
#- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true }
steps:
- name: Check out the repo
uses: actions/checkout@v4

README.md

@ -9,13 +9,6 @@

Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++

> [!IMPORTANT]
> New `llama.cpp` package location: [ggml-org/llama.cpp](https://github.com/ggml-org/llama.cpp/pkgs/container/llama.cpp)
>
> Update your container URLs to: `ghcr.io/ggml-org/llama.cpp`
>
> More info: https://github.com/ggml-org/llama.cpp/discussions/11801

## Recent API changes

- [Changelog for `libllama` API](https://github.com/ggml-org/llama.cpp/issues/9289)

@ -104,6 +97,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Flan T5](https://huggingface.co/models?search=flan-t5)
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat)
- [x] [GLM-4-0414](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

@ -247,6 +241,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [Vulkan](docs/build.md#vulkan) | GPU |
| [CANN](docs/build.md#cann) | Ascend NPU |
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/examples/rpc) | All |

## Building the project

@ -265,7 +260,9 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt
- [Trending](https://huggingface.co/models?library=gguf&sort=trending)
- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf)

You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from Hugging Face by using this CLI argument: `-hf <user>/<model>[:quant]`
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf <user>/<model>[:quant]`.

By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. For example, you may opt to downloading model checkpoints from ModelScope or other model sharing communities by setting the environment variable, e.g. `MODEL_ENDPOINT=https://www.modelscope.cn/`.

After downloading a model, use the CLI tools to run it locally - see below.
@ -41,6 +41,11 @@ COMMON_CMAKE_ARGS=(
|
||||
-DGGML_OPENMP=${GGML_OPENMP}
|
||||
)
|
||||
|
||||
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
check_required_tool() {
|
||||
local tool=$1
|
||||
local install_message=$2
|
||||
@ -325,21 +330,28 @@ combine_static_libraries() {
|
||||
|
||||
# Platform-specific post-processing for device builds
|
||||
if [[ "$is_simulator" == "false" ]]; then
|
||||
if command -v vtool &>/dev/null; then
|
||||
if command -v xcrun vtool &>/dev/null; then
|
||||
case "$platform" in
|
||||
"ios")
|
||||
echo "Marking binary as a framework binary for iOS..."
|
||||
vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
|
||||
xcrun vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
|
||||
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
|
||||
;;
|
||||
"visionos")
|
||||
echo "Marking binary as a framework binary for visionOS..."
|
||||
vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
|
||||
if [[ "$MAJOR_VERSION" -gt 16 ]] || [[ "$MAJOR_VERSION" -eq 16 && "$MINOR_VERSION" -gt 2 ]]; then
|
||||
echo "Xcode version greater than 16.2, using visionOS."
|
||||
VISION_OS_BUILD_VERSION="visionos"
|
||||
else
|
||||
echo "Xcode version less than or equal to 16.2, using xros."
|
||||
VISION_OS_BUILD_VERSION="xros"
|
||||
fi
|
||||
xcrun vtool -set-build-version ${VISION_OS_BUILD_VERSION} ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
|
||||
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
|
||||
;;
|
||||
"tvos")
|
||||
echo "Marking binary as a framework binary for tvOS..."
|
||||
vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
|
||||
xcrun vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
|
||||
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
|
||||
;;
|
||||
esac
|
||||
|

@ -228,12 +228,13 @@ static bool common_download_file_single(const std::string & url, const std::stri
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);

http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
// Check if hf-token or bearer-token was specified
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);

#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
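
This hunk appears to add a default `User-Agent: llama-cpp` header and to attach the header list outside the bearer-token branch, so it is sent even when no token is configured. A minimal standalone sketch of that libcurl pattern, assuming a plain handle and a hypothetical `bearer_token` string instead of llama.cpp's `curl_ptr`/`curl_slist_ptr` wrappers:

```cpp
#include <curl/curl.h>
#include <string>

int main() {
    curl_global_init(CURL_GLOBAL_DEFAULT);
    CURL * curl = curl_easy_init();
    if (!curl) return 1;

    std::string bearer_token;                       // hypothetical token, may be empty
    struct curl_slist * headers = nullptr;
    headers = curl_slist_append(headers, "User-Agent: llama-cpp");
    if (!bearer_token.empty()) {
        const std::string auth = "Authorization: Bearer " + bearer_token;
        headers = curl_slist_append(headers, auth.c_str());
    }
    // attach the list unconditionally so the User-Agent header is always sent
    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);

    curl_easy_setopt(curl, CURLOPT_URL, "https://huggingface.co/");
    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
    CURLcode res = curl_easy_perform(curl);

    curl_slist_free_all(headers);
    curl_easy_cleanup(curl);
    curl_global_cleanup();
    return res == CURLE_OK ? 0 : 1;
}
```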

@ -544,7 +545,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;

std::string model_endpoint = get_model_endpoint();

std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);

@ -659,13 +663,8 @@ static void common_params_handle_model(
}
}

std::string hf_endpoint = "https://huggingface.co/";
const char * hf_endpoint_env = getenv("HF_ENDPOINT");
if (hf_endpoint_env) {
hf_endpoint = hf_endpoint_env;
if (hf_endpoint.back() != '/') hf_endpoint += '/';
}
model.url = hf_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs

@ -1027,6 +1027,19 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}

std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
const char * hf_endpoint_env = getenv("HF_ENDPOINT");
const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
}
return model_endpoint;
}
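
The new `get_model_endpoint()` helper prefers `MODEL_ENDPOINT`, falls back to `HF_ENDPOINT`, and defaults to `https://huggingface.co/`. A small standalone sketch of the same resolution and of how `common_params_handle_model` composes the download URL; the repo and file names below are hypothetical:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Hedged sketch: mirrors the precedence of get_model_endpoint() above.
static std::string resolve_endpoint() {
    const char * model_env = std::getenv("MODEL_ENDPOINT");
    const char * hf_env    = std::getenv("HF_ENDPOINT");
    const char * env       = model_env ? model_env : hf_env;
    std::string endpoint = "https://huggingface.co/";
    if (env) {
        endpoint = env;
        if (endpoint.back() != '/') endpoint += '/';
    }
    return endpoint;
}

int main() {
    // hypothetical repo/file, combined the same way common_params_handle_model() does
    const std::string hf_repo = "some-user/some-model";
    const std::string hf_file = "model-q4_0.gguf";
    std::cout << resolve_endpoint() + hf_repo + "/resolve/main/" + hf_file << "\n";
    // with MODEL_ENDPOINT=https://www.modelscope.cn/ the same code yields a ModelScope URL
}
```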

void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
for (auto & la : lora) {

@ -543,6 +543,8 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

std::string get_model_endpoint();

//
// Batch utils
//
@ -65,6 +65,7 @@ class Model:
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
dir_model_card: Path
|
||||
remote_hf_model_id: str | None
|
||||
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
@ -73,7 +74,8 @@ class Model:
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, thread_count: int = 2):
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
|
||||
thread_count: int = 2):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
@ -83,11 +85,24 @@ class Model:
|
||||
self.is_big_endian = is_big_endian
|
||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||
self.use_temp_file = use_temp_file
|
||||
self.lazy = not eager
|
||||
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
|
||||
self.is_safetensors = len(self.part_names) > 0
|
||||
if not self.is_safetensors:
|
||||
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.lazy = not eager or (remote_hf_model_id is not None)
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
if remote_hf_model_id is not None:
|
||||
self.is_safetensors = True
|
||||
|
||||
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
self.tensor_names = set(name for name in remote_tensors.keys())
|
||||
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
|
||||
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
|
||||
|
||||
self.get_tensors = get_remote_tensors
|
||||
else:
|
||||
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
|
||||
self.is_safetensors = len(self.part_names) > 0
|
||||
if not self.is_safetensors:
|
||||
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
@ -394,6 +409,10 @@ class Model:
|
||||
|
||||
self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
|
||||
|
||||
# If we are using HF model id, set the metadata name to the model id
|
||||
if self.remote_hf_model_id:
|
||||
self.metadata.name = self.remote_hf_model_id
|
||||
|
||||
# Fallback to model directory name if metadata name is still missing
|
||||
if self.metadata.name is None:
|
||||
self.metadata.name = self.dir_model.name
|
||||
@ -718,6 +737,9 @@ class Model:
|
||||
if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
|
||||
# ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
res = "llama4"
|
||||
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
|
||||
res = "glm4"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@ -1733,7 +1755,7 @@ class LlamaModel(Model):
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
# assert low_freq_wavelen != high_freq_wavelen # Errors for Llama4
|
||||
|
||||
rope_factors = []
|
||||
for freq in freqs:
|
||||
@ -1789,10 +1811,6 @@ class Llama4Model(LlamaModel):
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
name = name.replace("feed_forward.", "mlp.") # a bit hacky for now
|
||||
name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now
|
||||
|
||||
# split the gate_up into gate and up
|
||||
if "gate_up_proj" in name:
|
||||
name_up = name.replace("gate_up_proj", "up_proj.weight")
|
||||
@ -2460,6 +2478,16 @@ class Qwen2MoeModel(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("Qwen3ForCausalLM")
|
||||
class Qwen3Model(Qwen2Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
|
||||
@Model.register("Qwen3MoeForCausalLM")
|
||||
class Qwen3MoeModel(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3MOE
|
||||
|
||||
|
||||
@Model.register("GPT2LMHeadModel")
|
||||
class GPT2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
@ -4874,6 +4902,22 @@ class JaisModel(Model):
|
||||
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
|
||||
|
||||
|
||||
@Model.register("Glm4ForCausalLM")
|
||||
class Glm4Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GLM4
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "yarn":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
class ChatGLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CHATGLM
|
||||
@ -5394,6 +5438,14 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
||||
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
||||
shape = remote_tensor.shape
|
||||
meta = cls.meta_with_dtype_and_shape(dtype, shape)
|
||||
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
del types # unused
|
||||
@ -5471,9 +5523,13 @@ def parse_args() -> argparse.Namespace:
|
||||
"--print-supported-models", action="store_true",
|
||||
help="Print the supported models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote", action="store_true",
|
||||
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--threads", type=int, default=2,
|
||||
help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this.",
|
||||
help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this. Defaults to 2.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -5515,6 +5571,14 @@ def main() -> None:
|
||||
|
||||
dir_model = args.model
|
||||
|
||||
if args.remote:
|
||||
from huggingface_hub import snapshot_download
|
||||
local_dir = snapshot_download(
|
||||
repo_id=str(dir_model),
|
||||
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
|
||||
dir_model = Path(local_dir)
|
||||
logger.info(f"Downloaded config and tokenizer to {local_dir}")
|
||||
|
||||
if not dir_model.is_dir():
|
||||
logger.error(f'Error: {args.model} is not a directory')
|
||||
sys.exit(1)
|
||||
@ -5536,6 +5600,9 @@ def main() -> None:
|
||||
|
||||
if args.outfile is not None:
|
||||
fname_out = args.outfile
|
||||
elif args.remote:
|
||||
# if remote, use the model ID as the output file name
|
||||
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
|
||||
else:
|
||||
fname_out = dir_model
|
||||
|
||||
@ -5546,7 +5613,6 @@ def main() -> None:
|
||||
with torch.inference_mode():
|
||||
output_type = ftype_map[args.outtype]
|
||||
model_architecture = hparams["architectures"][0]
|
||||
|
||||
try:
|
||||
model_class = Model.from_model_architecture(model_architecture)
|
||||
except NotImplementedError:
|
||||
@ -5559,7 +5625,9 @@ def main() -> None:
|
||||
metadata_override=args.metadata, model_name=args.model_name,
|
||||
split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split, thread_count=args.threads)
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=str(args.model) if args.remote else None,
|
||||
thread_count=args.threads)
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info("Exporting model vocab...")
|
||||
|
@ -114,6 +114,7 @@ models = [
|
||||
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
|
||||
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
|
||||
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
|
||||
]
|
||||
|
||||
|
||||
|
@ -425,13 +425,13 @@ Examples:
|
||||
- Use device 0:
|
||||
|
||||
```sh
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
|
||||
```
|
||||
|
||||
- Use multiple devices:
|
||||
|
||||
```sh
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer
|
||||
```
|
||||
|
||||
*Notes:*
|
||||
@ -697,13 +697,13 @@ Examples:
|
||||
- Use device 0:
|
||||
|
||||
```
|
||||
build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0
|
||||
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0
|
||||
```
|
||||
|
||||
- Use multiple devices:
|
||||
|
||||
```
|
||||
build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer
|
||||
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer
|
||||
```
|
||||
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
# llava (legacy)
|
||||
|
||||
add_library(llava OBJECT
|
||||
llava.cpp
|
||||
llava.h
|
||||
@ -22,12 +24,41 @@ if (BUILD_SHARED_LIBS)
|
||||
install(TARGETS llava_shared LIBRARY)
|
||||
endif()
|
||||
|
||||
# mtmd
|
||||
|
||||
add_library(mtmd OBJECT
|
||||
mtmd.cpp
|
||||
mtmd.h
|
||||
clip.cpp
|
||||
clip.h
|
||||
clip-impl.h
|
||||
)
|
||||
|
||||
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
target_include_directories(mtmd PUBLIC .)
|
||||
target_include_directories(mtmd PRIVATE ../..)
|
||||
target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h
|
||||
|
||||
target_compile_features(mtmd PRIVATE cxx_std_17)
|
||||
|
||||
add_library(mtmd_static STATIC $<TARGET_OBJECTS:mtmd>)
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
|
||||
add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
|
||||
target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
install(TARGETS mtmd_shared LIBRARY)
|
||||
endif()
|
||||
|
||||
if (NOT MSVC)
|
||||
target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
|
||||
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
|
||||
endif()
|
||||
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(llava BUILD_INFO)
|
||||
add_dependencies(mtmd BUILD_INFO)
|
||||
endif()
|
||||
|
||||
set(TARGET llama-llava-cli)
|
||||
@ -55,7 +86,7 @@ set(TARGET llama-gemma3-cli)
|
||||
add_executable(${TARGET} gemma3-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-llava-clip-quantize-cli)
|
||||
|
@ -1,5 +1,8 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
#include "clip.h"
|
||||
|
||||
#include "clip.h"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdarg>
|
||||
@ -7,6 +10,7 @@
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
// Internal header for clip.cpp
|
||||
|
||||
@ -120,6 +124,23 @@ static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
return PROJECTOR_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<uint8_t> buf;
|
||||
};
|
||||
|
||||
// RGB float32 image (NHWC)
|
||||
// Memory layout: RGBRGBRGB...
|
||||
struct clip_image_f32 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<float> buf;
|
||||
};
|
||||
|
||||
//
|
||||
// logging
|
||||
//
|
||||
@ -178,6 +199,36 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
|
||||
//
// cpp wrappers
//

// wrapper for clip_image_size
struct clip_image_size_deleter {
void operator()(clip_image_size * val) { clip_image_size_free(val); }
};
typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;

// wrapper for clip_image_u8
struct clip_image_u8_deleter {
void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
};
typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;

// wrapper for clip_image_f32
struct clip_image_f32_deleter {
void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
};
typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;

struct clip_image_u8_batch {
std::vector<clip_image_u8_ptr> entries;
};

struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
};
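
These wrappers give RAII ownership of the C-style clip image objects, so batches can hold smart pointers instead of raw arrays. A minimal sketch of the same custom-deleter pattern, with a hypothetical `thing_t`/`thing_free` pair standing in for the clip types:

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical C-style API, standing in for clip_image_u8 / clip_image_u8_free.
struct thing_t { int id; };
static thing_t * thing_init(int id)      { return new thing_t{id}; }
static void      thing_free(thing_t * t) { delete t; }

// Same pattern as the clip_image_*_ptr typedefs above: a deleter struct + unique_ptr alias.
struct thing_deleter {
    void operator()(thing_t * t) { thing_free(t); }
};
typedef std::unique_ptr<thing_t, thing_deleter> thing_ptr;

int main() {
    std::vector<thing_ptr> batch;              // analogous to clip_image_u8_batch::entries
    batch.push_back(thing_ptr(thing_init(1)));
    batch.push_back(thing_ptr(thing_init(2)));
    std::printf("batch size: %zu\n", batch.size());
    // each thing_t is released through thing_free() when the vector goes out of scope
}
```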

//
// common utils
//

@ -214,6 +265,20 @@ static void string_replace_all(std::string & s, const std::string & search, cons
s = std::move(builder);
}

// split string by a `std::string delim` instead of `char delim`
static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while ((pos = s.find(delimiter)) != std::string::npos) {
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);
return tokens;
}
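
For illustration, a short standalone sketch of how a multi-character-delimiter split like `string_split_str` behaves; the helper is repeated here so the example compiles on its own, and the input string is made up:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Same logic as string_split_str above, duplicated for a self-contained example.
static std::vector<std::string> split_str(std::string s, const std::string & delimiter) {
    std::vector<std::string> tokens;
    size_t pos = 0;
    while ((pos = s.find(delimiter)) != std::string::npos) {
        tokens.push_back(s.substr(0, pos));
        s.erase(0, pos + delimiter.length());
    }
    tokens.push_back(s);
    return tokens;
}

int main() {
    // splits on the whole "<sep>" token, not on individual characters
    for (const auto & part : split_str("a<sep>b<sep>c", "<sep>")) {
        std::cout << part << "\n";   // prints a, b, c on separate lines
    }
}
```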
|
||||
|
||||
//
|
||||
// gguf utils
|
||||
//
|
||||
@ -271,3 +336,9 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
|
||||
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// API used internally with mtmd
|
||||
//
|
||||
|
||||
projector_type clip_get_projector_type(const struct clip_ctx * ctx);
|
||||
|
@ -32,23 +32,6 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
|
||||
|
||||
//#define CLIP_DEBUG_FUNCTIONS
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<uint8_t> buf;
|
||||
};
|
||||
|
||||
// RGB float32 image (NHWC)
|
||||
// Memory layout: RGBRGBRGB...
|
||||
struct clip_image_f32 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<float> buf;
|
||||
};
|
||||
|
||||
#ifdef CLIP_DEBUG_FUNCTIONS
|
||||
static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
@ -331,60 +314,48 @@ struct clip_ctx {
|
||||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
bool use_silu = false;
|
||||
int32_t ftype = 1;
|
||||
|
||||
struct gguf_context * ctx_gguf = nullptr;
|
||||
struct ggml_context * ctx_data = nullptr;
|
||||
gguf_context_ptr ctx_gguf;
|
||||
ggml_context_ptr ctx_data;
|
||||
|
||||
std::vector<uint8_t> buf_compute_meta;
|
||||
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
|
||||
ggml_backend_t backend = nullptr;
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
ggml_backend_buffer_t buf = nullptr;
|
||||
ggml_backend_ptr backend;
|
||||
ggml_backend_ptr backend_cpu;
|
||||
ggml_backend_buffer_ptr buf;
|
||||
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
struct clip_image_size * load_image_size = nullptr;
|
||||
clip_image_size load_image_size;
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
backend = ctx_params.use_gpu
|
||||
backend_cpu = ggml_backend_ptr(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr));
|
||||
backend = ggml_backend_ptr(ctx_params.use_gpu
|
||||
? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
|
||||
: nullptr;
|
||||
: nullptr);
|
||||
|
||||
if (backend) {
|
||||
LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
backend_ptrs.push_back(backend);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
|
||||
LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend.get()));
|
||||
backend_ptrs.push_back(backend.get());
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend.get()));
|
||||
} else {
|
||||
backend = backend_cpu;
|
||||
backend = std::move(backend_cpu);
|
||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
}
|
||||
|
||||
backend_ptrs.push_back(backend_cpu);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
|
||||
backend_ptrs.push_back(backend_cpu.get());
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu.get()));
|
||||
|
||||
sched.reset(
|
||||
ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false)
|
||||
);
|
||||
}
|
||||
|
||||
~clip_ctx() {
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx_gguf);
|
||||
ggml_backend_buffer_free(buf);
|
||||
ggml_backend_free(backend);
|
||||
if (backend_cpu != backend) {
|
||||
ggml_backend_free(backend_cpu);
|
||||
}
|
||||
clip_image_size_free(load_image_size);
|
||||
}
|
||||
};
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
|
||||
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
||||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@ -400,7 +371,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
||||
const int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
|
||||
GGML_ASSERT(imgs->size == 1); // batch_size == 1
|
||||
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
||||
@ -408,7 +379,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
ggml_context_ptr ctx0_ptr(ggml_init(params));
|
||||
auto ctx0 = ctx0_ptr.get();
|
||||
|
||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
// input raw
|
||||
@ -530,12 +503,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
LOG_ERR("This gguf file seems to have no vision encoder\n");
|
||||
return nullptr;
|
||||
@ -548,23 +519,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (load_image_size == nullptr) {
|
||||
load_image_size = clip_image_size_init();
|
||||
}
|
||||
LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height);
|
||||
image_size_width = load_image_size->width;
|
||||
image_size_height = load_image_size->height;
|
||||
LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
|
||||
image_size_width = load_image_size.width;
|
||||
image_size_height = load_image_size.height;
|
||||
if (is_inf) {
|
||||
image_size_width = imgs->data->nx;
|
||||
image_size_height = imgs->data->ny;
|
||||
image_size_width = imgs.entries[0]->nx;
|
||||
image_size_height = imgs.entries[0]->ny;
|
||||
}
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
// use the image's native resolution when image is avaible
|
||||
if (is_inf) {
|
||||
// if (imgs->data->nx && imgs->data->ny) {
|
||||
image_size_width = imgs->data->nx;
|
||||
image_size_height = imgs->data->ny;
|
||||
image_size_width = imgs.entries[0]->nx;
|
||||
image_size_height = imgs.entries[0]->ny;
|
||||
}
|
||||
}
|
||||
const int patch_size = hparams.patch_size;
|
||||
@ -579,7 +547,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||
const float eps = hparams.eps;
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
const int batch_size = imgs.entries.size();
|
||||
|
||||
if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
@ -591,7 +559,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
ggml_context_ptr ctx0_ptr(ggml_init(params));
|
||||
auto ctx0 = ctx0_ptr.get();
|
||||
|
||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
|
||||
@ -1079,7 +1049,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("fatel error");
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
@ -1099,12 +1069,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return clip_image_build_graph_siglip(ctx, imgs);
|
||||
} else {
|
||||
@ -1142,9 +1110,6 @@ struct clip_model_loader {
|
||||
|
||||
// print gguf info
|
||||
{
|
||||
int ftype = -1;
|
||||
get_u32(KEY_FTYPE, ftype, false);
|
||||
const std::string ftype_str = ggml_type_name(static_cast<ggml_type>(ftype));
|
||||
std::string name;
|
||||
get_string(KEY_NAME, name, false);
|
||||
std::string description;
|
||||
@ -1155,7 +1120,6 @@ struct clip_model_loader {
|
||||
LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx_gguf.get()));
|
||||
LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);
|
||||
LOG_INF("%s: n_kv: %d\n", __func__, (int)gguf_get_n_kv(ctx_gguf.get()));
|
||||
LOG_INF("%s: ftype: %s\n", __func__, ftype_str.c_str());
|
||||
LOG_INF("\n");
|
||||
}
|
||||
|
||||
@ -1279,7 +1243,7 @@ struct clip_model_loader {
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ctx_clip.ctx_data = ggml_init(params);
|
||||
ctx_clip.ctx_data.reset(ggml_init(params));
|
||||
if (!ctx_clip.ctx_data) {
|
||||
throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__));
|
||||
}
|
||||
@ -1293,7 +1257,7 @@ struct clip_model_loader {
|
||||
if (cur) {
|
||||
tensors_to_load.push_back(cur);
|
||||
// add tensors to context
|
||||
struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data, cur);
|
||||
struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
|
||||
ggml_set_name(data_tensor, cur->name);
|
||||
cur = data_tensor;
|
||||
}
|
||||
@ -1464,11 +1428,11 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
// alloc memory and offload data
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
|
||||
ctx_clip.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data, buft);
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend.get());
|
||||
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
for (auto & t : tensors_to_load) {
|
||||
struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data, t->name);
|
||||
struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
|
||||
const size_t offset = tensor_offset[t->name];
|
||||
fin.seekg(offset, std::ios::beg);
|
||||
if (!fin) {
|
||||
@ -1493,10 +1457,20 @@ struct clip_model_loader {
|
||||
|
||||
void alloc_compute_meta() {
|
||||
ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
|
||||
|
||||
// create a fake batch
|
||||
clip_image_f32_batch batch;
|
||||
batch.size = 1;
|
||||
batch.data = nullptr;
|
||||
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, &batch, nullptr, false);
|
||||
clip_image_f32_ptr img(clip_image_f32_init());
|
||||
clip_image_size image_size;
|
||||
image_size.width = clip_get_image_size(&ctx_clip);
|
||||
image_size.height = clip_get_image_size(&ctx_clip);
|
||||
int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
|
||||
img->nx = n_patches;
|
||||
img->ny = n_patches;
|
||||
img->buf.resize(n_patches * image_size.width * image_size.height * 3);
|
||||
batch.entries.push_back(std::move(img));
|
||||
|
||||
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
|
||||
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
|
||||
for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = ctx_clip.backend_ptrs[i];
|
||||
@ -1597,11 +1571,11 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
}
|
||||
|
||||
void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
|
||||
ctx_clip->load_image_size = load_image_size;
|
||||
ctx_clip->load_image_size = *load_image_size; // copy
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
|
||||
return ctx_clip->load_image_size;
|
||||
return &ctx_clip->load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_image_size_init() {
|
||||
@ -1619,25 +1593,53 @@ struct clip_image_f32 * clip_image_f32_init() {
|
||||
return new clip_image_f32();
|
||||
}
|
||||
|
||||
struct clip_image_f32_batch * clip_image_f32_batch_init() {
|
||||
return new clip_image_f32_batch();
|
||||
}
|
||||
|
||||
unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
|
||||
if (nx) *nx = img->nx;
|
||||
if (ny) *ny = img->ny;
|
||||
return img->buf.data();
|
||||
}
|
||||
|
||||
void clip_image_size_free(struct clip_image_size * load_image_size) {
|
||||
if (load_image_size == nullptr) {
|
||||
return;
|
||||
}
|
||||
delete load_image_size;
|
||||
}
|
||||
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
|
||||
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
|
||||
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) {
|
||||
if (batch->size > 0) {
|
||||
delete[] batch->data;
|
||||
batch->size = 0;
|
||||
}
|
||||
void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
|
||||
void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
|
||||
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
|
||||
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
|
||||
|
||||
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
|
||||
return batch->entries.size();
|
||||
}
|
||||
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
|
||||
if (batch->size > 0) {
|
||||
delete[] batch->data;
|
||||
batch->size = 0;
|
||||
|
||||
size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
|
||||
if (idx < 0 || idx >= (int)batch->entries.size()) {
|
||||
LOG_ERR("%s: invalid index %d\n", __func__, idx);
|
||||
return 0;
|
||||
}
|
||||
return batch->entries[idx]->nx;
|
||||
}
|
||||
|
||||
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
|
||||
if (idx < 0 || idx >= (int)batch->entries.size()) {
|
||||
LOG_ERR("%s: invalid index %d\n", __func__, idx);
|
||||
return 0;
|
||||
}
|
||||
return batch->entries[idx]->ny;
|
||||
}
|
||||
|
||||
clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
|
||||
if (idx < 0 || idx >= (int)batch->entries.size()) {
|
||||
LOG_ERR("%s: invalid index %d\n", __func__, idx);
|
||||
return nullptr;
|
||||
}
|
||||
return batch->entries[idx].get();
|
||||
}
|
||||
|
||||
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
|
||||
@ -1711,14 +1713,15 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta
|
||||
}
|
||||
|
||||
// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
|
||||
static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
|
||||
dst->nx = src->nx;
|
||||
dst->ny = src->ny;
|
||||
dst->buf.resize(src->buf.size());
|
||||
static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
|
||||
dst.nx = src.nx;
|
||||
dst.ny = src.ny;
|
||||
dst.buf.resize(src.buf.size());
|
||||
|
||||
for (size_t i = 0; i < src->buf.size(); ++i) {
|
||||
// TODO @ngxson : seems like this could be done more efficiently on cgraph
|
||||
for (size_t i = 0; i < src.buf.size(); ++i) {
|
||||
int c = i % 3; // rgb
|
||||
dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];
|
||||
dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
|
||||
}
|
||||
}
|
||||
|
||||
@ -1726,7 +1729,7 @@ inline int clip(int x, int lower, int upper) {
|
||||
return std::max(lower, std::min(x, upper));
|
||||
}
|
||||
|
||||
static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
|
||||
static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
|
||||
const int nx = img.nx;
|
||||
const int ny = img.ny;
|
||||
|
||||
@ -1864,13 +1867,13 @@ static std::pair<int, int> select_best_resolution(const std::pair<int, int> & or
|
||||
return best_fit;
|
||||
}
|
||||
|
||||
static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
|
||||
std::vector<clip_image_u8*> patches;
|
||||
static std::vector<clip_image_u8_ptr> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
|
||||
std::vector<clip_image_u8_ptr> patches;
|
||||
int width = image.nx;
|
||||
int height = image.ny;
|
||||
for (int i = 0; i < height; i += patch_size) {
|
||||
for (int j = 0; j < width; j += patch_size) {
|
||||
clip_image_u8 *patch = clip_image_u8_init();
|
||||
clip_image_u8_ptr patch(clip_image_u8_init());
|
||||
patch->nx = std::min(patch_size, width - j);
|
||||
patch->ny = std::min(patch_size, height - i);
|
||||
patch->buf.resize(3 * patch->nx * patch->ny);
|
||||
@ -1881,7 +1884,7 @@ static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & im
|
||||
}
|
||||
}
|
||||
}
|
||||
patches.push_back(patch);
|
||||
patches.push_back(std::move(patch));
|
||||
}
|
||||
}
|
||||
return patches;
|
||||
@ -1962,7 +1965,7 @@ static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int mul
|
||||
// -> https://arxiv.org/pdf/2403.11703
|
||||
// -> https://github.com/thunlp/LLaVA-UHD
|
||||
// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
|
||||
static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
|
||||
static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
|
||||
const std::pair<int, int> original_size={img->nx,img->ny};
|
||||
const int original_width = img->nx;
|
||||
const int original_height = img->ny;
|
||||
@ -1970,30 +1973,30 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
|
||||
const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
|
||||
const int multiple = fmin(ceil(ratio), max_slice_nums);
|
||||
|
||||
std::vector<std::vector<clip_image_u8 *>> images;
|
||||
std::vector<std::vector<clip_image_u8_ptr>> images;
|
||||
LOG_DBG("%s: multiple %d\n", __func__, multiple);
|
||||
images.push_back(std::vector<clip_image_u8 *>());
|
||||
images.push_back(std::vector<clip_image_u8_ptr>());
|
||||
|
||||
if (multiple <= 1) {
|
||||
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
|
||||
clip_image_u8 * source_image = clip_image_u8_init();
|
||||
clip_image_u8_ptr source_image(clip_image_u8_init());
|
||||
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
|
||||
// source_image = image.resize(best_size, Image.Resampling.BICUBIC)
|
||||
images[images.size()-1].push_back(source_image);
|
||||
images.back().push_back(std::move(source_image));
|
||||
}
|
||||
else if (multiple > 1) {
|
||||
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
|
||||
clip_image_u8 * source_image = clip_image_u8_init();
|
||||
clip_image_u8_ptr source_image(clip_image_u8_init());
|
||||
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
|
||||
// source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
|
||||
LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
|
||||
images[images.size()-1].push_back(source_image);
|
||||
images.back().push_back(std::move(source_image));
|
||||
|
||||
std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
|
||||
LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
|
||||
|
||||
auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
|
||||
clip_image_u8 * refine_image = clip_image_u8_init();
|
||||
clip_image_u8_ptr refine_image(clip_image_u8_init());
|
||||
bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
|
||||
|
||||
LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
|
||||
@ -2004,9 +2007,9 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
|
||||
int grid_x = int(width / best_grid.first);
|
||||
int grid_y = int(height / best_grid.second);
|
||||
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
|
||||
images.push_back(std::vector<clip_image_u8 *>());
|
||||
images.push_back(std::vector<clip_image_u8_ptr>());
|
||||
for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
|
||||
clip_image_u8 * patch = clip_image_u8_init();
|
||||
clip_image_u8_ptr patch(clip_image_u8_init());
|
||||
patch->nx = grid_x;
|
||||
patch->ny = grid_y;
|
||||
patch->buf.resize(3 * patch->nx * patch->ny);
|
||||
@ -2019,10 +2022,9 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
|
||||
patch->buf[j+2] = refine_image->buf[i+2];
|
||||
}
|
||||
}
|
||||
images[images.size()-1].push_back(patch);
|
||||
images.back().push_back(std::move(patch));
|
||||
}
|
||||
}
|
||||
clip_image_u8_free(refine_image);
|
||||
}
|
||||
return images;
|
||||
}
|
||||
@ -2030,8 +2032,8 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
|
||||
int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
|
||||
const int max_slice_nums=9;
|
||||
const int scale_resolution=448;
|
||||
const int original_width = ctx_clip->load_image_size->width;
|
||||
const int original_height = ctx_clip->load_image_size->height;
|
||||
const int original_width = ctx_clip->load_image_size.width;
|
||||
const int original_height = ctx_clip->load_image_size.height;
|
||||
const float log_ratio = log(1.0*original_width/original_height);
|
||||
const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
|
||||
const int multiple = fmin(ceil(ratio), max_slice_nums);
|
||||
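// A worked example of the slicing heuristic above (hypothetical input, for illustration):
// for a 1344 x 1008 image with scale_resolution = 448 and max_slice_nums = 9,
//   ratio    = 1344 * 1008 / (448 * 448) = 6.75
//   multiple = min(ceil(6.75), 9)        = 7
// so the multi-slice path is taken and uhd_best_grid() maps multiple = 7 to a concrete grid.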
@ -2041,64 +2043,44 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
|
||||
|
||||
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
|
||||
// res_imgs memory is being allocated here, previous allocations will be freed if found
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
|
||||
|
||||
if(clip_is_minicpmv(ctx)){
|
||||
if (clip_is_minicpmv(ctx)) {
|
||||
int max_slice_nums = 9;
|
||||
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
|
||||
res_imgs->size = 0;
|
||||
for (size_t i = 0; i < imgs.size(); ++i){
|
||||
res_imgs->size += imgs[i].size();
|
||||
}
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
int idx = 0;
|
||||
std::vector<std::vector<clip_image_u8_ptr>> imgs = uhd_slice_image(img, max_slice_nums);
|
||||
for (size_t i = 0; i < imgs.size(); ++i) {
|
||||
for (size_t j = 0; j < imgs[i].size(); ++j) {
|
||||
LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
|
||||
clip_image_f32 * res = clip_image_f32_init();
|
||||
normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->data[idx++] = *res;
|
||||
clip_image_f32_free(res);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < imgs.size(); ++i) {
|
||||
for (size_t j = 0; j < imgs[i].size(); ++j) {
|
||||
if (imgs[i][j] != nullptr) {
|
||||
clip_image_u8_free(imgs[i][j]);
|
||||
}
|
||||
clip_image_f32_ptr res(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->entries.push_back(std::move(res));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
clip_image_u8 * resized = clip_image_u8_init();
|
||||
auto patch_size = clip_patch_size(ctx) * 2;
|
||||
clip_image_u8 resized;
|
||||
auto patch_size = clip_get_patch_size(ctx) * 2;
|
||||
int nx = ceil((float)img->nx / patch_size) * patch_size;
|
||||
int ny = ceil((float)img->ny / patch_size) * patch_size;
|
||||
bicubic_resize(*img, *resized, nx, ny);
|
||||
bicubic_resize(*img, resized, nx, ny);
|
||||
|
||||
res_imgs->data = new clip_image_f32[1];
|
||||
// clip_image_f32 * res = clip_image_f32_init();
|
||||
normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
|
||||
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
||||
// clip_image_f32_ptr res(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std);
|
||||
// res_imgs->data[0] = *res;
|
||||
res_imgs->size = 1;
|
||||
|
||||
// clip_image_f32_free(res);
|
||||
clip_image_u8_free(resized);
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
res_imgs->size = 1;
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
clip_image_u8 resized_image;
|
||||
int32_t sz=ctx->vision_model.hparams.image_size;
|
||||
bicubic_resize(*img, resized_image,sz,sz);
|
||||
clip_image_f32 * res = clip_image_f32_init();
|
||||
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
||||
//clip_image_save_to_bmp(resized_image, "resized.bmp");
|
||||
normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->data[0] = *res;
|
||||
clip_image_f32_free(res);
|
||||
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -2113,16 +2095,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
pad_to_square = false;
|
||||
}
|
||||
// free the previous res_imgs if any set
|
||||
if (res_imgs->size > 0) {
|
||||
clip_image_f32_batch_free(res_imgs);
|
||||
}
|
||||
res_imgs->data = nullptr;
|
||||
res_imgs->size = 0;
|
||||
res_imgs->entries.clear();
|
||||
|
||||
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
|
||||
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
|
||||
|
||||
clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
|
||||
clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
|
||||
if (pad_to_square && img->nx != img->ny) {
|
||||
int longer_side = std::max(img->nx, img->ny);
|
||||
temp->nx = longer_side;
|
||||
@ -2165,28 +2143,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
// clip_image_u8_free(temp2);
|
||||
// }
|
||||
|
||||
std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
|
||||
std::vector<clip_image_u8_ptr> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
|
||||
|
||||
clip_image_u8 *image_original_resize = clip_image_u8_init();
|
||||
clip_image_u8_ptr image_original_resize(clip_image_u8_init());
|
||||
// bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
|
||||
bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
|
||||
patches.insert(patches.begin(), image_original_resize);
|
||||
// clip_image_f32_batch_init(patches.size());
|
||||
res_imgs->size = patches.size();
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
int num=0;
|
||||
for (auto& patch : patches) {
|
||||
normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std);
|
||||
num++;
|
||||
patches.insert(patches.begin(), std::move(image_original_resize));
|
||||
for (auto & patch : patches) {
|
||||
clip_image_f32_ptr res(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->entries.push_back(std::move(res));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < patches.size(); i++) {
|
||||
// LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny);
|
||||
clip_image_u8_free(patches[i]);
|
||||
}
|
||||
|
||||
clip_image_u8_free(temp);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
temp->nx = img->nx;
|
||||
@ -2202,7 +2170,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
|
||||
const int nx2 = ctx->vision_model.hparams.image_size;
|
||||
const int ny2 = ctx->vision_model.hparams.image_size;
|
||||
clip_image_f32 * res = clip_image_f32_init();
|
||||
clip_image_f32_ptr res(clip_image_f32_init());
|
||||
res->nx = nx2;
|
||||
res->ny = ny2;
|
||||
res->buf.resize(3 * nx2 * ny2);
|
||||
@ -2254,7 +2222,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
}
|
||||
}
|
||||
}
|
||||
clip_image_u8_free(temp);
|
||||
|
||||
// {
|
||||
// clip_image_u8 * temp2 = clip_image_u8_init();
|
||||
@ -2264,10 +2231,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
// }
|
||||
// res_imgs.push_back(res);
|
||||
|
||||
res_imgs->size = 1;
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
res_imgs->data[0] = *res;
|
||||
clip_image_f32_free(res);
|
||||
res_imgs->entries.push_back(std::move(res));
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -2295,15 +2259,15 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w
|
||||
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
int32_t clip_image_size(const struct clip_ctx * ctx) {
|
||||
int32_t clip_get_image_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_size;
|
||||
}
|
||||
|
||||
int32_t clip_patch_size(const struct clip_ctx * ctx) {
|
||||
int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.patch_size;
|
||||
}
|
||||
|
||||
int32_t clip_hidden_size(const struct clip_ctx * ctx) {
|
||||
int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.hidden_size;
|
||||
}
|
||||
|
||||
@ -2351,6 +2315,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
||||
n_patches = x_patch * y_patch;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
n_patches = 256;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
@ -2448,19 +2414,23 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
|
||||
return false;
|
||||
}
|
||||
|
||||
clip_image_f32_batch imgs{};
|
||||
imgs.size = 1;
|
||||
imgs.data = img;
|
||||
clip_image_f32_batch imgs;
|
||||
clip_image_f32_ptr img_copy(clip_image_f32_init());
|
||||
*img_copy = *img;
|
||||
imgs.entries.push_back(std::move(img_copy));
|
||||
|
||||
return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
|
||||
}
|
||||
|
||||
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
|
||||
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
|
||||
const clip_image_f32_batch & imgs = *imgs_c_ptr;
|
||||
|
||||
if (!ctx->has_vision_encoder) {
|
||||
LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
int batch_size = imgs->size;
|
||||
int batch_size = imgs.entries.size();
|
||||
if (ctx->has_llava_projector) {
|
||||
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
|
||||
}
|
||||
@ -2487,25 +2457,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
|
||||
image_size_width = imgs->data[0].nx;
|
||||
image_size_height = imgs->data[0].ny;
|
||||
image_size_width = imgs.entries[0]->nx;
|
||||
image_size_height = imgs.entries[0]->ny;
|
||||
}
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||
if(ctx->load_image_size==nullptr){
|
||||
ctx->load_image_size= clip_image_size_init();
|
||||
}
|
||||
const int pos_w = ctx->load_image_size->width/patch_size;
|
||||
const int pos_h = ctx->load_image_size->height/patch_size;
|
||||
const int pos_w = ctx->load_image_size.width / patch_size;
|
||||
const int pos_h = ctx->load_image_size.height / patch_size;
|
||||
|
||||
{
|
||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
||||
float * data = (float *)malloc(ggml_nbytes(inp_raw));
|
||||
|
||||
for (size_t i = 0; i < imgs->size; i++) {
|
||||
const int nx = imgs->data[i].nx;
|
||||
const int ny = imgs->data[i].ny;
|
||||
for (size_t i = 0; i < imgs.entries.size(); i++) {
|
||||
const int nx = imgs.entries[i]->nx;
|
||||
const int ny = imgs.entries[i]->ny;
|
||||
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
|
||||
GGML_ASSERT(nx == image_size && ny == image_size);
|
||||
}
|
||||
@ -2516,7 +2483,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
for (int k = 0; k < 3; k++) {
|
||||
for (int y = 0; y < ny; y++) {
|
||||
for (int x = 0; x < nx; x++) {
|
||||
data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
|
||||
data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2643,7 +2610,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu.get(), n_threads);
|
||||
|
||||
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
@ -2676,8 +2643,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
||||
/* verbosity */ GGML_LOG_LEVEL_ERROR,
|
||||
});
|
||||
|
||||
const auto & ctx_src = ctx_clip->ctx_gguf;
|
||||
const auto & ctx_data = ctx_clip->ctx_data;
|
||||
const auto & ctx_src = ctx_clip->ctx_gguf.get();
|
||||
const auto & ctx_data = ctx_clip->ctx_data.get();
|
||||
|
||||
auto * ctx_out = gguf_init_empty();
|
||||
gguf_set_kv(ctx_out, ctx_src);
|
||||
@ -2898,3 +2865,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
|
||||
clip_image_encode(ctx, n_threads, &clip_img, vec);
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// API used internally with mtmd
|
||||
//
|
||||
|
||||
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type;
|
||||
}
|
||||
|
@ -30,15 +30,8 @@ struct clip_image_size {
|
||||
int height;
|
||||
};
|
||||
|
||||
struct clip_image_u8_batch {
|
||||
struct clip_image_u8 * data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
struct clip_image_f32_batch {
|
||||
struct clip_image_f32 * data;
|
||||
size_t size;
|
||||
};
|
||||
struct clip_image_u8_batch;
|
||||
struct clip_image_f32_batch;
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
@ -55,9 +48,9 @@ CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
|
||||
|
||||
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
|
||||
|
||||
// TODO: should be enum, not string
|
||||
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
@ -73,9 +66,13 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
||||
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
||||
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
||||
|
||||
CLIP_API struct clip_image_size * clip_image_size_init();
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_init();
|
||||
CLIP_API struct clip_image_size * clip_image_size_init();
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_init();
|
||||
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
|
||||
|
||||
// nx, ny are the output image dimensions
|
||||
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
|
||||
|
||||
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
|
||||
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
|
||||
@ -83,6 +80,12 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
|
||||
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
|
||||
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
||||
|
||||
// used for accessing the underlying data of clip_image_f32_batch
|
||||
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
|
||||
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
|
||||
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
|
||||
CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
||||
|
||||
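// A minimal sketch of how a caller iterates the now-opaque batch through these accessors,
// assuming `ctx` is a loaded clip_ctx and `img` a prepared clip_image_u8 (illustrative only):
//
//     clip_image_f32_batch * batch = clip_image_f32_batch_init();
//     if (clip_image_preprocess(ctx, img, batch)) {
//         for (size_t i = 0; i < clip_image_f32_batch_n_images(batch); i++) {
//             size_t nx = clip_image_f32_batch_nx(batch, i);           // batch[i]->nx
//             size_t ny = clip_image_f32_batch_ny(batch, i);           // batch[i]->ny
//             clip_image_f32 * f32 = clip_image_f32_get_img(batch, i);
//             // ... pass f32 to clip_image_encode(ctx, n_threads, f32, out_buf)
//         }
//     }
//     clip_image_f32_batch_free(batch);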
/**
|
||||
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
|
||||
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
|
||||
|
@ -2,11 +2,11 @@
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "stb_image.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
#include "console.h"
|
||||
#include "chat.h"
|
||||
#include "mtmd.h"
|
||||
|
||||
#include <vector>
|
||||
#include <limits.h>
|
||||
@ -57,13 +57,18 @@ static void sigint_handler(int signo) {
|
||||
#endif
|
||||
|
||||
struct gemma3_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
common_init_result llama_init;
|
||||
mtmd_context_ptr ctx_vision;
|
||||
common_init_result llama_init;
|
||||
|
||||
llama_model * model;
|
||||
llama_context * lctx;
|
||||
const llama_vocab * vocab;
|
||||
llama_batch batch;
|
||||
int n_batch;
|
||||
|
||||
// note: we know that the gemma3 template is "linear", meaning each turn is completely separate from the others
|
||||
// so here we don't need to keep track of chat history
|
||||
common_chat_templates_ptr tmpls;
|
||||
|
||||
int n_threads = 1;
|
||||
llama_pos n_past = 0;
|
||||
@ -74,21 +79,24 @@ struct gemma3_context {
|
||||
vocab = llama_model_get_vocab(model);
|
||||
n_threads = params.cpuparams.n_threads;
|
||||
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||
init_clip_model(params);
|
||||
n_batch = params.n_batch;
|
||||
tmpls = common_chat_templates_init(model, params.chat_template);
|
||||
init_vision_context(params);
|
||||
}
|
||||
|
||||
void init_clip_model(common_params & params) {
|
||||
void init_vision_context(common_params & params) {
|
||||
const char * clip_path = params.mmproj.path.c_str();
|
||||
ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
|
||||
if (!ctx_clip) {
|
||||
LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
|
||||
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
|
||||
/* use_gpu */ true,
|
||||
/* timings */ true,
|
||||
/* n_threads */ params.cpuparams.n_threads,
|
||||
/* verbosity */ GGML_LOG_LEVEL_INFO,
|
||||
}));
|
||||
if (!ctx_vision.get()) {
|
||||
LOG_ERR("Failed to load vision model from %s\n", clip_path);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
~gemma3_context() {
|
||||
clip_free(ctx_clip);
|
||||
}
|
||||
};
|
||||
|
||||
struct decode_embd_batch {
|
||||
@ -124,77 +132,6 @@ struct decode_embd_batch {
|
||||
}
|
||||
};
|
||||
|
||||
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
|
||||
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
|
||||
common_batch_clear(ctx.batch);
|
||||
for (llama_token & t : tokens) {
|
||||
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
|
||||
}
|
||||
if (logits_last) {
|
||||
ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
|
||||
}
|
||||
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
LOG_ERR("Failed to decode text\n");
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int eval_image(gemma3_context & ctx, std::string & fname) {
|
||||
std::vector<float> image_embd_v;
|
||||
int n_embd = llama_model_n_embd(ctx.model);
|
||||
int n_tokens = 256;
|
||||
image_embd_v.resize(n_tokens * n_embd);
|
||||
|
||||
bool ok;
|
||||
struct clip_image_u8 * img_u8 = clip_image_u8_init();
|
||||
ok = clip_image_load_from_file(fname.c_str(), img_u8);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to load image %s\n", fname.c_str());
|
||||
clip_image_u8_free(img_u8);
|
||||
return 2; // non-fatal error
|
||||
}
|
||||
|
||||
clip_image_f32_batch batch_f32;
|
||||
ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
clip_image_f32_batch_free(&batch_f32);
|
||||
clip_image_u8_free(img_u8);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
LOG("Encoding image %s\n", fname.c_str());
|
||||
ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to encode image\n");
|
||||
clip_image_f32_batch_free(&batch_f32);
|
||||
clip_image_u8_free(img_u8);
|
||||
return 1;
|
||||
}
|
||||
LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
|
||||
clip_image_f32_batch_free(&batch_f32);
|
||||
clip_image_u8_free(img_u8);
|
||||
|
||||
// decode image embeddings
|
||||
int64_t t1 = ggml_time_ms();
|
||||
eval_text(ctx, "<start_of_image>");
|
||||
llama_set_causal_attn(ctx.lctx, false);
|
||||
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
|
||||
if (llama_decode(ctx.lctx, batch_img.batch)) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
return 1;
|
||||
}
|
||||
ctx.n_past += n_tokens;
|
||||
llama_set_causal_attn(ctx.lctx, true);
|
||||
eval_text(ctx, "<end_of_image>");
|
||||
LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
|
||||
for (int i = 0; i < n_predict; i++) {
|
||||
if (i > n_predict || !g_is_generating) {
|
||||
@ -224,6 +161,45 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
|
||||
std::vector<mtmd_bitmap> bitmaps;
|
||||
|
||||
common_chat_templates_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = {msg};
|
||||
tmpl_inputs.add_generation_prompt = true;
|
||||
tmpl_inputs.use_jinja = false; // jinja is buggy here
|
||||
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
||||
|
||||
for (auto & fname : images_fname) {
|
||||
mtmd_bitmap bitmap;
|
||||
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
|
||||
LOG_ERR("Unable to load image %s\n", fname.c_str());
|
||||
return 2; // image not found
|
||||
}
|
||||
bitmaps.push_back(std::move(bitmap));
|
||||
}
|
||||
|
||||
mtmd_input_text text;
|
||||
text.text = formatted_chat.prompt;
|
||||
text.add_special = add_bos;
|
||||
text.parse_special = true;
|
||||
mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
|
||||
if (chunks == nullptr) {
|
||||
LOG_ERR("Unable to tokenize prompt\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
|
||||
LOG_ERR("Unable to eval prompt\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
@ -265,21 +241,15 @@ int main(int argc, char ** argv) {
|
||||
#endif
|
||||
}
|
||||
|
||||
if (eval_text(ctx, "<bos>")) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (is_single_turn) {
|
||||
g_is_generating = true;
|
||||
if (eval_text(ctx, "<start_of_turn>user\n")) {
|
||||
return 1;
|
||||
if (params.prompt.find("<__image__>") == std::string::npos) {
|
||||
params.prompt += " <__image__>";
|
||||
}
|
||||
for (auto & fname : params.image) {
|
||||
if (eval_image(ctx, fname)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = params.prompt;
|
||||
if (eval_message(ctx, msg, params.image, true)) {
|
||||
return 1;
|
||||
}
|
||||
if (generate_response(ctx, smpl, n_predict)) {
|
||||
@ -293,9 +263,9 @@ int main(int argc, char ** argv) {
|
||||
LOG("\n /quit or /exit exit the program");
|
||||
LOG("\n");
|
||||
|
||||
if (eval_text(ctx, "<start_of_turn>user\n")) {
|
||||
return 1;
|
||||
}
|
||||
bool is_first_msg = true;
|
||||
std::vector<std::string> images_fname;
|
||||
std::string content;
|
||||
|
||||
while (true) {
|
||||
g_is_generating = false;
|
||||
@ -320,24 +290,31 @@ int main(int argc, char ** argv) {
|
||||
g_is_generating = true;
|
||||
if (line.find("/image") == 0) {
|
||||
std::string image = line.substr(7);
|
||||
int res = eval_image(ctx, image);
|
||||
if (res == 2) {
|
||||
continue; // image not found
|
||||
}
|
||||
if (res) {
|
||||
return 1;
|
||||
}
|
||||
images_fname.push_back(string_strip(image));
|
||||
content += "<__image__>";
|
||||
continue;
|
||||
} else {
|
||||
content += line;
|
||||
}
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = content;
|
||||
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
|
||||
if (ret == 2) {
|
||||
// non-fatal error
|
||||
images_fname.clear();
|
||||
content.clear();
|
||||
continue;
|
||||
}
|
||||
if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
|
||||
if (ret) {
|
||||
return 1;
|
||||
}
|
||||
if (generate_response(ctx, smpl, n_predict)) {
|
||||
return 1;
|
||||
}
|
||||
if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
|
||||
return 1;
|
||||
}
|
||||
images_fname.clear();
|
||||
content.clear();
|
||||
is_first_msg = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#if defined(LLAVA_LOG_OFF)
|
||||
# define LOG_INF(...)
|
||||
@ -45,6 +46,17 @@ struct clip_image_grid_shape {
|
||||
int second;
|
||||
};
|
||||
|
||||
// convenience cpp wrapper
|
||||
struct clip_image_f32_batch_deleter {
|
||||
void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
|
||||
};
|
||||
typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
|
||||
|
||||
struct clip_image_size_deleter {
    void operator()(clip_image_size * val) { clip_image_size_free(val); }
};
typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
|
||||
|
||||
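// A minimal sketch of the wrapper in use, mirroring encode_image_with_clip() below
// (assumes `ctx_clip` and `img` already exist; illustrative only):
//
//     clip_image_f32_batch_ptr batch(clip_image_f32_batch_init());
//     if (!clip_image_preprocess(ctx_clip, img, batch.get())) {
//         // handle error; the batch is still freed automatically
//     }
//     // clip_image_f32_batch_free() runs when `batch` goes out of scope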
/**
|
||||
* Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
*
|
||||
@ -105,8 +117,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
||||
struct ggml_context * ctx;
|
||||
} model;
|
||||
|
||||
const int32_t image_size = clip_image_size(ctx_clip);
|
||||
const int32_t patch_size = clip_patch_size(ctx_clip);
|
||||
const int32_t image_size = clip_get_image_size(ctx_clip);
|
||||
const int32_t patch_size = clip_get_patch_size(ctx_clip);
|
||||
|
||||
int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
|
||||
|
||||
@ -246,12 +258,9 @@ static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size)
|
||||
|
||||
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
|
||||
// std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336
|
||||
clip_image_f32_batch img_res_v;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) {
|
||||
clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init());
|
||||
if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) {
|
||||
LOG_ERR("%s: unable to preprocess image\n", __func__);
|
||||
delete[] img_res_v.data;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -259,66 +268,72 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
|
||||
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
|
||||
|
||||
const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get());
|
||||
|
||||
if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
|
||||
std::vector<float *> image_embd_v;
|
||||
image_embd_v.resize(img_res_v.size);
|
||||
struct clip_image_size * load_image_size = clip_image_size_init();
|
||||
image_embd_v.resize(n_imgs);
|
||||
clip_image_size load_image_size;
|
||||
|
||||
for (size_t i = 0; i < img_res_v.size; i++) {
|
||||
for (size_t i = 0; i < n_imgs; i++) {
|
||||
const int64_t t_img_enc_step_start_us = ggml_time_us();
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
|
||||
int patch_size=14;
|
||||
load_image_size->width = img_res_v.data[i].nx;
|
||||
load_image_size->height = img_res_v.data[i].ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
|
||||
int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny));
|
||||
int patch_size = 14;
|
||||
load_image_size.width = nx;
|
||||
load_image_size.height = ny;
|
||||
clip_add_load_image_size(ctx_clip, &load_image_size);
|
||||
|
||||
bool encoded = false;
|
||||
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
|
||||
if (clip_is_qwen2vl(ctx_clip)) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]);
|
||||
}
|
||||
else {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]);
|
||||
}
|
||||
|
||||
if (!encoded) {
|
||||
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
|
||||
return false;
|
||||
}
|
||||
const int64_t t_img_enc_steop_batch_us = ggml_time_us();
|
||||
LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
|
||||
LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
|
||||
}
|
||||
const int64_t t_img_enc_batch_us = ggml_time_us();
|
||||
LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
|
||||
LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
|
||||
|
||||
int n_img_pos_out = 0;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
|
||||
int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
|
||||
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
|
||||
std::memcpy(
|
||||
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
|
||||
image_embd_v[i],
|
||||
clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
|
||||
n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
|
||||
clip_embd_nbytes_by_img(ctx_clip, nx, ny));
|
||||
n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
|
||||
}
|
||||
*n_img_pos = n_img_pos_out;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
free(image_embd_v[i]);
|
||||
}
|
||||
image_embd_v.clear();
|
||||
load_image_size->width = img->nx;
|
||||
load_image_size->height = img->ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
|
||||
delete[] img_res_v.data;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
load_image_size.width = img->nx;
|
||||
load_image_size.height = img->ny;
|
||||
clip_add_load_image_size(ctx_clip, &load_image_size);
|
||||
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height);
|
||||
}
|
||||
else if (clip_is_glm(ctx_clip)){
|
||||
struct clip_image_size * load_image_size = clip_image_size_init();
|
||||
load_image_size->width = img_res_v.data[0].nx;
|
||||
load_image_size->height = img_res_v.data[0].ny;
|
||||
load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0);
|
||||
load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0);
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);
|
||||
int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2);
|
||||
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
|
||||
int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2);
|
||||
*n_img_pos = (pos * pos + 2);
|
||||
if (!encoded){
|
||||
LOG_ERR("Unable to encode image \n");
|
||||
@ -328,8 +343,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
|
||||
// flat / default llava-1.5 type embedding
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096
|
||||
delete[] img_res_v.data;
|
||||
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
|
||||
if (!encoded) {
|
||||
LOG_ERR("Unable to encode image\n");
|
||||
|
||||
@ -340,17 +355,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
// spatial_unpad llava-1.6 type embedding
|
||||
// TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
|
||||
std::vector<float *> image_embd_v;
|
||||
image_embd_v.resize(img_res_v.size);
|
||||
for (size_t i = 0; i < img_res_v.size; i++) {
|
||||
image_embd_v.resize(n_imgs);
|
||||
for (size_t i = 0; i < n_imgs; i++) {
|
||||
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
|
||||
const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
|
||||
const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
|
||||
if (!encoded) {
|
||||
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const int64_t t_img_enc_batch_us = ggml_time_us();
|
||||
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
|
||||
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
|
||||
|
||||
const int32_t * image_grid = clip_image_grid(ctx_clip);
|
||||
const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);
|
||||
@ -360,12 +376,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
|
||||
}
|
||||
|
||||
// free all img_res_v - not needed anymore
|
||||
delete[] img_res_v.data;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
|
||||
const int32_t image_size = clip_image_size(ctx_clip);
|
||||
const int32_t image_size = clip_get_image_size(ctx_clip);
|
||||
|
||||
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
|
||||
|
||||
|
examples/llava/mtmd.cpp (new file, 341 lines)
@ -0,0 +1,341 @@
|
||||
#include "clip.h"
|
||||
#include "clip-impl.h"
|
||||
#include "mtmd.h"
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cerrno>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
struct mtmd_context {
|
||||
struct clip_ctx * ctx_clip;
|
||||
const struct llama_model * text_model;
|
||||
std::vector<float> image_embd_v; // image embedding vector
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
std::string image_marker;
|
||||
|
||||
// TODO @ngxson : add timings
|
||||
|
||||
mtmd_context(const char * mmproj_fname,
|
||||
const llama_model * text_model,
|
||||
const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
|
||||
clip_context_params ctx_clip_params;
|
||||
ctx_clip_params.use_gpu = ctx_params.use_gpu;
|
||||
ctx_clip_params.verbosity = ctx_params.verbosity;
|
||||
ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
|
||||
if (!ctx_clip) {
|
||||
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
|
||||
}
|
||||
this->text_model = text_model;
|
||||
}
|
||||
|
||||
~mtmd_context() {
|
||||
clip_free(ctx_clip);
|
||||
}
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens_data {
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
};
|
||||
|
||||
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
const struct llama_model * text_model,
|
||||
const struct mtmd_context_params ctx_params) {
|
||||
try {
|
||||
return new mtmd_context(mmproj_fname, text_model, ctx_params);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: error: %s\n", __func__, e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void mtmd_free(mtmd_context * ctx) {
|
||||
if (ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
}
|
||||
|
||||
// copied from common_tokenize
|
||||
static std::vector<llama_token> mtmd_tokenize_text_internal(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::string & text,
|
||||
bool add_special,
|
||||
bool parse_special) {
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + 2 * add_special;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
|
||||
const mtmd_input_text & text,
|
||||
const std::vector<mtmd_bitmap> & bitmaps) {
|
||||
mtmd_input_chunks * output = new mtmd_input_chunks;
|
||||
auto vocab = llama_model_get_vocab(ctx->text_model);
|
||||
|
||||
std::string prompt_modified(text.text);
|
||||
std::string marker_modified(ctx->image_marker);
|
||||
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
|
||||
// a bit hacky here, but works for now
|
||||
// for some models, we need to add prefix and suffix to the image embeddings
|
||||
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
// <start_of_image> ... (image embeddings) ... <end_of_image>
|
||||
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
}
|
||||
|
||||
std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
|
||||
output->clear();
|
||||
output->reserve(parts.size());
|
||||
|
||||
size_t i_img = 0;
|
||||
|
||||
for (const auto & part : parts) {
|
||||
//printf("tokenizing part: %s\n", part.c_str());
|
||||
bool add_bos = &parts.front() == &part;
|
||||
auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special);
|
||||
if (tokens.empty()) {
|
||||
continue;
|
||||
}
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
std::move(tokens),
|
||||
{},
|
||||
};
|
||||
output->emplace_back(std::move(chunk));
|
||||
|
||||
if (&parts.back() != &part) {
|
||||
// add an image chunk between the two surrounding text parts
|
||||
|
||||
if (i_img >= bitmaps.size()) {
|
||||
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// shim layer
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmaps[i_img].nx;
|
||||
img_u8->ny = bitmaps[i_img].ny;
|
||||
img_u8->buf.resize(bitmaps[i_img].data.size());
|
||||
std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
|
||||
|
||||
// preprocess image
|
||||
clip_image_f32_batch batch_f32;
|
||||
bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
|
||||
image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
|
||||
image_tokens->ny = 1; // TODO
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
image_tokens,
|
||||
};
|
||||
output->emplace_back(std::move(chunk));
|
||||
i_img++;
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
|
||||
for (auto & chunk : *chunks) {
|
||||
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
|
||||
delete chunk.tokens_image;
|
||||
}
|
||||
}
|
||||
delete chunks;
|
||||
}
|
||||
|
||||
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
|
||||
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
|
||||
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
|
||||
bool ok = clip_image_batch_encode(
|
||||
ctx->ctx_clip,
|
||||
ctx->n_threads,
|
||||
&image_tokens->batch_f32,
|
||||
ctx->image_embd_v.data());
|
||||
return ok ? 0 : 1;
|
||||
}
|
||||
|
||||
float * mtmd_get_output_embd(mtmd_context * ctx) {
|
||||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
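// A minimal sketch of the encode/read-back sequence for one image chunk, mirroring
// mtmd_helper_eval() below (assumes `ctx` and an image `chunk` from mtmd_tokenize()):
//
//     if (mtmd_encode(ctx, chunk.tokens_image) != 0) { /* encoding failed */ }
//     float * embd  = mtmd_get_output_embd(ctx);           // n_tokens * n_mmproj_embd floats
//     int32_t n_tok = chunk.tokens_image->n_tokens();
//     // feed `embd` to llama_decode() through an embedding batch (see decode_embd_batch below)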
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
|
||||
size_t n_tokens = 0;
|
||||
for (auto & chunk : *chunks) {
|
||||
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
n_tokens += chunk.tokens_text.size();
|
||||
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
n_tokens += chunk.tokens_image->n_tokens();
|
||||
} else {
|
||||
GGML_ASSERT(false && "chunk type not supported");
|
||||
}
|
||||
}
|
||||
return n_tokens;
|
||||
}
|
||||
|
||||
// helper struct to make working with embd batch easier
|
||||
// note: this will be removed after llama_batch_ext refactoring
|
||||
struct decode_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int32_t mtmd_helper_eval(mtmd_context * ctx,
|
||||
llama_context * lctx,
|
||||
mtmd_input_chunks * chunks,
|
||||
llama_pos pos0,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch) {
|
||||
int32_t ret;
|
||||
llama_pos n_past = pos0;
|
||||
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
for (auto & chunk : *chunks) {
|
||||
bool is_last = &chunk == &chunks->back();
|
||||
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
// TODO @ngxson : may need to split into smaller batches
|
||||
text_batch.n_tokens = chunk.tokens_text.size();
|
||||
for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
|
||||
text_batch.token [i] = chunk.tokens_text[i];
|
||||
text_batch.pos [i] = n_past++;
|
||||
text_batch.n_seq_id[i] = 1;
|
||||
text_batch.seq_id [i][0] = seq_id;
|
||||
text_batch.logits [i] = false;
|
||||
}
|
||||
if (is_last) {
|
||||
// always get logits for last input chunk
|
||||
text_batch.logits[text_batch.n_tokens - 1] = true;
|
||||
}
|
||||
ret = llama_decode(lctx, text_batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode text\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
|
||||
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
|
||||
GGML_ASSERT(chunk.tokens_image != nullptr);
|
||||
int64_t t0 = ggml_time_ms();
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("encoding image...\n");
|
||||
}
|
||||
ret = mtmd_encode(ctx, chunk.tokens_image);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to encode image\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
}
|
||||
|
||||
int32_t n_tokens = chunk.tokens_image->n_tokens();
|
||||
float * embd = mtmd_get_output_embd(ctx);
|
||||
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ret = llama_decode(lctx, batch_img.batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
|
||||
}
|
||||
|
||||
n_past += n_tokens;
|
||||
|
||||
} else {
|
||||
GGML_ASSERT(false && "chunk type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(text_batch);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) {
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to load image from buffer\n");
|
||||
return 1;
|
||||
}
|
||||
unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
|
||||
output.data.resize(output.nx * output.ny * 3);
|
||||
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
bool ok = clip_image_load_from_file(fname, img_u8.get());
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to load image %s\n", fname);
|
||||
return 1;
|
||||
}
|
||||
unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
|
||||
output.data.resize(output.nx * output.ny * 3);
|
||||
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
|
||||
return 0;
|
||||
}
|
examples/llava/mtmd.h (new file, 146 lines)
@ -0,0 +1,146 @@
|
||||
#ifndef MTMD_H
|
||||
#define MTMD_H
|
||||
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "clip.h"
|
||||
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
#include <memory>
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef LLAMA_BUILD
|
||||
# define MTMD_API __declspec(dllexport)
|
||||
# else
|
||||
# define MTMD_API __declspec(dllimport)
|
||||
# endif
|
||||
# else
|
||||
# define MTMD_API __attribute__ ((visibility ("default")))
|
||||
# endif
|
||||
#else
|
||||
# define MTMD_API
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
enum mtmd_input_chunk_type {
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
};
|
||||
|
||||
struct mtmd_context;
|
||||
struct mtmd_image_tokens;
|
||||
|
||||
// represents raw image data, layout is RGBRGBRGB...
|
||||
// length of data must be nx * ny * 3
|
||||
struct mtmd_bitmap {
|
||||
uint32_t nx;
|
||||
uint32_t ny;
|
||||
std::vector<unsigned char> data;
|
||||
};
|
||||
|
||||
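// A minimal sketch of filling a bitmap by hand under the layout described above
// (a solid 8x8 red image; illustrative only):
//
//     mtmd_bitmap bmp;
//     bmp.nx = 8;
//     bmp.ny = 8;
//     bmp.data.resize(bmp.nx * bmp.ny * 3);      // 3 bytes per pixel, RGBRGBRGB...
//     for (size_t i = 0; i < bmp.data.size(); i += 3) {
//         bmp.data[i + 0] = 255;  // R
//         bmp.data[i + 1] = 0;    // G
//         bmp.data[i + 2] = 0;    // B
//     }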
struct mtmd_input_chunk {
|
||||
mtmd_input_chunk_type type;
|
||||
std::vector<llama_token> tokens_text;
|
||||
mtmd_image_tokens * tokens_image = nullptr;
|
||||
};
|
||||
|
||||
using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
|
||||
|
||||
struct mtmd_context_params {
|
||||
bool use_gpu = true;
|
||||
bool print_timings = true;
|
||||
int n_threads = 4;
|
||||
enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
|
||||
const char * image_marker = "<__image__>";
|
||||
};
|
||||
|
||||
struct mtmd_input_text {
|
||||
std::string text;
|
||||
bool add_special;
|
||||
bool parse_special;
|
||||
};
|
||||
|
||||
// initialize the mtmd context
|
||||
// return nullptr on failure
|
||||
MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
const llama_model * text_model,
|
||||
const mtmd_context_params ctx_params);
|
||||
|
||||
MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
|
||||
// tokenize an input text prompt and an image
|
||||
// the prompt must have the input image marker (default: "<__image__>") in it
|
||||
// the marker will be replaced with the image tokens
|
||||
// for example:
|
||||
// "here is an image: <__image__>\ndescribe it in detail."
|
||||
// this will give 3 chunks:
|
||||
// 1. "here is an image: <start_of_image>"
|
||||
// 2. (image tokens)
|
||||
// 3. "<end_of_image>\ndescribe it in detail."
|
||||
// number of bitmaps must be equal to the number of image markers in the prompt
|
||||
// this function is thread-safe (shared ctx)
|
||||
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
|
||||
const mtmd_input_text & text,
|
||||
const std::vector<mtmd_bitmap> & bitmaps);
|
||||
|
||||
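// A minimal sketch of a tokenize call, along the lines of eval_message() in
// examples/llava/gemma3-cli.cpp (assumes `ctx` and `bitmaps` already exist;
// mtmd_input_chunks_ptr is defined at the bottom of this header):
//
//     mtmd_input_text text;
//     text.text          = "here is an image: <__image__>\ndescribe it in detail.";
//     text.add_special   = true;   // add BOS on the first turn
//     text.parse_special = true;
//     mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx, text, bitmaps));
//     if (chunks == nullptr) { /* tokenization failed */ }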
// free image chunk data
|
||||
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
|
||||
|
||||
// returns 0 on success
|
||||
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
|
||||
const mtmd_image_tokens * image_tokens);
|
||||
|
||||
// get output embeddings from the last encode pass
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
//
|
||||
// helper functions (can be implemented based on other functions)
|
||||
//
|
||||
|
||||
// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
|
||||
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
|
||||
|
||||
// helper function that automatically:
|
||||
// 1. run llama_decode() on text chunks
|
||||
// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
|
||||
// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
|
||||
// otherwise, returns 0 on success
|
||||
MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
|
||||
llama_context * lctx,
|
||||
mtmd_input_chunks * chunks,
|
||||
llama_pos pos0,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a file
|
||||
// returns 0 on success
|
||||
// this function is thread-safe
|
||||
MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a buffer
|
||||
// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
|
||||
// returns 0 on success
|
||||
// this function is thread-safe
|
||||
MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output);
|
||||
|
||||
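// A minimal end-to-end sketch of the helper flow, mirroring examples/llava/gemma3-cli.cpp
// (assumes `ctx`, `lctx`, `text` and `n_batch` already exist; illustrative only):
//
//     mtmd_bitmap bitmap;
//     if (mtmd_helper_bitmap_init_from_file("image.jpg", bitmap) != 0) { /* not found */ }
//     std::vector<mtmd_bitmap> bitmaps = { std::move(bitmap) };
//     mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx, text, bitmaps));
//     llama_pos n_past = 0;
//     if (mtmd_helper_eval(ctx, lctx, chunks.get(), n_past, 0, n_batch) != 0) { /* error */ }
//     n_past += mtmd_helper_get_n_tokens(chunks.get());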
// convenient unique_ptr wrappers
|
||||
struct mtmd_context_deleter {
|
||||
void operator()(mtmd_context * val) { mtmd_free(val); }
|
||||
};
|
||||
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
|
||||
|
||||
struct mtmd_input_chunks_deleter {
|
||||
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
|
||||
};
|
||||
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
|
||||
|
||||
#else
|
||||
|
||||
static_assert(false && "C header is not yet supported by this library");
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -697,8 +697,10 @@ class LlamaData {
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
std::string url;

std::string model_endpoint = get_model_endpoint();

if (pos == std::string::npos) {
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/");
hfr = model_name;

nlohmann::json manifest;
@ -713,7 +715,7 @@ class LlamaData {
hff = model.substr(pos + 1);
}

url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
url = model_endpoint + hfr + "/resolve/main/" + hff;

return download(url, bn, true, headers);
}

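// A worked example of the rewritten URL construction (hypothetical names): with the
// default endpoint "https://huggingface.co/", hfr = "org/model-gguf" and
// hff = "model-q4_0.gguf", the result is unchanged:
//   https://huggingface.co/org/model-gguf/resolve/main/model-q4_0.gguf
// while a custom endpoint returned by get_model_endpoint() (e.g. a mirror) now replaces
// the previously hard-coded host.
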
@ -3,7 +3,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "common/base64.hpp"
#include "base64.hpp"

// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576

@ -507,17 +507,12 @@ extern "C" {

GGML_OP_UNARY,

GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,

GGML_OP_MAP_CUSTOM1_F32,
GGML_OP_MAP_CUSTOM2_F32,
GGML_OP_MAP_CUSTOM3_F32,

GGML_OP_MAP_CUSTOM1,
GGML_OP_MAP_CUSTOM2,
GGML_OP_MAP_CUSTOM3,

GGML_OP_CUSTOM,

GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW,

@ -1722,24 +1717,29 @@ extern "C" {
float p0,
float p1);

// nearest interpolate
enum ggml_scale_mode {
GGML_SCALE_MODE_NEAREST = 0,
GGML_SCALE_MODE_BILINEAR = 1,
};

// interpolate
// multiplies ne0 and ne1 by scale factor
// used in stable-diffusion
GGML_API struct ggml_tensor * ggml_upscale(
struct ggml_context * ctx,
struct ggml_tensor * a,
int scale_factor);
int scale_factor,
enum ggml_scale_mode mode);

// nearest interpolate
// nearest interpolate to specified dimensions
// used in tortoise.cpp
// interpolate
// interpolate scale to specified dimensions
GGML_API struct ggml_tensor * ggml_upscale_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int ne0,
int ne1,
int ne2,
int ne3);
int ne3,
enum ggml_scale_mode mode);

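A hedged usage sketch of the extended interpolation API above (`ctx` and `img`, a GGML_TYPE_F32 tensor, are assumed to exist; the sizes are illustrative):

// 2x nearest-neighbour upscale, as before but with the mode now explicit
struct ggml_tensor * up2x = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_NEAREST);

// bilinear resize to a fixed 512x512 spatial size, keeping the remaining dims
struct ggml_tensor * resized = ggml_upscale_ext(ctx, img, 512, 512,
        (int) img->ne[2], (int) img->ne[3], GGML_SCALE_MODE_BILINEAR);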
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
GGML_API struct ggml_tensor * ggml_pad(
@ -1916,83 +1916,6 @@ extern "C" {

// custom operators

typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);

typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
ggml_unary_op_f32_t fun),
"use ggml_map_custom1 instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
ggml_unary_op_f32_t fun),
"use ggml_map_custom1_inplace instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
ggml_binary_op_f32_t fun),
"use ggml_map_custom2 instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
ggml_binary_op_f32_t fun),
"use ggml_map_custom2_inplace instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
ggml_custom1_op_f32_t fun),
"use ggml_map_custom1 instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
ggml_custom1_op_f32_t fun),
"use ggml_map_custom1_inplace instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
ggml_custom2_op_f32_t fun),
"use ggml_map_custom2 instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
ggml_custom2_op_f32_t fun),
"use ggml_map_custom2_inplace instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
ggml_custom3_op_f32_t fun),
"use ggml_map_custom3 instead");

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
ggml_custom3_op_f32_t fun),
"use ggml_map_custom3_inplace instead");

// custom operators v2

typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@ -2048,6 +1971,30 @@ extern "C" {
int n_tasks,
void * userdata);

typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);

GGML_API struct ggml_tensor * ggml_custom_4d(
struct ggml_context * ctx,
enum ggml_type type,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
struct ggml_tensor ** args,
int n_args,
ggml_custom_op_t fun,
int n_tasks,
void * userdata);

GGML_API struct ggml_tensor * ggml_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor ** args,
int n_args,
ggml_custom_op_t fun,
int n_tasks,
void * userdata);
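A hedged sketch of how the new custom-op entry point above can be used (names and sizes are illustrative; it assumes a contiguous F32 destination and no extra input tensors):

// callback matching ggml_custom_op_t: each of the nth threads fills a slice of dst
static void fill_constant(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    const float value = *(const float *) userdata;
    float * data = (float *) dst->data;          // assumes dst is contiguous F32
    const int64_t n = ggml_nelements(dst);
    for (int64_t i = ith; i < n; i += nth) {
        data[i] = value;
    }
}

// inside graph construction (ctx is an existing ggml_context; `value` must stay
// alive until the graph is computed)
float value = 2.0f;
struct ggml_tensor * t = ggml_custom_4d(ctx, GGML_TYPE_F32, 16, 16, 1, 1,
                                        /*args=*/NULL, /*n_args=*/0,
                                        fill_constant, GGML_N_TASKS_MAX, &value);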

// loss function

GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
@ -41,6 +41,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
return ACL_INT4;
case GGML_TYPE_Q8_0:
return ACL_INT8;
case GGML_TYPE_I64:
return ACL_INT64;
default:
return ACL_DT_UNDEFINED;
}
@ -57,6 +57,13 @@
#include <aclnnop/aclnn_sub.h>
#include <aclnnop/aclnn_mul.h>
#include <aclnnop/aclnn_div.h>
#include <aclnnop/aclnn_convolution.h>
#include <aclnnop/aclnn_elu.h>
#include <aclnnop/aclnn_log.h>
#include <aclnnop/aclnn_mean.h>
#include <aclnnop/aclnn_reflection_pad1d.h>
#include <aclnnop/aclnn_eq_tensor.h>
#include <aclnnop/aclnn_gt_scalar.h>
#include <float.h>

#include <cmath>
@ -86,6 +93,20 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
}
}

void ggml_cann_unary_op(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];

aclTensor* acl_src = ggml_cann_create_tensor(src);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);

unary_op(ctx, acl_src, acl_dst);

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
}
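A hedged usage sketch of the std::function-based helper defined above, mirroring how the backend dispatches simple unary ACLNN ops elsewhere in this commit (`ctx` and `dst` come from the caller):

auto exp_lambda = [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
    GGML_CANN_CALL_ACLNN_OP(Exp, acl_src, acl_dst);
};
ggml_cann_unary_op(exp_lambda, ctx, dst);  // equivalent to GGML_CANN_CALL_UNARY_OP(Exp)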

/**
* @brief Repeats elements of a tensor along each dimension according to the
* specified repeat array.
@ -2582,6 +2603,131 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);

GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst);

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];

// stride
int64_t s0 = ((const int32_t*)(dst->op_params))[0];

aclTensor* acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
aclTensor* acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);

int64_t strideVal[1];
strideVal[0] = s0;
aclIntArray *stride = aclCreateIntArray(strideVal, 1);
int64_t paddingVal[] = {0};
aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
int64_t dilationVal[] = {1};
aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
bool transposed = true;
int64_t groups = 1;
int8_t cubeMathType = 0;

GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride,
padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);

ACL_CHECK(aclDestroyTensor(acl_weight));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyIntArray(stride));
ACL_CHECK(aclDestroyIntArray(padding));
ACL_CHECK(aclDestroyIntArray(dilation));
}

void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];

aclTensor* acl_input = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);

float alphaValue = 1.0f;
aclScalar* alpha = nullptr;
alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);

GGML_CANN_CALL_ACLNN_OP(Elu, acl_input, alpha, alpha, alpha,
acl_dst);

ACL_CHECK(aclDestroyTensor(acl_input));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyScalar(alpha));
}

void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];

aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);

int64_t reduceDimValue[] = {3};
aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1);
bool keepDim = true;

GGML_CANN_CALL_ACLNN_OP(Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyIntArray(reduceDim));
}

void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];
int32_t *opts = (int32_t *) dst->op_params;
int64_t paddingsArray[2] = {opts[0], opts[1]};
aclIntArray* paddings = aclCreateIntArray(paddingsArray, 2);

for (int64_t i = 0; i < src0->ne[3]; i++) {
aclTensor* acl_src = ggml_cann_create_tensor(
(char*)src0->data + i * src0->ne[3],
ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
src0->ne, src0->nb, 3);

aclTensor* acl_dst = ggml_cann_create_tensor(
(char*)dst->data + i * src0->ne[3],
ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
dst->ne, dst->nb, 3);

GGML_CANN_CALL_ACLNN_OP(ReflectionPad1d, acl_src, paddings, acl_dst);

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
}
ACL_CHECK(aclDestroyIntArray(paddings));
}

void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];

aclTensor* acl_self = ggml_cann_create_tensor(src0);
aclTensor* acl_other = ggml_cann_create_tensor(src1);

GGML_CANN_CALL_ACLNN_OP(InplaceEqTensor, acl_self, acl_other);

ggml_cann_sum(ctx, dst);

ACL_CHECK(aclDestroyTensor(acl_self));
ACL_CHECK(aclDestroyTensor(acl_other));
}

void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];

aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);

float alphaValue = 0.0f;
aclScalar* alpha = nullptr;
alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);

GGML_CANN_CALL_ACLNN_OP(GtScalar, acl_src, alpha, acl_dst);

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyScalar(alpha));
}
@ -1,15 +1,4 @@
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
/**
|
||||
* @file acl_tensor
|
||||
* @brief This file contains related functions of ggml_tensor and acl_tensor.
|
||||
* Contains conversion from ggml_tensor to acl_tensor, broadcast and other
|
||||
* functions.
|
||||
* @author hipudding <huafengchun@gmail.com>
|
||||
* @author wangshuai09 <391746016@qq.com>
|
||||
* @date July 15, 2024
|
||||
*
|
||||
* Copyright (c) 2023-2024 The ggml authors
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
@ -31,6 +20,9 @@
|
||||
* IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
#include <aclnnop/aclnn_abs.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
@ -50,6 +42,8 @@
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
@ -483,8 +477,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
* operation is executed using the CANN backend for optimized performance.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the indices of the maximum values will be stored.
|
||||
* dst->op is `GGML_OP_ARGMAX`.
|
||||
* @param dst The destination tensor where the indices of the maximum values will
|
||||
* be stored. dst->op is `GGML_OP_ARGMAX`.
|
||||
*/
|
||||
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
@ -599,6 +593,160 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
|
||||
* output tensor.
|
||||
*
|
||||
* This function checks whether broadcasting is needed between `src0` and `src1`.
|
||||
* If broadcasting is required, it calculates the proper shapes and creates
|
||||
* ACL tensors with broadcast parameters. Otherwise, it directly creates ACL tensors
|
||||
* based on the original tensor shapes.
|
||||
*
|
||||
* @param src0 The first input tensor (reference shape).
|
||||
* @param src1 The second input tensor (possibly broadcasted).
|
||||
* @param dst The destination/output tensor.
|
||||
* @param acl_src0 Output pointer to the created ACL tensor corresponding to src0.
|
||||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
|
||||
aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
* tensor using the CANN backend.
|
||||
*
|
||||
* @details This function performs a 1D transposed convolution (also known as
|
||||
* deconvolution) operation on the input tensor. The computed result is stored
|
||||
* in the destination tensor `dst`. The operation is optimized using the CANN
|
||||
* backend for improved performance.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the transposed convolution result
|
||||
* will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`.
|
||||
*/
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor
|
||||
* using the CANN backend.
|
||||
*
|
||||
* @details This function performs an element-wise ELU activation on the input
|
||||
* tensor.
|
||||
* The result is written to the destination tensor `dst` in-place.
|
||||
* The ELU function is defined as:
|
||||
*
|
||||
* \text{ELU}(x) =
|
||||
* \begin{cases}
|
||||
* x, & \text{if } x > 0 \\
|
||||
* \alpha \left( \exp(x) - 1 \right), & \text{if } x \leq 0
|
||||
* \end{cases}
|
||||
*
|
||||
* where α (alpha) is a hyperparameter, typically set to 1.0.
|
||||
* This operation is optimized using the CANN backend for high-performance
|
||||
* inference or training.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the ELU-activated result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_ELU`.
|
||||
*/
|
||||
void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the mean of a ggml tensor element-wise using the CANN backend.
|
||||
*
|
||||
* @details This function calculates the element-wise mean of the input tensor.
|
||||
* The result is written to the destination tensor `dst`.
|
||||
* The mean is computed by averaging the values across the entire tensor.
|
||||
*
|
||||
* This operation is optimized using the CANN backend for high-performance inference or training.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the mean result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_MEAN`.
|
||||
*/
|
||||
void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Applies 1D reflect padding to a ggml tensor using the CANN backend.
|
||||
*
|
||||
* @details This function performs 1D reflect padding on the input tensor.
|
||||
* The amount of padding on each side is specified by parameters stored in `dst->op_params`.
|
||||
* The operation reflects the values at the borders of the tensor to generate the padded output.
|
||||
*
|
||||
* This operation is optimized using the CANN backend for high-performance inference or training.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the padded result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`.
|
||||
*/
|
||||
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Counts the number of equal elements in two ggml tensors using the CANN backend.
|
||||
*
|
||||
* @details This function performs an element-wise comparison between two input tensors,
|
||||
* and counts the number of positions where the elements are equal. The result is
|
||||
* stored in the destination tensor `dst` as a scalar.
|
||||
*
|
||||
* The operation is optimized using the CANN backend, making it suitable for
|
||||
* high-performance inference or training scenarios.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_COUNT_EQUAL`.
|
||||
*/
|
||||
void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the Step activation function to a ggml tensor using the CANN backend.
|
||||
*
|
||||
* @details This function applies a step function element-wise to the input tensor, where
|
||||
* each element is transformed to 1.0 if it is greater than 0, and 0.0 otherwise.
|
||||
* The result is stored in the destination tensor `dst`.
|
||||
*
|
||||
* This operation is accelerated using the CANN backend to improve runtime performance.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_STEP`.
|
||||
*/
|
||||
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a element-wise operation to two input tensors using the CANN
|
||||
* backend.
|
||||
*
|
||||
* This templated function takes a binary operator and applies it to two source
|
||||
* tensors
|
||||
* associated with the destination tensor. The function handles broadcasting as
|
||||
* needed.
|
||||
*
|
||||
* @tparam binary_op A callable object (e.g., lambda or function pointer) representing
|
||||
* the binary operation to be performed. It must take three arguments:
|
||||
* (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*).
|
||||
*
|
||||
* @param ctx The CANN backend context used to manage execution and resources.
|
||||
* @param dst The destination tensor.
|
||||
*/
|
||||
template <auto binary_op>
|
||||
void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0];
|
||||
ggml_tensor* src1 = dst->src[1];
|
||||
|
||||
aclTensor* acl_src0;
|
||||
aclTensor* acl_src1;
|
||||
aclTensor* acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ACL_CHECK(aclDestroyTensor(acl_src0));
|
||||
ACL_CHECK(aclDestroyTensor(acl_src1));
|
||||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
@ -631,56 +779,6 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream())); \
|
||||
} while (0)
|
||||
|
||||
|
||||
/**
|
||||
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one output tensor.
|
||||
*
|
||||
* This function checks whether broadcasting is needed between `src0` and `src1`.
|
||||
* If broadcasting is required, it calculates the proper shapes and creates
|
||||
* ACL tensors with broadcast parameters. Otherwise, it directly creates ACL tensors
|
||||
* based on the original tensor shapes.
|
||||
*
|
||||
* @param src0 The first input tensor (reference shape).
|
||||
* @param src1 The second input tensor (possibly broadcasted).
|
||||
* @param dst The destination/output tensor.
|
||||
* @param acl_src0 Output pointer to the created ACL tensor corresponding to src0.
|
||||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1, aclTensor ** acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a element-wise operation to two input tensors using the CANN backend.
|
||||
*
|
||||
* This templated function takes a binary operator and applies it to two source tensors
|
||||
* associated with the destination tensor. The function handles broadcasting as needed.
|
||||
*
|
||||
* @tparam binary_op A callable object (e.g., lambda or function pointer) representing
|
||||
* the binary operation to be performed. It must take three arguments:
|
||||
* (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*).
|
||||
*
|
||||
* @param ctx The CANN backend context used to manage execution and resources.
|
||||
* @param dst The destination tensor.
|
||||
*/
|
||||
template <auto binary_op>
|
||||
void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0];
|
||||
ggml_tensor* src1 = dst->src[1];
|
||||
|
||||
aclTensor* acl_src0;
|
||||
aclTensor* acl_src1;
|
||||
aclTensor* acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ACL_CHECK(aclDestroyTensor(acl_src0));
|
||||
ACL_CHECK(aclDestroyTensor(acl_src1));
|
||||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies a unary operation to an input tensor using the CANN backend.
|
||||
*
|
||||
@ -690,7 +788,6 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
* @tparam unary_op A callable with the signature:
|
||||
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
|
||||
* where the first aclTensor is the source and the second is the destination.
|
||||
*
|
||||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
*/
|
||||
@ -702,10 +799,30 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
|
||||
ACL_CHECK(aclDestroyTensor(acl_src));
|
||||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies a unary operation to a ggml tensor using the CANN backend.
|
||||
*
|
||||
* @details This function performs a unary operation on the input tensor using
|
||||
* a user-provided lambda or callable object `unary_op`, which accepts the CANN
|
||||
* context and two ACL tensors (source and destination). Internally, this function
|
||||
* creates ACL representations of the ggml tensors and invokes the unary operation.
|
||||
* The result is stored in the destination tensor `dst`. This utility abstracts the
|
||||
* common boilerplate of tensor conversion and cleanup when implementing unary ops.
|
||||
*
|
||||
* @param unary_op A callable that performs the unary operation using CANN APIs.
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the result will be stored.
|
||||
* The source tensor is retrieved from `dst->src[0]`.
|
||||
*/
|
||||
void ggml_cann_unary_op(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
|
||||
*
|
||||
@ -725,11 +842,12 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
||||
*/
|
||||
#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](auto ctx, auto acl_src, auto acl_dst) { \
|
||||
auto lambda = [](ggml_backend_cann_context& ctx, \
|
||||
aclTensor* acl_src, \
|
||||
aclTensor* acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_unary_op<lambda>(ctx, dst); \
|
||||
ggml_cann_unary_op(lambda, ctx, dst); \
|
||||
} \
|
||||
while (0)
|
||||
|
||||
#endif // CANN_ACLNN_OPS
|
||||
|
@ -1330,12 +1330,13 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
GGML_CANN_CALL_UNARY_OP(Silu);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: {
|
||||
auto lambda = [](auto ctx, auto acl_src, auto acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
|
||||
};
|
||||
ggml_cann_unary_op<lambda>(ctx, dst);
|
||||
}
|
||||
break;
|
||||
auto lambda = [](ggml_backend_cann_context& ctx,
|
||||
aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
|
||||
};
|
||||
ggml_cann_unary_op(lambda, ctx, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_TANH:
|
||||
GGML_CANN_CALL_UNARY_OP(Tanh);
|
||||
break;
|
||||
@ -1354,6 +1355,15 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
case GGML_UNARY_OP_EXP:
|
||||
GGML_CANN_CALL_UNARY_OP(Exp);
|
||||
break;
|
||||
case GGML_UNARY_OP_ELU:
|
||||
ggml_cann_elu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
GGML_CANN_CALL_UNARY_OP(Sign);
|
||||
break;
|
||||
case GGML_UNARY_OP_STEP:
|
||||
ggml_cann_step(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1448,7 +1458,22 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
break;
|
||||
case GGML_OP_SIN:
|
||||
ggml_cann_unary_op<aclnn_sin>(ctx, dst);
|
||||
break;
|
||||
break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_cann_conv_transpose_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_LOG:
|
||||
GGML_CANN_CALL_UNARY_OP(Log);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_cann_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
ggml_cann_pad_reflect_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
ggml_cann_count_equal(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1710,6 +1735,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@ -1796,6 +1824,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_POOL_2D: {
|
||||
@ -1842,6 +1873,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -323,8 +323,6 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||
#else
|
||||
#ifdef __POWER9_VECTOR__
|
||||
#include <altivec.h>
|
||||
#undef bool
|
||||
#define bool _Bool
|
||||
#else
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <intrin.h>
|
||||
|
@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
ggml_unary_op_f32_t fun;
|
||||
memcpy(&fun, tensor->op_params, sizeof(fun));
|
||||
ggml_compute_forward_map_unary(params, tensor, fun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MAP_BINARY:
|
||||
{
|
||||
ggml_binary_op_f32_t fun;
|
||||
memcpy(&fun, tensor->op_params, sizeof(fun));
|
||||
ggml_compute_forward_map_binary(params, tensor, fun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
{
|
||||
ggml_custom1_op_f32_t fun;
|
||||
memcpy(&fun, tensor->op_params, sizeof(fun));
|
||||
ggml_compute_forward_map_custom1_f32(params, tensor, fun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MAP_CUSTOM2_F32:
|
||||
{
|
||||
ggml_custom2_op_f32_t fun;
|
||||
memcpy(&fun, tensor->op_params, sizeof(fun));
|
||||
ggml_compute_forward_map_custom2_f32(params, tensor, fun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MAP_CUSTOM3_F32:
|
||||
{
|
||||
ggml_custom3_op_f32_t fun;
|
||||
memcpy(&fun, tensor->op_params, sizeof(fun));
|
||||
ggml_compute_forward_map_custom3_f32(params, tensor, fun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
{
|
||||
ggml_compute_forward_map_custom1(params, tensor);
|
||||
@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
ggml_compute_forward_map_custom3(params, tensor);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CUSTOM:
|
||||
{
|
||||
ggml_compute_forward_custom(params, tensor);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
{
|
||||
ggml_compute_forward_cross_entropy_loss(params, tensor);
|
||||
@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
case GGML_OP_MAP_CUSTOM2_F32:
|
||||
case GGML_OP_MAP_CUSTOM3_F32:
|
||||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = MIN(p.n_tasks, n_threads);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CUSTOM:
|
||||
{
|
||||
struct ggml_custom_op_params p;
|
||||
memcpy(&p, node->op_params, sizeof(p));
|
||||
if (p.n_tasks == GGML_N_TASKS_MAX) {
|
||||
n_tasks = n_threads;
|
||||
} else {
|
||||
n_tasks = MIN(p.n_tasks, n_threads);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
|
@ -6351,24 +6351,72 @@ static void ggml_compute_forward_upscale_f32(
const float sf2 = (float)ne2/src0->ne[2];
const float sf3 = (float)ne3/src0->ne[3];

// TODO: optimize
const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);

for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / sf1;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / sf0;
if (mode == GGML_SCALE_MODE_NEAREST) {
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / sf1;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / sf0;

const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);

*y = *x;
*y = *x;
}
}
}
}
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
// setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
const float pixel_offset = 0.5f;

for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
int64_t y0 = (int64_t)floorf(y);
int64_t y1 = y0 + 1;

y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));

float dy = y - (float)y0;
dy = std::max(0.0f, std::min(dy, 1.0f));

for (int64_t i0 = 0; i0 < ne0; i0++) {
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
int64_t x0 = (int64_t)floorf(x);
int64_t x1 = x0 + 1;

x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));

float dx = x - (float)x0;
dx = std::max(0.0f, std::min(dx, 1.0f));

// fetch the four surrounding pixel values and interpolate
const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);

const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;

float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y_dst = val;
}
}
}
}
} else {
GGML_ABORT("unsupported upscale mode");
}
}

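A worked instance of the bilinear coordinate mapping above (hedged, purely illustrative numbers): with sf0 = 2 and pixel_offset = 0.5, output column i0 = 3 maps to x = (3 + 0.5) / 2 - 0.5 = 1.25, giving x0 = 1, x1 = 2 and dx = 0.25, so the result blends source columns 1 and 2 with weights 0.75 and 0.25 before the same weighting is applied along y.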
@ -8268,152 +8316,6 @@ void ggml_compute_forward_rwkv_wkv7(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_unary
|
||||
|
||||
static void ggml_compute_forward_map_unary_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_unary_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
fun(nc,
|
||||
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_map_unary(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_unary_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_map_unary_f32(params, dst, fun);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_binary
|
||||
|
||||
static void ggml_compute_forward_map_binary_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_binary_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(src1));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
fun(nc,
|
||||
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i*(src0->nb[1])),
|
||||
(float *) ((char *) src1->data + i*(src1->nb[1])));
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_map_binary(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_binary_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_map_binary_f32(params, dst, fun);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_custom1
|
||||
|
||||
void ggml_compute_forward_map_custom1_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_custom1_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * a = dst->src[0];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
fun(dst, a);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_custom2
|
||||
|
||||
void ggml_compute_forward_map_custom2_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_custom2_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * a = dst->src[0];
|
||||
const ggml_tensor * b = dst->src[1];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
fun(dst, a, b);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_custom3
|
||||
|
||||
void ggml_compute_forward_map_custom3_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const ggml_custom3_op_f32_t fun) {
|
||||
|
||||
const ggml_tensor * a = dst->src[0];
|
||||
const ggml_tensor * b = dst->src[1];
|
||||
const ggml_tensor * c = dst->src[1];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
fun(dst, a, b, c);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_custom1
|
||||
|
||||
void ggml_compute_forward_map_custom1(
|
||||
@ -8459,6 +8361,18 @@ void ggml_compute_forward_map_custom3(
|
||||
p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_custom
|
||||
|
||||
void ggml_compute_forward_custom(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
struct ggml_custom_op_params p;
|
||||
memcpy(&p, dst->op_params, sizeof(p));
|
||||
|
||||
p.fun(dst, params->ith, params->nth, p.userdata);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_cross_entropy_loss
|
||||
|
||||
static void ggml_compute_forward_cross_entropy_loss_f32(
|
||||
|
@ -96,29 +96,10 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
|
||||
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_unary(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const ggml_unary_op_f32_t fun);
|
||||
void ggml_compute_forward_map_binary(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const ggml_binary_op_f32_t fun);
|
||||
void ggml_compute_forward_map_custom1_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const ggml_custom1_op_f32_t fun);
|
||||
void ggml_compute_forward_map_custom2_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const ggml_custom2_op_f32_t fun);
|
||||
void ggml_compute_forward_map_custom3_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const ggml_custom3_op_f32_t fun);
|
||||
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
@ -392,7 +392,11 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
|
||||
#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
|
||||
vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
|
||||
vec_extract_fp32_from_shortl(vec_xl(0, p))
|
||||
#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
|
||||
static inline unsigned char ggml_endian_byte(int i) {
|
||||
uint16_t tmp_val = 1;
|
||||
return ((unsigned char *)&tmp_val)[i];
|
||||
}
|
||||
#define GGML_ENDIAN_BYTE(i) ggml_endian_byte(i)
|
||||
#define GGML_F16_VEC_STORE(p, r, i) \
|
||||
if (i & 0x1) \
|
||||
vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
|
||||
@ -851,13 +855,17 @@ static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
|
||||
tmp[i] = GGML_FP16_TO_FP32(x[i]);
|
||||
}
|
||||
|
||||
return vec_xl(0, tmp);
|
||||
// note: keep type-cast here to prevent compiler bugs
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/12846
|
||||
return vec_xl(0, (const float *)(tmp));
|
||||
}
|
||||
|
||||
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
|
||||
float arr[4];
|
||||
|
||||
vec_xst(y, 0, arr);
|
||||
// note: keep type-cast here to prevent compiler bugs
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/12846
|
||||
vec_xst(y, 0, (float *)(arr));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
x[i] = GGML_FP32_TO_FP16(arr[i]);
|
||||
|
@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
|
||||
*dsti = *xi;
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
|
||||
|
||||
*dsti = *xi;
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
half * dsti = (half *) cdsti;
|
||||
@ -386,6 +393,16 @@ static void ggml_cpy_f32_f32_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_bf16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_f16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@ -581,6 +598,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
@ -634,6 +653,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return nullptr;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
|
@ -3079,6 +3079,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
@ -3213,6 +3216,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
@ -16,6 +16,14 @@
|
||||
#include <arm_sve.h>
|
||||
#endif // __ARM_FEATURE_SVE
|
||||
|
||||
#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
|
||||
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
||||
//
|
||||
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
|
||||
//
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#if defined(__F16C__)
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
@ -140,8 +148,14 @@ struct ggml_map_custom2_op_params {
|
||||
|
||||
struct ggml_map_custom3_op_params {
|
||||
ggml_custom3_op_t fun;
|
||||
int n_tasks;
|
||||
void * userdata;
|
||||
int n_tasks;
|
||||
void * userdata;
|
||||
};
|
||||
|
||||
struct ggml_custom_op_params {
|
||||
ggml_custom_op_t fun;
|
||||
int n_tasks;
|
||||
void * userdata;
|
||||
};
|
||||
|
||||
// bitset
|
||||
@ -311,13 +325,6 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
|
||||
// for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
|
||||
//
|
||||
#if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
|
||||
|
||||
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
||||
//
|
||||
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
|
||||
//
|
||||
#include <arm_neon.h>
|
||||
|
||||
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
||||
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
|
||||
|
||||
@ -355,8 +362,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
|
||||
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
|
||||
|
||||
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
|
||||
register float f;
|
||||
register double d;
|
||||
float f;
|
||||
double d;
|
||||
__asm__(
|
||||
"mtfprd %0,%2\n"
|
||||
"xscvhpdp %0,%0\n"
|
||||
@ -368,8 +375,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
|
||||
}
|
||||
|
||||
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
||||
register double d;
|
||||
register ggml_fp16_t r;
|
||||
double d;
|
||||
ggml_fp16_t r;
|
||||
__asm__( /* xscvdphp can work on double or single precision */
|
||||
"xscvdphp %0,%2\n"
|
||||
"mffprd %1,%0\n" :
|
||||
|
@ -1334,8 +1334,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
return false;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
File diff suppressed because it is too large
@ -2,6 +2,13 @@
#define GGML_SYCL_ELEMENTWISE_HPP

#include "common.hpp"
#include "ggml.h"
#include <limits.h>

template <typename T>
T neg_infinity() {
return -std::numeric_limits<T>::infinity();
}

static __dpct_inline__ float op_repeat(const float a, const float b) {
return b;
@ -24,6 +31,19 @@ static __dpct_inline__ float op_div(const float a, const float b) {
return a / b;
}

template<typename T>
struct typed_data {
const T * src;
T * dst;
};

template<typename T>
typed_data<T> cast_data(ggml_tensor * dst) {
return {
/* .src = */ static_cast<const T *>(dst->src[0]->data),
/* .dst = */ static_cast<T *>(dst->data)
};
}

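A hedged usage sketch of the helper above inside an element-wise SYCL op (it assumes dst and dst->src[0] are both F32 and have matching layouts):

// inside some ggml_sycl_* element-wise implementation
auto data = cast_data<float>(dst);                 // data.src -> input elements, data.dst -> output
const int64_t n = ggml_nelements(dst->src[0]);
// ... launch a kernel that reads data.src[i] and writes data.dst[i] for i < n ...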
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

@ -65,6 +85,10 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

// ---------

void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@ -1617,17 +1617,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
|
||||
dst[i] = scale * x[i];
|
||||
}
|
||||
|
||||
static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||
}
|
||||
|
||||
template <typename Ti, typename To>
|
||||
static void pool2d_nchw_kernel(
|
||||
@ -1768,18 +1757,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
||||
});
|
||||
}
|
||||
|
||||
static void clamp_f32_sycl(const float *x, float *dst, const float min,
|
||||
const float max, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
clamp_f32(x, dst, min, max, k, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||
const int nrows, queue_ptr stream) {
|
||||
@ -2258,26 +2235,6 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst
|
||||
SYCL_CHECK(0);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, dst->op_params, sizeof(float));
|
||||
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream());
|
||||
/*
|
||||
DPCT1010:88: SYCL uses exceptions to report errors and does not use the
|
||||
error codes. The call was replaced with 0. You need to rewrite this code.
|
||||
*/
|
||||
SYCL_CHECK(0);
|
||||
}
|
||||
|
||||
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
|
||||
static bool peer_access_enabled = false;
|
||||
|
||||
@ -3218,10 +3175,6 @@ static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
ggml_sycl_op_scale(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_clamp(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_diag_mask_inf(ctx, dst);
|
||||
}
|
||||
@ -3700,7 +3653,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
||||
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
if (!g_ggml_sycl_disable_graph) {
|
||||
if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
|
||||
const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
|
||||
if (!graph_support) {
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
@ -3711,8 +3665,10 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
model_sycl_graph.end_recording();
|
||||
|
||||
if (!sycl_ctx->exec_graph) {
|
||||
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
|
||||
const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
|
||||
if (!sycl_ctx->exec_graph || !graph_update_support) {
|
||||
auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
|
||||
model_sycl_graph.finalize();
|
||||
sycl_ctx->exec_graph = std::make_unique<
|
||||
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
|
||||
} else {
|
||||
@ -3900,7 +3856,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
|
||||
#if defined (GGML_SYCL_F16)
|
||||
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
|
||||
#else
|
||||
return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
||||
#endif
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -4022,13 +3982,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
return (op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LOG:
|
||||
return (op->src[0]->type == GGML_TYPE_F32);
|
||||
#if defined (GGML_SYCL_F16)
|
||||
return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
|
||||
#else
|
||||
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
||||
#endif
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
@ -4055,12 +4020,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_IM2COL:
|
||||
// TODO: add support for the new F32 operations
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
@ -4194,6 +4194,12 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
|
||||
if (split_k == 3) {
|
||||
split_k = 2;
|
||||
}
|
||||
if (ctx->device->coopmat2) {
|
||||
// coopmat2 shader expects splits to be aligned to 256
|
||||
while (split_k > 1 && ((k / split_k) % 256) != 0) {
|
||||
split_k /= 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5743,7 +5749,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_UPSCALE:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
|
||||
return ctx->device->pipeline_upscale_f32;
|
||||
}
|
||||
return nullptr;
|
||||
@ -9398,9 +9404,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
@ -9768,7 +9775,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_CONCAT) {
|
||||
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_UPSCALE) {
|
||||
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
|
||||
} else if (tensor->op == GGML_OP_SCALE) {
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
|
||||
|
@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
|
||||
block_q4_K_packed128 block;
|
||||
};
|
||||
|
||||
#if defined(IS_MUL_MM2)
|
||||
|
||||
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
|
||||
// into shared memory and then process the whole tile using those scales.
|
||||
// There is a fetch function that loads into private variables and then a store
|
||||
// function that stores into shared memory.
|
||||
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
|
||||
// the part that fetches from the structure (which has a different block layout).
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
const uint shAscales_stride = (BM + 2);
|
||||
// 1 scale per 32 elements -> 8 scales per block, per row
|
||||
shared vec2 shAscales[8 * shAscales_stride];
|
||||
uvec4 row_v;
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K)
|
||||
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
|
||||
|
||||
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
|
||||
{
|
||||
uint tids_per_row = BLOCK_SIZE / BM;
|
||||
uint is_per_tid = 8 / tids_per_row;
|
||||
uint is_start = is_per_tid * (tid % tids_per_row);
|
||||
uint tid_row = tid / tids_per_row;
|
||||
|
||||
uint row = ir_BM + tid_row;
|
||||
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
|
||||
if (in_bounds || row < p.M) {
|
||||
row_v = data_a_q4_k_packed128[block_index].q4k[0];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if defined(DATA_A_Q5_K)
|
||||
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
|
||||
|
||||
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
|
||||
{
|
||||
uint tids_per_row = BLOCK_SIZE / BM;
|
||||
uint is_per_tid = 8 / tids_per_row;
|
||||
uint is_start = is_per_tid * (tid % tids_per_row);
|
||||
uint tid_row = tid / tids_per_row;
|
||||
|
||||
uint row = ir_BM + tid_row;
|
||||
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
|
||||
if (in_bounds || row < p.M) {
|
||||
row_v = data_a_q5_k_packed128[block_index].q5k[0];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
void store_scalesQ4_K(uint tid)
|
||||
{
|
||||
barrier();
|
||||
|
||||
uint tids_per_row = BLOCK_SIZE / BM;
|
||||
uint is_per_tid = 8 / tids_per_row;
|
||||
uint is_start = is_per_tid * (tid % tids_per_row);
|
||||
uint tid_row = tid / tids_per_row;
|
||||
|
||||
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
|
||||
uint is = idx + is_start;
|
||||
uvec4 v = row_v;
|
||||
const vec2 loadd = vec2(unpackFloat2x16(v.x));
|
||||
|
||||
uint32_t sc;
|
||||
uint32_t mbyte;
|
||||
|
||||
uint32_t scale0 = v.y;
|
||||
uint32_t scale4 = v.z;
|
||||
uint32_t scale8 = v.w;
|
||||
|
||||
uint32_t sc_lo = scale0;
|
||||
uint32_t mb_lo = scale4;
|
||||
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
|
||||
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
|
||||
|
||||
sc = is < 4 ? sc_lo : sc_hi;
|
||||
mbyte = is < 4 ? mb_lo : mb_hi;
|
||||
sc = sc >> (8 * (is & 3));
|
||||
mbyte = mbyte >> (8 * (is & 3));
|
||||
sc &= 0x3F;
|
||||
mbyte &= 0x3F;
|
||||
|
||||
const float d = loadd.x * float(sc);
|
||||
const float m = loadd.y * float(mbyte);
|
||||
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
|
||||
@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
const uint b = (idx & 0x20) >> 5; // 0,1
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..7
|
||||
|
||||
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
|
||||
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
|
||||
float d = v.x;
|
||||
float m = v.y;
|
||||
#else
|
||||
uvec4 v = bl128.block.q4k[0];
|
||||
|
||||
const vec2 loadd = vec2(unpackFloat2x16(v.x));
|
||||
|
||||
uint32_t sc;
|
||||
@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
|
||||
const float d = loadd.x * float(sc);
|
||||
const float m = loadd.y * float(mbyte);
|
||||
#endif
|
||||
|
||||
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
||||
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
|
||||
@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
const uint b = (idx & 0x20) >> 5; // 0,1
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..7
|
||||
|
||||
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
|
||||
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
|
||||
float d = v.x;
|
||||
float m = v.y;
|
||||
#else
|
||||
uvec4 v = bl128.block.q5k[0];
|
||||
|
||||
const f16vec2 loadd = unpackFloat2x16(v.x);
|
||||
@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
|
||||
const float16_t d = loadd.x * float16_t(sc);
|
||||
const float16_t m = loadd.y * float16_t(mbyte);
|
||||
#endif
|
||||
|
||||
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
|
||||
qh = ((qh >> is) & 0x101) << 4;
|
||||
@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
qs = (qs >> (b * 4)) & 0x0F0F;
|
||||
qs = unpack8(qs | qh)[idx & 1];
|
||||
|
||||
float16_t ret = d * (float16_t(qs)) - m;
|
||||
float ret = d * float(qs) - m;
|
||||
|
||||
return ret;
|
||||
return float16_t(ret);
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
|
||||
@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
||||
#define dequantFuncA dequantFuncQ3_K
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define dequantFuncA dequantFuncQ4_K
|
||||
#define fetch_scales fetch_scalesQ4_K
|
||||
#define store_scales store_scalesQ4_K
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#define fetch_scales fetch_scalesQ5_K
|
||||
#define store_scales store_scalesQ4_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
|
@ -330,9 +330,11 @@ void main() {
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
O = eMdiag * O;
|
||||
// multiply with fp16 accumulation, then add to O.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
||||
PV = coopMatMulAdd(P_A, V, PV);
|
||||
|
||||
O = coopMatMulAdd(P_A, V, O);
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
|
@ -19,6 +19,9 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#define IS_MUL_MM2 1
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 256;
|
||||
layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
||||
@ -70,6 +73,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#define DECODEFUNCA
|
||||
#endif
|
||||
|
||||
#if !defined(fetch_scales)
|
||||
#define fetch_scales(a, b, c, d, e, f)
|
||||
#endif
|
||||
#if !defined(store_scales)
|
||||
#define store_scales(a)
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
@ -116,6 +126,8 @@ void main() {
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint tid = gl_LocalInvocationIndex;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.z;
|
||||
#else
|
||||
@ -218,14 +230,21 @@ void main() {
|
||||
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
||||
|
||||
#if !defined(MUL_MAT_ID)
|
||||
|
||||
const uint START_ALIGN_K = 256;
|
||||
// For Qi_K (block size 256), unroll whole 256 element tiles.
|
||||
// For legacy quants (block size 32), unroll 8x.
|
||||
const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
|
||||
const uint unroll_count = UNROLL_K / BK;
|
||||
|
||||
// Detect a fast path where all loads are entirely in bounds and no clamping is required
|
||||
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
|
||||
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
|
||||
#if QUANT_K == 1
|
||||
(stride_a % 8) == 0 &&
|
||||
#endif
|
||||
(stride_b % 8) == 0 && (start_k % 8) == 0) {
|
||||
(stride_b % 8) == 0) {
|
||||
// Hint to the compiler that values are aligned (want 16B alignment)
|
||||
start_k &= ~7;
|
||||
start_k &= ~(START_ALIGN_K-1);
|
||||
stride_b &= ~7;
|
||||
#if QUANT_K == 1
|
||||
stride_a &= ~7;
|
||||
@ -234,11 +253,39 @@ void main() {
|
||||
tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
|
||||
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
|
||||
|
||||
uint k_iters = (end_k - start_k + BK - 1) / BK;
|
||||
uint k_iters = (end_k - start_k) / UNROLL_K;
|
||||
uint block_k = start_k;
|
||||
|
||||
// fetch scale values for a tile of quants. These will be copied into shared memory.
|
||||
// The fetches and stores are pipelined to hide the latency.
|
||||
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
|
||||
|
||||
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
for (uint i = 0; i < k_iters; ++i) {
|
||||
|
||||
store_scales(tid);
|
||||
if (block_k + UNROLL_K < end_k) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
|
||||
}
|
||||
|
||||
// Manually partial unroll
|
||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
}
|
||||
// Do any remaining iterations that were not unrolled
|
||||
if (block_k < end_k) {
|
||||
store_scales(tid);
|
||||
}
|
||||
while (block_k < end_k) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
@ -246,6 +293,7 @@ void main() {
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
@ -253,8 +301,30 @@ void main() {
|
||||
return;
|
||||
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
for (uint i = 0; i < k_iters; ++i) {
|
||||
|
||||
store_scales(tid);
|
||||
if (block_k + UNROLL_K < end_k) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
|
||||
}
|
||||
|
||||
// Manually partial unroll
|
||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
}
|
||||
// Do any remaining iterations that were not unrolled
|
||||
if (block_k < end_k) {
|
||||
store_scales(tid);
|
||||
}
|
||||
while (block_k < end_k) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
@ -262,6 +332,7 @@ void main() {
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
@ -269,8 +340,31 @@ void main() {
|
||||
return;
|
||||
} else {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
for (uint i = 0; i < k_iters; ++i) {
|
||||
|
||||
store_scales(tid);
|
||||
if (block_k + UNROLL_K < end_k) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
|
||||
}
|
||||
|
||||
// Manually partial unroll
|
||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
}
|
||||
// Do any remaining iterations that were not unrolled
|
||||
if (block_k < end_k) {
|
||||
store_scales(tid);
|
||||
}
|
||||
while (block_k < end_k) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
@ -278,6 +372,7 @@ void main() {
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
@ -298,47 +393,29 @@ void main() {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
uint k_iters = (end_k - start_k + BK - 1) / BK;
|
||||
|
||||
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
store_scales(tid);
|
||||
if (block_k + BK < end_k) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
// Clamping is expensive, so detect different code paths for each combination
|
||||
// of A and B needing clamping.
|
||||
bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
bool unclampedB = true;
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
if (unclampedA && unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
|
||||
#endif
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (unclampedA && !unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (!unclampedA && unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
|
||||
#endif
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (!unclampedA && !unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
|
278
ggml/src/ggml.c
278
ggml/src/ggml.c
@ -982,23 +982,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
|
||||
"UNARY",
|
||||
|
||||
"MAP_UNARY",
|
||||
"MAP_BINARY",
|
||||
|
||||
"MAP_CUSTOM1_F32",
|
||||
"MAP_CUSTOM2_F32",
|
||||
"MAP_CUSTOM3_F32",
|
||||
|
||||
"MAP_CUSTOM1",
|
||||
"MAP_CUSTOM2",
|
||||
"MAP_CUSTOM3",
|
||||
|
||||
"CUSTOM",
|
||||
|
||||
"CROSS_ENTROPY_LOSS",
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@ -1081,23 +1076,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
|
||||
"unary(x)",
|
||||
|
||||
"f(x)",
|
||||
"f(x,y)",
|
||||
|
||||
"custom_f32(x)",
|
||||
"custom_f32(x,y)",
|
||||
"custom_f32(x,y,z)",
|
||||
"map_custom(x)",
|
||||
"map_custom(x,y)",
|
||||
"map_custom(x,y,z)",
|
||||
|
||||
"custom(x)",
|
||||
"custom(x,y)",
|
||||
"custom(x,y,z)",
|
||||
|
||||
"cross_entropy_loss(x,y)",
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@ -4184,7 +4174,8 @@ static struct ggml_tensor * ggml_upscale_impl(
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3) {
|
||||
int ne3,
|
||||
enum ggml_scale_mode mode) {
|
||||
GGML_ASSERT(a->ne[0] <= ne0);
|
||||
GGML_ASSERT(a->ne[1] <= ne1);
|
||||
GGML_ASSERT(a->ne[2] <= ne2);
|
||||
@ -4192,6 +4183,8 @@ static struct ggml_tensor * ggml_upscale_impl(
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, mode);
|
||||
|
||||
result->op = GGML_OP_UPSCALE;
|
||||
result->src[0] = a;
|
||||
|
||||
@ -4201,8 +4194,9 @@ static struct ggml_tensor * ggml_upscale_impl(
|
||||
struct ggml_tensor * ggml_upscale(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int scale_factor) {
|
||||
return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
|
||||
int scale_factor,
|
||||
enum ggml_scale_mode mode) {
|
||||
return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_upscale_ext(
|
||||
@ -4211,8 +4205,9 @@ struct ggml_tensor * ggml_upscale_ext(
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3) {
|
||||
return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
|
||||
int ne3,
|
||||
enum ggml_scale_mode mode) {
|
||||
return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
|
||||
}
|
||||
|
||||
// ggml_pad
|
||||
@ -4842,179 +4837,6 @@ struct ggml_tensor * ggml_unary_inplace(
|
||||
return ggml_unary_impl(ctx, a, op, true);
|
||||
}
|
||||
|
||||
// ggml_map_unary
|
||||
|
||||
static struct ggml_tensor * ggml_map_unary_impl_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
const ggml_unary_op_f32_t fun,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
|
||||
|
||||
result->op = GGML_OP_MAP_UNARY;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_unary_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
const ggml_unary_op_f32_t fun) {
|
||||
return ggml_map_unary_impl_f32(ctx, a, fun, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_unary_inplace_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
const ggml_unary_op_f32_t fun) {
|
||||
return ggml_map_unary_impl_f32(ctx, a, fun, true);
|
||||
}
|
||||
|
||||
// ggml_map_binary
|
||||
|
||||
static struct ggml_tensor * ggml_map_binary_impl_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const ggml_binary_op_f32_t fun,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
|
||||
|
||||
result->op = GGML_OP_MAP_BINARY;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_binary_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const ggml_binary_op_f32_t fun) {
|
||||
return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_binary_inplace_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const ggml_binary_op_f32_t fun) {
|
||||
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
|
||||
}
|
||||
|
||||
// ggml_map_custom1_f32
|
||||
|
||||
static struct ggml_tensor * ggml_map_custom1_impl_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
const ggml_custom1_op_f32_t fun,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
|
||||
|
||||
result->op = GGML_OP_MAP_CUSTOM1_F32;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_custom1_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
const ggml_custom1_op_f32_t fun) {
|
||||
return ggml_map_custom1_impl_f32(ctx, a, fun, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_custom1_inplace_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
const ggml_custom1_op_f32_t fun) {
|
||||
return ggml_map_custom1_impl_f32(ctx, a, fun, true);
|
||||
}
|
||||
|
||||
// ggml_map_custom2_f32
|
||||
|
||||
static struct ggml_tensor * ggml_map_custom2_impl_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const ggml_custom2_op_f32_t fun,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
|
||||
|
||||
result->op = GGML_OP_MAP_CUSTOM2_F32;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_custom2_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const ggml_custom2_op_f32_t fun) {
|
||||
return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_custom2_inplace_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const ggml_custom2_op_f32_t fun) {
|
||||
return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
|
||||
}
|
||||
|
||||
// ggml_map_custom3_f32
|
||||
|
||||
static struct ggml_tensor * ggml_map_custom3_impl_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
const ggml_custom3_op_f32_t fun,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
|
||||
|
||||
result->op = GGML_OP_MAP_CUSTOM3_F32;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
result->src[2] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_custom3_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
const ggml_custom3_op_f32_t fun) {
|
||||
return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_map_custom3_inplace_f32(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
const ggml_custom3_op_f32_t fun) {
|
||||
return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
|
||||
}
|
||||
|
||||
// ggml_map_custom1
|
||||
|
||||
static struct ggml_tensor * ggml_map_custom1_impl(
|
||||
@ -5033,7 +4855,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
|
||||
/*.n_tasks =*/ n_tasks,
|
||||
/*.userdata =*/ userdata
|
||||
};
|
||||
ggml_set_op_params(result, (const void *) ¶ms, sizeof(params));
|
||||
ggml_set_op_params(result, ¶ms, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_MAP_CUSTOM1;
|
||||
result->src[0] = a;
|
||||
@ -5078,7 +4900,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
|
||||
/*.n_tasks =*/ n_tasks,
|
||||
/*.userdata =*/ userdata
|
||||
};
|
||||
ggml_set_op_params(result, (const void *) ¶ms, sizeof(params));
|
||||
ggml_set_op_params(result, ¶ms, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_MAP_CUSTOM2;
|
||||
result->src[0] = a;
|
||||
@ -5127,7 +4949,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
|
||||
/*.n_tasks =*/ n_tasks,
|
||||
/*.userdata =*/ userdata
|
||||
};
|
||||
ggml_set_op_params(result, (const void *) ¶ms, sizeof(params));
|
||||
ggml_set_op_params(result, ¶ms, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_MAP_CUSTOM3;
|
||||
result->src[0] = a;
|
||||
@ -5159,6 +4981,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
|
||||
return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_custom_4d(
|
||||
struct ggml_context * ctx,
|
||||
enum ggml_type type,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3,
|
||||
struct ggml_tensor ** args,
|
||||
int n_args,
|
||||
ggml_custom_op_t fun,
|
||||
int n_tasks,
|
||||
void * userdata) {
|
||||
|
||||
GGML_ASSERT(n_args < GGML_MAX_SRC);
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
|
||||
|
||||
struct ggml_custom_op_params params = {
|
||||
/*.fun =*/ fun,
|
||||
/*.n_tasks =*/ n_tasks,
|
||||
/*.userdata =*/ userdata
|
||||
};
|
||||
ggml_set_op_params(result, ¶ms, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_CUSTOM;
|
||||
for (int i = 0; i < n_args; i++) {
|
||||
result->src[i] = args[i];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_custom_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor ** args,
|
||||
int n_args,
|
||||
ggml_custom_op_t fun,
|
||||
int n_tasks,
|
||||
void * userdata) {
|
||||
|
||||
GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
struct ggml_custom_op_params params = {
|
||||
/*.fun =*/ fun,
|
||||
/*.n_tasks =*/ n_tasks,
|
||||
/*.userdata =*/ userdata
|
||||
};
|
||||
ggml_set_op_params(result, ¶ms, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_CUSTOM;
|
||||
result->src[0] = a;
|
||||
for (int i = 0; i < n_args; i++) {
|
||||
result->src[i + 1] = args[i];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
// ggml_cross_entropy_loss
|
||||
|
||||
struct ggml_tensor * ggml_cross_entropy_loss(
|
||||
|
@ -248,6 +248,8 @@ class MODEL_ARCH(IntEnum):
|
||||
QWEN2 = auto()
|
||||
QWEN2MOE = auto()
|
||||
QWEN2VL = auto()
|
||||
QWEN3 = auto()
|
||||
QWEN3MOE = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PHIMOE = auto()
|
||||
@ -278,6 +280,7 @@ class MODEL_ARCH(IntEnum):
|
||||
DEEPSEEK = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
CHATGLM = auto()
|
||||
GLM4 = auto()
|
||||
BITNET = auto()
|
||||
T5 = auto()
|
||||
T5ENCODER = auto()
|
||||
@ -453,6 +456,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.QWEN2: "qwen2",
|
||||
MODEL_ARCH.QWEN2MOE: "qwen2moe",
|
||||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.QWEN3: "qwen3",
|
||||
MODEL_ARCH.QWEN3MOE: "qwen3moe",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
MODEL_ARCH.PHI3: "phi3",
|
||||
MODEL_ARCH.PHIMOE: "phimoe",
|
||||
@ -483,6 +488,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DEEPSEEK: "deepseek",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.CHATGLM: "chatglm",
|
||||
MODEL_ARCH.GLM4: "glm4",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
MODEL_ARCH.T5: "t5",
|
||||
MODEL_ARCH.T5ENCODER: "t5encoder",
|
||||
@ -953,6 +959,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.QWEN3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN3MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.PLAMO: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@ -1523,6 +1563,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.GLM4 : [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.BITNET: [
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
|
@ -13,7 +13,7 @@ class TensorNameMap:
|
||||
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
|
||||
"transformer.word_embeddings", # falcon
|
||||
"word_embeddings", # bloom
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert nomic-bert
|
||||
"language_model.embedding.word_embeddings", # persimmon
|
||||
@ -30,6 +30,7 @@ class TensorNameMap:
|
||||
"rwkv.embeddings", # rwkv6
|
||||
"model.embeddings", # rwkv7
|
||||
"model.word_embeddings", # bailingmoe
|
||||
"language_model.model.embed_tokens", # llama4
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@ -67,6 +68,7 @@ class TensorNameMap:
|
||||
"output_layer", # chatglm
|
||||
"head", # rwkv
|
||||
"head.out", # wavtokenizer
|
||||
"language_model.lm_head", # llama4
|
||||
),
|
||||
|
||||
# Output norm
|
||||
@ -89,6 +91,7 @@ class TensorNameMap:
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
"language_model.model.norm", # llama4
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@ -130,6 +133,7 @@ class TensorNameMap:
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
"language_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@ -169,6 +173,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@ -183,6 +188,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@ -196,6 +202,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wv", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
|
||||
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output
|
||||
@ -222,6 +229,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@ -233,7 +241,8 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_POST_NORM: (
|
||||
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2
|
||||
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
|
||||
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
|
||||
),
|
||||
|
||||
# Rotary embeddings
|
||||
@ -259,6 +268,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@ -269,6 +279,7 @@ class TensorNameMap:
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
|
||||
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
@ -278,6 +289,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"language_model.model.layers.{bid}.feed_forward.router", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@ -306,7 +318,7 @@ class TensorNameMap:
|
||||
"h.{bid}.mlp.c_fc", # gpt2
|
||||
"transformer.h.{bid}.mlp.fc1", # phi2
|
||||
"model.layers.{bid}.mlp.fc1", # phi2
|
||||
"model.layers.{bid}.mlp.gate_up_proj", # phi3
|
||||
"model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414
|
||||
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w3", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
|
||||
@ -315,6 +327,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@ -323,11 +336,13 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
@ -348,6 +363,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
@ -356,11 +372,13 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
@ -389,6 +407,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@ -398,11 +417,13 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
|
@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
||||
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
|
||||
@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
|
||||
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
|
||||
|
||||
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteTensor:
|
||||
dtype: str
|
||||
shape: tuple[int, ...]
|
||||
offset_start: int
|
||||
size: int
|
||||
url: str
|
||||
|
||||
def data(self) -> bytearray:
|
||||
# TODO: handle request errors (maybe with limited retries?)
|
||||
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
|
||||
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
|
||||
return data
|
||||
|
||||
|
||||
class SafetensorRemote:
|
||||
"""
|
||||
Uility class to handle remote safetensor files.
|
||||
This class is designed to work with Hugging Face model repositories.
|
||||
|
||||
Example (one model has single safetensor file, the other has multiple):
|
||||
for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
|
||||
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
|
||||
print(tensors)
|
||||
|
||||
Example reading tensor data:
|
||||
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
|
||||
for name, meta in tensors.items():
|
||||
dtype, shape, offset_start, size, remote_safetensor_url = meta
|
||||
# read the tensor data
|
||||
data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
|
||||
print(data)
|
||||
"""
|
||||
|
||||
BASE_DOMAIN = "https://huggingface.co"
|
||||
ALIGNMENT = 8 # bytes
|
||||
|
||||
@classmethod
|
||||
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
|
||||
"""
|
||||
Get list of tensors from a Hugging Face model repository.
|
||||
|
||||
Returns a dictionary of tensor names and their metadata.
|
||||
Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
|
||||
"""
|
||||
# case 1: model has only one single model.safetensor file
|
||||
is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
|
||||
if is_single_file:
|
||||
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
|
||||
return cls.get_list_tensors(url)
|
||||
|
||||
# case 2: model has multiple files
|
||||
index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
|
||||
is_multiple_files = cls.check_file_exist(index_url)
|
||||
if is_multiple_files:
|
||||
# read the index file
|
||||
index_data = cls.get_data_by_range(index_url, 0)
|
||||
index_str = index_data.decode('utf-8')
|
||||
index_json = json.loads(index_str)
|
||||
assert index_json.get("weight_map") is not None, "weight_map not found in index file"
|
||||
weight_map = index_json["weight_map"]
|
||||
# get the list of files
|
||||
all_files = list(set(weight_map.values()))
|
||||
all_files.sort() # make sure we load shard files in order
|
||||
# get the list of tensors
|
||||
tensors: dict[str, RemoteTensor] = {}
|
||||
for file in all_files:
|
||||
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
|
||||
for key, val in cls.get_list_tensors(url).items():
|
||||
tensors[key] = val
|
||||
return tensors
|
||||
|
||||
raise ValueError(f"Model {model_id} does not have any safetensor files")
|
||||
|
||||
@classmethod
|
||||
def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
|
||||
"""
|
||||
Get list of tensors from a remote safetensor file.
|
||||
|
||||
Returns a dictionary of tensor names and their metadata.
|
||||
Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
|
||||
"""
|
||||
metadata, data_start_offset = cls.get_metadata(url)
|
||||
res: dict[str, RemoteTensor] = {}
|
||||
|
||||
for name, meta in metadata.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
if not isinstance(meta, dict):
|
||||
raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
|
||||
try:
|
||||
dtype = meta["dtype"]
|
||||
shape = meta["shape"]
|
||||
offset_start_relative, offset_end_relative = meta["data_offsets"]
|
||||
size = offset_end_relative - offset_start_relative
|
||||
offset_start = data_start_offset + offset_start_relative
|
||||
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls, url: str) -> tuple[dict, int]:
|
||||
"""
|
||||
Get JSON metadata from a remote safetensor file.
|
||||
|
||||
Returns tuple of (metadata, data_start_offset)
|
||||
"""
|
||||
# Request first 5MB of the file (hopefully enough for metadata)
|
||||
read_size = 5 * 1024 * 1024
|
||||
raw_data = cls.get_data_by_range(url, 0, read_size)
|
||||
|
||||
# Parse header
|
||||
# First 8 bytes contain the metadata length as u64 little-endian
|
||||
if len(raw_data) < 8:
|
||||
raise ValueError("Not enough data to read metadata size")
|
||||
metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
|
||||
|
||||
# Calculate the data start offset
|
||||
data_start_offset = 8 + metadata_length
|
||||
alignment = SafetensorRemote.ALIGNMENT
|
||||
if data_start_offset % alignment != 0:
|
||||
data_start_offset += alignment - (data_start_offset % alignment)
|
||||
|
||||
# Check if we have enough data to read the metadata
|
||||
if len(raw_data) < 8 + metadata_length:
|
||||
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
|
||||
|
||||
# Extract metadata bytes and parse as JSON
|
||||
metadata_bytes = raw_data[8:8 + metadata_length]
|
||||
metadata_str = metadata_bytes.decode('utf-8')
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
return metadata, data_start_offset
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
|
||||
|
||||
@classmethod
|
||||
def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
|
||||
"""
|
||||
Get raw byte data from a remote file by range.
|
||||
If size is not specified, it will read the entire file.
|
||||
"""
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
|
||||
headers = cls._get_request_headers()
|
||||
if size > -1:
|
||||
headers["Range"] = f"bytes={start}-{start + size}"
|
||||
response = requests.get(url, allow_redirects=True, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
# Get raw byte data
|
||||
return response.content[:size]
|
||||
|
||||
@classmethod
|
||||
def check_file_exist(cls, url: str) -> bool:
|
||||
"""
|
||||
Check if a file exists at the given URL.
|
||||
Returns True if the file exists, False otherwise.
|
||||
"""
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
|
||||
try:
|
||||
headers = cls._get_request_headers()
|
||||
headers["Range"] = "bytes=0-0"
|
||||
response = requests.head(url, allow_redirects=True, headers=headers)
|
||||
# Success (2xx) or redirect (3xx)
|
||||
return 200 <= response.status_code < 400
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _get_request_headers(cls) -> dict[str, str]:
|
||||
"""Prepare common headers for requests."""
|
||||
headers = {"User-Agent": "convert_hf_to_gguf"}
|
||||
if os.environ.get("HF_TOKEN"):
|
||||
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
|
||||
return headers
|
||||
|
@ -158,13 +158,13 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# scripts/gen-authors.sh -> scripts/gen-authors.sh

cat ggml-src.patch | sed -E \
-e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
-e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
-e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
-e 's/([[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
-e 's/([[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
-e 's/([[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \
-e 's/([[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \
-e 's/([[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \
-e 's/([[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
@ -180,11 +180,11 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \
-e 's/^([[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
-e 's/^([[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
-e 's/^([[:space:]]| [ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \
-e 's/^([[:space:]]| [ab]\/)LICENSE/\1LICENSE/g' \
-e 's/^([[:space:]]| [ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \
-e 's/([[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
-e 's/([[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
-e 's/([[:space:]]| [ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \
-e 's/([[:space:]]| [ab]\/)LICENSE/\1LICENSE/g' \
-e 's/([[:space:]]| [ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \
> ggml-src.patch.tmp
mv ggml-src.patch.tmp ggml-src.patch
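In the updated rules the leading `^` anchor is dropped, so the path rewrite applies wherever a ggml path appears in the patch, not only at the start of a line. As a quick illustration of what one of these substitutions does to a patch header, here is the `.cpp` rule re-expressed with Python's `re` module; this is an approximation for illustration only (the script itself runs the sed -E rules above), and the sample input line is hypothetical.

import re

# Roughly mirrors: -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g'
line = "--- a/src/ggml-backend.cpp"
print(re.sub(r"(\s| [ab]/)src/ggml(.*)\.cpp", r"\1ggml/src/ggml\2.cpp", line))
# prints: --- a/ggml/src/ggml-backend.cpp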
@ -1 +1 @@
70e85f61f1fdcd1064a1e032ff564d5b5e67560c
2abf606f098844faebee578996cae9c6d63a40e2
@ -32,7 +32,7 @@ add_library(llama
unicode.h
)

target_include_directories(llama PUBLIC . ../include ../common)
target_include_directories(llama PUBLIC . ../include)
target_compile_features (llama PUBLIC cxx_std_17) # don't bump

target_link_libraries(llama PUBLIC ggml)
@ -26,6 +26,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_QWEN2, "qwen2" },
|
||||
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
|
||||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_QWEN3, "qwen3" },
|
||||
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
@ -52,6 +54,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DEEPSEEK, "deepseek" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_GLM4, "glm4" },
|
||||
{ LLM_ARCH_BITNET, "bitnet" },
|
||||
{ LLM_ARCH_T5, "t5" },
|
||||
{ LLM_ARCH_T5ENCODER, "t5encoder" },
|
||||
@ -595,6 +598,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PHI2,
|
||||
{
|
||||
@ -1111,6 +1153,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GLM4,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BITNET,
|
||||
{
|
||||
|
@ -30,6 +30,8 @@ enum llm_arch {
|
||||
LLM_ARCH_QWEN2,
|
||||
LLM_ARCH_QWEN2MOE,
|
||||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_QWEN3,
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
@ -56,6 +58,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DEEPSEEK,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_GLM4,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_T5,
|
||||
LLM_ARCH_T5ENCODER,
|
||||
@ -254,6 +257,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
LLM_TENSOR_POST_ATTN_NORM,
|
||||
LLM_TENSOR_POST_MLP_NORM,
|
||||
LLM_TENSOR_SSM_IN,
|
||||
LLM_TENSOR_SSM_CONV1D,
|
||||
LLM_TENSOR_SSM_X,
|
||||
|
@ -787,6 +787,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHI2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@ -1189,6 +1205,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GLM4:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 40: type = LLM_TYPE_9B; break;
|
||||
case 61: type = LLM_TYPE_32B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@ -2360,6 +2385,77 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
|
||||
if (n_expert == 0) {
|
||||
throw std::runtime_error("n_expert must be > 0 for QWEN3MOE");
|
||||
}
|
||||
if (n_expert_used == 0) {
|
||||
throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE");
|
||||
}
|
||||
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHI2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@ -3389,6 +3485,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GLM4:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (layer.wqkv == nullptr) {
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
|
||||
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_NEMOTRON:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@ -4168,6 +4303,10 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_QWEN3MOE) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
|
||||
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||
@ -4349,8 +4488,8 @@ struct llm_build_llama : public llm_graph_context {
|
||||
|
||||
if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
|
||||
// Llama4TextL2Norm
|
||||
Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
|
||||
Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
|
||||
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
|
||||
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
}
|
||||
@ -6582,6 +6721,255 @@ struct llm_build_qwen2moe : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_qwen3 : public llm_graph_context {
|
||||
llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_qwen3moe : public llm_graph_context {
|
||||
llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
ggml_tensor * moe_out =
|
||||
build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
cur = moe_out;
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_phi2 : public llm_graph_context {
|
||||
llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
@ -10514,6 +10902,157 @@ struct llm_build_chatglm : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_glm4 : public llm_graph_context {
|
||||
llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// Pre-attention norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * Qcur = nullptr;
|
||||
ggml_tensor * Kcur = nullptr;
|
||||
ggml_tensor * Vcur = nullptr;
|
||||
|
||||
if (model.layers[il].wqkv == nullptr) {
|
||||
Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
}
|
||||
Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
}
|
||||
Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
}
|
||||
} else {
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
if (model.layers[il].bqkv) {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// Post-attention norm (new!)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_post_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "post_attn_norm", il);
|
||||
|
||||
// Add the input (residual connection after post-attention norm)
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// FF
|
||||
{
|
||||
// Pre-MLP norm
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MLP
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// Post-MLP norm
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].ffn_post_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "post_mlp_norm", il);
|
||||
}
|
||||
|
||||
// Add residual connection after post-MLP norm
|
||||
inpL = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(inpL, "l_out", il);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(inpL,
|
||||
model.output_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// Output projection
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_nemotron : public llm_graph_context {
|
||||
llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
@ -12282,6 +12821,14 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen2moe>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen3>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen3moe>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_PHI2:
|
||||
{
|
||||
llm = std::make_unique<llm_build_phi2>(*this, params, gf);
|
||||
@ -12387,6 +12934,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||
{
|
||||
llm = std::make_unique<llm_build_chatglm>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_GLM4:
|
||||
{
|
||||
llm = std::make_unique<llm_build_glm4>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
llm = std::make_unique<llm_build_bitnet>(*this, params, gf);
|
||||
@ -12584,6 +13135,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_PLM:
|
||||
case LLM_ARCH_CHATGLM:
|
||||
case LLM_ARCH_GLM4:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
@ -12601,6 +13153,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_QWEN:
|
||||
case LLM_ARCH_QWEN2:
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
case LLM_ARCH_QWEN3:
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
case LLM_ARCH_OLMO2:
|
||||
case LLM_ARCH_OLMOE:
|
||||
case LLM_ARCH_PHI2:
|
||||
|
@ -1572,6 +1572,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "glm4" ||
|
||||
tokenizer_pre == "chatglm-bpe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
|
||||
special_bos_id = LLAMA_TOKEN_NULL;
|
||||
|
@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
|
||||
}
|
||||
}
|
||||
|
||||
static std::string var_to_str(ggml_scale_mode mode) {
|
||||
switch (mode) {
|
||||
case GGML_SCALE_MODE_NEAREST: return "nearest";
|
||||
case GGML_SCALE_MODE_BILINEAR: return "bilinear";
|
||||
default: return std::to_string(mode);
|
||||
}
|
||||
}
|
||||
|
||||
#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
|
||||
|
||||
#define VARS_TO_STR1(a) VAR_TO_STR(a)
|
||||
@ -2948,15 +2956,16 @@ struct test_upscale : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int32_t scale_factor;
|
||||
const bool transpose;
|
||||
const ggml_scale_mode mode;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, scale_factor, transpose);
|
||||
return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
|
||||
}
|
||||
|
||||
test_upscale(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {512, 512, 3, 1},
|
||||
int32_t scale_factor = 2, bool transpose = false)
|
||||
: type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
|
||||
int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
|
||||
: type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
@ -2967,7 +2976,7 @@ struct test_upscale : public test_case {
|
||||
ggml_set_name(a, "a_transposed");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
|
||||
ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
@ -2979,21 +2988,23 @@ struct test_upscale_ext : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const std::array<int64_t, 4> ne_tgt;
|
||||
const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(type, ne, ne_tgt);
|
||||
return VARS_TO_STR4(type, ne, ne_tgt, mode);
|
||||
}
|
||||
|
||||
test_upscale_ext(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {2, 5, 7, 11},
|
||||
std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
|
||||
: type(type), ne(ne), ne_tgt(ne_tgt) {}
|
||||
std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
|
||||
ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
|
||||
: type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
|
||||
ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
@ -4399,12 +4410,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||
}
|
||||
|
||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
|
||||
test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_upscale());
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
|
||||
test_cases.emplace_back(new test_upscale_ext());
|
||||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
|
||||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
|
||||
test_cases.emplace_back(new test_acc());
|
||||
|
@ -19,6 +19,8 @@ static std::string normalize_newlines(const std::string & s) {
|
||||
#endif
|
||||
}
|
||||
|
||||
#define U8C(x) (const char*)(u8##x)
|
||||
|
||||
static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
|
||||
common_chat_msg msg;
|
||||
msg.role = role;
|
||||
@ -35,6 +37,8 @@ int main(void) {
|
||||
{"assistant", " I am an assistant "},
|
||||
{"user", "Another question"},
|
||||
};
|
||||
|
||||
// std::string wrong = /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
|
||||
struct TestCase {
|
||||
std::string name;
|
||||
std::string template_str;
|
||||
@ -177,7 +181,7 @@ int main(void) {
|
||||
},
|
||||
{
|
||||
/* .name= */ "ChatGLM4",
|
||||
/* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
|
||||
/* .template_str= */ U8C("[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
|
||||
/* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
@ -193,8 +197,8 @@ int main(void) {
|
||||
},
|
||||
{
|
||||
/* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
|
||||
/* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
|
||||
/* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
|
||||
/* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
|
||||
/* .expected_output= */ U8C("You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>"),
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "",
|
||||
@ -202,7 +206,7 @@ int main(void) {
|
||||
{
|
||||
/* .name= */ "DeepSeek-V2",
|
||||
/* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
|
||||
/* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
|
||||
/* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "<|end▁of▁sentence|>",
|
||||
@ -256,7 +260,7 @@ int main(void) {
|
||||
},
|
||||
{
|
||||
/* .name= */ "Infinigence/Megrez-3B-Instruct",
|
||||
/* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
|
||||
/* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
|
||||
/* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
|