feat: 添加 VAD 模块和单元测试框架
- 新增 VoiceActivityDetector 基于能量+过零率的语音活动检测 - 引入 Catch2 单元测试框架 - 添加 4 个测试模块: AudioProcessor/VAD/MelSpectrogram/WhisperTokenizer - 从构建中移除废弃的 tokenizer/decoder 文件 - 39 个测试用例全部通过 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
59c12ab931
commit
a3f1b1d9a6
@ -49,8 +49,7 @@ set(SOURCES
|
|||||||
src/core/mel_spectrogram.cpp
|
src/core/mel_spectrogram.cpp
|
||||||
src/core/whisper_tokenizer.cpp
|
src/core/whisper_tokenizer.cpp
|
||||||
src/core/audio_processor.cpp
|
src/core/audio_processor.cpp
|
||||||
src/core/decoder.cpp
|
src/core/vad.cpp
|
||||||
src/core/tokenizer.cpp
|
|
||||||
|
|
||||||
# Audio
|
# Audio
|
||||||
src/audio/audio_capture.cpp
|
src/audio/audio_capture.cpp
|
||||||
@ -80,8 +79,7 @@ set(HEADERS
|
|||||||
src/core/mel_spectrogram.h
|
src/core/mel_spectrogram.h
|
||||||
src/core/whisper_tokenizer.h
|
src/core/whisper_tokenizer.h
|
||||||
src/core/audio_processor.h
|
src/core/audio_processor.h
|
||||||
src/core/decoder.h
|
src/core/vad.h
|
||||||
src/core/tokenizer.h
|
|
||||||
|
|
||||||
src/audio/audio_capture.h
|
src/audio/audio_capture.h
|
||||||
src/audio/audio_decoder.h
|
src/audio/audio_decoder.h
|
||||||
|
|||||||
119
src/core/vad.cpp
Normal file
119
src/core/vad.cpp
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
#include "vad.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace impress {
|
||||||
|
|
||||||
|
VoiceActivityDetector::VoiceActivityDetector(int sampleRate,
|
||||||
|
int frameMs,
|
||||||
|
float energyThreshold,
|
||||||
|
int minVoiceFrames)
|
||||||
|
: sampleRate_(sampleRate)
|
||||||
|
, frameSize_(sampleRate * frameMs / 1000)
|
||||||
|
, energyThreshold_(energyThreshold)
|
||||||
|
, minVoiceFrames_(minVoiceFrames)
|
||||||
|
{}
|
||||||
|
|
||||||
|
float VoiceActivityDetector::computeEnergy(const std::vector<float>& samples) const {
|
||||||
|
if (samples.empty()) return 0.0f;
|
||||||
|
|
||||||
|
float energy = 0.0f;
|
||||||
|
for (float s : samples) {
|
||||||
|
energy += s * s;
|
||||||
|
}
|
||||||
|
return energy / static_cast<float>(samples.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
float VoiceActivityDetector::computeZeroCrossingRate(const std::vector<float>& samples) const {
|
||||||
|
if (samples.size() < 2) return 0.0f;
|
||||||
|
|
||||||
|
int crossings = 0;
|
||||||
|
for (size_t i = 1; i < samples.size(); i++) {
|
||||||
|
if ((samples[i] >= 0.0f) != (samples[i - 1] >= 0.0f)) {
|
||||||
|
crossings++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return static_cast<float>(crossings) / static_cast<float>(samples.size() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool VoiceActivityDetector::process(const std::vector<float>& samples) {
|
||||||
|
currentEnergy_ = computeEnergy(samples);
|
||||||
|
zeroCrossingRate_ = computeZeroCrossingRate(samples);
|
||||||
|
|
||||||
|
// 能量 + 过零率联合判定
|
||||||
|
bool isVoice = false;
|
||||||
|
if (currentEnergy_ > energyThreshold_) {
|
||||||
|
// 高能量 + 低过零率 -> 语音
|
||||||
|
if (zeroCrossingRate_ < 0.35f) {
|
||||||
|
isVoice = true;
|
||||||
|
}
|
||||||
|
// 高能量 + 高过零率 -> 可能是摩擦音 /f/ /s/ 等
|
||||||
|
else if (zeroCrossingRate_ < 0.5f && currentEnergy_ > energyThreshold_ * 3.0f) {
|
||||||
|
isVoice = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 状态机:连续多帧语音才判定为"正在说话"
|
||||||
|
if (isVoice) {
|
||||||
|
consecutiveVoiceFrames_++;
|
||||||
|
} else {
|
||||||
|
consecutiveVoiceFrames_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool wasSpeaking = isSpeaking_;
|
||||||
|
isSpeaking_ = (consecutiveVoiceFrames_ >= minVoiceFrames_);
|
||||||
|
|
||||||
|
return isSpeaking_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<VoiceActivityDetector::SpeechSegment>
|
||||||
|
VoiceActivityDetector::processBatch(const std::vector<float>& samples)
|
||||||
|
{
|
||||||
|
std::vector<SpeechSegment> segments;
|
||||||
|
if (samples.empty()) return segments;
|
||||||
|
|
||||||
|
// 逐帧处理
|
||||||
|
int numSamples = static_cast<int>(samples.size());
|
||||||
|
int totalFrames = numSamples / frameSize_;
|
||||||
|
if (totalFrames == 0) totalFrames = 1;
|
||||||
|
|
||||||
|
SpeechSegment current;
|
||||||
|
current.startFrame = -1;
|
||||||
|
bool inSpeech = false;
|
||||||
|
|
||||||
|
for (int f = 0; f < totalFrames; f++) {
|
||||||
|
int start = f * frameSize_;
|
||||||
|
int end = std::min(start + frameSize_, numSamples);
|
||||||
|
std::vector<float> frame(samples.begin() + start, samples.begin() + end);
|
||||||
|
|
||||||
|
bool voice = process(frame);
|
||||||
|
|
||||||
|
if (voice && !inSpeech) {
|
||||||
|
// 语音段开始
|
||||||
|
current.startFrame = f;
|
||||||
|
inSpeech = true;
|
||||||
|
} else if (!voice && inSpeech) {
|
||||||
|
// 语音段结束
|
||||||
|
current.endFrame = f - 1;
|
||||||
|
if (current.endFrame - current.startFrame >= minVoiceFrames_) {
|
||||||
|
segments.push_back(current);
|
||||||
|
}
|
||||||
|
inSpeech = false;
|
||||||
|
consecutiveVoiceFrames_ = 0;
|
||||||
|
isSpeaking_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果末尾仍在语音段中
|
||||||
|
if (inSpeech) {
|
||||||
|
current.endFrame = totalFrames - 1;
|
||||||
|
if (current.endFrame - current.startFrame >= minVoiceFrames_) {
|
||||||
|
segments.push_back(current);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return segments;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace impress
|
||||||
72
src/core/vad.h
Normal file
72
src/core/vad.h
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace impress {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 语音活动检测 (Voice Activity Detection)
|
||||||
|
*
|
||||||
|
* 基于短时能量和过零率的简单 VAD 实现。
|
||||||
|
* 用于实时语音流中检测有效语音段,过滤静音。
|
||||||
|
*/
|
||||||
|
class VoiceActivityDetector {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief 构造函数
|
||||||
|
* @param sampleRate 采样率
|
||||||
|
* @param frameMs 帧长度(毫秒)
|
||||||
|
* @param energyThreshold 能量阈值(归一化,建议 0.01-0.05)
|
||||||
|
* @param minVoiceFrames 判定为语音所需的最小连续帧数
|
||||||
|
*/
|
||||||
|
VoiceActivityDetector(int sampleRate = 16000,
|
||||||
|
int frameMs = 30,
|
||||||
|
float energyThreshold = 0.02f,
|
||||||
|
int minVoiceFrames = 3);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 处理一帧音频数据
|
||||||
|
* @param samples 归一化 PCM 浮点数据 [-1, 1]
|
||||||
|
* @return true = 检测到语音, false = 静音
|
||||||
|
*/
|
||||||
|
bool process(const std::vector<float>& samples);
|
||||||
|
|
||||||
|
/** @brief 获取当前帧能量 */
|
||||||
|
float currentEnergy() const { return currentEnergy_; }
|
||||||
|
|
||||||
|
/** @brief 获取过零率 */
|
||||||
|
float zeroCrossingRate() const { return zeroCrossingRate_; }
|
||||||
|
|
||||||
|
/** @brief 是否正在说话 */
|
||||||
|
bool isSpeaking() const { return isSpeaking_; }
|
||||||
|
|
||||||
|
/** @brief 连续语音帧计数 */
|
||||||
|
int consecutiveVoiceFrames() const { return consecutiveVoiceFrames_; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 处理多帧音频数据(整段)
|
||||||
|
* @param samples 归一化 PCM 浮点数据 [-1, 1]
|
||||||
|
* @return 检测到的语音段起始/结束帧索引
|
||||||
|
*/
|
||||||
|
struct SpeechSegment {
|
||||||
|
int startFrame;
|
||||||
|
int endFrame;
|
||||||
|
};
|
||||||
|
std::vector<SpeechSegment> processBatch(const std::vector<float>& samples);
|
||||||
|
|
||||||
|
private:
|
||||||
|
float computeEnergy(const std::vector<float>& samples) const;
|
||||||
|
float computeZeroCrossingRate(const std::vector<float>& samples) const;
|
||||||
|
|
||||||
|
int sampleRate_;
|
||||||
|
int frameSize_;
|
||||||
|
float energyThreshold_;
|
||||||
|
int minVoiceFrames_;
|
||||||
|
|
||||||
|
float currentEnergy_ = 0.0f;
|
||||||
|
float zeroCrossingRate_ = 0.0f;
|
||||||
|
bool isSpeaking_ = false;
|
||||||
|
int consecutiveVoiceFrames_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace impress
|
||||||
47
tests/CMakeLists.txt
Normal file
47
tests/CMakeLists.txt
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
|
# 使用 Catch2 作为测试框架(通过 FetchContent 自动下载)
|
||||||
|
include(FetchContent)
|
||||||
|
FetchContent_Declare(
|
||||||
|
Catch2
|
||||||
|
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
|
||||||
|
GIT_TAG v3.6.0
|
||||||
|
GIT_SHALLOW TRUE
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(Catch2)
|
||||||
|
|
||||||
|
# 测试需要的源文件
|
||||||
|
set(TEST_SOURCES
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/core/audio_processor.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/core/mel_spectrogram.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/core/whisper_tokenizer.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/core/vad.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/utils/logger.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/utils/timer.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试可执行文件
|
||||||
|
add_executable(tests
|
||||||
|
test_audio_processor.cpp
|
||||||
|
test_vad.cpp
|
||||||
|
test_mel_spectrogram.cpp
|
||||||
|
test_whisper_tokenizer.cpp
|
||||||
|
${TEST_SOURCES}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(tests PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../src
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../third_party
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(tests PRIVATE
|
||||||
|
Catch2::Catch2WithMain
|
||||||
|
Qt6::Core
|
||||||
|
pthread
|
||||||
|
)
|
||||||
|
|
||||||
|
target_compile_features(tests PRIVATE cxx_std_17)
|
||||||
|
|
||||||
|
# 注册测试
|
||||||
|
include(Catch)
|
||||||
|
catch_discover_tests(tests)
|
||||||
125
tests/test_audio_processor.cpp
Normal file
125
tests/test_audio_processor.cpp
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers_floating_point.hpp>
|
||||||
|
#include "core/audio_processor.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
using namespace impress;
|
||||||
|
using Catch::Matchers::WithinAbs;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// AudioProcessor 测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
TEST_CASE("归一化 PCM16 数据", "[audio_processor]") {
|
||||||
|
std::vector<short> pcm16 = {0, 16384, -16384, 32767, -32768};
|
||||||
|
|
||||||
|
auto normalized = AudioProcessor::normalize(pcm16);
|
||||||
|
|
||||||
|
REQUIRE(normalized.size() == pcm16.size());
|
||||||
|
REQUIRE_THAT(normalized[0], WithinAbs(0.0f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[1], WithinAbs(0.5f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[2], WithinAbs(-0.5f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[3], WithinAbs(32767.0f / 32768.0f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[4], WithinAbs(-1.0f, 1e-5f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("归一化全零 PCM16 数据", "[audio_processor]") {
|
||||||
|
std::vector<short> pcm16(1000, 0);
|
||||||
|
auto normalized = AudioProcessor::normalize(pcm16);
|
||||||
|
|
||||||
|
for (float v : normalized) {
|
||||||
|
REQUIRE_THAT(v, WithinAbs(0.0f, 1e-5f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("浮点数据归一化", "[audio_processor]") {
|
||||||
|
std::vector<float> input = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f};
|
||||||
|
|
||||||
|
auto normalized = AudioProcessor::normalizeFloats(input);
|
||||||
|
|
||||||
|
// 最大绝对值为 1.0,所以数据不变
|
||||||
|
REQUIRE_THAT(normalized[0], WithinAbs(0.0f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[3], WithinAbs(1.0f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[4], WithinAbs(-1.0f, 1e-5f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("浮点数据归一化 - 小值放大", "[audio_processor]") {
|
||||||
|
std::vector<float> input = {0.001f, -0.002f, 0.003f};
|
||||||
|
|
||||||
|
auto normalized = AudioProcessor::normalizeFloats(input);
|
||||||
|
|
||||||
|
// 最大绝对值 0.003,归一化后最大值应为 1.0
|
||||||
|
REQUIRE_THAT(normalized[2], WithinAbs(1.0f, 1e-5f));
|
||||||
|
REQUIRE_THAT(normalized[0], WithinAbs(1.0f / 3.0f, 1e-4f));
|
||||||
|
REQUIRE_THAT(normalized[1], WithinAbs(-2.0f / 3.0f, 1e-4f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("等采样率重采样 - 数据不变", "[audio_processor]") {
|
||||||
|
AudioProcessor processor(16000);
|
||||||
|
std::vector<float> input = {0.0f, 0.5f, -0.5f, 1.0f};
|
||||||
|
|
||||||
|
auto output = processor.resample(input, 16000);
|
||||||
|
|
||||||
|
REQUIRE(output.size() == input.size());
|
||||||
|
for (size_t i = 0; i < input.size(); i++) {
|
||||||
|
REQUIRE_THAT(output[i], WithinAbs(input[i], 1e-5f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("上采样 - 样本数增加", "[audio_processor]") {
|
||||||
|
AudioProcessor processor(32000);
|
||||||
|
std::vector<float> input(1000, 0.5f);
|
||||||
|
|
||||||
|
auto output = processor.resample(input, 16000);
|
||||||
|
|
||||||
|
REQUIRE(output.size() == 2000); // 16k -> 32k = 2x
|
||||||
|
// 恒定信号值不变
|
||||||
|
REQUIRE_THAT(output[0], WithinAbs(0.5f, 1e-5f));
|
||||||
|
REQUIRE_THAT(output[1999], WithinAbs(0.5f, 1e-5f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("下采样 - 样本数减少", "[audio_processor]") {
|
||||||
|
AudioProcessor processor(8000);
|
||||||
|
std::vector<float> input(2000, 0.5f);
|
||||||
|
|
||||||
|
auto output = processor.resample(input, 16000);
|
||||||
|
|
||||||
|
REQUIRE(output.size() == 1000); // 16k -> 8k = 0.5x
|
||||||
|
REQUIRE_THAT(output[0], WithinAbs(0.5f, 1e-5f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("分帧 - 基本功能", "[audio_processor]") {
|
||||||
|
AudioProcessor processor(16000);
|
||||||
|
std::vector<float> input(1000, 1.0f);
|
||||||
|
|
||||||
|
auto frames = processor.frame(input, 200, 100);
|
||||||
|
|
||||||
|
// 帧数: (1000 - 200) / 100 + 1 = 9
|
||||||
|
REQUIRE(frames.size() == 9);
|
||||||
|
for (const auto& frame : frames) {
|
||||||
|
REQUIRE(frame.size() == 200);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("分帧 - 输入不足一帧", "[audio_processor]") {
|
||||||
|
AudioProcessor processor(16000);
|
||||||
|
std::vector<float> input(50, 1.0f);
|
||||||
|
|
||||||
|
auto frames = processor.frame(input, 200, 100);
|
||||||
|
|
||||||
|
REQUIRE(frames.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("分帧 - 精确匹配", "[audio_processor]") {
|
||||||
|
AudioProcessor processor(16000);
|
||||||
|
std::vector<float> input(400, 1.0f);
|
||||||
|
|
||||||
|
auto frames = processor.frame(input, 200, 200);
|
||||||
|
|
||||||
|
REQUIRE(frames.size() == 2);
|
||||||
|
for (const auto& frame : frames) {
|
||||||
|
REQUIRE(frame.size() == 200);
|
||||||
|
}
|
||||||
|
}
|
||||||
137
tests/test_mel_spectrogram.cpp
Normal file
137
tests/test_mel_spectrogram.cpp
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers_floating_point.hpp>
|
||||||
|
#include "core/mel_spectrogram.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using namespace impress;
|
||||||
|
using Catch::Matchers::WithinAbs;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// MelSpectrogram 测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
static std::vector<float> generateSilence(int numSamples) {
|
||||||
|
return std::vector<float>(numSamples, 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<float> generateTone(int numSamples, float frequency,
|
||||||
|
int sampleRate = 16000,
|
||||||
|
float amplitude = 0.5f) {
|
||||||
|
std::vector<float> samples(numSamples);
|
||||||
|
for (int i = 0; i < numSamples; i++) {
|
||||||
|
samples[i] = amplitude * std::sin(2.0f * M_PI * frequency * i / sampleRate);
|
||||||
|
}
|
||||||
|
return samples;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("构造函数 - 默认参数", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel;
|
||||||
|
|
||||||
|
REQUIRE(mel.nMel() == 80);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("帧数计算 - 30秒音频", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel(80, 400, 160, 16000);
|
||||||
|
int samples = 30 * 16000; // 30秒
|
||||||
|
|
||||||
|
int frames = mel.nFrames(samples);
|
||||||
|
|
||||||
|
// (480000 - 400 + 160) / 160 = 2999
|
||||||
|
REQUIRE(frames > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("帧数计算 - 短音频", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel(80, 400, 160, 16000);
|
||||||
|
int samples = 1600; // 100ms
|
||||||
|
|
||||||
|
int frames = mel.nFrames(samples);
|
||||||
|
REQUIRE(frames > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("静音频谱图 - 低能量", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel(80, 400, 160, 16000);
|
||||||
|
auto silence = generateSilence(480000); // 30秒
|
||||||
|
|
||||||
|
auto melSpec = mel.compute(silence);
|
||||||
|
|
||||||
|
int nFrames = mel.nFrames(static_cast<int>(silence.size()));
|
||||||
|
REQUIRE(melSpec.size() == static_cast<size_t>(80 * nFrames));
|
||||||
|
|
||||||
|
// 静音经 log 压缩和归一化后应接近 0
|
||||||
|
float maxVal = 0.0f;
|
||||||
|
for (float v : melSpec) {
|
||||||
|
maxVal = std::max(maxVal, std::abs(v));
|
||||||
|
}
|
||||||
|
// 归一化后应在 [0, 1] 范围内
|
||||||
|
REQUIRE(maxVal <= 1.1f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("频谱图维度 - 匹配预期", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel(80, 400, 160, 16000);
|
||||||
|
auto tone = generateTone(480000, 440.0f, 16000, 0.5f);
|
||||||
|
|
||||||
|
auto melSpec = mel.compute(tone);
|
||||||
|
|
||||||
|
int nFrames = mel.nFrames(static_cast<int>(tone.size()));
|
||||||
|
REQUIRE(melSpec.size() == static_cast<size_t>(80 * nFrames));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("正弦波频谱图 - 能量分布", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel(80, 400, 160, 16000);
|
||||||
|
auto tone = generateTone(480000, 440.0f, 16000, 0.8f);
|
||||||
|
|
||||||
|
auto melSpec = mel.compute(tone);
|
||||||
|
|
||||||
|
int nFrames = mel.nFrames(static_cast<int>(tone.size()));
|
||||||
|
REQUIRE(melSpec.size() == static_cast<size_t>(80 * nFrames));
|
||||||
|
|
||||||
|
// 440Hz 正弦波应在低频 Mel 滤波器上有较高能量
|
||||||
|
// 计算第一帧前几个 Mel bin 的能量
|
||||||
|
float lowFreqEnergy = 0.0f;
|
||||||
|
for (int m = 0; m < 10; m++) {
|
||||||
|
lowFreqEnergy += std::abs(melSpec[m * nFrames]);
|
||||||
|
}
|
||||||
|
|
||||||
|
float highFreqEnergy = 0.0f;
|
||||||
|
for (int m = 70; m < 80; m++) {
|
||||||
|
highFreqEnergy += std::abs(melSpec[m * nFrames]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 低频能量应高于高频(440Hz 是低频信号)
|
||||||
|
REQUIRE(lowFreqEnergy > highFreqEnergy);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("不同 Mel 滤波器数量", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel40(40, 400, 160, 16000);
|
||||||
|
REQUIRE(mel40.nMel() == 40);
|
||||||
|
|
||||||
|
auto tone = generateTone(480000, 440.0f, 16000, 0.5f);
|
||||||
|
auto melSpec = mel40.compute(tone);
|
||||||
|
int nFrames = mel40.nFrames(static_cast<int>(tone.size()));
|
||||||
|
REQUIRE(melSpec.size() == static_cast<size_t>(40 * nFrames));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("频谱图归一化 - 值在合理范围内", "[mel_spectrogram]") {
|
||||||
|
MelSpectrogram mel(80, 400, 160, 16000);
|
||||||
|
auto tone = generateTone(480000, 440.0f, 16000, 0.5f);
|
||||||
|
|
||||||
|
auto melSpec = mel.compute(tone);
|
||||||
|
|
||||||
|
// Whisper 归一化后的值通常在 [0, 1] 附近
|
||||||
|
// 但由于公式 (v - offset) / -kMinLevel,最小值可能为负
|
||||||
|
// 当 globalMin < kMinLevel 时,最小值 ≈ (globalMin - kMinLevel) / -kMinLevel
|
||||||
|
float minVal = melSpec[0];
|
||||||
|
float maxVal = melSpec[0];
|
||||||
|
for (float v : melSpec) {
|
||||||
|
if (v < minVal) minVal = v;
|
||||||
|
if (v > maxVal) maxVal = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 值应在合理范围内(不超过 ±2)
|
||||||
|
REQUIRE(minVal >= -2.0f);
|
||||||
|
REQUIRE(maxVal <= 2.0f);
|
||||||
|
// 动态范围不应过大
|
||||||
|
REQUIRE((maxVal - minVal) <= 3.0f);
|
||||||
|
}
|
||||||
156
tests/test_vad.cpp
Normal file
156
tests/test_vad.cpp
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers_floating_point.hpp>
|
||||||
|
#include "core/vad.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
using namespace impress;
|
||||||
|
using Catch::Matchers::WithinAbs;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// VoiceActivityDetector 测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
static std::vector<float> generateSilence(int numSamples) {
|
||||||
|
return std::vector<float>(numSamples, 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<float> generateTone(int numSamples, float frequency,
|
||||||
|
int sampleRate = 16000,
|
||||||
|
float amplitude = 0.5f) {
|
||||||
|
std::vector<float> samples(numSamples);
|
||||||
|
for (int i = 0; i < numSamples; i++) {
|
||||||
|
samples[i] = amplitude * std::sin(2.0f * M_PI * frequency * i / sampleRate);
|
||||||
|
}
|
||||||
|
return samples;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("静音检测 - 无信号", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto silence = generateSilence(480); // 30ms @ 16kHz
|
||||||
|
|
||||||
|
bool result = vad.process(silence);
|
||||||
|
REQUIRE(!result); // 静音
|
||||||
|
REQUIRE(!vad.isSpeaking());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("静音检测 - 极低能量", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
std::vector<float> noise(480, 0.001f); // 极低噪声
|
||||||
|
|
||||||
|
bool result = vad.process(noise);
|
||||||
|
REQUIRE(!result); // 应判定为静音
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("语音检测 - 纯正弦波", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto tone = generateTone(480, 440.0f, 16000, 0.5f);
|
||||||
|
|
||||||
|
// 需要连续 minVoiceFrames 帧才能判定为语音
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
bool result = vad.process(tone);
|
||||||
|
if (i < 2) {
|
||||||
|
REQUIRE(!result); // 帧数不足
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第 3 帧后应判定为语音
|
||||||
|
bool result = vad.process(tone);
|
||||||
|
REQUIRE(result); // 连续 3 帧语音
|
||||||
|
REQUIRE(vad.isSpeaking());
|
||||||
|
REQUIRE(vad.consecutiveVoiceFrames() >= 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("语音检测 - 多帧连续", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto tone = generateTone(480, 440.0f, 16000, 0.8f);
|
||||||
|
|
||||||
|
// 连续输入语音帧
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
vad.process(tone);
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE(vad.isSpeaking());
|
||||||
|
REQUIRE(vad.consecutiveVoiceFrames() == 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("语音转静音 - 状态重置", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto tone = generateTone(480, 440.0f, 16000, 0.8f);
|
||||||
|
auto silence = generateSilence(480);
|
||||||
|
|
||||||
|
// 先产生语音
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
vad.process(tone);
|
||||||
|
}
|
||||||
|
REQUIRE(vad.isSpeaking());
|
||||||
|
|
||||||
|
// 然后静音
|
||||||
|
vad.process(silence);
|
||||||
|
REQUIRE(!vad.isSpeaking());
|
||||||
|
REQUIRE(vad.consecutiveVoiceFrames() == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("批量处理 - 静音 + 语音 + 静音", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
|
||||||
|
// 90ms 静音 + 300ms 语音 + 90ms 静音
|
||||||
|
std::vector<float> samples;
|
||||||
|
auto silence1 = generateSilence(1440);
|
||||||
|
auto tone = generateTone(4800, 440.0f, 16000, 1.0f); // 更高振幅
|
||||||
|
auto silence2 = generateSilence(1440);
|
||||||
|
|
||||||
|
samples.insert(samples.end(), silence1.begin(), silence1.end());
|
||||||
|
samples.insert(samples.end(), tone.begin(), tone.end());
|
||||||
|
samples.insert(samples.end(), silence2.begin(), silence2.end());
|
||||||
|
|
||||||
|
auto segments = vad.processBatch(samples);
|
||||||
|
|
||||||
|
// 应检测到至少一个语音段
|
||||||
|
REQUIRE(!segments.empty());
|
||||||
|
REQUIRE(segments[0].startFrame >= 1); // 至少在第一段静音之后
|
||||||
|
REQUIRE(segments[0].endFrame > segments[0].startFrame);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("批量处理 - 全静音", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto silence = generateSilence(16000); // 1 秒静音
|
||||||
|
|
||||||
|
auto segments = vad.processBatch(silence);
|
||||||
|
|
||||||
|
REQUIRE(segments.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("能量计算 - 纯正弦波", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto tone = generateTone(480, 440.0f, 16000, 1.0f);
|
||||||
|
|
||||||
|
vad.process(tone);
|
||||||
|
|
||||||
|
// 正弦波能量 ≈ 0.5 (RMS²)
|
||||||
|
REQUIRE_THAT(vad.currentEnergy(), WithinAbs(0.5f, 0.01f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("过零率 - 正弦波", "[vad]") {
|
||||||
|
VoiceActivityDetector vad(16000, 30, 0.02f, 3);
|
||||||
|
auto tone = generateTone(480, 1000.0f, 16000, 1.0f);
|
||||||
|
|
||||||
|
vad.process(tone);
|
||||||
|
|
||||||
|
// 正弦波每个周期有 2 个过零点
|
||||||
|
// 1000Hz 在 16kHz 采样率下:每 16 个样本一个周期,2 个过零点
|
||||||
|
// ZCR ≈ 2 * f / sr = 2 * 1000 / 16000 = 0.125
|
||||||
|
float expectedZcr = 2.0f * 1000.0f / 16000.0f;
|
||||||
|
REQUIRE_THAT(vad.zeroCrossingRate(), WithinAbs(expectedZcr, 0.05f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("默认构造函数参数", "[vad]") {
|
||||||
|
VoiceActivityDetector vad; // 默认参数
|
||||||
|
|
||||||
|
auto silence = generateSilence(480);
|
||||||
|
bool result = vad.process(silence);
|
||||||
|
|
||||||
|
REQUIRE(!result); // 默认参数应能正常工作
|
||||||
|
}
|
||||||
122
tests/test_whisper_tokenizer.cpp
Normal file
122
tests/test_whisper_tokenizer.cpp
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers_string.hpp>
|
||||||
|
#include "core/whisper_tokenizer.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using namespace impress;
|
||||||
|
using Catch::Matchers::ContainsSubstring;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// WhisperTokenizer 测试
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
TEST_CASE("默认构造函数 - 未加载词表", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
|
||||||
|
REQUIRE(!tokenizer.isLoaded());
|
||||||
|
REQUIRE(tokenizer.vocabSize() == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("语言 token ID - 常用语言", "[whisper_tokenizer]") {
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("zh") == 50260);
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("en") == 50259);
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("ja") == 50261);
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("ko") == 50262);
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("fr") == 50265);
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("de") == 50266);
|
||||||
|
REQUIRE(WhisperTokenizer::languageTokenId("es") == 50267);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("语言 token ID - 未知语言回退到英语", "[whisper_tokenizer]") {
|
||||||
|
int unknownId = WhisperTokenizer::languageTokenId("unknown");
|
||||||
|
REQUIRE(unknownId == 50259); // 默认英语
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("特殊 token 检测", "[whisper_tokenizer]") {
|
||||||
|
// Whisper 特殊 token 范围: [50257, 50363]
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(50257) == true);
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(50362) == true);
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(50363) == true);
|
||||||
|
|
||||||
|
// 非特殊 token
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(0) == false);
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(100) == false);
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(50256) == false);
|
||||||
|
REQUIRE(WhisperTokenizer::isSpecialToken(50400) == false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("特殊 token 常量", "[whisper_tokenizer]") {
|
||||||
|
REQUIRE(WhisperTokenizer::kTokenEndOfText == 50257);
|
||||||
|
REQUIRE(WhisperTokenizer::kTokenEndOfSpeech == 50256);
|
||||||
|
REQUIRE(WhisperTokenizer::kTokenNoSpeech == 50362);
|
||||||
|
REQUIRE(WhisperTokenizer::kTokenTranscription == 50359);
|
||||||
|
REQUIRE(WhisperTokenizer::kTokenLanguageBase == 50259);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("解码空 token 列表", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
std::vector<int> emptyTokens;
|
||||||
|
|
||||||
|
QString result = tokenizer.decode(emptyTokens);
|
||||||
|
REQUIRE(result.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("未加载词表时解码 - 使用 token ID 格式", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
std::vector<int> tokens = {100, 200, 300};
|
||||||
|
|
||||||
|
QString result = tokenizer.decode(tokens);
|
||||||
|
|
||||||
|
// 未加载词表时,tokenToString_ 为空,走 else 分支
|
||||||
|
REQUIRE(result.contains("token:100"));
|
||||||
|
REQUIRE(result.contains("token:200"));
|
||||||
|
REQUIRE(result.contains("token:300"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("特殊 token 在解码中被跳过", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
|
||||||
|
// 即使未加载词表,特殊 token 也应该被跳过
|
||||||
|
std::vector<int> tokens = {50257, 100, 50258, 200, 50260};
|
||||||
|
|
||||||
|
QString result = tokenizer.decode(tokens);
|
||||||
|
|
||||||
|
// 特殊 token 不应出现在结果中
|
||||||
|
REQUIRE(!result.contains("token:50257"));
|
||||||
|
REQUIRE(!result.contains("token:50258"));
|
||||||
|
REQUIRE(!result.contains("token:50260"));
|
||||||
|
// 普通 token 应该在结果中
|
||||||
|
REQUIRE(result.contains("token:100"));
|
||||||
|
REQUIRE(result.contains("token:200"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("encode 空文本", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
|
||||||
|
auto tokens = tokenizer.encode("");
|
||||||
|
|
||||||
|
REQUIRE(tokens.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("decodeBytePair - 空格转义", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
|
||||||
|
// 0x0120 = Ġ (unicode 空格转义)
|
||||||
|
// 这个测试验证 BPE 解码时使用的 unicode 常量有效
|
||||||
|
QChar spaceEscape(0x0120);
|
||||||
|
REQUIRE(spaceEscape.unicode() == 0x0120);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("解码 - token ID 格式输出", "[whisper_tokenizer]") {
|
||||||
|
WhisperTokenizer tokenizer;
|
||||||
|
|
||||||
|
std::vector<int> tokens = {1, 42, 1000};
|
||||||
|
QString result = tokenizer.decode(tokens);
|
||||||
|
|
||||||
|
// 未加载词表时输出 <token:ID|> 格式
|
||||||
|
REQUIRE(result.contains("token:1"));
|
||||||
|
REQUIRE(result.contains("token:42"));
|
||||||
|
REQUIRE(result.contains("token:1000"));
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user