feat: 扩展音频格式支持与推理管线优化
- 新增 MP3/FLAC 格式解码 (dr_mp3/dr_flac) - 修复 Mel 频谱图使用 magnitude² 替代 magnitude 的问题 - 推理管线增加音频重采样 (非 16kHz 自动转换) - 更新 README 项目状态 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
760899e81c
commit
59c12ab931
@ -125,8 +125,9 @@ ctest
|
|||||||
- [x] 音频采集/解码框架 (PortAudio/dr_libs)
|
- [x] 音频采集/解码框架 (PortAudio/dr_libs)
|
||||||
- [x] 三个 GUI 页面 (实时识别 / 文件转写 / 配置)
|
- [x] 三个 GUI 页面 (实时识别 / 文件转写 / 配置)
|
||||||
- [x] 日志系统 (控制台 + 文件输出)
|
- [x] 日志系统 (控制台 + 文件输出)
|
||||||
|
- [x] 批量文件转写 (支持 WAV/MP3/FLAC)
|
||||||
|
- [x] 音频重采样 (非 16kHz 音频自动重采样)
|
||||||
- [ ] 完整 Whisper 推理 (自回归解码 + 流式识别)
|
- [ ] 完整 Whisper 推理 (自回归解码 + 流式识别)
|
||||||
- [ ] 批量文件转写
|
|
||||||
- [ ] 跨平台打包
|
- [ ] 跨平台打包
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
#include "audio_decoder.h"
|
#include "audio_decoder.h"
|
||||||
#include "utils/logger.h"
|
#include "utils/logger.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
#ifdef HAVE_DR_LIBS
|
#ifdef HAVE_DR_LIBS
|
||||||
#define DR_WAV_IMPLEMENTATION
|
#define DR_WAV_IMPLEMENTATION
|
||||||
#define DR_MP3_IMPLEMENTATION
|
#define DR_MP3_IMPLEMENTATION
|
||||||
@ -23,7 +25,7 @@ AudioDecoder::AudioDecoder(QObject* parent)
|
|||||||
AudioDecoder::~AudioDecoder() = default;
|
AudioDecoder::~AudioDecoder() = default;
|
||||||
|
|
||||||
QStringList AudioDecoder::supportedFormats() {
|
QStringList AudioDecoder::supportedFormats() {
|
||||||
return {"wav", "mp3", "flac", "ogg", "aac"};
|
return {"wav", "mp3", "flac"};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AudioDecoder::decode(const QString& filePath) {
|
bool AudioDecoder::decode(const QString& filePath) {
|
||||||
@ -34,13 +36,37 @@ bool AudioDecoder::decode(const QString& filePath) {
|
|||||||
sampleRate_ = 0;
|
sampleRate_ = 0;
|
||||||
channels_ = 0;
|
channels_ = 0;
|
||||||
|
|
||||||
if (ext != "wav") {
|
#ifdef HAVE_DR_LIBS
|
||||||
|
bool success = false;
|
||||||
|
|
||||||
|
if (ext == "wav") {
|
||||||
|
success = decodeWav(filePath);
|
||||||
|
} else if (ext == "mp3") {
|
||||||
|
success = decodeMp3(filePath);
|
||||||
|
} else if (ext == "flac") {
|
||||||
|
success = decodeFlac(filePath);
|
||||||
|
} else {
|
||||||
LOG_ERROR(kTag, QString("暂不支持格式: %1").arg(ext));
|
LOG_ERROR(kTag, QString("暂不支持格式: %1").arg(ext));
|
||||||
emit error(QString("暂不支持格式: %1").arg(ext));
|
emit error(QString("暂不支持格式: %1").arg(ext));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (success) {
|
||||||
|
emit progress(1.0);
|
||||||
|
emit decoded(filePath);
|
||||||
|
LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz, %4声道)")
|
||||||
|
.arg(filePath).arg(samples_.size()).arg(sampleRate_).arg(channels_));
|
||||||
|
}
|
||||||
|
return success;
|
||||||
|
#else
|
||||||
|
LOG_ERROR(kTag, "dr_libs 未编译启用");
|
||||||
|
emit error("音频解码库未启用");
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef HAVE_DR_LIBS
|
#ifdef HAVE_DR_LIBS
|
||||||
|
bool AudioDecoder::decodeWav(const QString& filePath) {
|
||||||
drwav wav;
|
drwav wav;
|
||||||
if (!drwav_init_file(&wav, filePath.toUtf8().constData(), nullptr)) {
|
if (!drwav_init_file(&wav, filePath.toUtf8().constData(), nullptr)) {
|
||||||
LOG_ERROR(kTag, QString("无法打开 WAV 文件: %1").arg(filePath));
|
LOG_ERROR(kTag, QString("无法打开 WAV 文件: %1").arg(filePath));
|
||||||
@ -55,33 +81,80 @@ bool AudioDecoder::decode(const QString& filePath) {
|
|||||||
drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
|
drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
|
||||||
drwav_uninit(&wav);
|
drwav_uninit(&wav);
|
||||||
|
|
||||||
// 多声道混合为单声道
|
convertToMono(pcm16, wav.totalPCMFrameCount);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AudioDecoder::decodeMp3(const QString& filePath) {
|
||||||
|
drmp3 mp3;
|
||||||
|
if (!drmp3_init_file(&mp3, filePath.toUtf8().constData(), nullptr)) {
|
||||||
|
LOG_ERROR(kTag, QString("无法打开 MP3 文件: %1").arg(filePath));
|
||||||
|
emit error("无法打开音频文件");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
channels_ = mp3.channels;
|
||||||
|
sampleRate_ = mp3.sampleRate;
|
||||||
|
|
||||||
|
drmp3_uint64 totalFrames = drmp3_get_pcm_frame_count(&mp3);
|
||||||
|
std::vector<float> rawPcm(totalFrames * channels_);
|
||||||
|
drmp3_read_pcm_frames_f32(&mp3, totalFrames, rawPcm.data());
|
||||||
|
drmp3_uninit(&mp3);
|
||||||
|
|
||||||
|
// MP3 解码器直接输出归一化浮点数据
|
||||||
if (channels_ == 1) {
|
if (channels_ == 1) {
|
||||||
samples_ = std::vector<float>(pcm16.begin(), pcm16.end());
|
samples_ = std::move(rawPcm);
|
||||||
// 归一化
|
|
||||||
for (auto& s : samples_) s /= 32768.0f;
|
|
||||||
} else {
|
} else {
|
||||||
samples_.resize(wav.totalPCMFrameCount);
|
samples_.resize(totalFrames);
|
||||||
for (size_t i = 0; i < wav.totalPCMFrameCount; ++i) {
|
for (drmp3_uint64 i = 0; i < totalFrames; ++i) {
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
for (int ch = 0; ch < channels_; ++ch) {
|
for (int ch = 0; ch < channels_; ++ch) {
|
||||||
sum += pcm16[static_cast<size_t>(i) * static_cast<size_t>(channels_) + static_cast<size_t>(ch)];
|
sum += rawPcm[i * channels_ + ch];
|
||||||
|
}
|
||||||
|
samples_[i] = sum / channels_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AudioDecoder::decodeFlac(const QString& filePath) {
|
||||||
|
drflac* flac = drflac_open_file(filePath.toUtf8().constData(), nullptr);
|
||||||
|
if (!flac) {
|
||||||
|
LOG_ERROR(kTag, QString("无法打开 FLAC 文件: %1").arg(filePath));
|
||||||
|
emit error("无法打开音频文件");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
channels_ = flac->channels;
|
||||||
|
sampleRate_ = flac->sampleRate;
|
||||||
|
|
||||||
|
drflac_uint64 totalFrames = flac->totalPCMFrameCount;
|
||||||
|
std::vector<short> pcm16(totalFrames * channels_);
|
||||||
|
drflac_read_pcm_frames_s16(flac, totalFrames, pcm16.data());
|
||||||
|
drflac_close(flac);
|
||||||
|
|
||||||
|
convertToMono(pcm16, totalFrames);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AudioDecoder::convertToMono(const std::vector<short>& pcm16, uint64_t frameCount) {
|
||||||
|
if (channels_ == 1) {
|
||||||
|
samples_.resize(frameCount);
|
||||||
|
for (uint64_t i = 0; i < frameCount; ++i) {
|
||||||
|
samples_[i] = pcm16[i] / 32768.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
samples_.resize(frameCount);
|
||||||
|
for (uint64_t i = 0; i < frameCount; ++i) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int ch = 0; ch < channels_; ++ch) {
|
||||||
|
sum += pcm16[i * channels_ + ch];
|
||||||
}
|
}
|
||||||
samples_[i] = sum / (channels_ * 32768.0f);
|
samples_[i] = sum / (channels_ * 32768.0f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
emit progress(1.0);
|
|
||||||
emit decoded(filePath);
|
|
||||||
LOG_INFO(kTag, QString("文件解码完成: %1 (%2 样本, %3Hz)")
|
|
||||||
.arg(filePath).arg(samples_.size()).arg(sampleRate_));
|
|
||||||
return true;
|
|
||||||
#else
|
|
||||||
LOG_ERROR(kTag, "dr_libs 未编译启用");
|
|
||||||
emit error("音频解码库未启用");
|
|
||||||
return false;
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
double AudioDecoder::duration() const {
|
double AudioDecoder::duration() const {
|
||||||
if (sampleRate_ == 0) return 0.0;
|
if (sampleRate_ == 0) return 0.0;
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
#include <QStringList>
|
#include <QStringList>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
namespace impress {
|
namespace impress {
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ public:
|
|||||||
explicit AudioDecoder(QObject* parent = nullptr);
|
explicit AudioDecoder(QObject* parent = nullptr);
|
||||||
~AudioDecoder() override;
|
~AudioDecoder() override;
|
||||||
|
|
||||||
/** @brief 支持的格式 */
|
/** @brief 支持的格式 (WAV, MP3, FLAC) */
|
||||||
static QStringList supportedFormats();
|
static QStringList supportedFormats();
|
||||||
|
|
||||||
/** @brief 解码音频文件 */
|
/** @brief 解码音频文件 */
|
||||||
@ -49,6 +50,13 @@ signals:
|
|||||||
void error(const QString& message);
|
void error(const QString& message);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
#ifdef HAVE_DR_LIBS
|
||||||
|
bool decodeWav(const QString& filePath);
|
||||||
|
bool decodeMp3(const QString& filePath);
|
||||||
|
bool decodeFlac(const QString& filePath);
|
||||||
|
void convertToMono(const std::vector<short>& pcm16, uint64_t frameCount);
|
||||||
|
#endif
|
||||||
|
|
||||||
std::vector<float> samples_;
|
std::vector<float> samples_;
|
||||||
int sampleRate_ = 0;
|
int sampleRate_ = 0;
|
||||||
int channels_ = 0;
|
int channels_ = 0;
|
||||||
|
|||||||
@ -161,9 +161,9 @@ std::vector<float> MelSpectrogram::stft(const std::vector<float>& samples, int f
|
|||||||
// 执行 FFT
|
// 执行 FFT
|
||||||
fft(fftInput);
|
fft(fftInput);
|
||||||
|
|
||||||
// 计算幅度谱
|
// 计算幅度谱 (Whisper 使用 magnitude 而非 magnitude²)
|
||||||
for (int k = 0; k < nFreq; k++) {
|
for (int k = 0; k < nFreq; k++) {
|
||||||
magnitude[k] = fftInput[k].magnitudeSq();
|
magnitude[k] = std::sqrt(fftInput[k].magnitudeSq());
|
||||||
}
|
}
|
||||||
|
|
||||||
return magnitude;
|
return magnitude;
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
#include "stt_engine.h"
|
#include "stt_engine.h"
|
||||||
#include "mel_spectrogram.h"
|
#include "mel_spectrogram.h"
|
||||||
#include "whisper_tokenizer.h"
|
#include "whisper_tokenizer.h"
|
||||||
|
#include "audio_processor.h"
|
||||||
#include "utils/logger.h"
|
#include "utils/logger.h"
|
||||||
#include "utils/timer.h"
|
#include "utils/timer.h"
|
||||||
|
|
||||||
@ -237,7 +238,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
RecognitionResult result;
|
RecognitionResult result;
|
||||||
|
|
||||||
QString lang = language.isEmpty() ? "zh" : language;
|
QString lang = language.isEmpty() ? "zh" : language;
|
||||||
LOG_DEBUG(kTag, QString("推理语言: %1").arg(lang));
|
LOG_DEBUG(kTag, QString("推理语言: %1 (采样率: %2Hz, 样本数: %3)")
|
||||||
|
.arg(lang).arg(sampleRate).arg(samples.size()));
|
||||||
|
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
if (!loaded_) {
|
if (!loaded_) {
|
||||||
@ -247,15 +249,29 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 1. 计算 Mel 频谱图
|
// 1. 重采样到 Whisper 所需的 16kHz
|
||||||
Timer melTimer;
|
Timer preprocessTimer;
|
||||||
MelSpectrogram melExtractor(kMelBins, 400, 160, sampleRate);
|
std::vector<float> processedSamples = samples;
|
||||||
std::vector<float> melSpec = melExtractor.compute(samples);
|
int currentSampleRate = sampleRate;
|
||||||
int nFrames = melExtractor.nFrames(static_cast<int>(samples.size()));
|
|
||||||
if (nFrames <= 0) nFrames = 1;
|
|
||||||
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)").arg(melTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
|
|
||||||
|
|
||||||
// 2. 运行 ONNX 推理
|
if (sampleRate != 16000) {
|
||||||
|
AudioProcessor processor(16000);
|
||||||
|
processedSamples = processor.resample(samples, sampleRate);
|
||||||
|
currentSampleRate = 16000;
|
||||||
|
LOG_DEBUG(kTag, QString("重采样: %1Hz -> %2Hz (%3 -> %4 样本)")
|
||||||
|
.arg(sampleRate).arg(currentSampleRate)
|
||||||
|
.arg(samples.size()).arg(processedSamples.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Mel 频谱图提取
|
||||||
|
MelSpectrogram melExtractor(kMelBins, 400, 160, currentSampleRate);
|
||||||
|
std::vector<float> melSpec = melExtractor.compute(processedSamples);
|
||||||
|
int nFrames = melExtractor.nFrames(static_cast<int>(processedSamples.size()));
|
||||||
|
if (nFrames <= 0) nFrames = 1;
|
||||||
|
LOG_DEBUG(kTag, QString("Mel 计算: %1 ms (%2 帧)")
|
||||||
|
.arg(preprocessTimer.elapsedMs(), 0, 'f', 1).arg(nFrames));
|
||||||
|
|
||||||
|
// 3. 运行 ONNX 推理
|
||||||
Timer inferTimer;
|
Timer inferTimer;
|
||||||
QMutexLocker locker(&impl_->mutex);
|
QMutexLocker locker(&impl_->mutex);
|
||||||
|
|
||||||
@ -277,7 +293,7 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1));
|
LOG_DEBUG(kTag, QString("ONNX 推理: %1 ms").arg(inferTimer.elapsedMs(), 0, 'f', 1));
|
||||||
|
|
||||||
// 3. 解析输出
|
// 4. 解析输出
|
||||||
auto& outputTensor = outputTensors[0];
|
auto& outputTensor = outputTensors[0];
|
||||||
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
const float* outputData = outputTensor.GetTensorMutableData<float>();
|
const float* outputData = outputTensor.GetTensorMutableData<float>();
|
||||||
@ -292,11 +308,12 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
if (shape.size() == 2 && shape[1] == vocabSize) {
|
if (shape.size() == 2 && shape[1] == vocabSize) {
|
||||||
// [1, vocab_size] - 直接输出
|
// [1, vocab_size] - 直接输出
|
||||||
int bestToken = argmax(outputData, 0, std::min(vocabSize, 50256));
|
int searchEnd = std::min(vocabSize, 50256);
|
||||||
|
int bestToken = argmax(outputData, 0, searchEnd);
|
||||||
if (!WhisperTokenizer::isSpecialToken(bestToken)) {
|
if (!WhisperTokenizer::isSpecialToken(bestToken)) {
|
||||||
tokens.push_back(bestToken);
|
tokens.push_back(bestToken);
|
||||||
}
|
}
|
||||||
auto probs = softmax(outputData, 0, std::min(vocabSize, 50256));
|
auto probs = softmax(outputData, 0, searchEnd);
|
||||||
float maxProb = probs[0];
|
float maxProb = probs[0];
|
||||||
for (size_t i = 1; i < probs.size(); i++) {
|
for (size_t i = 1; i < probs.size(); i++) {
|
||||||
if (probs[i] > maxProb) maxProb = probs[i];
|
if (probs[i] > maxProb) maxProb = probs[i];
|
||||||
@ -310,7 +327,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
for (int t = 0; t < seqLen && static_cast<int>(tokens.size()) < kMaxTokens; t++) {
|
for (int t = 0; t < seqLen && static_cast<int>(tokens.size()) < kMaxTokens; t++) {
|
||||||
int offset = t * vocabSize;
|
int offset = t * vocabSize;
|
||||||
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
int searchEnd = std::min(offset + vocabSize, offset + 50256);
|
||||||
|
int bestToken = argmax(outputData, offset, searchEnd);
|
||||||
if (WhisperTokenizer::isSpecialToken(bestToken)) break;
|
if (WhisperTokenizer::isSpecialToken(bestToken)) break;
|
||||||
if (!tokens.empty() && tokens.back() == bestToken) continue;
|
if (!tokens.empty() && tokens.back() == bestToken) continue;
|
||||||
tokens.push_back(bestToken);
|
tokens.push_back(bestToken);
|
||||||
@ -320,8 +338,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
float avgConf = 0.0f;
|
float avgConf = 0.0f;
|
||||||
for (int t = 0; t < seqLen && t < static_cast<int>(tokens.size()); t++) {
|
for (int t = 0; t < seqLen && t < static_cast<int>(tokens.size()); t++) {
|
||||||
int offset = t * vocabSize;
|
int offset = t * vocabSize;
|
||||||
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
|
||||||
auto probs = softmax(outputData, offset, offset + vocabSize);
|
auto probs = softmax(outputData, offset, offset + vocabSize);
|
||||||
|
int bestToken = argmax(outputData, offset, offset + vocabSize);
|
||||||
avgConf += probs[bestToken - offset];
|
avgConf += probs[bestToken - offset];
|
||||||
}
|
}
|
||||||
result.confidence = avgConf / tokens.size();
|
result.confidence = avgConf / tokens.size();
|
||||||
@ -332,7 +350,7 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 解码 token 为文本
|
// 5. 解码 token 为文本
|
||||||
if (tokens.empty()) {
|
if (tokens.empty()) {
|
||||||
result.text = "";
|
result.text = "";
|
||||||
} else if (impl_->tokenizer.isLoaded()) {
|
} else if (impl_->tokenizer.isLoaded()) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user