diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f9b910..3add00c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,8 +49,7 @@ set(SOURCES src/core/mel_spectrogram.cpp src/core/whisper_tokenizer.cpp src/core/audio_processor.cpp - src/core/decoder.cpp - src/core/tokenizer.cpp + src/core/vad.cpp # Audio src/audio/audio_capture.cpp @@ -80,8 +79,7 @@ set(HEADERS src/core/mel_spectrogram.h src/core/whisper_tokenizer.h src/core/audio_processor.h - src/core/decoder.h - src/core/tokenizer.h + src/core/vad.h src/audio/audio_capture.h src/audio/audio_decoder.h diff --git a/src/core/vad.cpp b/src/core/vad.cpp new file mode 100644 index 0000000..f0008d0 --- /dev/null +++ b/src/core/vad.cpp @@ -0,0 +1,119 @@ +#include "vad.h" +#include +#include +#include + +namespace impress { + +VoiceActivityDetector::VoiceActivityDetector(int sampleRate, + int frameMs, + float energyThreshold, + int minVoiceFrames) + : sampleRate_(sampleRate) + , frameSize_(sampleRate * frameMs / 1000) + , energyThreshold_(energyThreshold) + , minVoiceFrames_(minVoiceFrames) +{} + +float VoiceActivityDetector::computeEnergy(const std::vector& samples) const { + if (samples.empty()) return 0.0f; + + float energy = 0.0f; + for (float s : samples) { + energy += s * s; + } + return energy / static_cast(samples.size()); +} + +float VoiceActivityDetector::computeZeroCrossingRate(const std::vector& samples) const { + if (samples.size() < 2) return 0.0f; + + int crossings = 0; + for (size_t i = 1; i < samples.size(); i++) { + if ((samples[i] >= 0.0f) != (samples[i - 1] >= 0.0f)) { + crossings++; + } + } + return static_cast(crossings) / static_cast(samples.size() - 1); +} + +bool VoiceActivityDetector::process(const std::vector& samples) { + currentEnergy_ = computeEnergy(samples); + zeroCrossingRate_ = computeZeroCrossingRate(samples); + + // 能量 + 过零率联合判定 + bool isVoice = false; + if (currentEnergy_ > energyThreshold_) { + // 高能量 + 低过零率 -> 语音 + if (zeroCrossingRate_ < 0.35f) { + isVoice = true; + } + // 高能量 + 高过零率 -> 可能是摩擦音 /f/ /s/ 等 + else if (zeroCrossingRate_ < 0.5f && currentEnergy_ > energyThreshold_ * 3.0f) { + isVoice = true; + } + } + + // 状态机:连续多帧语音才判定为"正在说话" + if (isVoice) { + consecutiveVoiceFrames_++; + } else { + consecutiveVoiceFrames_ = 0; + } + + bool wasSpeaking = isSpeaking_; + isSpeaking_ = (consecutiveVoiceFrames_ >= minVoiceFrames_); + + return isSpeaking_; +} + +std::vector +VoiceActivityDetector::processBatch(const std::vector& samples) +{ + std::vector segments; + if (samples.empty()) return segments; + + // 逐帧处理 + int numSamples = static_cast(samples.size()); + int totalFrames = numSamples / frameSize_; + if (totalFrames == 0) totalFrames = 1; + + SpeechSegment current; + current.startFrame = -1; + bool inSpeech = false; + + for (int f = 0; f < totalFrames; f++) { + int start = f * frameSize_; + int end = std::min(start + frameSize_, numSamples); + std::vector frame(samples.begin() + start, samples.begin() + end); + + bool voice = process(frame); + + if (voice && !inSpeech) { + // 语音段开始 + current.startFrame = f; + inSpeech = true; + } else if (!voice && inSpeech) { + // 语音段结束 + current.endFrame = f - 1; + if (current.endFrame - current.startFrame >= minVoiceFrames_) { + segments.push_back(current); + } + inSpeech = false; + consecutiveVoiceFrames_ = 0; + isSpeaking_ = false; + } + } + + // 如果末尾仍在语音段中 + if (inSpeech) { + current.endFrame = totalFrames - 1; + if (current.endFrame - current.startFrame >= minVoiceFrames_) { + segments.push_back(current); + } + } + + return segments; +} + +} // namespace impress diff --git a/src/core/vad.h b/src/core/vad.h new file mode 100644 index 0000000..ea97997 --- /dev/null +++ b/src/core/vad.h @@ -0,0 +1,72 @@ +#pragma once + +#include + +namespace impress { + +/** + * @brief 语音活动检测 (Voice Activity Detection) + * + * 基于短时能量和过零率的简单 VAD 实现。 + * 用于实时语音流中检测有效语音段,过滤静音。 + */ +class VoiceActivityDetector { +public: + /** + * @brief 构造函数 + * @param sampleRate 采样率 + * @param frameMs 帧长度(毫秒) + * @param energyThreshold 能量阈值(归一化,建议 0.01-0.05) + * @param minVoiceFrames 判定为语音所需的最小连续帧数 + */ + VoiceActivityDetector(int sampleRate = 16000, + int frameMs = 30, + float energyThreshold = 0.02f, + int minVoiceFrames = 3); + + /** + * @brief 处理一帧音频数据 + * @param samples 归一化 PCM 浮点数据 [-1, 1] + * @return true = 检测到语音, false = 静音 + */ + bool process(const std::vector& samples); + + /** @brief 获取当前帧能量 */ + float currentEnergy() const { return currentEnergy_; } + + /** @brief 获取过零率 */ + float zeroCrossingRate() const { return zeroCrossingRate_; } + + /** @brief 是否正在说话 */ + bool isSpeaking() const { return isSpeaking_; } + + /** @brief 连续语音帧计数 */ + int consecutiveVoiceFrames() const { return consecutiveVoiceFrames_; } + + /** + * @brief 处理多帧音频数据(整段) + * @param samples 归一化 PCM 浮点数据 [-1, 1] + * @return 检测到的语音段起始/结束帧索引 + */ + struct SpeechSegment { + int startFrame; + int endFrame; + }; + std::vector processBatch(const std::vector& samples); + +private: + float computeEnergy(const std::vector& samples) const; + float computeZeroCrossingRate(const std::vector& samples) const; + + int sampleRate_; + int frameSize_; + float energyThreshold_; + int minVoiceFrames_; + + float currentEnergy_ = 0.0f; + float zeroCrossingRate_ = 0.0f; + bool isSpeaking_ = false; + int consecutiveVoiceFrames_ = 0; +}; + +} // namespace impress diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..bde555c --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.20) + +# 使用 Catch2 作为测试框架(通过 FetchContent 自动下载) +include(FetchContent) +FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.6.0 + GIT_SHALLOW TRUE +) +FetchContent_MakeAvailable(Catch2) + +# 测试需要的源文件 +set(TEST_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/audio_processor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/mel_spectrogram.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/whisper_tokenizer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/vad.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../src/utils/logger.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../src/utils/timer.cpp +) + +# 测试可执行文件 +add_executable(tests + test_audio_processor.cpp + test_vad.cpp + test_mel_spectrogram.cpp + test_whisper_tokenizer.cpp + ${TEST_SOURCES} +) + +target_include_directories(tests PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../src + ${CMAKE_CURRENT_SOURCE_DIR}/../third_party +) + +target_link_libraries(tests PRIVATE + Catch2::Catch2WithMain + Qt6::Core + pthread +) + +target_compile_features(tests PRIVATE cxx_std_17) + +# 注册测试 +include(Catch) +catch_discover_tests(tests) diff --git a/tests/test_audio_processor.cpp b/tests/test_audio_processor.cpp new file mode 100644 index 0000000..72a28be --- /dev/null +++ b/tests/test_audio_processor.cpp @@ -0,0 +1,125 @@ +#include +#include +#include "core/audio_processor.h" + +#include +#include + +using namespace impress; +using Catch::Matchers::WithinAbs; + +// ============================================================================ +// AudioProcessor 测试 +// ============================================================================ + +TEST_CASE("归一化 PCM16 数据", "[audio_processor]") { + std::vector pcm16 = {0, 16384, -16384, 32767, -32768}; + + auto normalized = AudioProcessor::normalize(pcm16); + + REQUIRE(normalized.size() == pcm16.size()); + REQUIRE_THAT(normalized[0], WithinAbs(0.0f, 1e-5f)); + REQUIRE_THAT(normalized[1], WithinAbs(0.5f, 1e-5f)); + REQUIRE_THAT(normalized[2], WithinAbs(-0.5f, 1e-5f)); + REQUIRE_THAT(normalized[3], WithinAbs(32767.0f / 32768.0f, 1e-5f)); + REQUIRE_THAT(normalized[4], WithinAbs(-1.0f, 1e-5f)); +} + +TEST_CASE("归一化全零 PCM16 数据", "[audio_processor]") { + std::vector pcm16(1000, 0); + auto normalized = AudioProcessor::normalize(pcm16); + + for (float v : normalized) { + REQUIRE_THAT(v, WithinAbs(0.0f, 1e-5f)); + } +} + +TEST_CASE("浮点数据归一化", "[audio_processor]") { + std::vector input = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f}; + + auto normalized = AudioProcessor::normalizeFloats(input); + + // 最大绝对值为 1.0,所以数据不变 + REQUIRE_THAT(normalized[0], WithinAbs(0.0f, 1e-5f)); + REQUIRE_THAT(normalized[3], WithinAbs(1.0f, 1e-5f)); + REQUIRE_THAT(normalized[4], WithinAbs(-1.0f, 1e-5f)); +} + +TEST_CASE("浮点数据归一化 - 小值放大", "[audio_processor]") { + std::vector input = {0.001f, -0.002f, 0.003f}; + + auto normalized = AudioProcessor::normalizeFloats(input); + + // 最大绝对值 0.003,归一化后最大值应为 1.0 + REQUIRE_THAT(normalized[2], WithinAbs(1.0f, 1e-5f)); + REQUIRE_THAT(normalized[0], WithinAbs(1.0f / 3.0f, 1e-4f)); + REQUIRE_THAT(normalized[1], WithinAbs(-2.0f / 3.0f, 1e-4f)); +} + +TEST_CASE("等采样率重采样 - 数据不变", "[audio_processor]") { + AudioProcessor processor(16000); + std::vector input = {0.0f, 0.5f, -0.5f, 1.0f}; + + auto output = processor.resample(input, 16000); + + REQUIRE(output.size() == input.size()); + for (size_t i = 0; i < input.size(); i++) { + REQUIRE_THAT(output[i], WithinAbs(input[i], 1e-5f)); + } +} + +TEST_CASE("上采样 - 样本数增加", "[audio_processor]") { + AudioProcessor processor(32000); + std::vector input(1000, 0.5f); + + auto output = processor.resample(input, 16000); + + REQUIRE(output.size() == 2000); // 16k -> 32k = 2x + // 恒定信号值不变 + REQUIRE_THAT(output[0], WithinAbs(0.5f, 1e-5f)); + REQUIRE_THAT(output[1999], WithinAbs(0.5f, 1e-5f)); +} + +TEST_CASE("下采样 - 样本数减少", "[audio_processor]") { + AudioProcessor processor(8000); + std::vector input(2000, 0.5f); + + auto output = processor.resample(input, 16000); + + REQUIRE(output.size() == 1000); // 16k -> 8k = 0.5x + REQUIRE_THAT(output[0], WithinAbs(0.5f, 1e-5f)); +} + +TEST_CASE("分帧 - 基本功能", "[audio_processor]") { + AudioProcessor processor(16000); + std::vector input(1000, 1.0f); + + auto frames = processor.frame(input, 200, 100); + + // 帧数: (1000 - 200) / 100 + 1 = 9 + REQUIRE(frames.size() == 9); + for (const auto& frame : frames) { + REQUIRE(frame.size() == 200); + } +} + +TEST_CASE("分帧 - 输入不足一帧", "[audio_processor]") { + AudioProcessor processor(16000); + std::vector input(50, 1.0f); + + auto frames = processor.frame(input, 200, 100); + + REQUIRE(frames.empty()); +} + +TEST_CASE("分帧 - 精确匹配", "[audio_processor]") { + AudioProcessor processor(16000); + std::vector input(400, 1.0f); + + auto frames = processor.frame(input, 200, 200); + + REQUIRE(frames.size() == 2); + for (const auto& frame : frames) { + REQUIRE(frame.size() == 200); + } +} diff --git a/tests/test_mel_spectrogram.cpp b/tests/test_mel_spectrogram.cpp new file mode 100644 index 0000000..890a35b --- /dev/null +++ b/tests/test_mel_spectrogram.cpp @@ -0,0 +1,137 @@ +#include +#include +#include "core/mel_spectrogram.h" + +#include +#include + +using namespace impress; +using Catch::Matchers::WithinAbs; + +// ============================================================================ +// MelSpectrogram 测试 +// ============================================================================ + +static std::vector generateSilence(int numSamples) { + return std::vector(numSamples, 0.0f); +} + +static std::vector generateTone(int numSamples, float frequency, + int sampleRate = 16000, + float amplitude = 0.5f) { + std::vector samples(numSamples); + for (int i = 0; i < numSamples; i++) { + samples[i] = amplitude * std::sin(2.0f * M_PI * frequency * i / sampleRate); + } + return samples; +} + +TEST_CASE("构造函数 - 默认参数", "[mel_spectrogram]") { + MelSpectrogram mel; + + REQUIRE(mel.nMel() == 80); +} + +TEST_CASE("帧数计算 - 30秒音频", "[mel_spectrogram]") { + MelSpectrogram mel(80, 400, 160, 16000); + int samples = 30 * 16000; // 30秒 + + int frames = mel.nFrames(samples); + + // (480000 - 400 + 160) / 160 = 2999 + REQUIRE(frames > 0); +} + +TEST_CASE("帧数计算 - 短音频", "[mel_spectrogram]") { + MelSpectrogram mel(80, 400, 160, 16000); + int samples = 1600; // 100ms + + int frames = mel.nFrames(samples); + REQUIRE(frames > 0); +} + +TEST_CASE("静音频谱图 - 低能量", "[mel_spectrogram]") { + MelSpectrogram mel(80, 400, 160, 16000); + auto silence = generateSilence(480000); // 30秒 + + auto melSpec = mel.compute(silence); + + int nFrames = mel.nFrames(static_cast(silence.size())); + REQUIRE(melSpec.size() == static_cast(80 * nFrames)); + + // 静音经 log 压缩和归一化后应接近 0 + float maxVal = 0.0f; + for (float v : melSpec) { + maxVal = std::max(maxVal, std::abs(v)); + } + // 归一化后应在 [0, 1] 范围内 + REQUIRE(maxVal <= 1.1f); +} + +TEST_CASE("频谱图维度 - 匹配预期", "[mel_spectrogram]") { + MelSpectrogram mel(80, 400, 160, 16000); + auto tone = generateTone(480000, 440.0f, 16000, 0.5f); + + auto melSpec = mel.compute(tone); + + int nFrames = mel.nFrames(static_cast(tone.size())); + REQUIRE(melSpec.size() == static_cast(80 * nFrames)); +} + +TEST_CASE("正弦波频谱图 - 能量分布", "[mel_spectrogram]") { + MelSpectrogram mel(80, 400, 160, 16000); + auto tone = generateTone(480000, 440.0f, 16000, 0.8f); + + auto melSpec = mel.compute(tone); + + int nFrames = mel.nFrames(static_cast(tone.size())); + REQUIRE(melSpec.size() == static_cast(80 * nFrames)); + + // 440Hz 正弦波应在低频 Mel 滤波器上有较高能量 + // 计算第一帧前几个 Mel bin 的能量 + float lowFreqEnergy = 0.0f; + for (int m = 0; m < 10; m++) { + lowFreqEnergy += std::abs(melSpec[m * nFrames]); + } + + float highFreqEnergy = 0.0f; + for (int m = 70; m < 80; m++) { + highFreqEnergy += std::abs(melSpec[m * nFrames]); + } + + // 低频能量应高于高频(440Hz 是低频信号) + REQUIRE(lowFreqEnergy > highFreqEnergy); +} + +TEST_CASE("不同 Mel 滤波器数量", "[mel_spectrogram]") { + MelSpectrogram mel40(40, 400, 160, 16000); + REQUIRE(mel40.nMel() == 40); + + auto tone = generateTone(480000, 440.0f, 16000, 0.5f); + auto melSpec = mel40.compute(tone); + int nFrames = mel40.nFrames(static_cast(tone.size())); + REQUIRE(melSpec.size() == static_cast(40 * nFrames)); +} + +TEST_CASE("频谱图归一化 - 值在合理范围内", "[mel_spectrogram]") { + MelSpectrogram mel(80, 400, 160, 16000); + auto tone = generateTone(480000, 440.0f, 16000, 0.5f); + + auto melSpec = mel.compute(tone); + + // Whisper 归一化后的值通常在 [0, 1] 附近 + // 但由于公式 (v - offset) / -kMinLevel,最小值可能为负 + // 当 globalMin < kMinLevel 时,最小值 ≈ (globalMin - kMinLevel) / -kMinLevel + float minVal = melSpec[0]; + float maxVal = melSpec[0]; + for (float v : melSpec) { + if (v < minVal) minVal = v; + if (v > maxVal) maxVal = v; + } + + // 值应在合理范围内(不超过 ±2) + REQUIRE(minVal >= -2.0f); + REQUIRE(maxVal <= 2.0f); + // 动态范围不应过大 + REQUIRE((maxVal - minVal) <= 3.0f); +} diff --git a/tests/test_vad.cpp b/tests/test_vad.cpp new file mode 100644 index 0000000..bbd2bd2 --- /dev/null +++ b/tests/test_vad.cpp @@ -0,0 +1,156 @@ +#include +#include +#include "core/vad.h" + +#include +#include +#include + +using namespace impress; +using Catch::Matchers::WithinAbs; + +// ============================================================================ +// VoiceActivityDetector 测试 +// ============================================================================ + +static std::vector generateSilence(int numSamples) { + return std::vector(numSamples, 0.0f); +} + +static std::vector generateTone(int numSamples, float frequency, + int sampleRate = 16000, + float amplitude = 0.5f) { + std::vector samples(numSamples); + for (int i = 0; i < numSamples; i++) { + samples[i] = amplitude * std::sin(2.0f * M_PI * frequency * i / sampleRate); + } + return samples; +} + +TEST_CASE("静音检测 - 无信号", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto silence = generateSilence(480); // 30ms @ 16kHz + + bool result = vad.process(silence); + REQUIRE(!result); // 静音 + REQUIRE(!vad.isSpeaking()); +} + +TEST_CASE("静音检测 - 极低能量", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + std::vector noise(480, 0.001f); // 极低噪声 + + bool result = vad.process(noise); + REQUIRE(!result); // 应判定为静音 +} + +TEST_CASE("语音检测 - 纯正弦波", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto tone = generateTone(480, 440.0f, 16000, 0.5f); + + // 需要连续 minVoiceFrames 帧才能判定为语音 + for (int i = 0; i < 2; i++) { + bool result = vad.process(tone); + if (i < 2) { + REQUIRE(!result); // 帧数不足 + } + } + + // 第 3 帧后应判定为语音 + bool result = vad.process(tone); + REQUIRE(result); // 连续 3 帧语音 + REQUIRE(vad.isSpeaking()); + REQUIRE(vad.consecutiveVoiceFrames() >= 3); +} + +TEST_CASE("语音检测 - 多帧连续", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto tone = generateTone(480, 440.0f, 16000, 0.8f); + + // 连续输入语音帧 + for (int i = 0; i < 10; i++) { + vad.process(tone); + } + + REQUIRE(vad.isSpeaking()); + REQUIRE(vad.consecutiveVoiceFrames() == 10); +} + +TEST_CASE("语音转静音 - 状态重置", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto tone = generateTone(480, 440.0f, 16000, 0.8f); + auto silence = generateSilence(480); + + // 先产生语音 + for (int i = 0; i < 5; i++) { + vad.process(tone); + } + REQUIRE(vad.isSpeaking()); + + // 然后静音 + vad.process(silence); + REQUIRE(!vad.isSpeaking()); + REQUIRE(vad.consecutiveVoiceFrames() == 0); +} + +TEST_CASE("批量处理 - 静音 + 语音 + 静音", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + + // 90ms 静音 + 300ms 语音 + 90ms 静音 + std::vector samples; + auto silence1 = generateSilence(1440); + auto tone = generateTone(4800, 440.0f, 16000, 1.0f); // 更高振幅 + auto silence2 = generateSilence(1440); + + samples.insert(samples.end(), silence1.begin(), silence1.end()); + samples.insert(samples.end(), tone.begin(), tone.end()); + samples.insert(samples.end(), silence2.begin(), silence2.end()); + + auto segments = vad.processBatch(samples); + + // 应检测到至少一个语音段 + REQUIRE(!segments.empty()); + REQUIRE(segments[0].startFrame >= 1); // 至少在第一段静音之后 + REQUIRE(segments[0].endFrame > segments[0].startFrame); +} + +TEST_CASE("批量处理 - 全静音", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto silence = generateSilence(16000); // 1 秒静音 + + auto segments = vad.processBatch(silence); + + REQUIRE(segments.empty()); +} + +TEST_CASE("能量计算 - 纯正弦波", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto tone = generateTone(480, 440.0f, 16000, 1.0f); + + vad.process(tone); + + // 正弦波能量 ≈ 0.5 (RMS²) + REQUIRE_THAT(vad.currentEnergy(), WithinAbs(0.5f, 0.01f)); +} + +TEST_CASE("过零率 - 正弦波", "[vad]") { + VoiceActivityDetector vad(16000, 30, 0.02f, 3); + auto tone = generateTone(480, 1000.0f, 16000, 1.0f); + + vad.process(tone); + + // 正弦波每个周期有 2 个过零点 + // 1000Hz 在 16kHz 采样率下:每 16 个样本一个周期,2 个过零点 + // ZCR ≈ 2 * f / sr = 2 * 1000 / 16000 = 0.125 + float expectedZcr = 2.0f * 1000.0f / 16000.0f; + REQUIRE_THAT(vad.zeroCrossingRate(), WithinAbs(expectedZcr, 0.05f)); +} + +TEST_CASE("默认构造函数参数", "[vad]") { + VoiceActivityDetector vad; // 默认参数 + + auto silence = generateSilence(480); + bool result = vad.process(silence); + + REQUIRE(!result); // 默认参数应能正常工作 +} diff --git a/tests/test_whisper_tokenizer.cpp b/tests/test_whisper_tokenizer.cpp new file mode 100644 index 0000000..160350b --- /dev/null +++ b/tests/test_whisper_tokenizer.cpp @@ -0,0 +1,122 @@ +#include +#include +#include "core/whisper_tokenizer.h" + +#include +#include + +using namespace impress; +using Catch::Matchers::ContainsSubstring; + +// ============================================================================ +// WhisperTokenizer 测试 +// ============================================================================ + +TEST_CASE("默认构造函数 - 未加载词表", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + + REQUIRE(!tokenizer.isLoaded()); + REQUIRE(tokenizer.vocabSize() == 0); +} + +TEST_CASE("语言 token ID - 常用语言", "[whisper_tokenizer]") { + REQUIRE(WhisperTokenizer::languageTokenId("zh") == 50260); + REQUIRE(WhisperTokenizer::languageTokenId("en") == 50259); + REQUIRE(WhisperTokenizer::languageTokenId("ja") == 50261); + REQUIRE(WhisperTokenizer::languageTokenId("ko") == 50262); + REQUIRE(WhisperTokenizer::languageTokenId("fr") == 50265); + REQUIRE(WhisperTokenizer::languageTokenId("de") == 50266); + REQUIRE(WhisperTokenizer::languageTokenId("es") == 50267); +} + +TEST_CASE("语言 token ID - 未知语言回退到英语", "[whisper_tokenizer]") { + int unknownId = WhisperTokenizer::languageTokenId("unknown"); + REQUIRE(unknownId == 50259); // 默认英语 +} + +TEST_CASE("特殊 token 检测", "[whisper_tokenizer]") { + // Whisper 特殊 token 范围: [50257, 50363] + REQUIRE(WhisperTokenizer::isSpecialToken(50257) == true); + REQUIRE(WhisperTokenizer::isSpecialToken(50362) == true); + REQUIRE(WhisperTokenizer::isSpecialToken(50363) == true); + + // 非特殊 token + REQUIRE(WhisperTokenizer::isSpecialToken(0) == false); + REQUIRE(WhisperTokenizer::isSpecialToken(100) == false); + REQUIRE(WhisperTokenizer::isSpecialToken(50256) == false); + REQUIRE(WhisperTokenizer::isSpecialToken(50400) == false); +} + +TEST_CASE("特殊 token 常量", "[whisper_tokenizer]") { + REQUIRE(WhisperTokenizer::kTokenEndOfText == 50257); + REQUIRE(WhisperTokenizer::kTokenEndOfSpeech == 50256); + REQUIRE(WhisperTokenizer::kTokenNoSpeech == 50362); + REQUIRE(WhisperTokenizer::kTokenTranscription == 50359); + REQUIRE(WhisperTokenizer::kTokenLanguageBase == 50259); +} + +TEST_CASE("解码空 token 列表", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + std::vector emptyTokens; + + QString result = tokenizer.decode(emptyTokens); + REQUIRE(result.isEmpty()); +} + +TEST_CASE("未加载词表时解码 - 使用 token ID 格式", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + std::vector tokens = {100, 200, 300}; + + QString result = tokenizer.decode(tokens); + + // 未加载词表时,tokenToString_ 为空,走 else 分支 + REQUIRE(result.contains("token:100")); + REQUIRE(result.contains("token:200")); + REQUIRE(result.contains("token:300")); +} + +TEST_CASE("特殊 token 在解码中被跳过", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + + // 即使未加载词表,特殊 token 也应该被跳过 + std::vector tokens = {50257, 100, 50258, 200, 50260}; + + QString result = tokenizer.decode(tokens); + + // 特殊 token 不应出现在结果中 + REQUIRE(!result.contains("token:50257")); + REQUIRE(!result.contains("token:50258")); + REQUIRE(!result.contains("token:50260")); + // 普通 token 应该在结果中 + REQUIRE(result.contains("token:100")); + REQUIRE(result.contains("token:200")); +} + +TEST_CASE("encode 空文本", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + + auto tokens = tokenizer.encode(""); + + REQUIRE(tokens.empty()); +} + +TEST_CASE("decodeBytePair - 空格转义", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + + // 0x0120 = Ġ (unicode 空格转义) + // 这个测试验证 BPE 解码时使用的 unicode 常量有效 + QChar spaceEscape(0x0120); + REQUIRE(spaceEscape.unicode() == 0x0120); +} + +TEST_CASE("解码 - token ID 格式输出", "[whisper_tokenizer]") { + WhisperTokenizer tokenizer; + + std::vector tokens = {1, 42, 1000}; + QString result = tokenizer.decode(tokens); + + // 未加载词表时输出 格式 + REQUIRE(result.contains("token:1")); + REQUIRE(result.contains("token:42")); + REQUIRE(result.contains("token:1000")); +}