feat: 添加 VAD 模块和单元测试框架

- 新增 VoiceActivityDetector 基于能量+过零率的语音活动检测
- 引入 Catch2 单元测试框架
- 添加 4 个测试模块: AudioProcessor/VAD/MelSpectrogram/WhisperTokenizer
- 从构建中移除废弃的 tokenizer/decoder 文件
- 39 个测试用例全部通过

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Alvin Young 2026-05-12 16:57:46 +08:00
parent 59c12ab931
commit a3f1b1d9a6
8 changed files with 780 additions and 4 deletions

View File

@ -49,8 +49,7 @@ set(SOURCES
src/core/mel_spectrogram.cpp src/core/mel_spectrogram.cpp
src/core/whisper_tokenizer.cpp src/core/whisper_tokenizer.cpp
src/core/audio_processor.cpp src/core/audio_processor.cpp
src/core/decoder.cpp src/core/vad.cpp
src/core/tokenizer.cpp
# Audio # Audio
src/audio/audio_capture.cpp src/audio/audio_capture.cpp
@ -80,8 +79,7 @@ set(HEADERS
src/core/mel_spectrogram.h src/core/mel_spectrogram.h
src/core/whisper_tokenizer.h src/core/whisper_tokenizer.h
src/core/audio_processor.h src/core/audio_processor.h
src/core/decoder.h src/core/vad.h
src/core/tokenizer.h
src/audio/audio_capture.h src/audio/audio_capture.h
src/audio/audio_decoder.h src/audio/audio_decoder.h

119
src/core/vad.cpp Normal file
View File

@ -0,0 +1,119 @@
#include "vad.h"
#include <cmath>
#include <algorithm>
#include <numeric>
namespace impress {
/**
 * @brief Builds a detector with a fixed frame layout and decision thresholds.
 *
 * @param sampleRate      Sample rate in Hz; combined with @p frameMs to derive
 *                        the frame size in samples.
 * @param frameMs         Frame duration in milliseconds.
 * @param energyThreshold Mean-square energy a frame must exceed to be
 *                        considered a voice candidate.
 * @param minVoiceFrames  Number of consecutive voice frames required before
 *                        the detector reports "speaking".
 *
 * The derived frame size (sampleRate * frameMs / 1000) is computed in 64-bit
 * arithmetic and clamped to at least one sample, because processBatch()
 * divides by it. minVoiceFrames is clamped to at least 1; with a value <= 0
 * the "consecutive frames >= minVoiceFrames" test would also be true for
 * silent frames.
 */
VoiceActivityDetector::VoiceActivityDetector(int sampleRate,
                                             int frameMs,
                                             float energyThreshold,
                                             int minVoiceFrames)
    : sampleRate_(sampleRate)
    , frameSize_(static_cast<int>(std::max<long long>(
          1LL, static_cast<long long>(sampleRate) * frameMs / 1000)))
    , energyThreshold_(energyThreshold)
    , minVoiceFrames_(std::max(1, minVoiceFrames))
{}
/**
 * @brief Mean-square energy of one frame.
 *
 * @param samples Frame samples.
 * @return Sum of the squared samples divided by the sample count, or 0 for an
 *         empty frame.
 */
float VoiceActivityDetector::computeEnergy(const std::vector<float>& samples) const {
    if (samples.empty()) {
        return 0.0f;
    }

    // Left-to-right float accumulation of the squared samples.
    const float sumOfSquares = std::accumulate(
        samples.begin(), samples.end(), 0.0f,
        [](float acc, float sample) { return acc + sample * sample; });

    return sumOfSquares / static_cast<float>(samples.size());
}
/**
 * @brief Fraction of adjacent sample pairs whose sign differs.
 *
 * A sample counts as "non-negative" when it is >= 0, so a transition between
 * 0 and a negative value is a crossing while 0 to a positive value is not.
 *
 * @param samples Frame samples.
 * @return crossings / (size - 1), or 0 when fewer than two samples are given.
 */
float VoiceActivityDetector::computeZeroCrossingRate(const std::vector<float>& samples) const {
    const size_t count = samples.size();
    if (count < 2) {
        return 0.0f;
    }

    int signChanges = 0;
    bool previousNonNegative = samples.front() >= 0.0f;
    for (auto it = samples.begin() + 1; it != samples.end(); ++it) {
        const bool nonNegative = *it >= 0.0f;
        if (nonNegative != previousNonNegative) {
            ++signChanges;
        }
        previousNonNegative = nonNegative;
    }

    return static_cast<float>(signChanges) / static_cast<float>(count - 1);
}
bool VoiceActivityDetector::process(const std::vector<float>& samples) {
currentEnergy_ = computeEnergy(samples);
zeroCrossingRate_ = computeZeroCrossingRate(samples);
// 能量 + 过零率联合判定
bool isVoice = false;
if (currentEnergy_ > energyThreshold_) {
// 高能量 + 低过零率 -> 语音
if (zeroCrossingRate_ < 0.35f) {
isVoice = true;
}
// 高能量 + 高过零率 -> 可能是摩擦音 /f/ /s/ 等
else if (zeroCrossingRate_ < 0.5f && currentEnergy_ > energyThreshold_ * 3.0f) {
isVoice = true;
}
}
// 状态机:连续多帧语音才判定为"正在说话"
if (isVoice) {
consecutiveVoiceFrames_++;
} else {
consecutiveVoiceFrames_ = 0;
}
bool wasSpeaking = isSpeaking_;
isSpeaking_ = (consecutiveVoiceFrames_ >= minVoiceFrames_);
return isSpeaking_;
}
std::vector<VoiceActivityDetector::SpeechSegment>
VoiceActivityDetector::processBatch(const std::vector<float>& samples)
{
std::vector<SpeechSegment> segments;
if (samples.empty()) return segments;
// 逐帧处理
int numSamples = static_cast<int>(samples.size());
int totalFrames = numSamples / frameSize_;
if (totalFrames == 0) totalFrames = 1;
SpeechSegment current;
current.startFrame = -1;
bool inSpeech = false;
for (int f = 0; f < totalFrames; f++) {
int start = f * frameSize_;
int end = std::min(start + frameSize_, numSamples);
std::vector<float> frame(samples.begin() + start, samples.begin() + end);
bool voice = process(frame);
if (voice && !inSpeech) {
// 语音段开始
current.startFrame = f;
inSpeech = true;
} else if (!voice && inSpeech) {
// 语音段结束
current.endFrame = f - 1;
if (current.endFrame - current.startFrame >= minVoiceFrames_) {
segments.push_back(current);
}
inSpeech = false;
consecutiveVoiceFrames_ = 0;
isSpeaking_ = false;
}
}
// 如果末尾仍在语音段中
if (inSpeech) {
current.endFrame = totalFrames - 1;
if (current.endFrame - current.startFrame >= minVoiceFrames_) {
segments.push_back(current);
}
}
return segments;
}
} // namespace impress

72
src/core/vad.h Normal file
View File

@ -0,0 +1,72 @@
#pragma once
#include <vector>
namespace impress {
/**
 * @brief Voice activity detector (VAD) based on frame energy and zero-crossing rate.
 *
 * Each frame passed to process() is classified from its mean-square energy
 * and zero-crossing rate. The detector reports "speaking" once the configured
 * number of consecutive voice frames has been observed. processBatch() applies
 * the same per-frame logic to a whole buffer and returns the speech segments.
 */
class VoiceActivityDetector {
public:
    /**
     * @brief Creates a detector.
     * @param sampleRate      Sample rate in Hz.
     * @param frameMs         Frame length in milliseconds; the frame size in
     *                        samples is derived from sampleRate * frameMs / 1000.
     * @param energyThreshold Mean-square energy a frame must exceed to count
     *                        as voice (the original note suggests 0.01-0.05).
     * @param minVoiceFrames  Consecutive voice frames required before
     *                        isSpeaking() becomes true.
     */
    VoiceActivityDetector(int sampleRate = 16000,
                          int frameMs = 30,
                          float energyThreshold = 0.02f,
                          int minVoiceFrames = 3);

    /**
     * @brief Processes a single frame and updates the detector state.
     * @param samples One frame of PCM samples, expected in [-1, 1].
     * @return true if the detector considers speech to be in progress after
     *         this frame, false otherwise.
     */
    bool process(const std::vector<float>& samples);

    /** @brief Mean-square energy of the most recently processed frame. */
    float currentEnergy() const { return currentEnergy_; }

    /** @brief Zero-crossing rate of the most recently processed frame. */
    float zeroCrossingRate() const { return zeroCrossingRate_; }

    /** @brief Whether the detector currently reports speech. */
    bool isSpeaking() const { return isSpeaking_; }

    /** @brief Length of the current run of consecutive voice frames. */
    int consecutiveVoiceFrames() const { return consecutiveVoiceFrames_; }

    /**
     * @brief Inclusive range of frame indices covered by one speech segment.
     *
     * Both members default to -1 ("not set") so a default-constructed segment
     * never exposes indeterminate values.
     */
    struct SpeechSegment {
        int startFrame = -1;  ///< Index of the first frame of the segment.
        int endFrame = -1;    ///< Index of the last frame of the segment.
    };

    /**
     * @brief Processes a whole buffer frame by frame.
     * @param samples PCM samples, expected in [-1, 1].
     * @return The detected speech segments as start/end frame indices.
     */
    std::vector<SpeechSegment> processBatch(const std::vector<float>& samples);

private:
    /** @brief Mean-square energy of @p samples. */
    float computeEnergy(const std::vector<float>& samples) const;

    /** @brief Fraction of adjacent sample pairs in @p samples that change sign. */
    float computeZeroCrossingRate(const std::vector<float>& samples) const;

    int sampleRate_;          // sample rate in Hz
    int frameSize_;           // samples per frame
    float energyThreshold_;   // mean-square energy threshold for voice
    int minVoiceFrames_;      // consecutive voice frames needed to report speech

    float currentEnergy_ = 0.0f;       // energy of the last processed frame
    float zeroCrossingRate_ = 0.0f;    // ZCR of the last processed frame
    bool isSpeaking_ = false;          // current speaking state
    int consecutiveVoiceFrames_ = 0;   // current run of voice frames
};
} // namespace impress

47
tests/CMakeLists.txt Normal file
View File

@ -0,0 +1,47 @@
cmake_minimum_required(VERSION 3.20)

# Fetch Catch2 through FetchContent so no system-wide install is required.
include(FetchContent)
FetchContent_Declare(
    Catch2
    GIT_REPOSITORY https://github.com/catchorg/Catch2.git
    GIT_TAG        v3.6.0
    GIT_SHALLOW    TRUE
)
FetchContent_MakeAvailable(Catch2)

# Catch.cmake (which provides catch_discover_tests) lives in Catch2's extras/
# directory; it is not on the module path by default when Catch2 is fetched.
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)

# Portable thread library target instead of a hard-coded "pthread".
find_package(Threads REQUIRED)

# Production sources compiled directly into the test executable.
set(TEST_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/audio_processor.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/mel_spectrogram.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/whisper_tokenizer.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../src/core/vad.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../src/utils/logger.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../src/utils/timer.cpp
)

# Test executable: one translation unit per tested module.
add_executable(tests
    test_audio_processor.cpp
    test_vad.cpp
    test_mel_spectrogram.cpp
    test_whisper_tokenizer.cpp
    ${TEST_SOURCES}
)

target_include_directories(tests PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/../src
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party
)

# NOTE(review): Qt6::Core is not imported in this file; presumably the parent
# project calls find_package(Qt6 ...) before adding this directory - confirm.
target_link_libraries(tests PRIVATE
    Catch2::Catch2WithMain
    Qt6::Core
    Threads::Threads
)

target_compile_features(tests PRIVATE cxx_std_17)

# Register every Catch2 test case with CTest.
include(Catch)
catch_discover_tests(tests)

View File

@ -0,0 +1,125 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include "core/audio_processor.h"
#include <cmath>
#include <numeric>
using namespace impress;
using Catch::Matchers::WithinAbs;
// ============================================================================
// AudioProcessor 测试
// ============================================================================
// normalize(): 16-bit PCM is expected to map to value / 32768.
TEST_CASE("归一化 PCM16 数据", "[audio_processor]") {
    // Zero, +/- half scale and both int16 extremes.
    std::vector<short> raw = {0, 16384, -16384, 32767, -32768};
    const std::vector<float> expected = {0.0f, 0.5f, -0.5f, 32767.0f / 32768.0f, -1.0f};

    auto scaled = AudioProcessor::normalize(raw);

    REQUIRE(scaled.size() == raw.size());
    for (size_t i = 0; i < expected.size(); ++i) {
        REQUIRE_THAT(scaled[i], WithinAbs(expected[i], 1e-5f));
    }
}
// normalize(): an all-zero PCM buffer must stay at zero.
TEST_CASE("归一化全零 PCM16 数据", "[audio_processor]") {
    std::vector<short> zeros(1000, 0);

    auto scaled = AudioProcessor::normalize(zeros);

    for (auto it = scaled.begin(); it != scaled.end(); ++it) {
        REQUIRE_THAT(*it, WithinAbs(0.0f, 1e-5f));
    }
}
// normalizeFloats(): data whose peak magnitude is already 1.0 is expected to be unchanged.
TEST_CASE("浮点数据归一化", "[audio_processor]") {
    std::vector<float> samples = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f};
    const float tolerance = 1e-5f;

    auto scaled = AudioProcessor::normalizeFloats(samples);

    REQUIRE_THAT(scaled[0], WithinAbs(0.0f, tolerance));
    REQUIRE_THAT(scaled[3], WithinAbs(1.0f, tolerance));
    REQUIRE_THAT(scaled[4], WithinAbs(-1.0f, tolerance));
}
// normalizeFloats(): small values are scaled so the peak magnitude becomes 1.0.
TEST_CASE("浮点数据归一化 - 小值放大", "[audio_processor]") {
    std::vector<float> quiet = {0.001f, -0.002f, 0.003f};

    auto scaled = AudioProcessor::normalizeFloats(quiet);

    // The peak (0.003) maps to 1.0; the other samples keep their ratio to it.
    const float looseTolerance = 1e-4f;
    REQUIRE_THAT(scaled[2], WithinAbs(1.0f, 1e-5f));
    REQUIRE_THAT(scaled[0], WithinAbs(1.0f / 3.0f, looseTolerance));
    REQUIRE_THAT(scaled[1], WithinAbs(-2.0f / 3.0f, looseTolerance));
}
// resample(): when source and target rates match, the data must pass through unchanged.
TEST_CASE("等采样率重采样 - 数据不变", "[audio_processor]") {
    AudioProcessor processor(16000);
    std::vector<float> input = {0.0f, 0.5f, -0.5f, 1.0f};

    auto output = processor.resample(input, 16000);

    REQUIRE(output.size() == input.size());
    size_t index = 0;
    for (float expected : input) {
        REQUIRE_THAT(output[index], WithinAbs(expected, 1e-5f));
        ++index;
    }
}
// resample(): 16 kHz -> 32 kHz doubles the sample count.
TEST_CASE("上采样 - 样本数增加", "[audio_processor]") {
    AudioProcessor upsampler(32000);
    std::vector<float> constantInput(1000, 0.5f);

    auto output = upsampler.resample(constantInput, 16000);

    REQUIRE(output.size() == 2000);  // 16k -> 32k = 2x

    // A constant signal keeps its value at both ends of the output.
    const float tolerance = 1e-5f;
    REQUIRE_THAT(output[0], WithinAbs(0.5f, tolerance));
    REQUIRE_THAT(output[1999], WithinAbs(0.5f, tolerance));
}
// resample(): 16 kHz -> 8 kHz halves the sample count.
TEST_CASE("下采样 - 样本数减少", "[audio_processor]") {
    AudioProcessor downsampler(8000);
    std::vector<float> constantInput(2000, 0.5f);

    auto output = downsampler.resample(constantInput, 16000);

    REQUIRE(output.size() == 1000);  // 16k -> 8k = 0.5x
    REQUIRE_THAT(output[0], WithinAbs(0.5f, 1e-5f));
}
// frame(): overlapping frames of length 200 with hop 100 over 1000 samples.
TEST_CASE("分帧 - 基本功能", "[audio_processor]") {
    AudioProcessor processor(16000);
    std::vector<float> signal(1000, 1.0f);

    auto frames = processor.frame(signal, 200, 100);

    // Expected frame count: (1000 - 200) / 100 + 1 = 9.
    REQUIRE(frames.size() == 9);
    for (const auto& oneFrame : frames) {
        REQUIRE(oneFrame.size() == 200);
    }
}
// frame(): input shorter than one frame yields no frames.
TEST_CASE("分帧 - 输入不足一帧", "[audio_processor]") {
    AudioProcessor processor(16000);
    std::vector<float> tooShort(50, 1.0f);

    auto frames = processor.frame(tooShort, 200, 100);

    REQUIRE(frames.empty());
}
// frame(): non-overlapping frames that tile the input exactly.
TEST_CASE("分帧 - 精确匹配", "[audio_processor]") {
    AudioProcessor processor(16000);
    std::vector<float> signal(400, 1.0f);

    auto frames = processor.frame(signal, 200, 200);

    // 400 samples / 200 per frame with hop 200 -> exactly two full frames.
    REQUIRE(frames.size() == 2);
    for (const auto& oneFrame : frames) {
        REQUIRE(oneFrame.size() == 200);
    }
}

