fix: 修复 SenseVoice argmax 偏移 bug + 添加调试音频保存

修复转写结果显示为 token ID 的问题:argmax() 返回的是扁平化数组的
绝对索引,需减去 offset 才能得到正确的 token ID。同时修正置信度
计算使用正确的绝对索引。

添加调试音频保存功能:开启后每次推理将原始 PCM 保存为 WAV 文件
到 /tmp/impress_audio_debug/,并增加 RMS 电平和 NaN 诊断日志。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Alvin Young 2026-05-13 11:12:34 +08:00
parent 85b67780b1
commit d87b3e1ff8
3 changed files with 111 additions and 5 deletions

View File

@ -7,7 +7,7 @@
#define DR_WAV_IMPLEMENTATION
#define DR_MP3_IMPLEMENTATION
#define DR_FLAC_IMPLEMENTATION
#include <dr_wav.h>
#include "dr_wav.h"
#include <dr_mp3.h>
#include <dr_flac.h>
#endif

View File

@ -12,6 +12,10 @@
#include <QMutex>
#include <QMutexLocker>
#include <QFileInfo>
#include <QDir>
#include <QDateTime>
#include <QFile>
#include <QDataStream>
#include <algorithm>
#include <cmath>
@ -22,6 +26,54 @@
static const char* const kTag = "SenseVoiceEngine";
/**
* @brief WAV dr_wav
*/
static bool saveWav16(const QString& path, const std::vector<float>& samples, int sampleRate) {
QFile file(path);
if (!file.open(QIODevice::WriteOnly)) return false;
int numSamples = static_cast<int>(samples.size());
int dataSize = numSamples * 2; // 16-bit mono
int totalSize = 36 + dataSize; // RIFF header size
quint16 audioFormat = 1; // PCM
quint16 numChannels = 1;
quint32 byteRate = sampleRate * 2; // sampleRate * numChannels * bitsPerSample/8
quint16 blockAlign = 2; // numChannels * bitsPerSample/8
quint16 bitsPerSample = 16;
QDataStream out(&file);
out.setByteOrder(QDataStream::LittleEndian);
// RIFF header
out.writeRawData("RIFF", 4);
out << static_cast<quint32>(totalSize);
out.writeRawData("WAVE", 4);
// fmt chunk
out.writeRawData("fmt ", 4);
out << static_cast<quint32>(16); // chunk size
out << audioFormat;
out << numChannels;
out << static_cast<quint32>(sampleRate);
out << byteRate;
out << blockAlign;
out << bitsPerSample;
// data chunk
out.writeRawData("data", 4);
out << static_cast<quint32>(dataSize);
// PCM data (float → int16)
for (float s : samples) {
s = std::max(-1.0f, std::min(1.0f, s)); // clip
qint16 val = static_cast<qint16>(s * 32767.0f);
out << val;
}
return file.error() == QFile::FileError::NoError;
}
namespace impress {
/** 语言代码映射 */
@ -227,6 +279,11 @@ bool SenseVoiceEngine::isLoaded() const {
return loaded_;
}
void SenseVoiceEngine::setDebugSaveAudio(bool enable) {
debugSaveAudio_ = enable;
LOG_INFO(kTag, QString("调试录音保存: %1").arg(enable ? "开启" : "关闭"));
}
/** CTC 贪婪解码:去重 + 去除空白 */
static std::vector<int> ctcGreedyDecode(const std::vector<int>& tokens, int blankToken) {
std::vector<int> result;
@ -286,9 +343,53 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
try {
// 1. 重采样到 16kHz
Timer preprocessTimer;
// 调试模式:保存原始音频到 WAV 文件
if (debugSaveAudio_ && !samples.empty()) {
QString debugDir = "/tmp/impress_audio_debug";
QDir dir;
if (!dir.exists(debugDir)) {
dir.mkpath(debugDir);
}
QString timestamp = QDateTime::currentDateTime().toString("yyyyMMdd_HHmmss_zzz");
QString wavPath = QString("%1/audio_%2_%3Hz.wav")
.arg(debugDir).arg(timestamp).arg(sampleRate);
if (saveWav16(wavPath, samples, sampleRate)) {
LOG_DEBUG(kTag, QString("调试音频已保存: %1 (%2 样本, %3Hz)")
.arg(wavPath).arg(samples.size()).arg(sampleRate));
} else {
LOG_WARNING(kTag, QString("无法创建调试音频文件: %1").arg(wavPath));
}
}
std::vector<float> processedSamples = samples;
int currentSampleRate = sampleRate;
// 计算输入音频 RMS 电平用于诊断
double rms = 0.0;
bool hasNaN = false;
for (float s : samples) {
if (std::isnan(s) || std::isinf(s)) { hasNaN = true; break; }
rms += s * s;
}
rms = std::sqrt(rms / samples.size());
if (hasNaN) {
result.text = "[错误] 输入音频包含 NaN/Inf 值,请检查麦克风设备";
result.latency_ms = timer.elapsedMs();
LOG_ERROR(kTag, QString("输入音频包含无效值 (NaN/Inf), 样本数: %1").arg(samples.size()));
return result;
}
if (rms < 1e-6) {
result.text = "";
result.latency_ms = timer.elapsedMs();
LOG_DEBUG(kTag, QString("静音段 (RMS: %1), 跳过推理").arg(rms, 0, 'f', 6));
return result;
}
LOG_DEBUG(kTag, QString("输入音频 RMS: %1 (样本数: %2)").arg(rms, 0, 'f', 6).arg(samples.size()));
if (sampleRate != 16000) {
AudioProcessor processor(16000);
processedSamples = processor.resample(samples, sampleRate);
@ -370,13 +471,14 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
for (int t = 0; t < seqLen; t++) {
int offset = t * vocabSize;
int bestToken = argmax(logitsData, offset, offset + vocabSize);
int bestAbsIdx = argmax(logitsData, offset, offset + vocabSize);
int bestToken = bestAbsIdx - offset; // 绝对索引 → token ID
if (bestToken != SenseVoiceTokenizer::kTokenBlank) {
rawTokens.push_back(bestToken);
// 计算置信度
float maxLogit = logitsData[offset + bestToken];
float maxLogit = logitsData[bestAbsIdx];
// 近似置信度: 使用 softmax 的最大值位置
totalConf += maxLogit;
confCount++;
@ -398,8 +500,8 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
result.text = "";
} else if (impl_->tokenizer.isLoaded()) {
result.text = impl_->tokenizer.decode(decodedTokens);
LOG_DEBUG(kTag, QString("解码文本: %1 个 token → %2 字符")
.arg(decodedTokens.size()).arg(result.text.length()));
LOG_DEBUG(kTag, QString("解码文本: %1 个 token → %2 字符: %3")
.arg(decodedTokens.size()).arg(result.text.length()).arg(result.text));
} else {
// 降级:输出 token ID
QString decodedText;

View File

@ -48,6 +48,9 @@ public:
int sampleRate,
const QString& language = QString());
/** @brief 设置调试模式:开启后每次推理保存音频到 WAV */
void setDebugSaveAudio(bool enable);
signals:
void modelLoaded(const QString& modelPath);
void modelLoadError(const QString& modelPath, const QString& error);
@ -58,6 +61,7 @@ private:
struct Impl;
std::unique_ptr<Impl> impl_;
bool loaded_ = false;
bool debugSaveAudio_ = false;
};
} // namespace impress