feat: 扩展音频格式支持与推理管线优化
- 新增 MP3/FLAC 格式解码 (dr_mp3/dr_flac) - 修复 Mel 频谱图使用 magnitude² 替代 magnitude 的问题 - 推理管线增加音频重采样 (非 16kHz 自动转换) - 更新 README 项目状态 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
760899e81c
commit
59c12ab931
@ -125,8 +125,9 @@ ctest
|
||||
- [x] 音频采集/解码框架 (PortAudio/dr_libs)
|
||||
- [x] 三个 GUI 页面 (实时识别 / 文件转写 / 配置)
|
||||
- [x] 日志系统 (控制台 + 文件输出)
|
||||
- [x] 批量文件转写 (支持 WAV/MP3/FLAC)
|
||||
- [x] 音频重采样 (非 16kHz 音频自动重采样)
|
||||
- [ ] 完整 Whisper 推理 (自回归解码 + 流式识别)
|
||||
- [ ] 批量文件转写
|
||||
- [ ] 跨平台打包
|
||||
|
||||
## License
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#include "audio_decoder.h"
|
||||
#include "utils/logger.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#ifdef HAVE_DR_LIBS
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#define DR_MP3_IMPLEMENTATION
|
||||
@ -23,7 +25,7 @@ AudioDecoder::AudioDecoder(QObject* parent)
|
||||
AudioDecoder::~AudioDecoder() = default;
|
||||
|
||||
QStringList AudioDecoder::supportedFormats() {
|
||||
return {"wav", "mp3", "flac", "ogg", "aac"};
|
||||
return {"wav", "mp3", "flac"};
|
||||
}
|
||||
|
||||
bool AudioDecoder::decode(const QString& filePath) {
|
||||
@ -34,13 +36,37 @@ bool AudioDecoder::decode(const QString& filePath) {
|
||||
sampleRate_ = 0;
|
||||
channels_ = 0;
|
||||
|
||||
if (ext != "wav") {
|
||||
#ifdef HAVE_DR_LIBS
|
||||
bool success = false;
|
||||
|
||||
if (ext == "wav") {
|
||||
success = decodeWav(filePath);
|
||||
} else if (ext == "mp3") {
|
||||
success = decodeMp3(filePath);
|
||||
} else if (ext == "flac") {
|
||||
success = decodeFlac(filePath);
|
||||
} else {
|
||||
LOG_ERROR(kTag, QString("暂不支持格式: %1").arg(ext));
|
||||
emit error(QString("暂不支持格式: %1").arg(ext));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (success) {
|
||||
emit progress(1.0);
|
||||
emit decoded(filePath);
|
||||
LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz, %4声道)")
|
||||
.arg(filePath).arg(samples_.size()).arg(sampleRate_).arg(channels_));
|
||||
}
|
||||
return success;
|
||||
#else
|
||||
LOG_ERROR(kTag, "dr_libs 未编译启用");
|
||||
emit error("音频解码库未启用");
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef HAVE_DR_LIBS
|
||||
bool AudioDecoder::decodeWav(const QString& filePath) {
|
||||
drwav wav;
|
||||
if (!drwav_init_file(&wav, filePath.toUtf8().constData(), nullptr)) {
|
||||
LOG_ERROR(kTag, QString("无法打开 WAV 文件: %1").arg(filePath));
|
||||
@ -55,33 +81,80 @@ bool AudioDecoder::decode(const QString& filePath) {
|
||||
drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// 多声道混合为单声道
|
||||
convertToMono(pcm16, wav.totalPCMFrameCount);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AudioDecoder::decodeMp3(const QString& filePath) {
|
||||
drmp3 mp3;
|
||||
if (!drmp3_init_file(&mp3, filePath.toUtf8().constData(), nullptr)) {
|
||||
LOG_ERROR(kTag, QString("无法打开 MP3 文件: %1").arg(filePath));
|
||||
emit error("无法打开音频文件");
|
||||
return false;
|
||||
}
|
||||
|
||||
channels_ = mp3.channels;
|
||||
sampleRate_ = mp3.sampleRate;
|
||||
|
||||
drmp3_uint64 totalFrames = drmp3_get_pcm_frame_count(&mp3);
|
||||
std::vector<float> rawPcm(totalFrames * channels_);
|
||||
drmp3_read_pcm_frames_f32(&mp3, totalFrames, rawPcm.data());
|
||||
drmp3_uninit(&mp3);
|
||||
|
||||
// MP3 解码器直接输出归一化浮点数据
|
||||
if (channels_ == 1) {
|
||||
samples_ = std::vector<float>(pcm16.begin(), pcm16.end());
|
||||
// 归一化
|
||||
for (auto& s : samples_) s /= 32768.0f;
|
||||
samples_ = std::move(rawPcm);
|
||||
} else {
|
||||
samples_.resize(wav.totalPCMFrameCount);
|
||||
for (size_t i = 0; i < wav.totalPCMFrameCount; ++i) {
|
||||
samples_.resize(totalFrames);
|
||||
for (drmp3_uint64 i = 0; i < totalFrames; ++i) {
|
||||
float sum = 0.0f;
|
||||
for (int ch = 0; ch < channels_; ++ch) {
|
||||
sum += pcm16[static_cast<size_t>(i) * static_cast<size_t>(channels_) + static_cast<size_t>(ch)];
|
||||
sum += rawPcm[i * channels_ + ch];
|
||||
}
|
||||
samples_[i] = sum / channels_;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AudioDecoder::decodeFlac(const QString& filePath) {
|
||||
drflac* flac = drflac_open_file(filePath.toUtf8().constData(), nullptr);
|
||||
if (!flac) {
|
||||
LOG_ERROR(kTag, QString("无法打开 FLAC 文件: %1").arg(filePath));
|
||||
emit error("无法打开音频文件");
|
||||
return false;
|
||||
}
|
||||
|
||||
channels_ = flac->channels;
|
||||
sampleRate_ = flac->sampleRate;
|
||||
|
||||
drflac_uint64 totalFrames = flac->totalPCMFrameCount;
|
||||
std::vector<short> pcm16(totalFrames * channels_);
|
||||
drflac_read_pcm_frames_s16(flac, totalFrames, pcm16.data());
|
||||
drflac_close(flac);
|
||||
|
||||
convertToMono(pcm16, totalFrames);
|
||||
return true;
|
||||
}
|
||||
|
||||
void AudioDecoder::convertToMono(const std::vector<short>& pcm16, uint64_t frameCount) {
|
||||
if (channels_ == 1) {
|
||||
samples_.resize(frameCount);
|
||||
for (uint64_t i = 0; i < frameCount; ++i) {
|
||||
samples_[i] = pcm16[i] / 32768.0f;
|
||||
}
|
||||
} else {
|
||||
samples_.resize(frameCount);
|
||||
for (uint64_t i = 0; i < frameCount; ++i) {
|
||||
float sum = 0.0f;
|
||||
for (int ch = 0; ch < channels_; ++ch) {
|
||||
sum += pcm16[i * channels_ + ch];
|
||||
}
|
||||
samples_[i] = sum / (channels_ * 32768.0f);
|
||||
}
|
||||
}
|
||||
|
||||
emit progress(1.0);
|
||||
emit decoded(filePath);
|
||||
LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz)")
|
||||
.arg(filePath).arg(samples_.size()).arg(sampleRate_));
|
||||
return true;
|
||||
#else
|
||||
LOG_ERROR(kTag, "dr_libs 未编译启用");
|
||||
emit error("音频解码库未启用");
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
double AudioDecoder::duration() const {
|
||||
if (sampleRate_ == 0) return 0.0;
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <QStringList>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <cstdint>
|
||||
|
||||
namespace impress {
|
||||
|
||||
@ -20,7 +21,7 @@ public:
|
||||
explicit AudioDecoder(QObject* parent = nullptr);
|
||||
~AudioDecoder() override;
|
||||
|
||||
/** @brief 支持的格式 */
|
||||
/** @brief 支持的格式 (WAV, MP3, FLAC) */
|
||||
static QStringList supportedFormats();
|
||||
|
||||
/** @brief 解码音频文件 */
|
||||
@ -49,6 +50,13 @@ signals:
|
||||
void error(const QString& message);
|
||||
|
||||
private:
|
||||
#ifdef HAVE_DR_LIBS
|
||||
bool decodeWav(const QString& filePath);
|
||||
bool decodeMp3(const QString& filePath);
|
||||
bool decodeFlac(const QString& filePath);
|
||||
void convertToMono(const std::vector<short>& pcm16, uint64_t frameCount);
|
||||
#endif
|
||||
|
||||
std::vector<float> samples_;
|
||||
int sampleRate_ = 0;
|
||||
int channels_ = 0;
|
||||
|
||||
@ -161,9 +161,9 @@ std::vector<float> MelSpectrogram::stft(const std::vector<float>& samples, int f
|
||||
// 执行 FFT
|
||||
fft(fftInput);
|
||||
|
||||
// 计算幅度谱
|
||||
// 计算幅度谱 (Whisper 使用 magnitude 而非 magnitude²)
|
||||
for (int k = 0; k < nFreq; k++) {
|
||||
magnitude[k] = fftInput[k].magnitudeSq();
|
||||
magnitude[k] = std::sqrt(fftInput[k].magnitudeSq());
|
||||
}
|
||||
|
||||
return magnitude;
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include "stt_engine.h"
|
||||
#include "mel_spectrogram.h"
|
||||
#include "whisper_tokenizer.h"
|
||||
#include "audio_processor.h"
|
||||
#include "utils/logger.h"
|
||||
#include "utils/timer.h"
|
||||
|
||||
@ -237,7 +238,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
RecognitionResult result;
|
||||
|
||||
QString lang = language.isEmpty() ? "zh" : language;
|
||||
LOG_DEBUG(kTag, QString("推理语言: %1").arg(lang));
|
||||
LOG_DEBUG(kTag, QString("推理语言: %1 (采样率: %2Hz, 样本数: %3)")
|
||||
.arg(lang).arg(sampleRate).arg(samples.size()));
|
||||
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
if (!loaded_) {
|
||||
@ -247,15 +249,29 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. 计算 Mel 频谱图
|
||||
Timer melTimer;
|
||||
MelSpectrogram melExtractor(kMelBins, 400, 160, sampleRate);
|
||||
std::vector<float> melSpec = melExtractor.compute(samples);
|
||||
int nFrames = melExtractor.nFrames(static_cast<int>(samples.size()));
|
||||
if (nFrames <= 0) nFrames = 1;
|
||||
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)").arg(melTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
|
||||
// 1. 重采样到 Whisper 所需的 16kHz
|
||||
Timer preprocessTimer;
|
||||
std::vector<float> processedSamples = samples;
|
||||
int currentSampleRate = sampleRate;
|
||||
|
||||
// 2. 运行 ONNX 推理
|
||||
if (sampleRate != 16000) {
|
||||
AudioProcessor processor(16000);
|
||||
processedSamples = processor.resample(samples, sampleRate);
|
||||
currentSampleRate = 16000;
|
||||
LOG_DEBUG(kTag, QString("重采样: %1Hz -> %2Hz (%3 -> %4 样本)")
|
||||
.arg(sampleRate).arg(currentSampleRate)
|
||||
.arg(samples.size()).arg(processedSamples.size()));
|
||||
}
|
||||
|
||||
// 2. Mel 频谱图提取
|
||||
MelSpectrogram melExtractor(kMelBins, 400, 160, currentSampleRate);
|
||||
std::vector<float> melSpec = melExtractor.compute(processedSamples);
|
||||
int nFrames = melExtractor.nFrames(static_cast<int>(processedSamples.size()));
|
||||
if (nFrames <= 0) nFrames = 1;
|
||||
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)")
|
||||
.arg(preprocessTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
|
||||
|
||||
// 3. 运行 ONNX 推理
|
||||
Timer inferTimer;
|
||||
QMutexLocker locker(&impl_->mutex);
|
||||
|
||||
@ -277,7 +293,7 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
|
||||
LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1));
|
||||
|
||||
// 3. 解析输出
|
||||
// 4. 解析输出
|
||||
auto& outputTensor = outputTensors[0];
|
||||
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float* outputData = outputTensor.GetTensorMutableData<float>();
|
||||
@ -292,11 +308,12 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
|
||||
if (shape.size() == 2 && shape[1] == vocabSize) {
|
||||
// [1, vocab_size] - 直接输出
|
||||
int bestToken = argmax(outputData, 0, std::min(vocabSize, 50256));
|
||||
int searchEnd = std::min(vocabSize, 50256);
|
||||
int bestToken = argmax(outputData, 0, searchEnd);
|
||||
if (!WhisperTokenizer::isSpecialToken(bestToken)) {
|
||||
tokens.push_back(bestToken);
|
||||
}
|
||||
auto probs = softmax(outputData, 0, std::min(vocabSize, 50256));
|
||||
auto probs = softmax(outputData, 0, searchEnd);
|
||||
float maxProb = probs[0];
|
||||
for (size_t i = 1; i < probs.size(); i++) {
|
||||
if (probs[i] > maxProb) maxProb = probs[i];
|
||||
@ -310,7 +327,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
|
||||
for (int t = 0; t < seqLen && static_cast<int>(tokens.size()) < kMaxTokens; t++) {
|
||||
int offset = t * vocabSize;
|
||||
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
||||
int searchEnd = std::min(offset + vocabSize, offset + 50256);
|
||||
int bestToken = argmax(outputData, offset, searchEnd);
|
||||
if (WhisperTokenizer::isSpecialToken(bestToken)) break;
|
||||
if (!tokens.empty() && tokens.back() == bestToken) continue;
|
||||
tokens.push_back(bestToken);
|
||||
@ -320,8 +338,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
float avgConf = 0.0f;
|
||||
for (int t = 0; t < seqLen && t < static_cast<int>(tokens.size()); t++) {
|
||||
int offset = t * vocabSize;
|
||||
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
||||
auto probs = softmax(outputData, offset, offset + vocabSize);
|
||||
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
||||
avgConf += probs[bestToken - offset];
|
||||
}
|
||||
result.confidence = avgConf / tokens.size();
|
||||
@ -332,7 +350,7 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
return result;
|
||||
}
|
||||
|
||||
// 4. 解码 token 为文本
|
||||
// 5. 解码 token 为文本
|
||||
if (tokens.empty()) {
|
||||
result.text = "";
|
||||
} else if (impl_->tokenizer.isLoaded()) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user