View File

@ -0,0 +1,137 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include "core/mel_spectrogram.h"
#include <cmath>
#include <vector>
using namespace impress;
using Catch::Matchers::WithinAbs;
// ============================================================================
// MelSpectrogram 测试
// ============================================================================
// Returns `numSamples` samples of digital silence (all 0.0f).
static std::vector<float> generateSilence(int numSamples) {
    std::vector<float> silence;
    silence.resize(numSamples);  // value-initialised, i.e. zero-filled
    return silence;
}
// Generates `numSamples` samples of a sine tone.
//   frequency  - tone frequency in Hz
//   sampleRate - sample rate in Hz
//   amplitude  - peak amplitude of the tone
static std::vector<float> generateTone(int numSamples, float frequency,
                                       int sampleRate = 16000,
                                       float amplitude = 0.5f) {
    // M_PI is a POSIX extension rather than standard C++ (MSVC only defines it
    // with _USE_MATH_DEFINES), so the constant is spelled out here.
    constexpr double kPi = 3.14159265358979323846;

    std::vector<float> samples(numSamples);
    for (int i = 0; i < numSamples; i++) {
        // The phase is evaluated in double precision, as before.
        const double phase = 2.0 * kPi * frequency * i / sampleRate;
        samples[i] = static_cast<float>(amplitude * std::sin(phase));
    }
    return samples;
}
// A default-constructed MelSpectrogram is expected to use 80 mel bins.
TEST_CASE("构造函数 - 默认参数", "[mel_spectrogram]") {
    MelSpectrogram defaultMel;
    const auto melBins = defaultMel.nMel();
    REQUIRE(melBins == 80);
}
// nFrames(): 30 seconds of 16 kHz audio must yield a positive frame count.
TEST_CASE("帧数计算 - 30秒音频", "[mel_spectrogram]") {
    MelSpectrogram mel(80, 400, 160, 16000);
    const int sampleCount = 30 * 16000;  // 30 seconds

    const int frameCount = mel.nFrames(sampleCount);

    // Only positivity is asserted. The exact count depends on how nFrames()
    // handles the window and padding, which is not pinned down here; the
    // formula (480000 - 400 + 160) / 160 would give 2998 in integer arithmetic.
    REQUIRE(frameCount > 0);
}
// nFrames(): a short 100 ms clip still yields at least one frame.
TEST_CASE("帧数计算 - 短音频", "[mel_spectrogram]") {
    MelSpectrogram mel(80, 400, 160, 16000);
    const int sampleCount = 1600;  // 100 ms at 16 kHz

    const int frameCount = mel.nFrames(sampleCount);

    REQUIRE(frameCount > 0);
}
// compute(): silence yields a spectrogram of the expected size whose values
// stay within a small magnitude bound.
TEST_CASE("静音频谱图 - 低能量", "[mel_spectrogram]") {
    MelSpectrogram mel(80, 400, 160, 16000);
    auto silence = generateSilence(480000);  // 30 seconds at 16 kHz
    auto melSpec = mel.compute(silence);

    int nFrames = mel.nFrames(static_cast<int>(silence.size()));
    REQUIRE(melSpec.size() == static_cast<size_t>(80 * nFrames));

    // Largest absolute value of the spectrogram. Written as a plain
    // comparison so the test does not rely on <algorithm> (std::max) being
    // pulled in transitively by another header.
    float maxVal = 0.0f;
    for (float v : melSpec) {
        const float magnitude = std::abs(v);
        if (magnitude > maxVal) {
            maxVal = magnitude;
        }
    }

    // After normalisation the magnitudes should stay roughly within [0, 1].
    REQUIRE(maxVal <= 1.1f);
}
// compute(): output length equals nMel * nFrames for a 30 s tone.
TEST_CASE("频谱图维度 - 匹配预期", "[mel_spectrogram]") {
    MelSpectrogram mel(80, 400, 160, 16000);
    auto sine = generateTone(480000, 440.0f, 16000, 0.5f);

    auto melSpec = mel.compute(sine);
    int frameCount = mel.nFrames(static_cast<int>(sine.size()));

    const size_t expectedSize = static_cast<size_t>(80 * frameCount);
    REQUIRE(melSpec.size() == expectedSize);
}
// compute(): a 440 Hz tone should put more energy into the lowest mel bins
// than into the highest ones.
TEST_CASE("正弦波频谱图 - 能量分布", "[mel_spectrogram]") {
    MelSpectrogram mel(80, 400, 160, 16000);
    auto tone = generateTone(480000, 440.0f, 16000, 0.8f);
    auto melSpec = mel.compute(tone);

    int nFrames = mel.nFrames(static_cast<int>(tone.size()));
    // Guard the indexing below: with zero frames melSpec would be empty.
    REQUIRE(nFrames > 0);
    REQUIRE(melSpec.size() == static_cast<size_t>(80 * nFrames));

    // The indexing m * nFrames reads frame 0 of mel bin m.
    // NOTE(review): this assumes a mel-major layout (mel * nFrames + frame);
    // confirm against MelSpectrogram::compute.
    float lowFreqEnergy = 0.0f;
    for (int m = 0; m < 10; m++) {
        lowFreqEnergy += std::abs(melSpec[m * nFrames]);
    }

    float highFreqEnergy = 0.0f;
    for (int m = 70; m < 80; m++) {
        highFreqEnergy += std::abs(melSpec[m * nFrames]);
    }

    // 440 Hz is a low-frequency signal, so the low bins should dominate.
    REQUIRE(lowFreqEnergy > highFreqEnergy);
}
// A detector configured with 40 mel bins reports 40 and sizes its output accordingly.
TEST_CASE("不同 Mel 滤波器数量", "[mel_spectrogram]") {
    MelSpectrogram mel40(40, 400, 160, 16000);
    REQUIRE(mel40.nMel() == 40);

    auto sine = generateTone(480000, 440.0f, 16000, 0.5f);
    auto melSpec = mel40.compute(sine);
    int frameCount = mel40.nFrames(static_cast<int>(sine.size()));

    const size_t expectedSize = static_cast<size_t>(40 * frameCount);
    REQUIRE(melSpec.size() == expectedSize);
}
// compute(): normalised values of a tone stay in a sane range.
TEST_CASE("频谱图归一化 - 值在合理范围内", "[mel_spectrogram]") {
    MelSpectrogram mel(80, 400, 160, 16000);
    auto tone = generateTone(480000, 440.0f, 16000, 0.5f);
    auto melSpec = mel.compute(tone);

    // melSpec[0] is read below; fail cleanly instead of indexing an empty result.
    REQUIRE_FALSE(melSpec.empty());

    // Whisper-style normalised values usually sit near [0, 1]. According to
    // the original note, the formula (v - offset) / -kMinLevel can produce
    // negative minima when the global minimum falls below kMinLevel.
    float minVal = melSpec[0];
    float maxVal = melSpec[0];
    for (float v : melSpec) {
        if (v < minVal) minVal = v;
        if (v > maxVal) maxVal = v;
    }

    // Values must stay within +/-2 ...
    REQUIRE(minVal >= -2.0f);
    REQUIRE(maxVal <= 2.0f);
    // ... and the dynamic range must not be excessive.
    REQUIRE((maxVal - minVal) <= 3.0f);
}

