feat: 扩展音频格式支持与推理管线优化

- 新增 MP3/FLAC 格式解码 (dr_mp3/dr_flac)
- 修复 Mel 频谱图使用 magnitude² 替代 magnitude 的问题
- 推理管线增加音频重采样 (非 16kHz 自动转换)
- 更新 README 项目状态

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Alvin Young 2026-05-12 16:35:48 +08:00
parent 760899e81c
commit 59c12ab931
5 changed files with 139 additions and 39 deletions

View File

@ -125,8 +125,9 @@ ctest
- [x] 音频采集/解码框架 (PortAudio/dr_libs)
- [x] 三个 GUI 页面 (实时识别 / 文件转写 / 配置)
- [x] 日志系统 (控制台 + 文件输出)
- [x] 批量文件转写 (支持 WAV/MP3/FLAC)
- [x] 音频重采样 (非 16kHz 音频自动重采样)
- [ ] 完整 Whisper 推理 (自回归解码 + 流式识别)
- [ ] 批量文件转写
- [ ] 跨平台打包
## License

View File

@ -1,6 +1,8 @@
#include "audio_decoder.h"
#include "utils/logger.h"
#include <cstdint>
#ifdef HAVE_DR_LIBS
#define DR_WAV_IMPLEMENTATION
#define DR_MP3_IMPLEMENTATION
@ -23,7 +25,7 @@ AudioDecoder::AudioDecoder(QObject* parent)
AudioDecoder::~AudioDecoder() = default;
QStringList AudioDecoder::supportedFormats() {
return {"wav", "mp3", "flac", "ogg", "aac"};
return {"wav", "mp3", "flac"};
}
bool AudioDecoder::decode(const QString& filePath) {
@ -34,13 +36,37 @@ bool AudioDecoder::decode(const QString& filePath) {
sampleRate_ = 0;
channels_ = 0;
if (ext != "wav") {
#ifdef HAVE_DR_LIBS
bool success = false;
if (ext == "wav") {
success = decodeWav(filePath);
} else if (ext == "mp3") {
success = decodeMp3(filePath);
} else if (ext == "flac") {
success = decodeFlac(filePath);
} else {
LOG_ERROR(kTag, QString("暂不支持格式: %1").arg(ext));
emit error(QString("暂不支持格式: %1").arg(ext));
return false;
}
if (success) {
emit progress(1.0);
emit decoded(filePath);
LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz, %4声道)")
.arg(filePath).arg(samples_.size()).arg(sampleRate_).arg(channels_));
}
return success;
#else
LOG_ERROR(kTag, "dr_libs 未编译启用");
emit error("音频解码库未启用");
return false;
#endif
}
#ifdef HAVE_DR_LIBS
bool AudioDecoder::decodeWav(const QString& filePath) {
drwav wav;
if (!drwav_init_file(&wav, filePath.toUtf8().constData(), nullptr)) {
LOG_ERROR(kTag, QString("无法打开 WAV 文件: %1").arg(filePath));
@ -55,33 +81,80 @@ bool AudioDecoder::decode(const QString& filePath) {
drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
drwav_uninit(&wav);
// 多声道混合为单声道
convertToMono(pcm16, wav.totalPCMFrameCount);
return true;
}
bool AudioDecoder::decodeMp3(const QString& filePath) {
drmp3 mp3;
if (!drmp3_init_file(&mp3, filePath.toUtf8().constData(), nullptr)) {
LOG_ERROR(kTag, QString("无法打开 MP3 文件: %1").arg(filePath));
emit error("无法打开音频文件");
return false;
}
channels_ = mp3.channels;
sampleRate_ = mp3.sampleRate;
drmp3_uint64 totalFrames = drmp3_get_pcm_frame_count(&mp3);
std::vector<float> rawPcm(totalFrames * channels_);
drmp3_read_pcm_frames_f32(&mp3, totalFrames, rawPcm.data());
drmp3_uninit(&mp3);
// MP3 解码器直接输出归一化浮点数据
if (channels_ == 1) {
samples_ = std::vector<float>(pcm16.begin(), pcm16.end());
// 归一化
for (auto& s : samples_) s /= 32768.0f;
samples_ = std::move(rawPcm);
} else {
samples_.resize(wav.totalPCMFrameCount);
for (size_t i = 0; i < wav.totalPCMFrameCount; ++i) {
samples_.resize(totalFrames);
for (drmp3_uint64 i = 0; i < totalFrames; ++i) {
float sum = 0.0f;
for (int ch = 0; ch < channels_; ++ch) {
sum += pcm16[static_cast<size_t>(i) * static_cast<size_t>(channels_) + static_cast<size_t>(ch)];
sum += rawPcm[i * channels_ + ch];
}
samples_[i] = sum / channels_;
}
}
return true;
}
bool AudioDecoder::decodeFlac(const QString& filePath) {
drflac* flac = drflac_open_file(filePath.toUtf8().constData(), nullptr);
if (!flac) {
LOG_ERROR(kTag, QString("无法打开 FLAC 文件: %1").arg(filePath));
emit error("无法打开音频文件");
return false;
}
channels_ = flac->channels;
sampleRate_ = flac->sampleRate;
drflac_uint64 totalFrames = flac->totalPCMFrameCount;
std::vector<short> pcm16(totalFrames * channels_);
drflac_read_pcm_frames_s16(flac, totalFrames, pcm16.data());
drflac_close(flac);
convertToMono(pcm16, totalFrames);
return true;
}
void AudioDecoder::convertToMono(const std::vector<short>& pcm16, uint64_t frameCount) {
if (channels_ == 1) {
samples_.resize(frameCount);
for (uint64_t i = 0; i < frameCount; ++i) {
samples_[i] = pcm16[i] / 32768.0f;
}
} else {
samples_.resize(frameCount);
for (uint64_t i = 0; i < frameCount; ++i) {
float sum = 0.0f;
for (int ch = 0; ch < channels_; ++ch) {
sum += pcm16[i * channels_ + ch];
}
samples_[i] = sum / (channels_ * 32768.0f);
}
}
emit progress(1.0);
emit decoded(filePath);
LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz)")
.arg(filePath).arg(samples_.size()).arg(sampleRate_));
return true;
#else
LOG_ERROR(kTag, "dr_libs 未编译启用");
emit error("音频解码库未启用");
return false;
#endif
}
#endif
double AudioDecoder::duration() const {
if (sampleRate_ == 0) return 0.0;

View File

@ -5,6 +5,7 @@
#include <QStringList>
#include <vector>
#include <memory>
#include <cstdint>
namespace impress {
@ -20,7 +21,7 @@ public:
explicit AudioDecoder(QObject* parent = nullptr);
~AudioDecoder() override;
/** @brief 支持的格式 */
/** @brief 支持的格式 (WAV, MP3, FLAC) */
static QStringList supportedFormats();
/** @brief 解码音频文件 */
@ -49,6 +50,13 @@ signals:
void error(const QString& message);
private:
#ifdef HAVE_DR_LIBS
bool decodeWav(const QString& filePath);
bool decodeMp3(const QString& filePath);
bool decodeFlac(const QString& filePath);
void convertToMono(const std::vector<short>& pcm16, uint64_t frameCount);
#endif
std::vector<float> samples_;
int sampleRate_ = 0;
int channels_ = 0;

View File

@ -161,9 +161,9 @@ std::vector<float> MelSpectrogram::stft(const std::vector<float>& samples, int f
// 执行 FFT
fft(fftInput);
// 计算幅度谱
// 计算幅度谱 (Whisper 使用 magnitude 而非 magnitude²)
for (int k = 0; k < nFreq; k++) {
magnitude[k] = fftInput[k].magnitudeSq();
magnitude[k] = std::sqrt(fftInput[k].magnitudeSq());
}
return magnitude;

View File

@ -1,6 +1,7 @@
#include "stt_engine.h"
#include "mel_spectrogram.h"
#include "whisper_tokenizer.h"
#include "audio_processor.h"
#include "utils/logger.h"
#include "utils/timer.h"
@ -237,7 +238,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
RecognitionResult result;
QString lang = language.isEmpty() ? "zh" : language;
LOG_DEBUG(kTag, QString("推理语言: %1").arg(lang));
LOG_DEBUG(kTag, QString("推理语言: %1 (采样率: %2Hz, 样本数: %3)")
.arg(lang).arg(sampleRate).arg(samples.size()));
#ifdef HAVE_ONNXRUNTIME
if (!loaded_) {
@ -247,15 +249,29 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
}
try {
// 1. 计算 Mel 频谱图
Timer melTimer;
MelSpectrogram melExtractor(kMelBins, 400, 160, sampleRate);
std::vector<float> melSpec = melExtractor.compute(samples);
int nFrames = melExtractor.nFrames(static_cast<int>(samples.size()));
if (nFrames <= 0) nFrames = 1;
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)").arg(melTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
// 1. 重采样到 Whisper 所需的 16kHz
Timer preprocessTimer;
std::vector<float> processedSamples = samples;
int currentSampleRate = sampleRate;
// 2. 运行 ONNX 推理
if (sampleRate != 16000) {
AudioProcessor processor(16000);
processedSamples = processor.resample(samples, sampleRate);
currentSampleRate = 16000;
LOG_DEBUG(kTag, QString("重采样: %1Hz -> %2Hz (%3 -> %4 样本)")
.arg(sampleRate).arg(currentSampleRate)
.arg(samples.size()).arg(processedSamples.size()));
}
// 2. Mel 频谱图提取
MelSpectrogram melExtractor(kMelBins, 400, 160, currentSampleRate);
std::vector<float> melSpec = melExtractor.compute(processedSamples);
int nFrames = melExtractor.nFrames(static_cast<int>(processedSamples.size()));
if (nFrames <= 0) nFrames = 1;
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)")
.arg(preprocessTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
// 3. 运行 ONNX 推理
Timer inferTimer;
QMutexLocker locker(&impl_->mutex);
@ -277,7 +293,7 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1));
// 3. 解析输出
// 4. 解析输出
auto& outputTensor = outputTensors[0];
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
const float* outputData = outputTensor.GetTensorMutableData<float>();
@ -292,11 +308,12 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
if (shape.size() == 2 && shape[1] == vocabSize) {
// [1, vocab_size] - 直接输出
int bestToken = argmax(outputData, 0, std::min(vocabSize, 50256));
int searchEnd = std::min(vocabSize, 50256);
int bestToken = argmax(outputData, 0, searchEnd);
if (!WhisperTokenizer::isSpecialToken(bestToken)) {
tokens.push_back(bestToken);
}
auto probs = softmax(outputData, 0, std::min(vocabSize, 50256));
auto probs = softmax(outputData, 0, searchEnd);
float maxProb = probs[0];
for (size_t i = 1; i < probs.size(); i++) {
if (probs[i] > maxProb) maxProb = probs[i];
@ -310,7 +327,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
for (int t = 0; t < seqLen && static_cast<int>(tokens.size()) < kMaxTokens; t++) {
int offset = t * vocabSize;
int bestToken = argmax(outputData, offset, offset + vocabSize);
int searchEnd = std::min(offset + vocabSize, offset + 50256);
int bestToken = argmax(outputData, offset, searchEnd);
if (WhisperTokenizer::isSpecialToken(bestToken)) break;
if (!tokens.empty() && tokens.back() == bestToken) continue;
tokens.push_back(bestToken);
@ -320,8 +338,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
float avgConf = 0.0f;
for (int t = 0; t < seqLen && t < static_cast<int>(tokens.size()); t++) {
int offset = t * vocabSize;
int bestToken = argmax(outputData, offset, offset + vocabSize);
auto probs = softmax(outputData, offset, offset + vocabSize);
int bestToken = argmax(outputData, offset, offset + vocabSize);
avgConf += probs[bestToken - offset];
}
result.confidence = avgConf / tokens.size();
@ -332,7 +350,7 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
return result;
}
// 4. 解码 token 为文本
// 5. 解码 token 为文本
if (tokens.empty()) {
result.text = "";
} else if (impl_->tokenizer.isLoaded()) {