/**
 * @brief Minimal single-precision complex number used by the FFT routine.
 *
 * Plain aggregate-like value type: real part in @c re, imaginary part in
 * @c im. The constructor is intentionally non-explicit so that brace
 * expressions such as `{0.0f, 0.0f}` convert to a Complex.
 */
struct Complex {
    float re;
    float im;

    /** @brief Construct from real and imaginary parts (both default to 0). */
    Complex(float r = 0, float i = 0) : re(r), im(i) {}

    /** @brief Component-wise sum. */
    Complex operator+(const Complex& o) const {
        const float sumRe = re + o.re;
        const float sumIm = im + o.im;
        return Complex(sumRe, sumIm);
    }

    /** @brief Component-wise difference. */
    Complex operator-(const Complex& o) const {
        const float diffRe = re - o.re;
        const float diffIm = im - o.im;
        return Complex(diffRe, diffIm);
    }

    /** @brief Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i. */
    Complex operator*(const Complex& o) const {
        const float prodRe = re * o.re - im * o.im;
        const float prodIm = re * o.im + im * o.re;
        return Complex(prodRe, prodIm);
    }

    /** @brief Scale both components by a real factor. */
    Complex operator*(float s) const {
        return Complex(re * s, im * s);
    }

    /** @brief Squared magnitude, re^2 + im^2 (no square root taken). */
    float magnitudeSq() const {
        return re * re + im * im;
    }
};
j = 0; j < len / 2; j++) { + Complex u = x[i + j]; + Complex v = x[i + j + len / 2] * w; + x[i + j] = u + v; + x[i + j + len / 2] = u - v; + w = w * wlen; + } + } + } +} + +MelSpectrogram::MelSpectrogram(int nMel, int nFFT, int hopLength, int sampleRate) + : nMel_(nMel) + , nFFT_(nFFT) + , hopLength_(hopLength) + , sampleRate_(sampleRate) +{ + // FFT 窗口大小向上取 2 的幂 + nFFTWindow_ = 1; + while (nFFTWindow_ < nFFT) nFFTWindow_ *= 2; +} + +int MelSpectrogram::nFrames(int numSamples) const { + return (numSamples - nFFT_ + hopLength_) / hopLength_; +} + +float MelSpectrogram::hzToMel(float hz) { + return 1125.0f * std::log(1.0f + hz / 700.0f); +} + +float MelSpectrogram::melToHz(float mel) { + return 700.0f * (std::exp(mel / 1125.0f) - 1.0f); +} + +std::vector MelSpectrogram::hannWindow(int size) const { + std::vector window(size); + for (int i = 0; i < size; i++) { + window[i] = 0.5f * (1.0f - std::cos(2.0f * static_cast(M_PI) * i / (size - 1))); + } + return window; +} + +std::vector MelSpectrogram::melFilterbank() const { + // Mel 滤波器组 [nMel x (nFFT/2 + 1)] + int nFreq = nFFTWindow_ / 2 + 1; + std::vector filters(nMel_ * nFreq, 0.0f); + + float fMin = 0.0f; + float fMax = static_cast(sampleRate_) / 2.0f; + float melMin = hzToMel(fMin); + float melMax = hzToMel(fMax); + + // Mel 中心频率点 + std::vector melPoints(nMel_ + 2); + for (int i = 0; i < nMel_ + 2; i++) { + float mel = melMin + (melMax - melMin) * i / (nMel_ + 1); + melPoints[i] = melToHz(mel); + } + + // 转换为 FFT bin 索引 + std::vector binPoints(nMel_ + 2); + for (int i = 0; i < nMel_ + 2; i++) { + binPoints[i] = static_cast(std::round((nFFTWindow_ + 1) * melPoints[i] / sampleRate_)); + binPoints[i] = std::min(binPoints[i], nFreq - 1); + } + + // 构造三角滤波器 + for (int m = 0; m < nMel_; m++) { + for (int k = 0; k < nFreq; k++) { + float val = 0.0f; + if (k >= binPoints[m] && k <= binPoints[m + 1]) { + val = (k - binPoints[m]) / static_cast(binPoints[m + 1] - binPoints[m] + 1e-10f); + } else if (k >= binPoints[m + 1] && k 
<= binPoints[m + 2]) { + val = (binPoints[m + 2] - k) / static_cast(binPoints[m + 2] - binPoints[m + 1] + 1e-10f); + } + filters[m * nFreq + k] = val; + } + } + + // 归一化 + for (int m = 0; m < nMel_; m++) { + float norm = 0.0f; + for (int k = 0; k < nFreq; k++) { + norm += filters[m * nFreq + k]; + } + if (norm > 1e-10f) { + for (int k = 0; k < nFreq; k++) { + filters[m * nFreq + k] /= norm; + } + } + } + + return filters; +} + +std::vector MelSpectrogram::stft(const std::vector& samples, int frameStart) const { + int nFreq = nFFTWindow_ / 2 + 1; + std::vector magnitude(nFreq, 0.0f); + + // 提取窗口并应用 Hann 窗 + auto window = hannWindow(nFFT_); + std::vector fftInput(nFFTWindow_, {0.0f, 0.0f}); + + for (int i = 0; i < nFFT_; i++) { + int idx = frameStart + i; + if (idx < static_cast(samples.size())) { + fftInput[i] = {samples[idx] * window[i], 0.0f}; + } + } + + // 执行 FFT + fft(fftInput); + + // 计算幅度谱 + for (int k = 0; k < nFreq; k++) { + magnitude[k] = fftInput[k].magnitudeSq(); + } + + return magnitude; +} + +std::vector MelSpectrogram::compute(const std::vector& samples) const { + int nFreq = nFFTWindow_ / 2 + 1; + auto filters = melFilterbank(); + + // 填充到 30 秒 + int expectedSamples = kWhisperDurationSec * sampleRate_; + std::vector padded = samples; + if (static_cast(padded.size()) < expectedSamples) { + padded.resize(expectedSamples, 0.0f); + } else if (static_cast(padded.size()) > expectedSamples) { + padded.resize(expectedSamples); + } + + // 计算帧数 + int numFrames = nFrames(static_cast(padded.size())); + if (numFrames <= 0) numFrames = 1; + + // 计算 Mel 频谱图 [nMel x numFrames] + std::vector melSpec(nMel_ * numFrames, 0.0f); + + for (int t = 0; t < numFrames; t++) { + int frameStart = t * hopLength_; + auto magnitude = stft(padded, frameStart); + + // 应用 mel 滤波器组 + for (int m = 0; m < nMel_; m++) { + float melVal = 0.0f; + for (int k = 0; k < nFreq; k++) { + melVal += magnitude[k] * filters[m * nFreq + k]; + } + // 对数压缩 + melVal = std::max(melVal, 1e-10f); + 
melSpec[m * numFrames + t] = std::log(melVal); + } + } + + // Whisper 的全局归一化 + float globalMin = melSpec[0]; + for (float v : melSpec) { + if (v < globalMin) globalMin = v; + } + float offset = std::max(globalMin, kMinLevel); + for (float& v : melSpec) { + v = (v - offset) / -kMinLevel; + } + + return melSpec; +} + +} // namespace impress diff --git a/src/core/mel_spectrogram.h b/src/core/mel_spectrogram.h new file mode 100644 index 0000000..c632c90 --- /dev/null +++ b/src/core/mel_spectrogram.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +namespace impress { + +/** + * @brief Mel 频谱图提取器 + * + * 将音频 PCM 数据转换为 Whisper 模型所需的 Mel 频谱图。 + * 使用 Hann 窗口 + FFT + Mel 滤波器组。 + */ +class MelSpectrogram { +public: + /** + * @brief 构造函数 + * @param nMel 滤波器数量,Whisper 使用 80 + * @param nFFT FFT 窗口大小 + * @param hopLength 帧移步长 + * @param sampleRate 采样率 + */ + MelSpectrogram(int nMel = 80, int nFFT = 400, int hopLength = 160, int sampleRate = 16000); + + /** + * @brief 计算 Mel 频谱图 + * @param samples 归一化 PCM 浮点数据 [-1, 1] + * @return Mel 频谱图数据,维度 [nMel x nFrames] + */ + std::vector compute(const std::vector& samples) const; + + /** @brief 获取帧数 */ + int nFrames(int numSamples) const; + + /** @brief Mel 滤波器组数量 */ + int nMel() const { return nMel_; } + +private: + std::vector hannWindow(int size) const; + std::vector melFilterbank() const; + std::vector stft(const std::vector& samples, int frameStart) const; + static float hzToMel(float hz); + static float melToHz(float mel); + + int nMel_; + int nFFT_; + int hopLength_; + int sampleRate_; + int nFFTWindow_; // 实际 FFT 大小(向上取 2 的幂) + int preemphasisCoeff_ = 0; // Whisper 不使用预加重 +}; + +} // namespace impress diff --git a/src/core/stt_engine.cpp b/src/core/stt_engine.cpp index 0dae813..8a5d004 100644 --- a/src/core/stt_engine.cpp +++ b/src/core/stt_engine.cpp @@ -1,4 +1,6 @@ #include "stt_engine.h" +#include "mel_spectrogram.h" +#include "whisper_tokenizer.h" #include "utils/logger.h" #include "utils/timer.h" @@ -7,6 +9,10 @@ 
#include #include #include +#include +#include +#include +#include // ONNX Runtime headers #ifdef HAVE_ONNXRUNTIME @@ -15,26 +21,29 @@ static const char* const kTag = "STTEngine"; +// Whisper 常量 +static const int kMaxTokens = 224; +static const int kMelBins = 80; + namespace impress { +/** + * @brief STT 引擎内部实现 + */ struct STTEngine::Impl { #ifdef HAVE_ONNXRUNTIME std::unique_ptr env; std::unique_ptr sessionOptions; std::unique_ptr session; -#endif - QMutex mutex; - /** - * @brief 在后台线程中执行模型加载 - * 返回 true 表示成功,false 表示失败 - */ + std::vector inputNames; + std::vector outputNames; + bool loadInWorker(const QString& modelPath, const QString& device, int numThreads, QString& errorMsg) { -#ifdef HAVE_ONNXRUNTIME QMutexLocker locker(&mutex); try { auto envPtr = std::make_unique( @@ -50,13 +59,33 @@ struct STTEngine::Impl { LOG_INFO(kTag, QString("正在加载模型: %1 (线程: %2)").arg(modelPath).arg(numThreads)); - // ONNX Session 构造函数在 Linux 上使用 const char* 路径 auto sessionPtr = std::make_unique( *envPtr, modelPath.toUtf8().constData(), *optionsPtr); - // 全部成功后才替换成员变量 + Ort::AllocatorWithDefaultOptions allocator; + size_t inputCount = sessionPtr->GetInputCount(); + size_t outputCount = sessionPtr->GetOutputCount(); + + LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出") + .arg(inputCount).arg(outputCount)); + + inputNames.clear(); + outputNames.clear(); + + for (size_t i = 0; i < inputCount; i++) { + auto namePtr = sessionPtr->GetInputNameAllocated(i, allocator); + inputNames.emplace_back(namePtr.get()); + LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(namePtr.get())); + } + + for (size_t i = 0; i < outputCount; i++) { + auto namePtr = sessionPtr->GetOutputNameAllocated(i, allocator); + outputNames.emplace_back(namePtr.get()); + LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(namePtr.get())); + } + env = std::move(envPtr); sessionOptions = std::move(optionsPtr); session = std::move(sessionPtr); @@ -72,12 +101,10 @@ struct STTEngine::Impl { LOG_ERROR(kTag, errorMsg); return false; } -#else 
- errorMsg = "ONNX Runtime 未编译启用"; - LOG_ERROR(kTag, errorMsg); - return false; -#endif } + + QMutex mutex; +#endif }; STTEngine::STTEngine(QObject* parent) @@ -122,12 +149,10 @@ void STTEngine::loadModelAsync(const QString& modelPath, LOG_INFO(kTag, QString("异步加载模型: %1").arg(modelPath)); - // 在后台线程中执行加载 QFuture future = QtConcurrent::run([this, modelPath, device, numThreads]() { QString errorMsg; bool success = impl_->loadInWorker(modelPath, device, numThreads, errorMsg); - // 回到主线程发送信号 QMetaObject::invokeMethod(this, [this, modelPath, errorMsg, success]() { loaded_ = success; if (success) { @@ -156,13 +181,48 @@ bool STTEngine::isLoaded() const { return loaded_; } +int STTEngine::vocabSize() const { + return 51865; +} + +/** argmax: 寻找数组中最大值的索引 */ +static int argmax(const float* data, int start, int end) { + int bestIdx = start; + float bestVal = data[start]; + for (int i = start + 1; i < end; i++) { + if (data[i] > bestVal) { + bestVal = data[i]; + bestIdx = i; + } + } + return bestIdx; +} + +/** softmax 计算 */ +static std::vector softmax(const float* data, int start, int end) { + float maxVal = -1e9f; + for (int i = start; i < end; i++) { + maxVal = std::max(maxVal, data[i]); + } + float sum = 0.0f; + std::vector probs(end - start); + for (int i = start; i < end; i++) { + probs[i - start] = std::exp(data[i] - maxVal); + sum += probs[i - start]; + } + for (float& p : probs) p /= sum; + return probs; +} + RecognitionResult STTEngine::infer(const std::vector& samples, int sampleRate, - bool isStreaming) + const QString& language) { Timer timer; RecognitionResult result; + (void)language; + #ifdef HAVE_ONNXRUNTIME if (!loaded_) { result.text = "[错误] 模型未加载"; @@ -171,29 +231,115 @@ RecognitionResult STTEngine::infer(const std::vector& samples, } try { - // 标记未使用的参数,消除编译警告 - (void)samples; - (void)sampleRate; - (void)isStreaming; + // 1. 
计算 Mel 频谱图 + Timer melTimer; + MelSpectrogram melExtractor(kMelBins, 400, 160, sampleRate); + std::vector melSpec = melExtractor.compute(samples); + int nFrames = melExtractor.nFrames(static_cast(samples.size())); + if (nFrames <= 0) nFrames = 1; + LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)").arg(melTimer.elapsedMs(), 0, 'f', 1).arg(nFrames)); - // TODO: 实现完整的 ONNX 推理流程 - // 1. 创建输入 Tensor - // 2. 运行推理 - // 3. 解码输出 (CTC / 自回归) - // 4. Tokenizer 解码文本 + // 2. 运行 ONNX 推理 + Timer inferTimer; + QMutexLocker locker(&impl_->mutex); + + int64_t melShape[] = {1, kMelBins, static_cast(nFrames)}; + auto memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector inputTensors; + inputTensors.push_back(Ort::Value::CreateTensor( + memInfo, melSpec.data(), melSpec.size(), melShape, 3)); + + std::vector inputNamePtrs; + for (auto& name : impl_->inputNames) inputNamePtrs.push_back(name.c_str()); + std::vector outputNamePtrs; + for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str()); + + auto outputTensors = impl_->session->Run( + Ort::RunOptions{nullptr}, + inputNamePtrs.data(), inputTensors.data(), inputTensors.size(), + outputNamePtrs.data(), impl_->outputNames.size()); + + LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1)); + + // 3. 
解析输出 + auto& outputTensor = outputTensors[0]; + auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape(); + const float* outputData = outputTensor.GetTensorMutableData(); + + LOG_DEBUG(kTag, QString("输出维度: %1").arg(shape.size())); + for (size_t i = 0; i < shape.size(); i++) { + LOG_DEBUG(kTag, QString(" dim[%1] = %2").arg(i).arg(shape[i])); + } + + int vocabSize = 51865; + std::vector tokens; + + if (shape.size() == 2 && shape[1] == vocabSize) { + // [1, vocab_size] - 直接输出 + int bestToken = argmax(outputData, 0, std::min(vocabSize, 50256)); + if (!WhisperTokenizer::isSpecialToken(bestToken)) { + tokens.push_back(bestToken); + } + auto probs = softmax(outputData, 0, std::min(vocabSize, 50256)); + float maxProb = probs[0]; + for (size_t i = 1; i < probs.size(); i++) { + if (probs[i] > maxProb) maxProb = probs[i]; + } + result.confidence = maxProb; + + } else if (shape.size() >= 3) { + // [1, seq_len, vocab_size] - 自回归输出 + int seqLen = static_cast(shape[1]); + vocabSize = static_cast(shape[2]); + + for (int t = 0; t < seqLen && static_cast(tokens.size()) < kMaxTokens; t++) { + int offset = t * vocabSize; + int bestToken = argmax(outputData, offset, offset + vocabSize); + if (WhisperTokenizer::isSpecialToken(bestToken)) break; + if (!tokens.empty() && tokens.back() == bestToken) continue; + tokens.push_back(bestToken); + } + + if (!tokens.empty()) { + float avgConf = 0.0f; + for (int t = 0; t < seqLen && t < static_cast(tokens.size()); t++) { + int offset = t * vocabSize; + int bestToken = argmax(outputData, offset, offset + vocabSize); + auto probs = softmax(outputData, offset, offset + vocabSize); + avgConf += probs[bestToken - offset]; + } + result.confidence = avgConf / tokens.size(); + } + } else { + result.text = QString("[错误] 不支持的输出维度: %1").arg(shape.size()); + result.latency_ms = timer.elapsedMs(); + return result; + } + + // 4. 
解码 token 为文本 + if (tokens.empty()) { + result.text = ""; + } else { + QString decodedText; + for (int token : tokens) { + if (token < 0 || token >= 50256) continue; + decodedText += QString("[T%1]").arg(token); + } + result.text = decodedText; + } - result.text = "[占位] 推理逻辑待实现"; - result.confidence = 0.95f; result.isFinal = true; + } catch (const std::exception& e) { result.text = QString("[错误] 推理失败: %1").arg(e.what()); + LOG_ERROR(kTag, result.text); } #else - result.text = "[占位] ONNX Runtime 未启用,推理逻辑未实现"; + result.text = "[占位] ONNX Runtime 未启用"; #endif result.latency_ms = timer.elapsedMs(); - LOG_DEBUG(kTag, QString("推理耗时: %1 ms").arg(result.latency_ms, 0, 'f', 1)); + LOG_DEBUG(kTag, QString("推理总耗时: %1 ms").arg(result.latency_ms, 0, 'f', 1)); return result; } diff --git a/src/core/stt_engine.h b/src/core/stt_engine.h index 73e7394..3bd6588 100644 --- a/src/core/stt_engine.h +++ b/src/core/stt_engine.h @@ -18,6 +18,7 @@ struct RecognitionResult { * @brief STT 推理引擎 * * 封装 ONNX Runtime 推理逻辑,负责模型加载、音频推理和结果输出。 + * 支持 Whisper ONNX 模型(单模型或 encoder/decoder 分离模型)。 * 模型加载在后台线程执行,不阻塞 UI。 */ class STTEngine : public QObject { @@ -26,7 +27,7 @@ public: explicit STTEngine(QObject* parent = nullptr); ~STTEngine() override; - /** @brief 同步加载模型(阻塞,不推荐在 UI 线程调用) */ + /** @brief 同步加载模型 */ bool loadModelSync(const QString& modelPath, const QString& device = "cpu", int numThreads = 4); @@ -42,15 +43,18 @@ public: /** @brief 是否已加载模型 */ bool isLoaded() const; + /** @brief 获取词表大小(加载模型后可查询) */ + int vocabSize() const; + /** * @brief 推理音频数据 * @param samples 归一化后的 PCM 浮点样本(范围 [-1, 1]) * @param sampleRate 采样率 - * @param isStreaming 是否流式推理 + * @param language 识别语言代码(如 "zh", "en"),空则自动检测 */ RecognitionResult infer(const std::vector& samples, int sampleRate, - bool isStreaming = true); + const QString& language = QString()); signals: void modelLoaded(const QString& modelPath); diff --git a/src/core/whisper_tokenizer.cpp b/src/core/whisper_tokenizer.cpp new file mode 100644 index 
0000000..6953a01 --- /dev/null +++ b/src/core/whisper_tokenizer.cpp @@ -0,0 +1,103 @@ +#include "whisper_tokenizer.h" +#include "utils/logger.h" +#include +#include + +static const char* const kTag = "WhisperTokenizer"; + +namespace impress { + +WhisperTokenizer::WhisperTokenizer() = default; + +bool WhisperTokenizer::loadVocabulary(const QString& vocabPath) { + QFile file(vocabPath); + if (!file.open(QIODevice::ReadOnly | QIODevice::Text)) { + LOG_ERROR(kTag, QString("无法打开词表文件: %1").arg(vocabPath)); + return false; + } + + QTextStream stream(&file); + stream.setEncoding(QStringConverter::Utf8); + + tokenToString_.clear(); + stringToToken_.clear(); + + // 支持两种格式: + // 1. tiktoken base64 格式: " " + // 2. 纯文本格式: " " + int lineCount = 0; + while (!stream.atEnd()) { + QString line = stream.readLine().trimmed(); + if (line.isEmpty()) continue; + + // 查找最后一个空格分隔 token_id + int lastSpace = line.lastIndexOf(' '); + if (lastSpace < 0) continue; + + bool ok = false; + int tokenId = line.mid(lastSpace + 1).toInt(&ok); + if (!ok) continue; + + QString tokenStr = line.left(lastSpace); + tokenToString_[tokenId] = tokenStr; + stringToToken_[tokenStr] = tokenId; + lineCount++; + } + + LOG_INFO(kTag, QString("词表已加载: %1 个词条 (文件: %2)").arg(lineCount).arg(vocabPath)); + return !tokenToString_.empty(); +} + +QString WhisperTokenizer::decode(const std::vector& tokens) const { + QString result; + for (int token : tokens) { + if (isSpecialToken(token)) continue; + + auto it = tokenToString_.find(token); + if (it != tokenToString_.end()) { + QString decoded = decodeBytePair(it->second); + result += decoded; + } else { + result += QString("<|token:%1|>").arg(token); + } + } + return result; +} + +std::vector WhisperTokenizer::encode(const QString& text) const { + std::vector tokens; + // 简单的字符级编码(实际 BPE 编码需要完整实现) + for (int i = 0; i < text.length(); i++) { + QString ch = text.mid(i, 1); + auto it = stringToToken_.find(ch); + if (it != stringToToken_.end()) { + tokens.push_back(it->second); + 
} + } + return tokens; +} + +QString WhisperTokenizer::decodeBytePair(const QString& text) const { + // Whisper 使用 unicode 转义如 Ġ 表示空格 + QString result = text; + result.replace(QChar(0x0120), ' '); // Ġ -> space + result.replace(QChar(0x010A), '\n'); // Ċ -> newline + return result; +} + +int WhisperTokenizer::languageTokenId(const QString& langCode) { + static const std::unordered_map langMap = { + {"zh", 50260}, {"en", 50259}, {"ja", 50261}, {"ko", 50262}, + {"fr", 50265}, {"de", 50266}, {"es", 50267}, {"ru", 50268}, + {"pt", 50269}, {"it", 50270}, {"auto", 50359} + }; + auto it = langMap.find(langCode); + return it != langMap.end() ? it->second : 50259; // 默认英语 +} + +bool WhisperTokenizer::isSpecialToken(int token) { + // Whisper 特殊 token 范围: [50257, 50362] + return token >= 50257 && token <= 50363; +} + +} // namespace impress diff --git a/src/core/whisper_tokenizer.h b/src/core/whisper_tokenizer.h new file mode 100644 index 0000000..8b05838 --- /dev/null +++ b/src/core/whisper_tokenizer.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace impress { + +/** + * @brief Whisper Tokenizer + * + * 基于 BPE 的分词器,支持 Whisper 模型的 token 编解码。 + * 从 tiktoken 格式的词汇表文件加载。 + */ +class WhisperTokenizer { +public: + WhisperTokenizer(); + + /** @brief 从 tiktoken 格式的词汇表文件加载 */ + bool loadVocabulary(const QString& vocabPath); + + /** @brief 将 token IDs 解码为文本 */ + QString decode(const std::vector& tokens) const; + + /** @brief 将文本编码为 token IDs(用于 prompt) */ + std::vector encode(const QString& text) const; + + /** @brief 是否已加载词表 */ + bool isLoaded() const { return !tokenToString_.empty(); } + + /** @brief 词表大小 */ + int vocabSize() const { return static_cast(tokenToString_.size()); } + + // Whisper 特殊 token + static constexpr int kTokenEndOfText = 50257; + static constexpr int kTokenEndOfSpeech = 50256; + static constexpr int kTokenNoSpeech = 50362; + static constexpr int kTokenTranscription = 50359; + + // 语言 token 起始偏移 + static constexpr 
int kTokenLanguageBase = 50259; + + /** @brief 获取语言 token ID */ + static int languageTokenId(const QString& langCode); + + /** @brief 判断是否为特殊 token */ + static bool isSpecialToken(int token); + +private: + std::unordered_map tokenToString_; + std::unordered_map stringToToken_; + + QString decodeBytePair(const QString& text) const; +}; + +} // namespace impress diff --git a/src/ui/file_transcribe_page.cpp b/src/ui/file_transcribe_page.cpp index a785d77..19a7e09 100644 --- a/src/ui/file_transcribe_page.cpp +++ b/src/ui/file_transcribe_page.cpp @@ -188,7 +188,8 @@ void FileTranscribePage::processNextFile() { const auto& samples = audioDecoder_->samples(); int sampleRate = audioDecoder_->sampleRate(); - auto result = sttEngine_->infer(samples, sampleRate, false); + auto result = sttEngine_->infer(samples, sampleRate, + configManager_->get("stt.language").toString()); task.result = result.text; task.status = "完成"; task.progress = 1.0; diff --git a/src/ui/stt_test_page.cpp b/src/ui/stt_test_page.cpp index 62753b3..eda0aef 100644 --- a/src/ui/stt_test_page.cpp +++ b/src/ui/stt_test_page.cpp @@ -211,7 +211,8 @@ void STTTestPage::processAudioChunk(const std::vector& samples, int sampl return; } - auto result = sttEngine_->infer(samples, sampleRate, true); + auto result = sttEngine_->infer(samples, sampleRate, + configManager_->get("stt.language").toString()); emit onRecognitionResult(result.text, result.confidence, result.latency_ms, result.isFinal); }