156
tests/test_vad.cpp Normal file
View File

@ -0,0 +1,156 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include "core/vad.h"
#include <cmath>
#include <vector>
#include <numeric>
using namespace impress;
using Catch::Matchers::WithinAbs;
// ============================================================================
// VoiceActivityDetector 测试
// ============================================================================
// Returns `numSamples` samples of digital silence (all 0.0f).
static std::vector<float> generateSilence(int numSamples) {
    std::vector<float> silence;
    silence.resize(numSamples);  // value-initialised, i.e. zero-filled
    return silence;
}
// Generates `numSamples` samples of a sine tone.
//   frequency  - tone frequency in Hz
//   sampleRate - sample rate in Hz
//   amplitude  - peak amplitude of the tone
static std::vector<float> generateTone(int numSamples, float frequency,
                                       int sampleRate = 16000,
                                       float amplitude = 0.5f) {
    // M_PI is a POSIX extension rather than standard C++ (MSVC only defines it
    // with _USE_MATH_DEFINES), so the constant is spelled out here.
    constexpr double kPi = 3.14159265358979323846;

    std::vector<float> samples(numSamples);
    for (int i = 0; i < numSamples; i++) {
        // The phase is evaluated in double precision, as before.
        const double phase = 2.0 * kPi * frequency * i / sampleRate;
        samples[i] = static_cast<float>(amplitude * std::sin(phase));
    }
    return samples;
}
// A frame of pure silence must not be reported as speech.
TEST_CASE("静音检测 - 无信号", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto silentFrame = generateSilence(480);  // 30 ms at 16 kHz

    const bool speaking = vad.process(silentFrame);

    REQUIRE_FALSE(speaking);
    REQUIRE_FALSE(vad.isSpeaking());
}
// A frame whose energy is far below the threshold counts as silence.
TEST_CASE("静音检测 - 极低能量", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    std::vector<float> faintFrame(480, 0.001f);  // constant, very low level

    const bool speaking = vad.process(faintFrame);

    REQUIRE_FALSE(speaking);
}
// A tone is only reported as speech once minVoiceFrames (3) consecutive voice
// frames have been processed.
TEST_CASE("语音检测 - 纯正弦波", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto tone = generateTone(480, 440.0f, 16000, 0.5f);

    // The first two frames are not enough to reach minVoiceFrames.
    for (int i = 0; i < 2; i++) {
        bool result = vad.process(tone);
        REQUIRE(!result);
    }

    // The third consecutive voice frame flips the detector to "speaking".
    bool result = vad.process(tone);
    REQUIRE(result);
    REQUIRE(vad.isSpeaking());
    REQUIRE(vad.consecutiveVoiceFrames() >= 3);
}
// Ten consecutive voice frames keep the detector speaking and are all counted.
TEST_CASE("语音检测 - 多帧连续", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto voicedFrame = generateTone(480, 440.0f, 16000, 0.8f);

    // Feed the same voiced frame ten times in a row.
    int remaining = 10;
    while (remaining > 0) {
        vad.process(voicedFrame);
        --remaining;
    }

    REQUIRE(vad.isSpeaking());
    REQUIRE(vad.consecutiveVoiceFrames() == 10);
}
// A single silent frame after speech resets both the speaking flag and the
// consecutive-frame counter.
TEST_CASE("语音转静音 - 状态重置", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto voicedFrame = generateTone(480, 440.0f, 16000, 0.8f);
    auto silentFrame = generateSilence(480);

    // Establish the speaking state with five voiced frames.
    for (int frame = 0; frame != 5; ++frame) {
        vad.process(voicedFrame);
    }
    REQUIRE(vad.isSpeaking());

    // One silent frame is enough to drop back to "not speaking".
    vad.process(silentFrame);
    REQUIRE_FALSE(vad.isSpeaking());
    REQUIRE(vad.consecutiveVoiceFrames() == 0);
}
// processBatch(): a tone surrounded by silence produces at least one segment
// that starts after the leading silence.
TEST_CASE("批量处理 - 静音 + 语音 + 静音", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);

    // 90 ms silence + 300 ms tone (full amplitude) + 90 ms silence.
    auto padding = generateSilence(1440);
    auto speech = generateTone(4800, 440.0f, 16000, 1.0f);

    std::vector<float> samples;
    samples.reserve(padding.size() * 2 + speech.size());
    samples.insert(samples.end(), padding.begin(), padding.end());
    samples.insert(samples.end(), speech.begin(), speech.end());
    samples.insert(samples.end(), padding.begin(), padding.end());

    auto segments = vad.processBatch(samples);

    // At least one speech segment must be found.
    REQUIRE_FALSE(segments.empty());
    const auto& first = segments[0];
    REQUIRE(first.startFrame >= 1);  // not before the leading silence
    REQUIRE(first.endFrame > first.startFrame);
}
// processBatch(): one second of silence yields no segments.
TEST_CASE("批量处理 - 全静音", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto oneSecondOfSilence = generateSilence(16000);

    auto segments = vad.processBatch(oneSecondOfSilence);

    REQUIRE(segments.empty());
}
// currentEnergy(): a full-scale sine has a mean-square energy of about 0.5.
TEST_CASE("能量计算 - 纯正弦波", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto fullScaleTone = generateTone(480, 440.0f, 16000, 1.0f);

    vad.process(fullScaleTone);

    // Mean of sin^2 is 1/2, i.e. RMS squared.
    const float expectedEnergy = 0.5f;
    REQUIRE_THAT(vad.currentEnergy(), WithinAbs(expectedEnergy, 0.01f));
}
// zeroCrossingRate(): a sine crosses zero twice per period.
TEST_CASE("过零率 - 正弦波", "[vad]") {
    VoiceActivityDetector vad(16000, 30, 0.02f, 3);
    auto tone = generateTone(480, 1000.0f, 16000, 1.0f);

    vad.process(tone);

    // 1000 Hz at 16 kHz is one period every 16 samples, with 2 crossings per
    // period: ZCR ~= 2 * f / sr = 2 * 1000 / 16000 = 0.125.
    const float toneHz = 1000.0f;
    const float sampleRateHz = 16000.0f;
    float expectedZcr = 2.0f * toneHz / sampleRateHz;
    REQUIRE_THAT(vad.zeroCrossingRate(), WithinAbs(expectedZcr, 0.05f));
}
// A default-constructed detector is usable and treats silence as non-speech.
TEST_CASE("默认构造函数参数", "[vad]") {
    VoiceActivityDetector defaultVad;  // all constructor defaults
    auto silentFrame = generateSilence(480);

    const bool speaking = defaultVad.process(silentFrame);

    REQUIRE_FALSE(speaking);
}

