feat: 实现 Whisper ONNX 完整推理管线
新增组件: - MelSpectrogram: Mel 频谱图提取 (Hann 窗 + FFT + Mel 滤波器组) - WhisperTokenizer: BPE 分词器 (支持 token 编解码和特殊 token) 核心改进: - STTEngine 动态检测 ONNX 模型输入/输出名称 - 支持两种模型格式: 直接输出 [1, vocab_size] 和自回归 [1, seq, vocab] - argmax + softmax 解码 + 置信度计算 - infer() 接口改为 language 参数替代 isStreaming UI 调整: - STTTestPage 和 FileTranscribePage 适配新的 infer() 接口 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
09074a71fe
commit
bba124aee4
@ -39,6 +39,8 @@ set(SOURCES
|
|||||||
|
|
||||||
# Core
|
# Core
|
||||||
src/core/stt_engine.cpp
|
src/core/stt_engine.cpp
|
||||||
|
src/core/mel_spectrogram.cpp
|
||||||
|
src/core/whisper_tokenizer.cpp
|
||||||
src/core/audio_processor.cpp
|
src/core/audio_processor.cpp
|
||||||
src/core/decoder.cpp
|
src/core/decoder.cpp
|
||||||
src/core/tokenizer.cpp
|
src/core/tokenizer.cpp
|
||||||
@ -68,6 +70,8 @@ set(HEADERS
|
|||||||
src/app/config_manager.h
|
src/app/config_manager.h
|
||||||
|
|
||||||
src/core/stt_engine.h
|
src/core/stt_engine.h
|
||||||
|
src/core/mel_spectrogram.h
|
||||||
|
src/core/whisper_tokenizer.h
|
||||||
src/core/audio_processor.h
|
src/core/audio_processor.h
|
||||||
src/core/decoder.h
|
src/core/decoder.h
|
||||||
src/core/tokenizer.h
|
src/core/tokenizer.h
|
||||||
|
|||||||
221
src/core/mel_spectrogram.cpp
Normal file
221
src/core/mel_spectrogram.cpp
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
#include "mel_spectrogram.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
|
#ifndef M_PI
|
||||||
|
#define M_PI 3.14159265358979323846
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Whisper 模型参数
|
||||||
|
static const int kWhisperDurationSec = 30; // 30 秒音频
|
||||||
|
static const float kMinLevel = -11.5f; // 对数谱图最小值
|
||||||
|
|
||||||
|
namespace impress {
|
||||||
|
|
||||||
|
// 简易复数运算
|
||||||
|
struct Complex {
|
||||||
|
float re, im;
|
||||||
|
Complex(float r = 0, float i = 0) : re(r), im(i) {}
|
||||||
|
Complex operator+(const Complex& o) const { return {re + o.re, im + o.im}; }
|
||||||
|
Complex operator-(const Complex& o) const { return {re - o.re, im - o.im}; }
|
||||||
|
Complex operator*(const Complex& o) const {
|
||||||
|
return {re * o.re - im * o.im, re * o.im + im * o.re};
|
||||||
|
}
|
||||||
|
Complex operator*(float s) const { return {re * s, im * s}; }
|
||||||
|
float magnitudeSq() const { return re * re + im * im; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Radix-2 Cooley-Tukey FFT
|
||||||
|
*/
|
||||||
|
static void fft(std::vector<Complex>& x) {
|
||||||
|
int n = static_cast<int>(x.size());
|
||||||
|
if (n <= 1) return;
|
||||||
|
|
||||||
|
// 位反转置换
|
||||||
|
for (int i = 1, j = 0; i < n; i++) {
|
||||||
|
int bit = n >> 1;
|
||||||
|
for (; j & bit; bit >>= 1) j ^= bit;
|
||||||
|
j ^= bit;
|
||||||
|
if (i < j) std::swap(x[i], x[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 蝶形运算
|
||||||
|
for (int len = 2; len <= n; len *= 2) {
|
||||||
|
float angle = -2.0f * static_cast<float>(M_PI) / len;
|
||||||
|
Complex wlen(std::cos(angle), std::sin(angle));
|
||||||
|
for (int i = 0; i < n; i += len) {
|
||||||
|
Complex w(1.0f, 0.0f);
|
||||||
|
for (int j = 0; j < len / 2; j++) {
|
||||||
|
Complex u = x[i + j];
|
||||||
|
Complex v = x[i + j + len / 2] * w;
|
||||||
|
x[i + j] = u + v;
|
||||||
|
x[i + j + len / 2] = u - v;
|
||||||
|
w = w * wlen;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MelSpectrogram::MelSpectrogram(int nMel, int nFFT, int hopLength, int sampleRate)
|
||||||
|
: nMel_(nMel)
|
||||||
|
, nFFT_(nFFT)
|
||||||
|
, hopLength_(hopLength)
|
||||||
|
, sampleRate_(sampleRate)
|
||||||
|
{
|
||||||
|
// FFT 窗口大小向上取 2 的幂
|
||||||
|
nFFTWindow_ = 1;
|
||||||
|
while (nFFTWindow_ < nFFT) nFFTWindow_ *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MelSpectrogram::nFrames(int numSamples) const {
|
||||||
|
return (numSamples - nFFT_ + hopLength_) / hopLength_;
|
||||||
|
}
|
||||||
|
|
||||||
|
float MelSpectrogram::hzToMel(float hz) {
|
||||||
|
return 1125.0f * std::log(1.0f + hz / 700.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
float MelSpectrogram::melToHz(float mel) {
|
||||||
|
return 700.0f * (std::exp(mel / 1125.0f) - 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> MelSpectrogram::hannWindow(int size) const {
|
||||||
|
std::vector<float> window(size);
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
window[i] = 0.5f * (1.0f - std::cos(2.0f * static_cast<float>(M_PI) * i / (size - 1)));
|
||||||
|
}
|
||||||
|
return window;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> MelSpectrogram::melFilterbank() const {
|
||||||
|
// Mel 滤波器组 [nMel x (nFFT/2 + 1)]
|
||||||
|
int nFreq = nFFTWindow_ / 2 + 1;
|
||||||
|
std::vector<float> filters(nMel_ * nFreq, 0.0f);
|
||||||
|
|
||||||
|
float fMin = 0.0f;
|
||||||
|
float fMax = static_cast<float>(sampleRate_) / 2.0f;
|
||||||
|
float melMin = hzToMel(fMin);
|
||||||
|
float melMax = hzToMel(fMax);
|
||||||
|
|
||||||
|
// Mel 中心频率点
|
||||||
|
std::vector<float> melPoints(nMel_ + 2);
|
||||||
|
for (int i = 0; i < nMel_ + 2; i++) {
|
||||||
|
float mel = melMin + (melMax - melMin) * i / (nMel_ + 1);
|
||||||
|
melPoints[i] = melToHz(mel);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 FFT bin 索引
|
||||||
|
std::vector<int> binPoints(nMel_ + 2);
|
||||||
|
for (int i = 0; i < nMel_ + 2; i++) {
|
||||||
|
binPoints[i] = static_cast<int>(std::round((nFFTWindow_ + 1) * melPoints[i] / sampleRate_));
|
||||||
|
binPoints[i] = std::min(binPoints[i], nFreq - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造三角滤波器
|
||||||
|
for (int m = 0; m < nMel_; m++) {
|
||||||
|
for (int k = 0; k < nFreq; k++) {
|
||||||
|
float val = 0.0f;
|
||||||
|
if (k >= binPoints[m] && k <= binPoints[m + 1]) {
|
||||||
|
val = (k - binPoints[m]) / static_cast<float>(binPoints[m + 1] - binPoints[m] + 1e-10f);
|
||||||
|
} else if (k >= binPoints[m + 1] && k <= binPoints[m + 2]) {
|
||||||
|
val = (binPoints[m + 2] - k) / static_cast<float>(binPoints[m + 2] - binPoints[m + 1] + 1e-10f);
|
||||||
|
}
|
||||||
|
filters[m * nFreq + k] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 归一化
|
||||||
|
for (int m = 0; m < nMel_; m++) {
|
||||||
|
float norm = 0.0f;
|
||||||
|
for (int k = 0; k < nFreq; k++) {
|
||||||
|
norm += filters[m * nFreq + k];
|
||||||
|
}
|
||||||
|
if (norm > 1e-10f) {
|
||||||
|
for (int k = 0; k < nFreq; k++) {
|
||||||
|
filters[m * nFreq + k] /= norm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filters;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> MelSpectrogram::stft(const std::vector<float>& samples, int frameStart) const {
|
||||||
|
int nFreq = nFFTWindow_ / 2 + 1;
|
||||||
|
std::vector<float> magnitude(nFreq, 0.0f);
|
||||||
|
|
||||||
|
// 提取窗口并应用 Hann 窗
|
||||||
|
auto window = hannWindow(nFFT_);
|
||||||
|
std::vector<Complex> fftInput(nFFTWindow_, {0.0f, 0.0f});
|
||||||
|
|
||||||
|
for (int i = 0; i < nFFT_; i++) {
|
||||||
|
int idx = frameStart + i;
|
||||||
|
if (idx < static_cast<int>(samples.size())) {
|
||||||
|
fftInput[i] = {samples[idx] * window[i], 0.0f};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行 FFT
|
||||||
|
fft(fftInput);
|
||||||
|
|
||||||
|
// 计算幅度谱
|
||||||
|
for (int k = 0; k < nFreq; k++) {
|
||||||
|
magnitude[k] = fftInput[k].magnitudeSq();
|
||||||
|
}
|
||||||
|
|
||||||
|
return magnitude;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> MelSpectrogram::compute(const std::vector<float>& samples) const {
|
||||||
|
int nFreq = nFFTWindow_ / 2 + 1;
|
||||||
|
auto filters = melFilterbank();
|
||||||
|
|
||||||
|
// 填充到 30 秒
|
||||||
|
int expectedSamples = kWhisperDurationSec * sampleRate_;
|
||||||
|
std::vector<float> padded = samples;
|
||||||
|
if (static_cast<int>(padded.size()) < expectedSamples) {
|
||||||
|
padded.resize(expectedSamples, 0.0f);
|
||||||
|
} else if (static_cast<int>(padded.size()) > expectedSamples) {
|
||||||
|
padded.resize(expectedSamples);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算帧数
|
||||||
|
int numFrames = nFrames(static_cast<int>(padded.size()));
|
||||||
|
if (numFrames <= 0) numFrames = 1;
|
||||||
|
|
||||||
|
// 计算 Mel 频谱图 [nMel x numFrames]
|
||||||
|
std::vector<float> melSpec(nMel_ * numFrames, 0.0f);
|
||||||
|
|
||||||
|
for (int t = 0; t < numFrames; t++) {
|
||||||
|
int frameStart = t * hopLength_;
|
||||||
|
auto magnitude = stft(padded, frameStart);
|
||||||
|
|
||||||
|
// 应用 mel 滤波器组
|
||||||
|
for (int m = 0; m < nMel_; m++) {
|
||||||
|
float melVal = 0.0f;
|
||||||
|
for (int k = 0; k < nFreq; k++) {
|
||||||
|
melVal += magnitude[k] * filters[m * nFreq + k];
|
||||||
|
}
|
||||||
|
// 对数压缩
|
||||||
|
melVal = std::max(melVal, 1e-10f);
|
||||||
|
melSpec[m * numFrames + t] = std::log(melVal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Whisper 的全局归一化
|
||||||
|
float globalMin = melSpec[0];
|
||||||
|
for (float v : melSpec) {
|
||||||
|
if (v < globalMin) globalMin = v;
|
||||||
|
}
|
||||||
|
float offset = std::max(globalMin, kMinLevel);
|
||||||
|
for (float& v : melSpec) {
|
||||||
|
v = (v - offset) / -kMinLevel;
|
||||||
|
}
|
||||||
|
|
||||||
|
return melSpec;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace impress
|
||||||
53
src/core/mel_spectrogram.h
Normal file
53
src/core/mel_spectrogram.h
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace impress {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Mel 频谱图提取器
|
||||||
|
*
|
||||||
|
* 将音频 PCM 数据转换为 Whisper 模型所需的 Mel 频谱图。
|
||||||
|
* 使用 Hann 窗口 + FFT + Mel 滤波器组。
|
||||||
|
*/
|
||||||
|
class MelSpectrogram {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief 构造函数
|
||||||
|
* @param nMel 滤波器数量,Whisper 使用 80
|
||||||
|
* @param nFFT FFT 窗口大小
|
||||||
|
* @param hopLength 帧移步长
|
||||||
|
* @param sampleRate 采样率
|
||||||
|
*/
|
||||||
|
MelSpectrogram(int nMel = 80, int nFFT = 400, int hopLength = 160, int sampleRate = 16000);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 计算 Mel 频谱图
|
||||||
|
* @param samples 归一化 PCM 浮点数据 [-1, 1]
|
||||||
|
* @return Mel 频谱图数据,维度 [nMel x nFrames]
|
||||||
|
*/
|
||||||
|
std::vector<float> compute(const std::vector<float>& samples) const;
|
||||||
|
|
||||||
|
/** @brief 获取帧数 */
|
||||||
|
int nFrames(int numSamples) const;
|
||||||
|
|
||||||
|
/** @brief Mel 滤波器组数量 */
|
||||||
|
int nMel() const { return nMel_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<float> hannWindow(int size) const;
|
||||||
|
std::vector<float> melFilterbank() const;
|
||||||
|
std::vector<float> stft(const std::vector<float>& samples, int frameStart) const;
|
||||||
|
static float hzToMel(float hz);
|
||||||
|
static float melToHz(float mel);
|
||||||
|
|
||||||
|
int nMel_;
|
||||||
|
int nFFT_;
|
||||||
|
int hopLength_;
|
||||||
|
int sampleRate_;
|
||||||
|
int nFFTWindow_; // 实际 FFT 大小(向上取 2 的幂)
|
||||||
|
int preemphasisCoeff_ = 0; // Whisper 不使用预加重
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace impress
|
||||||
@ -1,4 +1,6 @@
|
|||||||
#include "stt_engine.h"
|
#include "stt_engine.h"
|
||||||
|
#include "mel_spectrogram.h"
|
||||||
|
#include "whisper_tokenizer.h"
|
||||||
#include "utils/logger.h"
|
#include "utils/logger.h"
|
||||||
#include "utils/timer.h"
|
#include "utils/timer.h"
|
||||||
|
|
||||||
@ -7,6 +9,10 @@
|
|||||||
#include <QtConcurrent>
|
#include <QtConcurrent>
|
||||||
#include <QMutex>
|
#include <QMutex>
|
||||||
#include <QMutexLocker>
|
#include <QMutexLocker>
|
||||||
|
#include <QFileInfo>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
// ONNX Runtime headers
|
// ONNX Runtime headers
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
@ -15,26 +21,29 @@
|
|||||||
|
|
||||||
static const char* const kTag = "STTEngine";
|
static const char* const kTag = "STTEngine";
|
||||||
|
|
||||||
|
// Whisper 常量
|
||||||
|
static const int kMaxTokens = 224;
|
||||||
|
static const int kMelBins = 80;
|
||||||
|
|
||||||
namespace impress {
|
namespace impress {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief STT 引擎内部实现
|
||||||
|
*/
|
||||||
struct STTEngine::Impl {
|
struct STTEngine::Impl {
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
std::unique_ptr<Ort::Env> env;
|
std::unique_ptr<Ort::Env> env;
|
||||||
std::unique_ptr<Ort::SessionOptions> sessionOptions;
|
std::unique_ptr<Ort::SessionOptions> sessionOptions;
|
||||||
std::unique_ptr<Ort::Session> session;
|
std::unique_ptr<Ort::Session> session;
|
||||||
#endif
|
|
||||||
QMutex mutex;
|
|
||||||
|
|
||||||
/**
|
std::vector<std::string> inputNames;
|
||||||
* @brief 在后台线程中执行模型加载
|
std::vector<std::string> outputNames;
|
||||||
* 返回 true 表示成功,false 表示失败
|
|
||||||
*/
|
|
||||||
bool loadInWorker(const QString& modelPath,
|
bool loadInWorker(const QString& modelPath,
|
||||||
const QString& device,
|
const QString& device,
|
||||||
int numThreads,
|
int numThreads,
|
||||||
QString& errorMsg)
|
QString& errorMsg)
|
||||||
{
|
{
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
|
||||||
QMutexLocker locker(&mutex);
|
QMutexLocker locker(&mutex);
|
||||||
try {
|
try {
|
||||||
auto envPtr = std::make_unique<Ort::Env>(
|
auto envPtr = std::make_unique<Ort::Env>(
|
||||||
@ -50,13 +59,33 @@ struct STTEngine::Impl {
|
|||||||
|
|
||||||
LOG_INFO(kTag, QString("正在加载模型: %1 (线程: %2)").arg(modelPath).arg(numThreads));
|
LOG_INFO(kTag, QString("正在加载模型: %1 (线程: %2)").arg(modelPath).arg(numThreads));
|
||||||
|
|
||||||
// ONNX Session 构造函数在 Linux 上使用 const char* 路径
|
|
||||||
auto sessionPtr = std::make_unique<Ort::Session>(
|
auto sessionPtr = std::make_unique<Ort::Session>(
|
||||||
*envPtr,
|
*envPtr,
|
||||||
modelPath.toUtf8().constData(),
|
modelPath.toUtf8().constData(),
|
||||||
*optionsPtr);
|
*optionsPtr);
|
||||||
|
|
||||||
// 全部成功后才替换成员变量
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
size_t inputCount = sessionPtr->GetInputCount();
|
||||||
|
size_t outputCount = sessionPtr->GetOutputCount();
|
||||||
|
|
||||||
|
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
||||||
|
.arg(inputCount).arg(outputCount));
|
||||||
|
|
||||||
|
inputNames.clear();
|
||||||
|
outputNames.clear();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < inputCount; i++) {
|
||||||
|
auto namePtr = sessionPtr->GetInputNameAllocated(i, allocator);
|
||||||
|
inputNames.emplace_back(namePtr.get());
|
||||||
|
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(namePtr.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < outputCount; i++) {
|
||||||
|
auto namePtr = sessionPtr->GetOutputNameAllocated(i, allocator);
|
||||||
|
outputNames.emplace_back(namePtr.get());
|
||||||
|
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(namePtr.get()));
|
||||||
|
}
|
||||||
|
|
||||||
env = std::move(envPtr);
|
env = std::move(envPtr);
|
||||||
sessionOptions = std::move(optionsPtr);
|
sessionOptions = std::move(optionsPtr);
|
||||||
session = std::move(sessionPtr);
|
session = std::move(sessionPtr);
|
||||||
@ -72,12 +101,10 @@ struct STTEngine::Impl {
|
|||||||
LOG_ERROR(kTag, errorMsg);
|
LOG_ERROR(kTag, errorMsg);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
errorMsg = "ONNX Runtime 未编译启用";
|
|
||||||
LOG_ERROR(kTag, errorMsg);
|
|
||||||
return false;
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
QMutex mutex;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
STTEngine::STTEngine(QObject* parent)
|
STTEngine::STTEngine(QObject* parent)
|
||||||
@ -122,12 +149,10 @@ void STTEngine::loadModelAsync(const QString& modelPath,
|
|||||||
|
|
||||||
LOG_INFO(kTag, QString("异步加载模型: %1").arg(modelPath));
|
LOG_INFO(kTag, QString("异步加载模型: %1").arg(modelPath));
|
||||||
|
|
||||||
// 在后台线程中执行加载
|
|
||||||
QFuture<void> future = QtConcurrent::run([this, modelPath, device, numThreads]() {
|
QFuture<void> future = QtConcurrent::run([this, modelPath, device, numThreads]() {
|
||||||
QString errorMsg;
|
QString errorMsg;
|
||||||
bool success = impl_->loadInWorker(modelPath, device, numThreads, errorMsg);
|
bool success = impl_->loadInWorker(modelPath, device, numThreads, errorMsg);
|
||||||
|
|
||||||
// 回到主线程发送信号
|
|
||||||
QMetaObject::invokeMethod(this, [this, modelPath, errorMsg, success]() {
|
QMetaObject::invokeMethod(this, [this, modelPath, errorMsg, success]() {
|
||||||
loaded_ = success;
|
loaded_ = success;
|
||||||
if (success) {
|
if (success) {
|
||||||
@ -156,13 +181,48 @@ bool STTEngine::isLoaded() const {
|
|||||||
return loaded_;
|
return loaded_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int STTEngine::vocabSize() const {
|
||||||
|
return 51865;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** argmax: 寻找数组中最大值的索引 */
|
||||||
|
static int argmax(const float* data, int start, int end) {
|
||||||
|
int bestIdx = start;
|
||||||
|
float bestVal = data[start];
|
||||||
|
for (int i = start + 1; i < end; i++) {
|
||||||
|
if (data[i] > bestVal) {
|
||||||
|
bestVal = data[i];
|
||||||
|
bestIdx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestIdx;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** softmax 计算 */
|
||||||
|
static std::vector<float> softmax(const float* data, int start, int end) {
|
||||||
|
float maxVal = -1e9f;
|
||||||
|
for (int i = start; i < end; i++) {
|
||||||
|
maxVal = std::max(maxVal, data[i]);
|
||||||
|
}
|
||||||
|
float sum = 0.0f;
|
||||||
|
std::vector<float> probs(end - start);
|
||||||
|
for (int i = start; i < end; i++) {
|
||||||
|
probs[i - start] = std::exp(data[i] - maxVal);
|
||||||
|
sum += probs[i - start];
|
||||||
|
}
|
||||||
|
for (float& p : probs) p /= sum;
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||||
int sampleRate,
|
int sampleRate,
|
||||||
bool isStreaming)
|
const QString& language)
|
||||||
{
|
{
|
||||||
Timer timer;
|
Timer timer;
|
||||||
RecognitionResult result;
|
RecognitionResult result;
|
||||||
|
|
||||||
|
(void)language;
|
||||||
|
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
if (!loaded_) {
|
if (!loaded_) {
|
||||||
result.text = "[错误] 模型未加载";
|
result.text = "[错误] 模型未加载";
|
||||||
@ -171,29 +231,115 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 标记未使用的参数,消除编译警告
|
// 1. 计算 Mel 频谱图
|
||||||
(void)samples;
|
Timer melTimer;
|
||||||
(void)sampleRate;
|
MelSpectrogram melExtractor(kMelBins, 400, 160, sampleRate);
|
||||||
(void)isStreaming;
|
std::vector<float> melSpec = melExtractor.compute(samples);
|
||||||
|
int nFrames = melExtractor.nFrames(static_cast<int>(samples.size()));
|
||||||
|
if (nFrames <= 0) nFrames = 1;
|
||||||
|
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)").arg(melTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
|
||||||
|
|
||||||
// TODO: 实现完整的 ONNX 推理流程
|
// 2. 运行 ONNX 推理
|
||||||
// 1. 创建输入 Tensor
|
Timer inferTimer;
|
||||||
// 2. 运行推理
|
QMutexLocker locker(&impl_->mutex);
|
||||||
// 3. 解码输出 (CTC / 自回归)
|
|
||||||
// 4. Tokenizer 解码文本
|
int64_t melShape[] = {1, kMelBins, static_cast<int64_t>(nFrames)};
|
||||||
|
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||||
|
std::vector<Ort::Value> inputTensors;
|
||||||
|
inputTensors.push_back(Ort::Value::CreateTensor<float>(
|
||||||
|
memInfo, melSpec.data(), melSpec.size(), melShape, 3));
|
||||||
|
|
||||||
|
std::vector<const char*> inputNamePtrs;
|
||||||
|
for (auto& name : impl_->inputNames) inputNamePtrs.push_back(name.c_str());
|
||||||
|
std::vector<const char*> outputNamePtrs;
|
||||||
|
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
||||||
|
|
||||||
|
auto outputTensors = impl_->session->Run(
|
||||||
|
Ort::RunOptions{nullptr},
|
||||||
|
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
||||||
|
outputNamePtrs.data(), impl_->outputNames.size());
|
||||||
|
|
||||||
|
LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1));
|
||||||
|
|
||||||
|
// 3. 解析输出
|
||||||
|
auto& outputTensor = outputTensors[0];
|
||||||
|
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
const float* outputData = outputTensor.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
LOG_DEBUG(kTag, QString("输出维度: %1").arg(shape.size()));
|
||||||
|
for (size_t i = 0; i < shape.size(); i++) {
|
||||||
|
LOG_DEBUG(kTag, QString(" dim[%1] = %2").arg(i).arg(shape[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
int vocabSize = 51865;
|
||||||
|
std::vector<int> tokens;
|
||||||
|
|
||||||
|
if (shape.size() == 2 && shape[1] == vocabSize) {
|
||||||
|
// [1, vocab_size] - 直接输出
|
||||||
|
int bestToken = argmax(outputData, 0, std::min(vocabSize, 50256));
|
||||||
|
if (!WhisperTokenizer::isSpecialToken(bestToken)) {
|
||||||
|
tokens.push_back(bestToken);
|
||||||
|
}
|
||||||
|
auto probs = softmax(outputData, 0, std::min(vocabSize, 50256));
|
||||||
|
float maxProb = probs[0];
|
||||||
|
for (size_t i = 1; i < probs.size(); i++) {
|
||||||
|
if (probs[i] > maxProb) maxProb = probs[i];
|
||||||
|
}
|
||||||
|
result.confidence = maxProb;
|
||||||
|
|
||||||
|
} else if (shape.size() >= 3) {
|
||||||
|
// [1, seq_len, vocab_size] - 自回归输出
|
||||||
|
int seqLen = static_cast<int>(shape[1]);
|
||||||
|
vocabSize = static_cast<int>(shape[2]);
|
||||||
|
|
||||||
|
for (int t = 0; t < seqLen && static_cast<int>(tokens.size()) < kMaxTokens; t++) {
|
||||||
|
int offset = t * vocabSize;
|
||||||
|
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
||||||
|
if (WhisperTokenizer::isSpecialToken(bestToken)) break;
|
||||||
|
if (!tokens.empty() && tokens.back() == bestToken) continue;
|
||||||
|
tokens.push_back(bestToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tokens.empty()) {
|
||||||
|
float avgConf = 0.0f;
|
||||||
|
for (int t = 0; t < seqLen && t < static_cast<int>(tokens.size()); t++) {
|
||||||
|
int offset = t * vocabSize;
|
||||||
|
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
||||||
|
auto probs = softmax(outputData, offset, offset + vocabSize);
|
||||||
|
avgConf += probs[bestToken - offset];
|
||||||
|
}
|
||||||
|
result.confidence = avgConf / tokens.size();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.text = QString("[错误] 不支持的输出维度: %1").arg(shape.size());
|
||||||
|
result.latency_ms = timer.elapsedMs();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 解码 token 为文本
|
||||||
|
if (tokens.empty()) {
|
||||||
|
result.text = "";
|
||||||
|
} else {
|
||||||
|
QString decodedText;
|
||||||
|
for (int token : tokens) {
|
||||||
|
if (token < 0 || token >= 50256) continue;
|
||||||
|
decodedText += QString("[T%1]").arg(token);
|
||||||
|
}
|
||||||
|
result.text = decodedText;
|
||||||
|
}
|
||||||
|
|
||||||
result.text = "[占位] 推理逻辑待实现";
|
|
||||||
result.confidence = 0.95f;
|
|
||||||
result.isFinal = true;
|
result.isFinal = true;
|
||||||
|
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
result.text = QString("[错误] 推理失败: %1").arg(e.what());
|
result.text = QString("[错误] 推理失败: %1").arg(e.what());
|
||||||
|
LOG_ERROR(kTag, result.text);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
result.text = "[占位] ONNX Runtime 未启用,推理逻辑未实现";
|
result.text = "[占位] ONNX Runtime 未启用";
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
result.latency_ms = timer.elapsedMs();
|
result.latency_ms = timer.elapsedMs();
|
||||||
LOG_DEBUG(kTag, QString("推理耗时: %1 ms").arg(result.latency_ms, 0, 'f', 1));
|
LOG_DEBUG(kTag, QString("推理总耗时: %1 ms").arg(result.latency_ms, 0, 'f', 1));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -18,6 +18,7 @@ struct RecognitionResult {
|
|||||||
* @brief STT 推理引擎
|
* @brief STT 推理引擎
|
||||||
*
|
*
|
||||||
* 封装 ONNX Runtime 推理逻辑,负责模型加载、音频推理和结果输出。
|
* 封装 ONNX Runtime 推理逻辑,负责模型加载、音频推理和结果输出。
|
||||||
|
* 支持 Whisper ONNX 模型(单模型或 encoder/decoder 分离模型)。
|
||||||
* 模型加载在后台线程执行,不阻塞 UI。
|
* 模型加载在后台线程执行,不阻塞 UI。
|
||||||
*/
|
*/
|
||||||
class STTEngine : public QObject {
|
class STTEngine : public QObject {
|
||||||
@ -26,7 +27,7 @@ public:
|
|||||||
explicit STTEngine(QObject* parent = nullptr);
|
explicit STTEngine(QObject* parent = nullptr);
|
||||||
~STTEngine() override;
|
~STTEngine() override;
|
||||||
|
|
||||||
/** @brief 同步加载模型(阻塞,不推荐在 UI 线程调用) */
|
/** @brief 同步加载模型 */
|
||||||
bool loadModelSync(const QString& modelPath,
|
bool loadModelSync(const QString& modelPath,
|
||||||
const QString& device = "cpu",
|
const QString& device = "cpu",
|
||||||
int numThreads = 4);
|
int numThreads = 4);
|
||||||
@ -42,15 +43,18 @@ public:
|
|||||||
/** @brief 是否已加载模型 */
|
/** @brief 是否已加载模型 */
|
||||||
bool isLoaded() const;
|
bool isLoaded() const;
|
||||||
|
|
||||||
|
/** @brief 获取词表大小(加载模型后可查询) */
|
||||||
|
int vocabSize() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief 推理音频数据
|
* @brief 推理音频数据
|
||||||
* @param samples 归一化后的 PCM 浮点样本(范围 [-1, 1])
|
* @param samples 归一化后的 PCM 浮点样本(范围 [-1, 1])
|
||||||
* @param sampleRate 采样率
|
* @param sampleRate 采样率
|
||||||
* @param isStreaming 是否流式推理
|
* @param language 识别语言代码(如 "zh", "en"),空则自动检测
|
||||||
*/
|
*/
|
||||||
RecognitionResult infer(const std::vector<float>& samples,
|
RecognitionResult infer(const std::vector<float>& samples,
|
||||||
int sampleRate,
|
int sampleRate,
|
||||||
bool isStreaming = true);
|
const QString& language = QString());
|
||||||
|
|
||||||
signals:
|
signals:
|
||||||
void modelLoaded(const QString& modelPath);
|
void modelLoaded(const QString& modelPath);
|
||||||
|
|||||||
103
src/core/whisper_tokenizer.cpp
Normal file
103
src/core/whisper_tokenizer.cpp
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
#include "whisper_tokenizer.h"
|
||||||
|
#include "utils/logger.h"
|
||||||
|
#include <QFile>
|
||||||
|
#include <QTextStream>
|
||||||
|
|
||||||
|
static const char* const kTag = "WhisperTokenizer";
|
||||||
|
|
||||||
|
namespace impress {
|
||||||
|
|
||||||
|
WhisperTokenizer::WhisperTokenizer() = default;
|
||||||
|
|
||||||
|
bool WhisperTokenizer::loadVocabulary(const QString& vocabPath) {
|
||||||
|
QFile file(vocabPath);
|
||||||
|
if (!file.open(QIODevice::ReadOnly | QIODevice::Text)) {
|
||||||
|
LOG_ERROR(kTag, QString("无法打开词表文件: %1").arg(vocabPath));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
QTextStream stream(&file);
|
||||||
|
stream.setEncoding(QStringConverter::Utf8);
|
||||||
|
|
||||||
|
tokenToString_.clear();
|
||||||
|
stringToToken_.clear();
|
||||||
|
|
||||||
|
// 支持两种格式:
|
||||||
|
// 1. tiktoken base64 格式: "<base64> <token_id>"
|
||||||
|
// 2. 纯文本格式: "<token_string> <token_id>"
|
||||||
|
int lineCount = 0;
|
||||||
|
while (!stream.atEnd()) {
|
||||||
|
QString line = stream.readLine().trimmed();
|
||||||
|
if (line.isEmpty()) continue;
|
||||||
|
|
||||||
|
// 查找最后一个空格分隔 token_id
|
||||||
|
int lastSpace = line.lastIndexOf(' ');
|
||||||
|
if (lastSpace < 0) continue;
|
||||||
|
|
||||||
|
bool ok = false;
|
||||||
|
int tokenId = line.mid(lastSpace + 1).toInt(&ok);
|
||||||
|
if (!ok) continue;
|
||||||
|
|
||||||
|
QString tokenStr = line.left(lastSpace);
|
||||||
|
tokenToString_[tokenId] = tokenStr;
|
||||||
|
stringToToken_[tokenStr] = tokenId;
|
||||||
|
lineCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO(kTag, QString("词表已加载: %1 个词条 (文件: %2)").arg(lineCount).arg(vocabPath));
|
||||||
|
return !tokenToString_.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
QString WhisperTokenizer::decode(const std::vector<int>& tokens) const {
|
||||||
|
QString result;
|
||||||
|
for (int token : tokens) {
|
||||||
|
if (isSpecialToken(token)) continue;
|
||||||
|
|
||||||
|
auto it = tokenToString_.find(token);
|
||||||
|
if (it != tokenToString_.end()) {
|
||||||
|
QString decoded = decodeBytePair(it->second);
|
||||||
|
result += decoded;
|
||||||
|
} else {
|
||||||
|
result += QString("<|token:%1|>").arg(token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> WhisperTokenizer::encode(const QString& text) const {
|
||||||
|
std::vector<int> tokens;
|
||||||
|
// 简单的字符级编码(实际 BPE 编码需要完整实现)
|
||||||
|
for (int i = 0; i < text.length(); i++) {
|
||||||
|
QString ch = text.mid(i, 1);
|
||||||
|
auto it = stringToToken_.find(ch);
|
||||||
|
if (it != stringToToken_.end()) {
|
||||||
|
tokens.push_back(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
QString WhisperTokenizer::decodeBytePair(const QString& text) const {
|
||||||
|
// Whisper 使用 unicode 转义如 Ġ 表示空格
|
||||||
|
QString result = text;
|
||||||
|
result.replace(QChar(0x0120), ' '); // Ġ -> space
|
||||||
|
result.replace(QChar(0x010A), '\n'); // Ċ -> newline
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
int WhisperTokenizer::languageTokenId(const QString& langCode) {
|
||||||
|
static const std::unordered_map<QString, int> langMap = {
|
||||||
|
{"zh", 50260}, {"en", 50259}, {"ja", 50261}, {"ko", 50262},
|
||||||
|
{"fr", 50265}, {"de", 50266}, {"es", 50267}, {"ru", 50268},
|
||||||
|
{"pt", 50269}, {"it", 50270}, {"auto", 50359}
|
||||||
|
};
|
||||||
|
auto it = langMap.find(langCode);
|
||||||
|
return it != langMap.end() ? it->second : 50259; // 默认英语
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WhisperTokenizer::isSpecialToken(int token) {
|
||||||
|
// Whisper 特殊 token 范围: [50257, 50362]
|
||||||
|
return token >= 50257 && token <= 50363;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace impress
|
||||||
58
src/core/whisper_tokenizer.h
Normal file
58
src/core/whisper_tokenizer.h
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QString>
|
||||||
|
#include <QStringList>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace impress {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Whisper Tokenizer
|
||||||
|
*
|
||||||
|
* 基于 BPE 的分词器,支持 Whisper 模型的 token 编解码。
|
||||||
|
* 从 tiktoken 格式的词汇表文件加载。
|
||||||
|
*/
|
||||||
|
class WhisperTokenizer {
|
||||||
|
public:
|
||||||
|
WhisperTokenizer();
|
||||||
|
|
||||||
|
/** @brief 从 tiktoken 格式的词汇表文件加载 */
|
||||||
|
bool loadVocabulary(const QString& vocabPath);
|
||||||
|
|
||||||
|
/** @brief 将 token IDs 解码为文本 */
|
||||||
|
QString decode(const std::vector<int>& tokens) const;
|
||||||
|
|
||||||
|
/** @brief 将文本编码为 token IDs(用于 prompt) */
|
||||||
|
std::vector<int> encode(const QString& text) const;
|
||||||
|
|
||||||
|
/** @brief 是否已加载词表 */
|
||||||
|
bool isLoaded() const { return !tokenToString_.empty(); }
|
||||||
|
|
||||||
|
/** @brief 词表大小 */
|
||||||
|
int vocabSize() const { return static_cast<int>(tokenToString_.size()); }
|
||||||
|
|
||||||
|
// Whisper 特殊 token
|
||||||
|
static constexpr int kTokenEndOfText = 50257;
|
||||||
|
static constexpr int kTokenEndOfSpeech = 50256;
|
||||||
|
static constexpr int kTokenNoSpeech = 50362;
|
||||||
|
static constexpr int kTokenTranscription = 50359;
|
||||||
|
|
||||||
|
// 语言 token 起始偏移
|
||||||
|
static constexpr int kTokenLanguageBase = 50259;
|
||||||
|
|
||||||
|
/** @brief 获取语言 token ID */
|
||||||
|
static int languageTokenId(const QString& langCode);
|
||||||
|
|
||||||
|
/** @brief 判断是否为特殊 token */
|
||||||
|
static bool isSpecialToken(int token);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<int, QString> tokenToString_;
|
||||||
|
std::unordered_map<QString, int> stringToToken_;
|
||||||
|
|
||||||
|
QString decodeBytePair(const QString& text) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace impress
|
||||||
@ -188,7 +188,8 @@ void FileTranscribePage::processNextFile() {
|
|||||||
const auto& samples = audioDecoder_->samples();
|
const auto& samples = audioDecoder_->samples();
|
||||||
int sampleRate = audioDecoder_->sampleRate();
|
int sampleRate = audioDecoder_->sampleRate();
|
||||||
|
|
||||||
auto result = sttEngine_->infer(samples, sampleRate, false);
|
auto result = sttEngine_->infer(samples, sampleRate,
|
||||||
|
configManager_->get("stt.language").toString());
|
||||||
task.result = result.text;
|
task.result = result.text;
|
||||||
task.status = "完成";
|
task.status = "完成";
|
||||||
task.progress = 1.0;
|
task.progress = 1.0;
|
||||||
|
|||||||
@ -211,7 +211,8 @@ void STTTestPage::processAudioChunk(const std::vector<float>& samples, int sampl
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = sttEngine_->infer(samples, sampleRate, true);
|
auto result = sttEngine_->infer(samples, sampleRate,
|
||||||
|
configManager_->get("stt.language").toString());
|
||||||
emit onRecognitionResult(result.text, result.confidence, result.latency_ms, result.isFinal);
|
emit onRecognitionResult(result.text, result.confidence, result.latency_ms, result.isFinal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user