diff --git a/README.md b/README.md index 12919c2..dd825c7 100644 --- a/README.md +++ b/README.md @@ -125,8 +125,9 @@ ctest - [x] 音频采集/解码框架 (PortAudio/dr_libs) - [x] 三个 GUI 页面 (实时识别 / 文件转写 / 配置) - [x] 日志系统 (控制台 + 文件输出) +- [x] 批量文件转写 (支持 WAV/MP3/FLAC) +- [x] 音频重采样 (非 16kHz 音频自动重采样) - [ ] 完整 Whisper 推理 (自回归解码 + 流式识别) -- [ ] 批量文件转写 - [ ] 跨平台打包 ## License diff --git a/src/audio/audio_decoder.cpp b/src/audio/audio_decoder.cpp index c626614..8848362 100644 --- a/src/audio/audio_decoder.cpp +++ b/src/audio/audio_decoder.cpp @@ -1,6 +1,8 @@ #include "audio_decoder.h" #include "utils/logger.h" +#include + #ifdef HAVE_DR_LIBS #define DR_WAV_IMPLEMENTATION #define DR_MP3_IMPLEMENTATION @@ -23,7 +25,7 @@ AudioDecoder::AudioDecoder(QObject* parent) AudioDecoder::~AudioDecoder() = default; QStringList AudioDecoder::supportedFormats() { - return {"wav", "mp3", "flac", "ogg", "aac"}; + return {"wav", "mp3", "flac"}; } bool AudioDecoder::decode(const QString& filePath) { @@ -34,13 +36,37 @@ bool AudioDecoder::decode(const QString& filePath) { sampleRate_ = 0; channels_ = 0; - if (ext != "wav") { +#ifdef HAVE_DR_LIBS + bool success = false; + + if (ext == "wav") { + success = decodeWav(filePath); + } else if (ext == "mp3") { + success = decodeMp3(filePath); + } else if (ext == "flac") { + success = decodeFlac(filePath); + } else { LOG_ERROR(kTag, QString("暂不支持格式: %1").arg(ext)); emit error(QString("暂不支持格式: %1").arg(ext)); return false; } + if (success) { + emit progress(1.0); + emit decoded(filePath); + LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz, %4声道)") + .arg(filePath).arg(samples_.size()).arg(sampleRate_).arg(channels_)); + } + return success; +#else + LOG_ERROR(kTag, "dr_libs 未编译启用"); + emit error("音频解码库未启用"); + return false; +#endif +} + #ifdef HAVE_DR_LIBS +bool AudioDecoder::decodeWav(const QString& filePath) { drwav wav; if (!drwav_init_file(&wav, filePath.toUtf8().constData(), nullptr)) { LOG_ERROR(kTag, QString("无法打开 WAV 文件: %1").arg(filePath)); @@ -55,33 +81,80 @@ bool AudioDecoder::decode(const QString& filePath) { drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data()); drwav_uninit(&wav); - // 多声道混合为单声道 + convertToMono(pcm16, wav.totalPCMFrameCount); + return true; +} + +bool AudioDecoder::decodeMp3(const QString& filePath) { + drmp3 mp3; + if (!drmp3_init_file(&mp3, filePath.toUtf8().constData(), nullptr)) { + LOG_ERROR(kTag, QString("无法打开 MP3 文件: %1").arg(filePath)); + emit error("无法打开音频文件"); + return false; + } + + channels_ = mp3.channels; + sampleRate_ = mp3.sampleRate; + + drmp3_uint64 totalFrames = drmp3_get_pcm_frame_count(&mp3); + std::vector rawPcm(totalFrames * channels_); + drmp3_read_pcm_frames_f32(&mp3, totalFrames, rawPcm.data()); + drmp3_uninit(&mp3); + + // MP3 解码器直接输出归一化浮点数据 if (channels_ == 1) { - samples_ = std::vector(pcm16.begin(), pcm16.end()); - // 归一化 - for (auto& s : samples_) s /= 32768.0f; + samples_ = std::move(rawPcm); } else { - samples_.resize(wav.totalPCMFrameCount); - for (size_t i = 0; i < wav.totalPCMFrameCount; ++i) { + samples_.resize(totalFrames); + for (drmp3_uint64 i = 0; i < totalFrames; ++i) { float sum = 0.0f; for (int ch = 0; ch < channels_; ++ch) { - sum += pcm16[static_cast(i) * static_cast(channels_) + static_cast(ch)]; + sum += rawPcm[i * channels_ + ch]; + } + samples_[i] = sum / channels_; + } + } + return true; +} + +bool AudioDecoder::decodeFlac(const QString& filePath) { + drflac* flac = drflac_open_file(filePath.toUtf8().constData(), nullptr); + if (!flac) { + LOG_ERROR(kTag, QString("无法打开 FLAC 文件: %1").arg(filePath)); + emit error("无法打开音频文件"); + return false; + } + + channels_ = flac->channels; + sampleRate_ = flac->sampleRate; + + drflac_uint64 totalFrames = flac->totalPCMFrameCount; + std::vector pcm16(totalFrames * channels_); + drflac_read_pcm_frames_s16(flac, totalFrames, pcm16.data()); + drflac_close(flac); + + convertToMono(pcm16, totalFrames); + return true; +} + +void AudioDecoder::convertToMono(const std::vector& pcm16, uint64_t frameCount) { + if (channels_ == 1) { + samples_.resize(frameCount); + for (uint64_t i = 0; i < frameCount; ++i) { + samples_[i] = pcm16[i] / 32768.0f; + } + } else { + samples_.resize(frameCount); + for (uint64_t i = 0; i < frameCount; ++i) { + float sum = 0.0f; + for (int ch = 0; ch < channels_; ++ch) { + sum += pcm16[i * channels_ + ch]; } samples_[i] = sum / (channels_ * 32768.0f); } } - - emit progress(1.0); - emit decoded(filePath); - LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz)") - .arg(filePath).arg(samples_.size()).arg(sampleRate_)); - return true; -#else - LOG_ERROR(kTag, "dr_libs 未编译启用"); - emit error("音频解码库未启用"); - return false; -#endif } +#endif double AudioDecoder::duration() const { if (sampleRate_ == 0) return 0.0; diff --git a/src/audio/audio_decoder.h b/src/audio/audio_decoder.h index 8db44ee..40623af 100644 --- a/src/audio/audio_decoder.h +++ b/src/audio/audio_decoder.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace impress { @@ -20,7 +21,7 @@ public: explicit AudioDecoder(QObject* parent = nullptr); ~AudioDecoder() override; - /** @brief 支持的格式 */ + /** @brief 支持的格式 (WAV, MP3, FLAC) */ static QStringList supportedFormats(); /** @brief 解码音频文件 */ @@ -49,6 +50,13 @@ signals: void error(const QString& message); private: +#ifdef HAVE_DR_LIBS + bool decodeWav(const QString& filePath); + bool decodeMp3(const QString& filePath); + bool decodeFlac(const QString& filePath); + void convertToMono(const std::vector& pcm16, uint64_t frameCount); +#endif + std::vector samples_; int sampleRate_ = 0; int channels_ = 0; diff --git a/src/core/mel_spectrogram.cpp b/src/core/mel_spectrogram.cpp index b65bf07..0766c99 100644 --- a/src/core/mel_spectrogram.cpp +++ b/src/core/mel_spectrogram.cpp @@ -161,9 +161,9 @@ std::vector MelSpectrogram::stft(const std::vector& samples, int f // 执行 FFT fft(fftInput); - // 计算幅度谱 + // 计算幅度谱 (Whisper 使用 magnitude 而非 magnitude²) for (int k = 0; k < nFreq; k++) { - magnitude[k] = fftInput[k].magnitudeSq(); + magnitude[k] = std::sqrt(fftInput[k].magnitudeSq()); } return magnitude; diff --git a/src/core/stt_engine.cpp b/src/core/stt_engine.cpp index e541e90..0fe97a2 100644 --- a/src/core/stt_engine.cpp +++ b/src/core/stt_engine.cpp @@ -1,6 +1,7 @@ #include "stt_engine.h" #include "mel_spectrogram.h" #include "whisper_tokenizer.h" +#include "audio_processor.h" #include "utils/logger.h" #include "utils/timer.h" @@ -237,7 +238,8 @@ RecognitionResult STTEngine::infer(const std::vector& samples, RecognitionResult result; QString lang = language.isEmpty() ? "zh" : language; - LOG_DEBUG(kTag, QString("推理语言: %1").arg(lang)); + LOG_DEBUG(kTag, QString("推理语言: %1 (采样率: %2Hz, 样本数: %3)") + .arg(lang).arg(sampleRate).arg(samples.size())); #ifdef HAVE_ONNXRUNTIME if (!loaded_) { @@ -247,15 +249,29 @@ RecognitionResult STTEngine::infer(const std::vector& samples, } try { - // 1. 计算 Mel 频谱图 - Timer melTimer; - MelSpectrogram melExtractor(kMelBins, 400, 160, sampleRate); - std::vector melSpec = melExtractor.compute(samples); - int nFrames = melExtractor.nFrames(static_cast(samples.size())); - if (nFrames <= 0) nFrames = 1; - LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)").arg(melTimer.elapsedMs(), 0, 'f', 1).arg(nFrames)); + // 1. 重采样到 Whisper 所需的 16kHz + Timer preprocessTimer; + std::vector processedSamples = samples; + int currentSampleRate = sampleRate; - // 2. 运行 ONNX 推理 + if (sampleRate != 16000) { + AudioProcessor processor(16000); + processedSamples = processor.resample(samples, sampleRate); + currentSampleRate = 16000; + LOG_DEBUG(kTag, QString("重采样: %1Hz -> %2Hz (%3 -> %4 样本)") + .arg(sampleRate).arg(currentSampleRate) + .arg(samples.size()).arg(processedSamples.size())); + } + + // 2. Mel 频谱图提取 + MelSpectrogram melExtractor(kMelBins, 400, 160, currentSampleRate); + std::vector melSpec = melExtractor.compute(processedSamples); + int nFrames = melExtractor.nFrames(static_cast(processedSamples.size())); + if (nFrames <= 0) nFrames = 1; + LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)") + .arg(preprocessTimer.elapsedMs(), 0, 'f', 1).arg(nFrames)); + + // 3. 运行 ONNX 推理 Timer inferTimer; QMutexLocker locker(&impl_->mutex); @@ -277,7 +293,7 @@ RecognitionResult STTEngine::infer(const std::vector& samples, LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1)); - // 3. 解析输出 + // 4. 解析输出 auto& outputTensor = outputTensors[0]; auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape(); const float* outputData = outputTensor.GetTensorMutableData(); @@ -292,11 +308,12 @@ RecognitionResult STTEngine::infer(const std::vector& samples, if (shape.size() == 2 && shape[1] == vocabSize) { // [1, vocab_size] - 直接输出 - int bestToken = argmax(outputData, 0, std::min(vocabSize, 50256)); + int searchEnd = std::min(vocabSize, 50256); + int bestToken = argmax(outputData, 0, searchEnd); if (!WhisperTokenizer::isSpecialToken(bestToken)) { tokens.push_back(bestToken); } - auto probs = softmax(outputData, 0, std::min(vocabSize, 50256)); + auto probs = softmax(outputData, 0, searchEnd); float maxProb = probs[0]; for (size_t i = 1; i < probs.size(); i++) { if (probs[i] > maxProb) maxProb = probs[i]; @@ -310,7 +327,8 @@ RecognitionResult STTEngine::infer(const std::vector& samples, for (int t = 0; t < seqLen && static_cast(tokens.size()) < kMaxTokens; t++) { int offset = t * vocabSize; - int bestToken = argmax(outputData, offset, offset + vocabSize); + int searchEnd = std::min(offset + vocabSize, offset + 50256); + int bestToken = argmax(outputData, offset, searchEnd); if (WhisperTokenizer::isSpecialToken(bestToken)) break; if (!tokens.empty() && tokens.back() == bestToken) continue; tokens.push_back(bestToken); @@ -320,8 +338,8 @@ RecognitionResult STTEngine::infer(const std::vector& samples, float avgConf = 0.0f; for (int t = 0; t < seqLen && t < static_cast(tokens.size()); t++) { int offset = t * vocabSize; - int bestToken = argmax(outputData, offset, offset + vocabSize); auto probs = softmax(outputData, offset, offset + vocabSize); + int bestToken = argmax(outputData, offset, offset + vocabSize); avgConf += probs[bestToken - offset]; } result.confidence = avgConf / tokens.size(); @@ -332,7 +350,7 @@ RecognitionResult STTEngine::infer(const std::vector& samples, return result; } - // 4. 解码 token 为文本 + // 5. 解码 token 为文本 if (tokens.empty()) { result.text = ""; } else if (impl_->tokenizer.isLoaded()) {