View File

@ -0,0 +1,122 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>
#include "core/whisper_tokenizer.h"
#include <string>
#include <vector>
using namespace impress;
using Catch::Matchers::ContainsSubstring;
// ============================================================================
// WhisperTokenizer 测试
// ============================================================================
// A freshly constructed tokenizer has no vocabulary loaded.
TEST_CASE("默认构造函数 - 未加载词表", "[whisper_tokenizer]") {
    WhisperTokenizer freshTokenizer;

    REQUIRE_FALSE(freshTokenizer.isLoaded());
    REQUIRE(freshTokenizer.vocabSize() == 0);
}
// languageTokenId(): expected token IDs for commonly used language codes.
TEST_CASE("语言 token ID - 常用语言", "[whisper_tokenizer]") {
    const int zh = WhisperTokenizer::languageTokenId("zh");
    REQUIRE(zh == 50260);

    const int en = WhisperTokenizer::languageTokenId("en");
    REQUIRE(en == 50259);

    const int ja = WhisperTokenizer::languageTokenId("ja");
    REQUIRE(ja == 50261);

    const int ko = WhisperTokenizer::languageTokenId("ko");
    REQUIRE(ko == 50262);

    const int fr = WhisperTokenizer::languageTokenId("fr");
    REQUIRE(fr == 50265);

    const int de = WhisperTokenizer::languageTokenId("de");
    REQUIRE(de == 50266);

    const int es = WhisperTokenizer::languageTokenId("es");
    REQUIRE(es == 50267);
}
// languageTokenId(): an unrecognised language code falls back to English.
TEST_CASE("语言 token ID - 未知语言回退到英语", "[whisper_tokenizer]") {
    const int englishId = 50259;  // the ID the "en" code is expected to map to

    int fallbackId = WhisperTokenizer::languageTokenId("unknown");

    REQUIRE(fallbackId == englishId);
}
// isSpecialToken(): IDs inside [50257, 50363] are special, everything else is not.
TEST_CASE("特殊 token 检测", "[whisper_tokenizer]") {
    // Lower bound and the two highest IDs of the special range.
    REQUIRE(WhisperTokenizer::isSpecialToken(50257));
    REQUIRE(WhisperTokenizer::isSpecialToken(50362));
    REQUIRE(WhisperTokenizer::isSpecialToken(50363));

    // Ordinary tokens, including the IDs just outside the range.
    REQUIRE_FALSE(WhisperTokenizer::isSpecialToken(0));
    REQUIRE_FALSE(WhisperTokenizer::isSpecialToken(100));
    REQUIRE_FALSE(WhisperTokenizer::isSpecialToken(50256));
    REQUIRE_FALSE(WhisperTokenizer::isSpecialToken(50400));
}
// The special-token constants keep their expected numeric values.
TEST_CASE("特殊 token 常量", "[whisper_tokenizer]") {
    const auto endOfText = WhisperTokenizer::kTokenEndOfText;
    const auto endOfSpeech = WhisperTokenizer::kTokenEndOfSpeech;
    const auto noSpeech = WhisperTokenizer::kTokenNoSpeech;
    const auto transcription = WhisperTokenizer::kTokenTranscription;
    const auto languageBase = WhisperTokenizer::kTokenLanguageBase;

    REQUIRE(endOfText == 50257);
    REQUIRE(endOfSpeech == 50256);
    REQUIRE(noSpeech == 50362);
    REQUIRE(transcription == 50359);
    REQUIRE(languageBase == 50259);
}
// decode(): an empty token list decodes to an empty string.
TEST_CASE("解码空 token 列表", "[whisper_tokenizer]") {
    WhisperTokenizer tokenizer;
    std::vector<int> noTokens;

    QString decoded = tokenizer.decode(noTokens);

    REQUIRE(decoded.isEmpty());
}
// decode(): without a loaded vocabulary, tokens are rendered by their numeric ID.
TEST_CASE("未加载词表时解码 - 使用 token ID 格式", "[whisper_tokenizer]") {
    WhisperTokenizer tokenizer;
    std::vector<int> ids = {100, 200, 300};

    QString decoded = tokenizer.decode(ids);

    // Per the original note, an empty vocabulary makes decode() take its
    // fallback branch, which emits "token:<id>" for each token.
    const bool hasFirst = decoded.contains("token:100");
    const bool hasSecond = decoded.contains("token:200");
    const bool hasThird = decoded.contains("token:300");
    REQUIRE(hasFirst);
    REQUIRE(hasSecond);
    REQUIRE(hasThird);
}
// decode(): special tokens are dropped from the output even without a vocabulary.
TEST_CASE("特殊 token 在解码中被跳过", "[whisper_tokenizer]") {
    WhisperTokenizer tokenizer;
    // Special IDs (50257, 50258, 50260) interleaved with ordinary ones (100, 200).
    std::vector<int> mixedIds = {50257, 100, 50258, 200, 50260};

    QString decoded = tokenizer.decode(mixedIds);

    // None of the special tokens may appear in the result ...
    REQUIRE_FALSE(decoded.contains("token:50257"));
    REQUIRE_FALSE(decoded.contains("token:50258"));
    REQUIRE_FALSE(decoded.contains("token:50260"));

    // ... while the ordinary tokens must.
    REQUIRE(decoded.contains("token:100"));
    REQUIRE(decoded.contains("token:200"));
}
// encode(): empty text produces no tokens.
TEST_CASE("encode 空文本", "[whisper_tokenizer]") {
    WhisperTokenizer tokenizer;

    auto encoded = tokenizer.encode("");
    const bool producedNothing = encoded.empty();

    REQUIRE(producedNothing);
}
// Sanity check of the Unicode constant used for the BPE space escape.
// U+0120 is the escaped form of a space in the byte-pair vocabulary.
// NOTE(review): this only verifies the QChar constant round-trips; it does not
// call the tokenizer's byte-pair decoding, so the previously unused
// WhisperTokenizer instance was removed. Consider adding a test that drives
// the decoding itself.
TEST_CASE("decodeBytePair - 空格转义", "[whisper_tokenizer]") {
    QChar spaceEscape(0x0120);
    REQUIRE(spaceEscape.unicode() == 0x0120);
}
// decode(): without a vocabulary each token is rendered in a form containing "token:<id>".
TEST_CASE("解码 - token ID 格式输出", "[whisper_tokenizer]") {
    WhisperTokenizer tokenizer;
    std::vector<int> ids = {1, 42, 1000};

    QString decoded = tokenizer.decode(ids);

    // NOTE(review): "token:1" is also a prefix of "token:1000", so the first
    // check is satisfied by the third token alone; the exact delimiter of the
    // output format is not visible here, so the checks are kept as substrings.
    const bool hasOne = decoded.contains("token:1");
    const bool hasFortyTwo = decoded.contains("token:42");
    const bool hasThousand = decoded.contains("token:1000");
    REQUIRE(hasOne);
    REQUIRE(hasFortyTwo);
    REQUIRE(hasThousand);
}