fix: 修复 SenseVoice argmax 偏移 bug + 添加调试音频保存
修复转写结果显示为 token ID 的问题:argmax() 返回的是扁平化数组的绝对索引,需减去 offset 才能得到正确的 token ID。同时修正置信度计算,使其使用正确的绝对索引。 添加调试音频保存功能:开启后每次推理将原始 PCM 保存为 WAV 文件到 /tmp/impress_audio_debug/,并增加 RMS 电平和 NaN 诊断日志。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
85b67780b1
commit
d87b3e1ff8
@ -7,7 +7,7 @@
|
|||||||
#define DR_WAV_IMPLEMENTATION
|
#define DR_WAV_IMPLEMENTATION
|
||||||
#define DR_MP3_IMPLEMENTATION
|
#define DR_MP3_IMPLEMENTATION
|
||||||
#define DR_FLAC_IMPLEMENTATION
|
#define DR_FLAC_IMPLEMENTATION
|
||||||
#include <dr_wav.h>
|
#include "dr_wav.h"
|
||||||
#include <dr_mp3.h>
|
#include <dr_mp3.h>
|
||||||
#include <dr_flac.h>
|
#include <dr_flac.h>
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -12,6 +12,10 @@
|
|||||||
#include <QMutex>
|
#include <QMutex>
|
||||||
#include <QMutexLocker>
|
#include <QMutexLocker>
|
||||||
#include <QFileInfo>
|
#include <QFileInfo>
|
||||||
|
#include <QDir>
|
||||||
|
#include <QDateTime>
|
||||||
|
#include <QFile>
|
||||||
|
#include <QDataStream>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
@ -22,6 +26,54 @@
|
|||||||
|
|
||||||
static const char* const kTag = "SenseVoiceEngine";
|
static const char* const kTag = "SenseVoiceEngine";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 简易 WAV 写入(不依赖 dr_wav,避免多定义冲突)
|
||||||
|
*/
|
||||||
|
static bool saveWav16(const QString& path, const std::vector<float>& samples, int sampleRate) {
|
||||||
|
QFile file(path);
|
||||||
|
if (!file.open(QIODevice::WriteOnly)) return false;
|
||||||
|
|
||||||
|
int numSamples = static_cast<int>(samples.size());
|
||||||
|
int dataSize = numSamples * 2; // 16-bit mono
|
||||||
|
int totalSize = 36 + dataSize; // RIFF header size
|
||||||
|
quint16 audioFormat = 1; // PCM
|
||||||
|
quint16 numChannels = 1;
|
||||||
|
quint32 byteRate = sampleRate * 2; // sampleRate * numChannels * bitsPerSample/8
|
||||||
|
quint16 blockAlign = 2; // numChannels * bitsPerSample/8
|
||||||
|
quint16 bitsPerSample = 16;
|
||||||
|
|
||||||
|
QDataStream out(&file);
|
||||||
|
out.setByteOrder(QDataStream::LittleEndian);
|
||||||
|
|
||||||
|
// RIFF header
|
||||||
|
out.writeRawData("RIFF", 4);
|
||||||
|
out << static_cast<quint32>(totalSize);
|
||||||
|
out.writeRawData("WAVE", 4);
|
||||||
|
|
||||||
|
// fmt chunk
|
||||||
|
out.writeRawData("fmt ", 4);
|
||||||
|
out << static_cast<quint32>(16); // chunk size
|
||||||
|
out << audioFormat;
|
||||||
|
out << numChannels;
|
||||||
|
out << static_cast<quint32>(sampleRate);
|
||||||
|
out << byteRate;
|
||||||
|
out << blockAlign;
|
||||||
|
out << bitsPerSample;
|
||||||
|
|
||||||
|
// data chunk
|
||||||
|
out.writeRawData("data", 4);
|
||||||
|
out << static_cast<quint32>(dataSize);
|
||||||
|
|
||||||
|
// PCM data (float → int16)
|
||||||
|
for (float s : samples) {
|
||||||
|
s = std::max(-1.0f, std::min(1.0f, s)); // clip
|
||||||
|
qint16 val = static_cast<qint16>(s * 32767.0f);
|
||||||
|
out << val;
|
||||||
|
}
|
||||||
|
|
||||||
|
return file.error() == QFile::FileError::NoError;
|
||||||
|
}
|
||||||
|
|
||||||
namespace impress {
|
namespace impress {
|
||||||
|
|
||||||
/** 语言代码映射 */
|
/** 语言代码映射 */
|
||||||
@ -227,6 +279,11 @@ bool SenseVoiceEngine::isLoaded() const {
|
|||||||
return loaded_;
|
return loaded_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Enable or disable the debug-audio-save flag.
 *
 * Stores the new value in debugSaveAudio_ and logs the resulting state
 * at info level.
 *
 * @param enable true to turn the flag on, false to turn it off.
 */
void SenseVoiceEngine::setDebugSaveAudio(bool enable) {
    const char* const stateLabel = enable ? "开启" : "关闭";
    debugSaveAudio_ = enable;
    LOG_INFO(kTag, QString("调试录音保存: %1").arg(stateLabel));
}
|
||||||
|
|
||||||
/** CTC 贪婪解码:去重 + 去除空白 */
|
/** CTC 贪婪解码:去重 + 去除空白 */
|
||||||
static std::vector<int> ctcGreedyDecode(const std::vector<int>& tokens, int blankToken) {
|
static std::vector<int> ctcGreedyDecode(const std::vector<int>& tokens, int blankToken) {
|
||||||
std::vector<int> result;
|
std::vector<int> result;
|
||||||
@ -286,9 +343,53 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
try {
|
try {
|
||||||
// 1. 重采样到 16kHz
|
// 1. 重采样到 16kHz
|
||||||
Timer preprocessTimer;
|
Timer preprocessTimer;
|
||||||
|
|
||||||
|
// 调试模式:保存原始音频到 WAV 文件
|
||||||
|
if (debugSaveAudio_ && !samples.empty()) {
|
||||||
|
QString debugDir = "/tmp/impress_audio_debug";
|
||||||
|
QDir dir;
|
||||||
|
if (!dir.exists(debugDir)) {
|
||||||
|
dir.mkpath(debugDir);
|
||||||
|
}
|
||||||
|
QString timestamp = QDateTime::currentDateTime().toString("yyyyMMdd_HHmmss_zzz");
|
||||||
|
QString wavPath = QString("%1/audio_%2_%3Hz.wav")
|
||||||
|
.arg(debugDir).arg(timestamp).arg(sampleRate);
|
||||||
|
|
||||||
|
if (saveWav16(wavPath, samples, sampleRate)) {
|
||||||
|
LOG_DEBUG(kTag, QString("调试音频已保存: %1 (%2 样本, %3Hz)")
|
||||||
|
.arg(wavPath).arg(samples.size()).arg(sampleRate));
|
||||||
|
} else {
|
||||||
|
LOG_WARNING(kTag, QString("无法创建调试音频文件: %1").arg(wavPath));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<float> processedSamples = samples;
|
std::vector<float> processedSamples = samples;
|
||||||
int currentSampleRate = sampleRate;
|
int currentSampleRate = sampleRate;
|
||||||
|
|
||||||
|
// 计算输入音频 RMS 电平用于诊断
|
||||||
|
double rms = 0.0;
|
||||||
|
bool hasNaN = false;
|
||||||
|
for (float s : samples) {
|
||||||
|
if (std::isnan(s) || std::isinf(s)) { hasNaN = true; break; }
|
||||||
|
rms += s * s;
|
||||||
|
}
|
||||||
|
rms = std::sqrt(rms / samples.size());
|
||||||
|
|
||||||
|
if (hasNaN) {
|
||||||
|
result.text = "[错误] 输入音频包含 NaN/Inf 值,请检查麦克风设备";
|
||||||
|
result.latency_ms = timer.elapsedMs();
|
||||||
|
LOG_ERROR(kTag, QString("输入音频包含无效值 (NaN/Inf), 样本数: %1").arg(samples.size()));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
if (rms < 1e-6) {
|
||||||
|
result.text = "";
|
||||||
|
result.latency_ms = timer.elapsedMs();
|
||||||
|
LOG_DEBUG(kTag, QString("静音段 (RMS: %1), 跳过推理").arg(rms, 0, 'f', 6));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DEBUG(kTag, QString("输入音频 RMS: %1 (样本数: %2)").arg(rms, 0, 'f', 6).arg(samples.size()));
|
||||||
|
|
||||||
if (sampleRate != 16000) {
|
if (sampleRate != 16000) {
|
||||||
AudioProcessor processor(16000);
|
AudioProcessor processor(16000);
|
||||||
processedSamples = processor.resample(samples, sampleRate);
|
processedSamples = processor.resample(samples, sampleRate);
|
||||||
@ -370,13 +471,14 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
for (int t = 0; t < seqLen; t++) {
|
for (int t = 0; t < seqLen; t++) {
|
||||||
int offset = t * vocabSize;
|
int offset = t * vocabSize;
|
||||||
int bestToken = argmax(logitsData, offset, offset + vocabSize);
|
int bestAbsIdx = argmax(logitsData, offset, offset + vocabSize);
|
||||||
|
int bestToken = bestAbsIdx - offset; // 绝对索引 → token ID
|
||||||
|
|
||||||
if (bestToken != SenseVoiceTokenizer::kTokenBlank) {
|
if (bestToken != SenseVoiceTokenizer::kTokenBlank) {
|
||||||
rawTokens.push_back(bestToken);
|
rawTokens.push_back(bestToken);
|
||||||
|
|
||||||
// 计算置信度
|
// 计算置信度
|
||||||
float maxLogit = logitsData[offset + bestToken];
|
float maxLogit = logitsData[bestAbsIdx];
|
||||||
// 近似置信度: 使用 softmax 的最大值位置
|
// 近似置信度: 使用 softmax 的最大值位置
|
||||||
totalConf += maxLogit;
|
totalConf += maxLogit;
|
||||||
confCount++;
|
confCount++;
|
||||||
@ -398,8 +500,8 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
result.text = "";
|
result.text = "";
|
||||||
} else if (impl_->tokenizer.isLoaded()) {
|
} else if (impl_->tokenizer.isLoaded()) {
|
||||||
result.text = impl_->tokenizer.decode(decodedTokens);
|
result.text = impl_->tokenizer.decode(decodedTokens);
|
||||||
LOG_DEBUG(kTag, QString("解码文本: %1 个 token → %2 字符")
|
LOG_DEBUG(kTag, QString("解码文本: %1 个 token → %2 字符: %3")
|
||||||
.arg(decodedTokens.size()).arg(result.text.length()));
|
.arg(decodedTokens.size()).arg(result.text.length()).arg(result.text));
|
||||||
} else {
|
} else {
|
||||||
// 降级:输出 token ID
|
// 降级:输出 token ID
|
||||||
QString decodedText;
|
QString decodedText;
|
||||||
|
|||||||
@ -48,6 +48,9 @@ public:
|
|||||||
int sampleRate,
|
int sampleRate,
|
||||||
const QString& language = QString());
|
const QString& language = QString());
|
||||||
|
|
||||||
|
/** @brief 设置调试模式:开启后每次推理保存音频到 WAV */
|
||||||
|
void setDebugSaveAudio(bool enable);
|
||||||
|
|
||||||
signals:
|
signals:
|
||||||
void modelLoaded(const QString& modelPath);
|
void modelLoaded(const QString& modelPath);
|
||||||
void modelLoadError(const QString& modelPath, const QString& error);
|
void modelLoadError(const QString& modelPath, const QString& error);
|
||||||
@ -58,6 +61,7 @@ private:
|
|||||||
struct Impl;
|
struct Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
bool loaded_ = false;
|
bool loaded_ = false;
|
||||||
|
bool debugSaveAudio_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace impress
|
} // namespace impress
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user