feat: 添加 Windows 交叉编译支持与 ONNX Runtime MinGW 兼容方案
- 新增 C API shim (ort_api_shim.h) 解决 MinGW 与 ONNX Runtime 的 SAL 注解/_stdcall 兼容性问题
- 新增轻量级 C++ 包装器 (ort_minimal) 替代 onnxruntime_cxx_api.h
- cmake/dependencies.cmake 支持 Windows/Linux 平台自动识别依赖路径
- 修复音频采集 paNonInterleaved bug(指针被误解析为 float 导致 RMS=inf)
- 修复 Windows 热键和 UI 相关代码
- 添加 MinGW 交叉编译工具链配置

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
01a39ddc8c
commit
8c2e787a25
9
.gitignore
vendored
9
.gitignore
vendored
@ -64,3 +64,12 @@ configs/config.json
|
||||
|
||||
# 日志
|
||||
*.log
|
||||
|
||||
# 构建目录(所有平台)
|
||||
build-win/
|
||||
build_linux/
|
||||
build_win/
|
||||
dist/
|
||||
|
||||
# Windows 第三方依赖(ONNX Runtime Win x64 预编译二进制)
|
||||
third_party/onnxruntime-win-x64/
|
||||
|
||||
@ -55,6 +55,7 @@ set(SOURCES
|
||||
# Core (平台无关)
|
||||
src/core/stt_engine.cpp
|
||||
src/core/sense_voice_engine.cpp
|
||||
src/core/ort_minimal.cpp
|
||||
src/core/sense_voice_features.cpp
|
||||
src/core/sense_voice_tokenizer.cpp
|
||||
src/core/mel_spectrogram.cpp
|
||||
@ -166,6 +167,36 @@ target_compile_options(${PROJECT_NAME} PRIVATE
|
||||
$<$<CXX_COMPILER_ID:MSVC>:/W4>
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Windows DLL 复制(交叉编译时自动拷贝到输出目录)
|
||||
# ============================================================================
|
||||
# Copy the runtime DLLs next to the built executable on Windows so the binary
# can be started straight from the build tree (also when cross-compiling).
# Each listed variable is expected to hold the full path of one DLL; variables
# that are unset or empty are skipped.
if(WIN32)
    foreach(runtime_dll_var IN ITEMS ONNXRUNTIME_DLL ONNXRUNTIME_PROVIDERS_DLL PORTAUDIO_DLL)
        if(${runtime_dll_var})
            # Only the file name is shown in the build log message.
            get_filename_component(runtime_dll_name "${${runtime_dll_var}}" NAME)
            add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
                COMMAND ${CMAKE_COMMAND} -E copy_if_different
                        "${${runtime_dll_var}}"
                        "$<TARGET_FILE_DIR:${PROJECT_NAME}>"
                COMMENT "拷贝 ${runtime_dll_name} 到输出目录"
                # VERBATIM keeps argument escaping identical on every platform/generator.
                VERBATIM
            )
        endif()
    endforeach()
endif()
|
||||
|
||||
# ============================================================================
|
||||
# 资源文件
|
||||
# ============================================================================
|
||||
|
||||
@ -2,21 +2,39 @@ include(FetchContent)
|
||||
|
||||
set(THIRD_PARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party")
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
# ONNX Runtime
|
||||
# ----------------------------------------------------------------------------
|
||||
set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime")
|
||||
|
||||
find_library(ONNXRUNTIME_LIB
|
||||
NAMES onnxruntime
|
||||
PATHS "${ONNXRUNTIME_ROOT}/lib"
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
find_path(ONNXRUNTIME_INCLUDE_DIR
|
||||
NAMES onnxruntime_cxx_api.h
|
||||
PATHS "${ONNXRUNTIME_ROOT}/include"
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
# ============================================================================
|
||||
# Locate ONNX Runtime inside third_party/.
# Results: ONNXRUNTIME_LIB (link input), ONNXRUNTIME_INCLUDE_DIR, and on Windows
# ONNXRUNTIME_DLL / ONNXRUNTIME_PROVIDERS_DLL (runtime files to deploy).
if(WIN32)
    # Windows: prefer the prebuilt win-x64 package directory and fall back to
    # the legacy directory name when onnxruntime.dll is not there.
    set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime-win-x64")
    if(NOT EXISTS "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
        set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime")
    endif()

    # The DLL itself is used as the link input (MinGW can link against a DLL
    # directly), and the same path is recorded for deployment.
    if(EXISTS "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
        set(ONNXRUNTIME_LIB "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
        set(ONNXRUNTIME_DLL "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
        set(ONNXRUNTIME_INCLUDE_DIR "${ONNXRUNTIME_ROOT}/include")
    endif()

    # Record the shared-providers DLL when the package ships it, so that a
    # deployment step testing ONNXRUNTIME_PROVIDERS_DLL has a value to work with.
    # A value supplied by the user (e.g. -DONNXRUNTIME_PROVIDERS_DLL=...) wins.
    if(NOT ONNXRUNTIME_PROVIDERS_DLL
       AND EXISTS "${ONNXRUNTIME_ROOT}/lib/onnxruntime_providers_shared.dll")
        set(ONNXRUNTIME_PROVIDERS_DLL "${ONNXRUNTIME_ROOT}/lib/onnxruntime_providers_shared.dll")
    endif()

    if(NOT ONNXRUNTIME_INCLUDE_DIR)
        set(ONNXRUNTIME_INCLUDE_DIR "${ONNXRUNTIME_ROOT}/include")
    endif()
else()
    # Linux: look for libonnxruntime.so and the headers in third_party only.
    set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime")
    find_library(ONNXRUNTIME_LIB
        NAMES onnxruntime
        PATHS "${ONNXRUNTIME_ROOT}/lib"
        NO_DEFAULT_PATH
    )
    find_path(ONNXRUNTIME_INCLUDE_DIR
        NAMES onnxruntime_cxx_api.h
        PATHS "${ONNXRUNTIME_ROOT}/include"
        NO_DEFAULT_PATH
    )
endif()
|
||||
|
||||
if(ONNXRUNTIME_LIB AND ONNXRUNTIME_INCLUDE_DIR)
|
||||
set(ONNXRUNTIME_LIBRARIES ${ONNXRUNTIME_LIB})
|
||||
@ -27,21 +45,45 @@ else()
|
||||
message(WARNING "未找到 ONNX Runtime,推理功能将使用占位实现")
|
||||
endif()
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
# PortAudio
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
set(PORTAUDIO_ROOT "${THIRD_PARTY_DIR}/portaudio")
|
||||
|
||||
find_library(PORTAUDIO_LIB
|
||||
NAMES portaudio libportaudio
|
||||
PATHS "${PORTAUDIO_ROOT}/lib"
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
find_path(PORTAUDIO_INCLUDE_DIR
|
||||
NAMES portaudio.h
|
||||
PATHS "${PORTAUDIO_ROOT}/include"
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
# Locate PortAudio.
# Results: PORTAUDIO_LIB and PORTAUDIO_INCLUDE_DIR (both must be set for the
# library to count as found), plus PORTAUDIO_DLL on Windows for deployment.
if(WIN32)
    # Windows: libportaudio.dll lives in bin/ and doubles as the link input.
    if(EXISTS "${PORTAUDIO_ROOT}/bin/libportaudio.dll")
        set(PORTAUDIO_LIB "${PORTAUDIO_ROOT}/bin/libportaudio.dll")
        set(PORTAUDIO_DLL "${PORTAUDIO_ROOT}/bin/libportaudio.dll")
    endif()
    if(EXISTS "${PORTAUDIO_ROOT}/include/portaudio.h")
        set(PORTAUDIO_INCLUDE_DIR "${PORTAUDIO_ROOT}/include")
    endif()
else()
    # Linux: prefer the locally built copy, then a fixed list of system dirs.
    find_library(PORTAUDIO_LIB
        NAMES portaudio libportaudio
        PATHS "${PORTAUDIO_ROOT}/lib" /usr/lib64 /usr/lib /usr/local/lib
        NO_DEFAULT_PATH
    )
    find_path(PORTAUDIO_INCLUDE_DIR
        NAMES portaudio.h
        PATHS "${PORTAUDIO_ROOT}/include" /usr/include /usr/include/portaudio /usr/local/include
        NO_DEFAULT_PATH
    )

    # Fallback: ask pkg-config.
    if(NOT PORTAUDIO_LIB OR NOT PORTAUDIO_INCLUDE_DIR)
        find_package(PkgConfig QUIET)
        if(PKG_CONFIG_FOUND)
            pkg_check_modules(PORTAUDIO_PC portaudio-2.0 QUIET)
            if(PORTAUDIO_PC_FOUND)
                set(PORTAUDIO_LIBRARIES ${PORTAUDIO_PC_LIBRARIES})
                set(PORTAUDIO_INCLUDE_DIRS ${PORTAUDIO_PC_INCLUDE_DIRS})
                # Also resolve the two variables that decide "found", using the
                # pkg-config directories as hints on top of CMake's default
                # search paths. find_* only searches again for a variable that
                # is still unset/NOTFOUND, so earlier hits are kept.
                find_library(PORTAUDIO_LIB
                    NAMES portaudio libportaudio
                    HINTS ${PORTAUDIO_PC_LIBRARY_DIRS}
                )
                find_path(PORTAUDIO_INCLUDE_DIR
                    NAMES portaudio.h
                    HINTS ${PORTAUDIO_PC_INCLUDE_DIRS}
                )
            endif()
        endif()
    endif()
endif()
|
||||
|
||||
if(PORTAUDIO_LIB AND PORTAUDIO_INCLUDE_DIR)
|
||||
set(PORTAUDIO_LIBRARIES ${PORTAUDIO_LIB})
|
||||
@ -52,9 +94,9 @@ else()
|
||||
message(WARNING "未找到 PortAudio,音频采集功能将使用占位实现")
|
||||
endif()
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
# dr_libs (header-only)
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
set(DR_LIBS_INCLUDE_DIR "${THIRD_PARTY_DIR}/dr_libs")
|
||||
if(EXISTS "${DR_LIBS_INCLUDE_DIR}/dr_wav.h")
|
||||
message(STATUS "找到 dr_libs: ${DR_LIBS_INCLUDE_DIR}")
|
||||
@ -63,7 +105,7 @@ else()
|
||||
message(WARNING "未找到 dr_libs 头文件")
|
||||
endif()
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
# nlohmann/json (header-only)
|
||||
# ----------------------------------------------------------------------------
|
||||
# ============================================================================
|
||||
set(NLOHMANN_JSON_INCLUDE_DIR "${THIRD_PARTY_DIR}/nlohmann_json")
|
||||
|
||||
15
cmake/mingw-toolchain.cmake
Normal file
15
cmake/mingw-toolchain.cmake
Normal file
@ -0,0 +1,15 @@
|
||||
# MinGW-w64 toolchain file: cross-compile for 64-bit Windows.
set(CMAKE_SYSTEM_NAME Windows)
set(CMAKE_SYSTEM_PROCESSOR x86_64)

# Every cross tool shares the same target-triplet prefix.
set(MINGW_TOOL_PREFIX x86_64-w64-mingw32)
set(CMAKE_C_COMPILER ${MINGW_TOOL_PREFIX}-gcc)
set(CMAKE_CXX_COMPILER ${MINGW_TOOL_PREFIX}-g++)
set(CMAKE_RC_COMPILER ${MINGW_TOOL_PREFIX}-windres)

# Root of the target environment used to re-root find_*() searches.
set(CMAKE_FIND_ROOT_PATH /usr/${MINGW_TOOL_PREFIX}/sys-root/mingw)

# Search policy: programs are never looked up under the find root, while
# libraries, headers and packages are searched both under the find root and
# at their plain (non-re-rooted) locations.
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||
26
run.sh
26
run.sh
@ -1,11 +1,31 @@
|
||||
#!/bin/bash
|
||||
# Impress Voice Input 启动脚本
|
||||
# 设置 ONNX Runtime / PortAudio 库路径并启动应用
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
BUILD_DIR="${SCRIPT_DIR}/build"
|
||||
ONNXRUNTIME_LIB_DIR="${SCRIPT_DIR}/third_party/onnxruntime/lib"
|
||||
PORTAUDIO_LIB_DIR="${SCRIPT_DIR}/third_party/portaudio/lib"
|
||||
|
||||
# 设置库路径
|
||||
export LD_LIBRARY_PATH="${SCRIPT_DIR}/third_party/onnxruntime/lib:${SCRIPT_DIR}/third_party/portaudio/lib:${LD_LIBRARY_PATH}"
|
||||
# 检查可执行文件
|
||||
if [ ! -f "${BUILD_DIR}/impress_voice_input" ]; then
|
||||
echo "错误:未找到可执行文件,请先编译:"
|
||||
echo " mkdir -p build && cd build"
|
||||
echo " cmake .. -DCMAKE_BUILD_TYPE=RelWithDebInfo"
|
||||
echo " cmake --build . -j\$(nproc)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 运行
|
||||
# 检查 ONNX Runtime
|
||||
if [ ! -f "${ONNXRUNTIME_LIB_DIR}/libonnxruntime.so" ]; then
|
||||
echo "警告:ONNX Runtime 未部署,推理功能将不可用"
|
||||
echo " 请按照 third_party/README.md 部署 ONNX Runtime"
|
||||
fi
|
||||
|
||||
# 设置库路径(ONNX Runtime 优先,PortAudio 回退到系统)
|
||||
export LD_LIBRARY_PATH="${ONNXRUNTIME_LIB_DIR}:${PORTAUDIO_LIB_DIR}:${LD_LIBRARY_PATH}"
|
||||
|
||||
# 启动应用
|
||||
exec "${BUILD_DIR}/impress_voice_input" "$@"
|
||||
|
||||
@ -13,6 +13,7 @@ namespace impress {
|
||||
// 预分配缓冲区,避免在实时回调中分配内存
|
||||
static constexpr int kMaxBufferSize = 8192;
|
||||
|
||||
#ifdef HAVE_PORTAUDIO
|
||||
// 全局 PortAudio 初始化状态
|
||||
static bool gPaInitialized = false;
|
||||
|
||||
@ -33,6 +34,7 @@ static void safePaTerminate() {
|
||||
gPaInitialized = false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// 回调上下文:独立于 Impl 的 POD 结构,供静态回调使用
|
||||
struct CallbackContext {
|
||||
@ -52,13 +54,13 @@ struct AudioCapture::Impl {
|
||||
CallbackContext ctx;
|
||||
};
|
||||
|
||||
#ifdef HAVE_PORTAUDIO
|
||||
static int paCallback(const void* input, void* /*output*/,
|
||||
unsigned long frameCount,
|
||||
const PaStreamCallbackTimeInfo* /*timeInfo*/,
|
||||
PaStreamCallbackFlags /*statusFlags*/,
|
||||
void* userData)
|
||||
{
|
||||
#ifdef HAVE_PORTAUDIO
|
||||
auto* ctx = static_cast<CallbackContext*>(userData);
|
||||
|
||||
const float* samples = static_cast<const float*>(input);
|
||||
@ -88,11 +90,15 @@ static int paCallback(const void* input, void* /*output*/,
|
||||
emit ctx->owner->audioDataReady(data, ctx->sampleRate);
|
||||
|
||||
return paContinue;
|
||||
#else
|
||||
(void)input; (void)frameCount; (void)userData;
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
// 占位回调(无 PortAudio 时不使用)
|
||||
static int paCallbackStub(const void*, void*, unsigned long,
|
||||
const int*, int, void*)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
AudioCapture::AudioCapture(QObject* parent)
|
||||
: QObject(parent)
|
||||
@ -107,8 +113,8 @@ AudioCapture::~AudioCapture() {
|
||||
|
||||
QStringList AudioCapture::getDeviceList() {
|
||||
QStringList devices;
|
||||
devices << "默认设备";
|
||||
#ifdef HAVE_PORTAUDIO
|
||||
devices << "默认设备";
|
||||
if (!ensurePaInitialized()) {
|
||||
LOG_ERROR(kTag, "PortAudio 初始化失败");
|
||||
return devices;
|
||||
@ -124,6 +130,9 @@ QStringList AudioCapture::getDeviceList() {
|
||||
.arg(info->defaultSampleRate).arg(hostApiName);
|
||||
}
|
||||
}
|
||||
#else
|
||||
devices << "PortAudio 未启用(占位设备)";
|
||||
LOG_WARNING(kTag, "PortAudio 未编译启用,设备列表为占位");
|
||||
#endif
|
||||
return devices;
|
||||
}
|
||||
@ -195,7 +204,9 @@ bool AudioCapture::start(int deviceIndex, int sampleRate, int bufferSizeMs) {
|
||||
PaStreamParameters inputParams{};
|
||||
inputParams.device = devIdx;
|
||||
inputParams.channelCount = 1;
|
||||
inputParams.sampleFormat = paFloat32 | paNonInterleaved;
|
||||
inputParams.sampleFormat = paFloat32;
|
||||
// 不使用 paNonInterleaved:input 指针直接是 float* 数组(interleaved mono),
|
||||
// 回调中可以安全地 static_cast<const float*>(input)
|
||||
// 使用高延迟以避免回调过快
|
||||
inputParams.suggestedLatency = devInfo->defaultHighInputLatency;
|
||||
|
||||
@ -232,8 +243,9 @@ bool AudioCapture::start(int deviceIndex, int sampleRate, int bufferSizeMs) {
|
||||
.arg(deviceIndex).arg(sampleRate).arg(bufferSizeMs));
|
||||
return true;
|
||||
#else
|
||||
LOG_ERROR(kTag, "PortAudio 未编译启用");
|
||||
emit error("PortAudio 未编译启用");
|
||||
(void)deviceIndex; (void)sampleRate; (void)bufferSizeMs;
|
||||
LOG_ERROR(kTag, "PortAudio 未编译启用,无法启动采集");
|
||||
emit error("PortAudio 未编译启用,请在 third_party/portaudio/ 中部署后重新编译");
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
62
src/core/ort_api_shim.h
Normal file
62
src/core/ort_api_shim.h
Normal file
@ -0,0 +1,62 @@
|
||||
#pragma once
/**
 * @brief ONNX Runtime C API shim.
 *
 * Wraps onnxruntime_c_api.h to work around problems seen when cross-compiling
 * with MinGW:
 *   1. SAL annotations (the macros from specstrings.h do not work with MinGW).
 *   2. The _stdcall calling convention breaking function declarations.
 *
 * Important: the macros below must be defined BEFORE onnxruntime_c_api.h is
 * included, because the #defines inside that header expand to them.
 */

#ifdef HAVE_ONNXRUNTIME
#ifndef ORT_API_SHIM_H
#define ORT_API_SHIM_H

#ifdef _WIN32
#define ORT_DLL_IMPORT

/* Pre-define the SAL annotations as empty before specstrings.h is pulled in
   by onnxruntime_c_api.h. If specstrings.h guards its definitions with
   #ifndef, these empty versions stay in effect; if it overrides them, the
   calling-convention definitions further down still apply.
   NOTE(review): the exact #include location inside onnxruntime_c_api.h
   depends on the ONNX Runtime version in use. */
#define _Success_(x)
#define _Check_return_
#define _Ret_maybenull_
#define _In_
#define _In_z_
#define _In_opt_
#define _In_opt_z_
#define _Out_
#define _Outptr_
#define _Out_opt_
#define _Inout_
#define _Inout_opt_
#define _Frees_ptr_opt_
#define _Ret_notnull_
#define _In_reads_(x)
#define _Inout_updates_(x)
#define _Out_writes_(x)
#define _Inout_updates_all_(x)
#define _Out_writes_bytes_all_(x)
#define _Out_writes_all_(x)
#define _Outptr_result_maybenull_(x)
#define _In_reads_opt_(x)
#define _Outptr_result_buffer_maybenull_(x)
#define _Return_type_success_(x)
#define _Out_writes_bytes_all_opt_(x)
#define _In_reads_bytes_(x)

/* Define the calling-convention keywords as empty.
   onnxruntime_c_api.h defines ORT_API_CALL as _stdcall, so ORT_API_CALL first
   expands to _stdcall and _stdcall then expands to nothing. The resulting
   declarations carry no calling-convention specifier and MinGW parses them.
   The #undef lines avoid a "macro redefined" diagnostic on toolchains that
   already predefine these names as macros.
   NOTE(review): the empty definitions remain active for everything included
   after this header. That is presumably harmless for x86_64 targets; confirm
   before building for 32-bit x86. */
#undef _stdcall
#undef __stdcall
#define _stdcall
#define __stdcall
#endif /* _WIN32 */

#include <onnxruntime_c_api.h>

#endif /* ORT_API_SHIM_H */
#endif /* HAVE_ONNXRUNTIME */
|
||||
365
src/core/ort_minimal.cpp
Normal file
365
src/core/ort_minimal.cpp
Normal file
@ -0,0 +1,365 @@
|
||||
/**
 * @brief Lightweight ONNX Runtime C API wrapper - implementation.
 *
 * Uses the C API directly to avoid the MinGW compatibility problems of
 * onnxruntime_cxx_api.h.
 */

#ifdef HAVE_ONNXRUNTIME

#include "ort_minimal.h"
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#ifdef _WIN32
#include <windows.h>
#endif

namespace ort {

namespace {

/**
 * Turn a failed OrtStatus into an ort::Exception.
 *
 * Does nothing when @p status is null (success). Otherwise the error text is
 * copied, the status object is released (the caller owns every OrtStatus the
 * C API returns), and the exception is thrown.
 *
 * @param status   Status returned by a C API call; may be null.
 * @param fallback Message used when the status carries no text.
 * @throws Exception when @p status is non-null.
 */
void throwIfFailed(OrtStatus* status, const char* fallback) {
    if (!status) {
        return;
    }
    const OrtApi* api = getApi();
    const char* msg = api->GetErrorMessage(status);
    // Copy before ReleaseStatus: the text is owned by the status object.
    const std::string text(msg ? msg : fallback);
    api->ReleaseStatus(status);
    throw Exception(text);
}

/**
 * Copy a C string that was allocated by @p allocator and hand the original
 * buffer back to that allocator.
 *
 * @return The copied text, or an empty string when @p text is null.
 */
std::string takeAllocatedString(OrtAllocator* allocator, char* text) {
    if (!text) {
        return std::string();
    }
    std::string copy(text);
    // A failure to free cannot be acted upon here; just release its status.
    if (OrtStatus* freeStatus = getApi()->AllocatorFree(allocator, text)) {
        getApi()->ReleaseStatus(freeStatus);
    }
    return copy;
}

/** Releases an OrtTensorTypeAndShapeInfo when it goes out of scope. */
struct ShapeInfoDeleter {
    void operator()(OrtTensorTypeAndShapeInfo* info) const {
        getApi()->ReleaseTensorTypeAndShapeInfo(info);
    }
};

/**
 * Wrap a caller-owned buffer in an OrtValue tensor (no copy is requested from
 * the C API; the buffer is passed through as-is).
 *
 * @throws Exception when the C API reports an error.
 */
OrtValue* wrapBuffer(const MemoryInfo& info, void* data, size_t byteCount,
                     const int64_t* shape, size_t shapeLen,
                     ONNXTensorElementDataType elementType) {
    OrtValue* value = nullptr;
    throwIfFailed(getApi()->CreateTensorWithDataAsOrtValue(
                      info.ptr(), data, byteCount, shape, shapeLen, elementType, &value),
                  "CreateTensorWithDataAsOrtValue failed");
    return value;
}

}  // namespace

/**
 * Obtain the OrtApi function table.
 *
 * Call path: OrtGetApiBase() -> OrtApiBase -> GetApi(version) -> OrtApi.
 * The lookup is done once; the result is cached in a function-local static.
 *
 * @throws Exception when the loaded library does not provide the API version
 *         this code was compiled against (GetApi returned null).
 */
const OrtApi* getApi() {
    static const OrtApi* const api = []() -> const OrtApi* {
        const OrtApiBase* base = OrtGetApiBase();
        const OrtApi* resolved = base ? base->GetApi(ORT_API_VERSION) : nullptr;
        if (!resolved) {
            throw Exception("ONNX Runtime library does not support the requested API version");
        }
        return resolved;
    }();
    return api;
}

// ============================================================================
// Env
// ============================================================================

/** Create the runtime environment. @throws Exception on failure. */
Env::Env(OrtLoggingLevel logLevel, const char* logId) {
    throwIfFailed(getApi()->CreateEnv(logLevel, logId, &env_), "CreateEnv failed");
}

Env::~Env() {
    if (env_) {
        getApi()->ReleaseEnv(env_);
    }
}

// ============================================================================
// SessionOptions
// ============================================================================

/** Create a default options object. @throws Exception on failure. */
SessionOptions::SessionOptions() {
    throwIfFailed(getApi()->CreateSessionOptions(&opts_), "CreateSessionOptions failed");
}

SessionOptions::~SessionOptions() {
    if (opts_) {
        getApi()->ReleaseSessionOptions(opts_);
    }
}

/** Set the intra-op thread count. @throws Exception on failure. */
void SessionOptions::setIntraOpNumThreads(int n) {
    throwIfFailed(getApi()->SetIntraOpNumThreads(opts_, n), "SetIntraOpNumThreads failed");
}

/** Set the graph optimization level. @throws Exception on failure. */
void SessionOptions::setGraphOptimizationLevel(GraphOptimizationLevel level) {
    throwIfFailed(getApi()->SetSessionGraphOptimizationLevel(opts_, level),
                  "SetSessionGraphOptimizationLevel failed");
}

// ============================================================================
// Session
// ============================================================================

/**
 * Load a model from a narrow-character path.
 *
 * On Windows the path is converted from UTF-8 to UTF-16 first, because the
 * C API takes wide-character paths there.
 *
 * @throws Exception when the path is null, cannot be converted, or the
 *         session cannot be created.
 */
Session::Session(const Env& env, const char* modelPath, const SessionOptions& opts) {
    if (!modelPath) {
        throw Exception("CreateSession failed: model path is null");
    }
    const OrtApi* api = getApi();
#ifdef _WIN32
    // First call asks for the required length (including the terminator).
    const int wideLen = MultiByteToWideChar(CP_UTF8, 0, modelPath, -1, nullptr, 0);
    if (wideLen <= 0) {
        throw Exception("CreateSession failed: cannot convert model path to UTF-16");
    }
    std::vector<wchar_t> widePath(static_cast<size_t>(wideLen));
    if (MultiByteToWideChar(CP_UTF8, 0, modelPath, -1, widePath.data(), wideLen) <= 0) {
        throw Exception("CreateSession failed: cannot convert model path to UTF-16");
    }
    throwIfFailed(api->CreateSession(env.ptr(), widePath.data(), opts.ptr(), &session_),
                  "CreateSession failed");
#else
    throwIfFailed(api->CreateSession(env.ptr(), modelPath, opts.ptr(), &session_),
                  "CreateSession failed");
#endif
}

#ifdef _WIN32
/** Load a model from a wide-character path. @throws Exception on failure. */
Session::Session(const Env& env, const wchar_t* modelPath, const SessionOptions& opts) {
    if (!modelPath) {
        throw Exception("CreateSession failed: model path is null");
    }
    throwIfFailed(getApi()->CreateSession(env.ptr(), modelPath, opts.ptr(), &session_),
                  "CreateSession failed");
}
#endif

Session::~Session() {
    if (session_) {
        getApi()->ReleaseSession(session_);
    }
}

/** @return Number of model inputs. @throws Exception on failure. */
size_t Session::getInputCount() const {
    size_t count = 0;
    throwIfFailed(getApi()->SessionGetInputCount(session_, &count),
                  "SessionGetInputCount failed");
    return count;
}

/** @return Number of model outputs. @throws Exception on failure. */
size_t Session::getOutputCount() const {
    size_t count = 0;
    throwIfFailed(getApi()->SessionGetOutputCount(session_, &count),
                  "SessionGetOutputCount failed");
    return count;
}

/**
 * @return Name of the input at @p index, copied into a std::string.
 * @throws Exception on failure.
 */
std::string Session::getInputName(size_t index) const {
    const OrtApi* api = getApi();
    OrtAllocator* allocator = nullptr;
    throwIfFailed(api->GetAllocatorWithDefaultOptions(&allocator),
                  "GetAllocatorWithDefaultOptions failed");
    char* name = nullptr;
    throwIfFailed(api->SessionGetInputName(session_, index, allocator, &name),
                  "SessionGetInputName failed");
    // The name buffer was allocated for us; copy it and give it back.
    return takeAllocatedString(allocator, name);
}

/**
 * @return Name of the output at @p index, copied into a std::string.
 * @throws Exception on failure.
 */
std::string Session::getOutputName(size_t index) const {
    const OrtApi* api = getApi();
    OrtAllocator* allocator = nullptr;
    throwIfFailed(api->GetAllocatorWithDefaultOptions(&allocator),
                  "GetAllocatorWithDefaultOptions failed");
    char* name = nullptr;
    throwIfFailed(api->SessionGetOutputName(session_, index, allocator, &name),
                  "SessionGetOutputName failed");
    return takeAllocatedString(allocator, name);
}

// ============================================================================
// MemoryInfo
// ============================================================================

MemoryInfo::MemoryInfo(OrtMemoryInfo* info) : info_(info) {}

/** Create a CPU memory descriptor. @throws Exception on failure. */
MemoryInfo MemoryInfo::createCpu(OrtAllocatorType type, OrtMemType memType) {
    OrtMemoryInfo* info = nullptr;
    throwIfFailed(getApi()->CreateCpuMemoryInfo(type, memType, &info),
                  "CreateCpuMemoryInfo failed");
    return MemoryInfo(info);
}

MemoryInfo::~MemoryInfo() {
    if (info_) {
        getApi()->ReleaseMemoryInfo(info_);
    }
}

// ============================================================================
// Value
// ============================================================================

Value::Value(OrtValue* value) : value_(value) {}

/** Take ownership of a raw OrtValue returned by the C API. */
Value Value::fromRaw(OrtValue* value) {
    return Value(value);
}

/**
 * Create a float tensor over a caller-owned buffer of @p elemCount elements.
 * @throws Exception on failure.
 */
Value Value::createTensor(const MemoryInfo& info, float* data, size_t elemCount,
                          const int64_t* shape, size_t shapeLen) {
    return Value(wrapBuffer(info, data, elemCount * sizeof(float), shape, shapeLen,
                            ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
}

/**
 * Create an int32 tensor over a caller-owned buffer of @p elemCount elements.
 * @throws Exception on failure.
 */
Value Value::createTensor(const MemoryInfo& info, int32_t* data, size_t elemCount,
                          const int64_t* shape, size_t shapeLen) {
    return Value(wrapBuffer(info, data, elemCount * sizeof(int32_t), shape, shapeLen,
                            ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
}

Value::~Value() {
    if (value_) {
        getApi()->ReleaseValue(value_);
    }
}

/**
 * @return The tensor dimensions (empty for a zero-dimensional tensor).
 * @throws Exception when any of the shape queries fails.
 */
std::vector<int64_t> Value::getShape() const {
    const OrtApi* api = getApi();
    OrtTensorTypeAndShapeInfo* rawInfo = nullptr;
    throwIfFailed(api->GetTensorTypeAndShape(value_, &rawInfo),
                  "GetTensorTypeAndShape failed");
    // Released on every exit path, including the throwing ones below.
    std::unique_ptr<OrtTensorTypeAndShapeInfo, ShapeInfoDeleter> info(rawInfo);

    size_t dimCount = 0;
    throwIfFailed(api->GetDimensionsCount(info.get(), &dimCount),
                  "GetDimensionsCount failed");

    std::vector<int64_t> shape(dimCount);
    if (dimCount > 0) {
        throwIfFailed(api->GetDimensions(info.get(), shape.data(), dimCount),
                      "GetDimensions failed");
    }
    return shape;
}

/**
 * @return Pointer to the tensor buffer, viewed as float. No element-type
 *         check is performed here.
 * @throws Exception on failure.
 */
const float* Value::getTensorData() const {
    void* data = nullptr;
    throwIfFailed(getApi()->GetTensorMutableData(value_, &data),
                  "GetTensorMutableData failed");
    return static_cast<const float*>(data);
}

Value::Value(Value&& other) noexcept : value_(other.value_) {
    other.value_ = nullptr;
}

Value& Value::operator=(Value&& other) noexcept {
    if (this != &other) {
        if (value_) {
            getApi()->ReleaseValue(value_);
        }
        value_ = other.value_;
        other.value_ = nullptr;
    }
    return *this;
}

// ============================================================================
// RunOptions
// ============================================================================

/** Create default run options. @throws Exception on failure. */
RunOptions::RunOptions() {
    throwIfFailed(getApi()->CreateRunOptions(&opts_), "CreateRunOptions failed");
}

RunOptions::~RunOptions() {
    if (opts_) {
        getApi()->ReleaseRunOptions(opts_);
    }
}

// ============================================================================
// run
// ============================================================================

/**
 * Run inference.
 *
 * @param session      Session to run.
 * @param runOptions   Options for this run.
 * @param inputNames   Array of @p inputCount input names.
 * @param inputValues  Array of @p inputCount input tensors (not consumed).
 * @param inputCount   Number of inputs.
 * @param outputNames  Array of @p outputCount output names.
 * @param outputCount  Number of outputs requested.
 * @return One owning Value per requested output, in @p outputNames order.
 * @throws Exception when the C API reports an error.
 */
std::vector<Value> run(Session& session,
                       const RunOptions& runOptions,
                       const char* const* inputNames,
                       Value* inputValues,
                       size_t inputCount,
                       const char* const* outputNames,
                       size_t outputCount)
{
    // Collect the raw handles the C API expects.
    std::vector<const OrtValue*> inputPtrs;
    inputPtrs.reserve(inputCount);
    for (size_t i = 0; i < inputCount; ++i) {
        inputPtrs.push_back(inputValues[i].ptr());
    }

    std::vector<OrtValue*> outputPtrs(outputCount, nullptr);

    // Counts are passed as size_t, matching the C API parameter types.
    throwIfFailed(getApi()->Run(session.ptr(),
                                runOptions.ptr(),
                                inputNames, inputPtrs.data(), inputCount,
                                outputNames, outputCount,
                                outputPtrs.data()),
                  "Run failed");

    // Hand ownership of every returned handle to a Value.
    std::vector<Value> results;
    results.reserve(outputCount);
    for (OrtValue* raw : outputPtrs) {
        results.emplace_back(Value::fromRaw(raw));
    }
    return results;
}

} // namespace ort

#endif // HAVE_ONNXRUNTIME
|
||||
134
src/core/ort_minimal.h
Normal file
134
src/core/ort_minimal.h
Normal file
@ -0,0 +1,134 @@
|
||||
#pragma once
/**
 * @brief Lightweight ONNX Runtime C API wrapper.
 *
 * Replaces onnxruntime_cxx_api.h (which has compatibility problems with
 * MinGW). Works directly on the C API (onnxruntime_c_api.h) and reports
 * errors as exceptions instead of C-style status returns.
 *
 * Every class below owns exactly one C API handle and releases it in its
 * destructor, so none of them is copyable.
 */

#ifdef HAVE_ONNXRUNTIME

#ifndef ORT_MINIMAL_H
#define ORT_MINIMAL_H

/* The shim header takes care of the MinGW compatibility workarounds. */
#include "ort_api_shim.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

namespace ort {

/** Exception type thrown by every wrapper in this namespace. */
class Exception : public std::runtime_error {
public:
    explicit Exception(const char* msg) : std::runtime_error(msg) {}
    explicit Exception(const std::string& msg) : std::runtime_error(msg) {}
};

/** Get the OrtApi function table (internal use). */
const OrtApi* getApi();

/** Env: ONNX Runtime environment. */
class Env {
public:
    explicit Env(OrtLoggingLevel logLevel = ORT_LOGGING_LEVEL_WARNING,
                 const char* logId = "ort");
    ~Env();
    OrtEnv* ptr() const { return env_; }
    Env(const Env&) = delete;
    Env& operator=(const Env&) = delete;
private:
    OrtEnv* env_ = nullptr;
};

/** SessionOptions: session configuration. */
class SessionOptions {
public:
    SessionOptions();
    ~SessionOptions();
    OrtSessionOptions* ptr() const { return opts_; }
    void setIntraOpNumThreads(int n);
    void setGraphOptimizationLevel(GraphOptimizationLevel level);
    SessionOptions(const SessionOptions&) = delete;
    SessionOptions& operator=(const SessionOptions&) = delete;
private:
    OrtSessionOptions* opts_ = nullptr;
};

/** Session: inference session for one loaded model. */
class Session {
public:
    /** Load a model from a narrow-character path. */
    Session(const Env& env, const char* modelPath, const SessionOptions& opts);
#ifdef _WIN32
    /** Load a model from a wide-character path (only implemented on Windows). */
    Session(const Env& env, const wchar_t* modelPath, const SessionOptions& opts);
#endif
    ~Session();
    OrtSession* ptr() const { return session_; }
    size_t getInputCount() const;
    size_t getOutputCount() const;
    std::string getInputName(size_t index) const;
    std::string getOutputName(size_t index) const;
    Session(const Session&) = delete;
    Session& operator=(const Session&) = delete;
private:
    OrtSession* session_ = nullptr;
};

/** MemoryInfo: memory descriptor. Move-only owner of an OrtMemoryInfo. */
class MemoryInfo {
public:
    static MemoryInfo createCpu(OrtAllocatorType type = OrtDeviceAllocator,
                                OrtMemType memType = OrtMemTypeCPU);
    ~MemoryInfo();
    const OrtMemoryInfo* ptr() const { return info_; }

    /** Transfer ownership; the source is left empty. */
    MemoryInfo(MemoryInfo&& other) noexcept : info_(other.info_) {
        other.info_ = nullptr;
    }
    /** Release the current handle (if any) and take over the source's. */
    MemoryInfo& operator=(MemoryInfo&& other) noexcept {
        if (this != &other) {
            if (info_) {
                getApi()->ReleaseMemoryInfo(info_);
            }
            info_ = other.info_;
            other.info_ = nullptr;
        }
        return *this;
    }
    // Copying would make two objects release the same handle.
    MemoryInfo(const MemoryInfo&) = delete;
    MemoryInfo& operator=(const MemoryInfo&) = delete;
private:
    explicit MemoryInfo(OrtMemoryInfo* info);
    OrtMemoryInfo* info_ = nullptr;
};

/** Value: tensor value. Move-only owner of an OrtValue. */
class Value {
public:
    static Value createTensor(const MemoryInfo& info, float* data, size_t elemCount,
                              const int64_t* shape, size_t shapeLen);
    static Value createTensor(const MemoryInfo& info, int32_t* data, size_t elemCount,
                              const int64_t* shape, size_t shapeLen);
    ~Value();
    OrtValue* ptr() const { return value_; }
    std::vector<int64_t> getShape() const;
    const float* getTensorData() const;
    Value(Value&& other) noexcept;
    Value& operator=(Value&& other) noexcept;
    Value(const Value&) = delete;
    Value& operator=(const Value&) = delete;
private:
    explicit Value(OrtValue* value);
    OrtValue* value_ = nullptr;
public:
    /** @brief Construct from a raw pointer (used to receive values returned by the C API). */
    static Value fromRaw(OrtValue* value);
};

/** RunOptions: per-run inference options. */
class RunOptions {
public:
    RunOptions();
    ~RunOptions();
    OrtRunOptions* ptr() const { return opts_; }
    // Copying would make two objects release the same handle.
    RunOptions(const RunOptions&) = delete;
    RunOptions& operator=(const RunOptions&) = delete;
private:
    OrtRunOptions* opts_ = nullptr;
};

/**
 * Run inference.
 *
 * @param inputNames / inputValues  Arrays of inputCount entries.
 * @param outputNames               Array of outputCount entries.
 * @return One Value per requested output.
 */
std::vector<Value> run(Session& session,
                       const RunOptions& runOptions,
                       const char* const* inputNames,
                       Value* inputValues,
                       size_t inputCount,
                       const char* const* outputNames,
                       size_t outputCount);

} // namespace ort

#endif // ORT_MINIMAL_H
#endif // HAVE_ONNXRUNTIME
|
||||
@ -3,6 +3,7 @@
|
||||
#include "sense_voice_tokenizer.h"
|
||||
#include "sense_voice_cmvn.h"
|
||||
#include "audio_processor.h"
|
||||
#include "ort_minimal.h"
|
||||
#include "utils/logger.h"
|
||||
#include "utils/timer.h"
|
||||
|
||||
@ -19,11 +20,6 @@
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
// ONNX Runtime headers
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
#endif
|
||||
|
||||
static const char* const kTag = "SenseVoiceEngine";
|
||||
|
||||
/**
|
||||
@ -93,9 +89,9 @@ static int languageToInt(const QString& lang) {
|
||||
*/
|
||||
struct SenseVoiceEngine::Impl {
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
std::unique_ptr<Ort::Env> env;
|
||||
std::unique_ptr<Ort::SessionOptions> sessionOptions;
|
||||
std::unique_ptr<Ort::Session> session;
|
||||
std::unique_ptr<ort::Env> env;
|
||||
std::unique_ptr<ort::SessionOptions> sessionOptions;
|
||||
std::unique_ptr<ort::Session> session;
|
||||
|
||||
std::vector<std::string> inputNames;
|
||||
std::vector<std::string> outputNames;
|
||||
@ -111,11 +107,11 @@ struct SenseVoiceEngine::Impl {
|
||||
{
|
||||
QMutexLocker locker(&mutex);
|
||||
try {
|
||||
auto envPtr = std::make_unique<Ort::Env>(
|
||||
auto envPtr = std::make_unique<ort::Env>(
|
||||
ORT_LOGGING_LEVEL_WARNING, "impress_sensevoice");
|
||||
auto optionsPtr = std::make_unique<Ort::SessionOptions>();
|
||||
optionsPtr->SetIntraOpNumThreads(numThreads);
|
||||
optionsPtr->SetGraphOptimizationLevel(
|
||||
auto optionsPtr = std::make_unique<ort::SessionOptions>();
|
||||
optionsPtr->setIntraOpNumThreads(numThreads);
|
||||
optionsPtr->setGraphOptimizationLevel(
|
||||
GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||
|
||||
if (device == "gpu") {
|
||||
@ -125,14 +121,17 @@ struct SenseVoiceEngine::Impl {
|
||||
LOG_INFO(kTag, QString("正在加载 SenseVoice 模型: %1 (线程: %2)")
|
||||
.arg(modelPath).arg(numThreads));
|
||||
|
||||
auto sessionPtr = std::make_unique<Ort::Session>(
|
||||
auto sessionPtr = std::make_unique<ort::Session>(
|
||||
*envPtr,
|
||||
#ifdef _WIN32
|
||||
modelPath.toStdWString().c_str(),
|
||||
#else
|
||||
modelPath.toUtf8().constData(),
|
||||
#endif
|
||||
*optionsPtr);
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
size_t inputCount = sessionPtr->GetInputCount();
|
||||
size_t outputCount = sessionPtr->GetOutputCount();
|
||||
size_t inputCount = sessionPtr->getInputCount();
|
||||
size_t outputCount = sessionPtr->getOutputCount();
|
||||
|
||||
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
||||
.arg(inputCount).arg(outputCount));
|
||||
@ -141,15 +140,13 @@ struct SenseVoiceEngine::Impl {
|
||||
outputNames.clear();
|
||||
|
||||
for (size_t i = 0; i < inputCount; i++) {
|
||||
auto namePtr = sessionPtr->GetInputNameAllocated(i, allocator);
|
||||
inputNames.emplace_back(namePtr.get());
|
||||
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(namePtr.get()));
|
||||
inputNames.emplace_back(sessionPtr->getInputName(i));
|
||||
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(QString::fromStdString(inputNames.back())));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputCount; i++) {
|
||||
auto namePtr = sessionPtr->GetOutputNameAllocated(i, allocator);
|
||||
outputNames.emplace_back(namePtr.get());
|
||||
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(namePtr.get()));
|
||||
outputNames.emplace_back(sessionPtr->getOutputName(i));
|
||||
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(QString::fromStdString(outputNames.back())));
|
||||
}
|
||||
|
||||
env = std::move(envPtr);
|
||||
@ -174,7 +171,7 @@ struct SenseVoiceEngine::Impl {
|
||||
|
||||
LOG_INFO(kTag, QString("SenseVoice 模型加载成功: %1").arg(modelPath));
|
||||
return true;
|
||||
} catch (const Ort::Exception& e) {
|
||||
} catch (const ort::Exception& e) {
|
||||
errorMsg = QString("ONNX 异常: %1").arg(e.what());
|
||||
LOG_ERROR(kTag, errorMsg);
|
||||
return false;
|
||||
@ -185,6 +182,21 @@ struct SenseVoiceEngine::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
QMutex mutex;
|
||||
#else
|
||||
// 占位实现:无 ONNX Runtime 时仅提供基本结构
|
||||
bool loadInWorker(const QString& /*modelPath*/,
|
||||
const QString& /*tokensPath*/,
|
||||
const QString& /*device*/,
|
||||
int /*numThreads*/,
|
||||
QString& errorMsg)
|
||||
{
|
||||
errorMsg = "ONNX Runtime 未安装,推理功能不可用。"
|
||||
"请在 third_party/onnxruntime/ 中部署 ONNX Runtime 后重新编译。";
|
||||
LOG_ERROR(kTag, errorMsg);
|
||||
return false;
|
||||
}
|
||||
|
||||
QMutex mutex;
|
||||
#endif
|
||||
};
|
||||
@ -206,11 +218,13 @@ bool SenseVoiceEngine::loadModelSync(const QString& modelPath,
|
||||
if (loaded_) {
|
||||
LOG_WARNING(kTag, "模型已加载,先卸载再加载");
|
||||
// 内联清理,避免调用 unloadModel() 导致 mutex 递归死锁
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
impl_->session.reset();
|
||||
impl_->sessionOptions.reset();
|
||||
impl_->env.reset();
|
||||
impl_->features.reset();
|
||||
impl_->tokenizer = SenseVoiceTokenizer();
|
||||
#endif
|
||||
loaded_ = false;
|
||||
}
|
||||
|
||||
@ -235,11 +249,13 @@ void SenseVoiceEngine::loadModelAsync(const QString& modelPath,
|
||||
if (loaded_) {
|
||||
LOG_WARNING(kTag, "模型已加载,先卸载再加载");
|
||||
// 内联清理,避免调用 unloadModel() 导致 mutex 递归死锁
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
impl_->session.reset();
|
||||
impl_->sessionOptions.reset();
|
||||
impl_->env.reset();
|
||||
impl_->features.reset();
|
||||
impl_->tokenizer = SenseVoiceTokenizer();
|
||||
#endif
|
||||
loaded_ = false;
|
||||
}
|
||||
|
||||
@ -417,7 +433,7 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
||||
|
||||
// 输入: x, x_length, language, text_norm
|
||||
int64_t xShape[] = {1, numFrames, kLFROutputDim};
|
||||
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
auto memInfo = ort::MemoryInfo::createCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
int32_t xLengthVal = numFrames;
|
||||
int64_t xLengthShape[] = {1};
|
||||
@ -429,14 +445,14 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
||||
int32_t textNormVal = kTextNormWithITN;
|
||||
int64_t textNormShape[] = {1};
|
||||
|
||||
std::vector<Ort::Value> inputTensors;
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(
|
||||
std::vector<ort::Value> inputTensors;
|
||||
inputTensors.push_back(ort::Value::createTensor(
|
||||
memInfo, lfrFeatures.data(), lfrFeatures.size(), xShape, 3));
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<int32_t>(
|
||||
inputTensors.push_back(ort::Value::createTensor(
|
||||
memInfo, &xLengthVal, 1, xLengthShape, 1));
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<int32_t>(
|
||||
inputTensors.push_back(ort::Value::createTensor(
|
||||
memInfo, &langVal, 1, langShape, 1));
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<int32_t>(
|
||||
inputTensors.push_back(ort::Value::createTensor(
|
||||
memInfo, &textNormVal, 1, textNormShape, 1));
|
||||
|
||||
// 4. 运行推理
|
||||
@ -446,8 +462,10 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
||||
std::vector<const char*> outputNamePtrs;
|
||||
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
||||
|
||||
auto outputTensors = impl_->session->Run(
|
||||
Ort::RunOptions{nullptr},
|
||||
ort::RunOptions runOptions;
|
||||
auto outputTensors = ort::run(
|
||||
*impl_->session,
|
||||
runOptions,
|
||||
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
||||
outputNamePtrs.data(), outputNamePtrs.size());
|
||||
|
||||
@ -455,8 +473,8 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
||||
|
||||
// 5. 解析输出 logits [1, seq_len, 25055]
|
||||
auto& outputTensor = outputTensors[0];
|
||||
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float* logitsData = outputTensor.GetTensorData<float>();
|
||||
auto shape = outputTensor.getShape();
|
||||
const float* logitsData = outputTensor.getTensorData();
|
||||
|
||||
LOG_DEBUG(kTag, QString("输出维度: [%1, %2, %3]")
|
||||
.arg(shape[0]).arg(shape[1]).arg(shape[2]));
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include "mel_spectrogram.h"
|
||||
#include "whisper_tokenizer.h"
|
||||
#include "audio_processor.h"
|
||||
#include "ort_minimal.h"
|
||||
#include "utils/logger.h"
|
||||
#include "utils/timer.h"
|
||||
|
||||
@ -15,11 +16,6 @@
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
// ONNX Runtime headers
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
#endif
|
||||
|
||||
static const char* const kTag = "STTEngine";
|
||||
|
||||
// Whisper 常量
|
||||
@ -33,9 +29,9 @@ namespace impress {
|
||||
*/
|
||||
struct STTEngine::Impl {
|
||||
#ifdef HAVE_ONNXRUNTIME
|
||||
std::unique_ptr<Ort::Env> env;
|
||||
std::unique_ptr<Ort::SessionOptions> sessionOptions;
|
||||
std::unique_ptr<Ort::Session> session;
|
||||
std::unique_ptr<ort::Env> env;
|
||||
std::unique_ptr<ort::SessionOptions> sessionOptions;
|
||||
std::unique_ptr<ort::Session> session;
|
||||
|
||||
std::vector<std::string> inputNames;
|
||||
std::vector<std::string> outputNames;
|
||||
@ -50,11 +46,11 @@ struct STTEngine::Impl {
|
||||
{
|
||||
QMutexLocker locker(&mutex);
|
||||
try {
|
||||
auto envPtr = std::make_unique<Ort::Env>(
|
||||
auto envPtr = std::make_unique<ort::Env>(
|
||||
ORT_LOGGING_LEVEL_WARNING, "impress_voice");
|
||||
auto optionsPtr = std::make_unique<Ort::SessionOptions>();
|
||||
optionsPtr->SetIntraOpNumThreads(numThreads);
|
||||
optionsPtr->SetGraphOptimizationLevel(
|
||||
auto optionsPtr = std::make_unique<ort::SessionOptions>();
|
||||
optionsPtr->setIntraOpNumThreads(numThreads);
|
||||
optionsPtr->setGraphOptimizationLevel(
|
||||
GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||
|
||||
if (device == "gpu") {
|
||||
@ -63,14 +59,17 @@ struct STTEngine::Impl {
|
||||
|
||||
LOG_INFO(kTag, QString("正在加载模型: %1 (线程: %2)").arg(modelPath).arg(numThreads));
|
||||
|
||||
auto sessionPtr = std::make_unique<Ort::Session>(
|
||||
auto sessionPtr = std::make_unique<ort::Session>(
|
||||
*envPtr,
|
||||
#ifdef _WIN32
|
||||
modelPath.toStdWString().c_str(),
|
||||
#else
|
||||
modelPath.toUtf8().constData(),
|
||||
#endif
|
||||
*optionsPtr);
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
size_t inputCount = sessionPtr->GetInputCount();
|
||||
size_t outputCount = sessionPtr->GetOutputCount();
|
||||
size_t inputCount = sessionPtr->getInputCount();
|
||||
size_t outputCount = sessionPtr->getOutputCount();
|
||||
|
||||
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
||||
.arg(inputCount).arg(outputCount));
|
||||
@ -79,15 +78,13 @@ struct STTEngine::Impl {
|
||||
outputNames.clear();
|
||||
|
||||
for (size_t i = 0; i < inputCount; i++) {
|
||||
auto namePtr = sessionPtr->GetInputNameAllocated(i, allocator);
|
||||
inputNames.emplace_back(namePtr.get());
|
||||
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(namePtr.get()));
|
||||
inputNames.emplace_back(sessionPtr->getInputName(i));
|
||||
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(QString::fromStdString(inputNames.back())));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputCount; i++) {
|
||||
auto namePtr = sessionPtr->GetOutputNameAllocated(i, allocator);
|
||||
outputNames.emplace_back(namePtr.get());
|
||||
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(namePtr.get()));
|
||||
outputNames.emplace_back(sessionPtr->getOutputName(i));
|
||||
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(QString::fromStdString(outputNames.back())));
|
||||
}
|
||||
|
||||
env = std::move(envPtr);
|
||||
@ -106,7 +103,7 @@ struct STTEngine::Impl {
|
||||
|
||||
LOG_INFO(kTag, QString("模型加载成功: %1").arg(modelPath));
|
||||
return true;
|
||||
} catch (const Ort::Exception& e) {
|
||||
} catch (const ort::Exception& e) {
|
||||
errorMsg = QString("ONNX 异常: %1").arg(e.what());
|
||||
LOG_ERROR(kTag, errorMsg);
|
||||
return false;
|
||||
@ -117,6 +114,20 @@ struct STTEngine::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
QMutex mutex;
|
||||
#else
|
||||
// 占位实现:无 ONNX Runtime 时仅提供基本结构
|
||||
bool loadInWorker(const QString& /*modelPath*/,
|
||||
const QString& /*device*/,
|
||||
int /*numThreads*/,
|
||||
QString& errorMsg)
|
||||
{
|
||||
errorMsg = "ONNX Runtime 未安装,推理功能不可用。"
|
||||
"请在 third_party/onnxruntime/ 中部署 ONNX Runtime 后重新编译。";
|
||||
LOG_ERROR(kTag, errorMsg);
|
||||
return false;
|
||||
}
|
||||
|
||||
QMutex mutex;
|
||||
#endif
|
||||
};
|
||||
@ -276,9 +287,10 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
QMutexLocker locker(&impl_->mutex);
|
||||
|
||||
int64_t melShape[] = {1, kMelBins, static_cast<int64_t>(nFrames)};
|
||||
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
std::vector<Ort::Value> inputTensors;
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(
|
||||
auto memInfo = ort::MemoryInfo::createCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
std::vector<ort::Value> inputTensors;
|
||||
inputTensors.push_back(ort::Value::createTensor(
|
||||
memInfo, melSpec.data(), melSpec.size(), melShape, 3));
|
||||
|
||||
std::vector<const char*> inputNamePtrs;
|
||||
@ -286,8 +298,10 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
std::vector<const char*> outputNamePtrs;
|
||||
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
||||
|
||||
auto outputTensors = impl_->session->Run(
|
||||
Ort::RunOptions{nullptr},
|
||||
ort::RunOptions runOptions;
|
||||
auto outputTensors = ort::run(
|
||||
*impl_->session,
|
||||
runOptions,
|
||||
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
||||
outputNamePtrs.data(), impl_->outputNames.size());
|
||||
|
||||
@ -295,8 +309,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
||||
|
||||
// 4. 解析输出
|
||||
auto& outputTensor = outputTensors[0];
|
||||
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float* outputData = outputTensor.GetTensorMutableData<float>();
|
||||
auto shape = outputTensor.getShape();
|
||||
const float* outputData = outputTensor.getTensorData();
|
||||
|
||||
LOG_DEBUG(kTag, QString("输出维度: %1").arg(shape.size()));
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
|
||||
@ -5,6 +5,8 @@
|
||||
#include <windows.h>
|
||||
#include <QAbstractNativeEventFilter>
|
||||
#include <QGuiApplication>
|
||||
#include <QThread>
|
||||
#include <QWidget>
|
||||
#endif
|
||||
|
||||
static const char* const kTag = "CapsLockVoiceHotkey";
|
||||
@ -20,6 +22,7 @@ struct CapsLockVoiceHotkey::Impl {
|
||||
bool longPressFired = false;
|
||||
bool pollThreadRunning = false;
|
||||
void* nativeEventFilter = nullptr;
|
||||
QWidget* hiddenWindow = nullptr;
|
||||
static constexpr int kLongPressMs = 1000;
|
||||
#endif
|
||||
};
|
||||
@ -63,14 +66,19 @@ bool CapsLockVoiceHotkey::start() {
|
||||
if (active_) return true;
|
||||
|
||||
#ifdef Q_OS_WIN
|
||||
HWND hwnd = reinterpret_cast<HWND>(QGuiApplication::instance()->winId());
|
||||
if (!hwnd) {
|
||||
// Try to get the top-level widget's window handle
|
||||
hwnd = GetForegroundWindow();
|
||||
// 创建隐藏窗口用于接收 WM_HOTKEY 消息
|
||||
if (!impl_->hiddenWindow) {
|
||||
impl_->hiddenWindow = new QWidget();
|
||||
impl_->hiddenWindow->setObjectName("HotkeyReceiver");
|
||||
impl_->hiddenWindow->setWindowFlags(Qt::Tool | Qt::FramelessWindowHint);
|
||||
impl_->hiddenWindow->resize(0, 0);
|
||||
}
|
||||
// 确保窗口已创建(show 会创建原生句柄)
|
||||
impl_->hiddenWindow->show();
|
||||
|
||||
HWND hwnd = reinterpret_cast<HWND>(impl_->hiddenWindow->winId());
|
||||
if (!hwnd) {
|
||||
emit error("无法获取窗口句柄");
|
||||
emit error("无法创建窗口句柄");
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -146,11 +154,14 @@ void CapsLockVoiceHotkey::stop() {
|
||||
}
|
||||
|
||||
// 注销快捷键
|
||||
HWND hwnd = reinterpret_cast<HWND>(QGuiApplication::instance()->winId());
|
||||
if (hwnd && impl_->hotkeyId) {
|
||||
UnregisterHotKey(hwnd, impl_->hotkeyId);
|
||||
GlobalDeleteAtom(impl_->hotkeyId);
|
||||
impl_->hotkeyId = 0;
|
||||
if (impl_->hiddenWindow) {
|
||||
HWND hwnd = reinterpret_cast<HWND>(impl_->hiddenWindow->winId());
|
||||
if (hwnd && impl_->hotkeyId) {
|
||||
UnregisterHotKey(hwnd, impl_->hotkeyId);
|
||||
GlobalDeleteAtom(impl_->hotkeyId);
|
||||
impl_->hotkeyId = 0;
|
||||
}
|
||||
impl_->hiddenWindow->hide();
|
||||
}
|
||||
|
||||
active_ = false;
|
||||
@ -208,7 +219,7 @@ void CapsLockVoiceHotkey::onHotkeyEvent(int /*hotkeyId*/) {
|
||||
QThread::msleep(50);
|
||||
}
|
||||
impl_->pollThreadRunning = false;
|
||||
}).start();
|
||||
})->start();
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ signals:
|
||||
void error(const QString& message);
|
||||
|
||||
/** @brief 处理 WM_HOTKEY 事件(由原生事件过滤器调用) */
|
||||
public:
|
||||
void onHotkeyEvent(int hotkeyId);
|
||||
|
||||
private:
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include <QLabel>
|
||||
#include <QFileDialog>
|
||||
#include <QMessageBox>
|
||||
#include <QScrollArea>
|
||||
|
||||
static const char* const kTag = "SettingsPage";
|
||||
|
||||
@ -34,10 +35,24 @@ SettingsPage::~SettingsPage() = default;
|
||||
|
||||
void SettingsPage::setupUI() {
|
||||
auto* mainLayout = new QVBoxLayout(this);
|
||||
mainLayout->setContentsMargins(0, 0, 0, 0);
|
||||
|
||||
// ---- 可滚动内容区域 ----
|
||||
auto* scrollArea = new QScrollArea(this);
|
||||
scrollArea->setWidgetResizable(true);
|
||||
scrollArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
|
||||
scrollArea->setFrameShape(QFrame::NoFrame);
|
||||
|
||||
auto* contentWidget = new QWidget(scrollArea);
|
||||
auto* contentLayout = new QVBoxLayout(contentWidget);
|
||||
contentLayout->setContentsMargins(12, 8, 12, 8);
|
||||
contentLayout->setSpacing(12);
|
||||
|
||||
// STT 设置
|
||||
auto* sttGroup = new QGroupBox("STT 推理设置", this);
|
||||
auto* sttGroup = new QGroupBox("STT 推理设置", contentWidget);
|
||||
auto* sttLayout = new QFormLayout(sttGroup);
|
||||
sttLayout->setSpacing(8);
|
||||
sttLayout->setContentsMargins(10, 12, 10, 12);
|
||||
|
||||
auto* modelRow = new QHBoxLayout();
|
||||
modelPathEdit_ = new QLineEdit(this);
|
||||
@ -109,11 +124,13 @@ void SettingsPage::setupUI() {
|
||||
temperatureSpin_->setValue(0.0);
|
||||
sttLayout->addRow("温度 (Temperature):", temperatureSpin_);
|
||||
|
||||
mainLayout->addWidget(sttGroup);
|
||||
contentLayout->addWidget(sttGroup);
|
||||
|
||||
// 音频设置
|
||||
auto* audioGroup = new QGroupBox("音频设置", this);
|
||||
auto* audioGroup = new QGroupBox("音频设置", contentWidget);
|
||||
auto* audioLayout = new QFormLayout(audioGroup);
|
||||
audioLayout->setSpacing(8);
|
||||
audioLayout->setContentsMargins(10, 12, 10, 12);
|
||||
|
||||
// 音频输入设备选择器
|
||||
audioDeviceCombo_ = new QComboBox(this);
|
||||
@ -150,11 +167,13 @@ void SettingsPage::setupUI() {
|
||||
paddingSpin_->setSuffix(" ms");
|
||||
audioLayout->addRow("块间重叠:", paddingSpin_);
|
||||
|
||||
mainLayout->addWidget(audioGroup);
|
||||
contentLayout->addWidget(audioGroup);
|
||||
|
||||
// UI 设置
|
||||
auto* uiGroup = new QGroupBox("界面设置", this);
|
||||
auto* uiGroup = new QGroupBox("界面设置", contentWidget);
|
||||
auto* uiLayout = new QFormLayout(uiGroup);
|
||||
uiLayout->setSpacing(8);
|
||||
uiLayout->setContentsMargins(10, 12, 10, 12);
|
||||
|
||||
themeCombo_ = new QComboBox(this);
|
||||
themeCombo_->addItems({"light", "dark"});
|
||||
@ -173,26 +192,32 @@ void SettingsPage::setupUI() {
|
||||
showConfidenceCheck_->setChecked(true);
|
||||
uiLayout->addRow("置信度显示:", showConfidenceCheck_);
|
||||
|
||||
mainLayout->addWidget(uiGroup);
|
||||
contentLayout->addWidget(uiGroup);
|
||||
contentLayout->addStretch();
|
||||
|
||||
// 操作按钮
|
||||
auto* btnLayout = new QHBoxLayout();
|
||||
auto* saveBtn = new QPushButton("保存配置", this);
|
||||
scrollArea->setWidget(contentWidget);
|
||||
mainLayout->addWidget(scrollArea);
|
||||
|
||||
// ---- 底部操作按钮(固定不滚动) ----
|
||||
auto* btnBar = new QWidget(this);
|
||||
auto* btnLayout = new QHBoxLayout(btnBar);
|
||||
btnLayout->setContentsMargins(12, 4, 12, 8);
|
||||
|
||||
auto* saveBtn = new QPushButton("保存配置", btnBar);
|
||||
saveBtn->setStyleSheet("QPushButton { font-weight: bold; padding: 8px 16px; }");
|
||||
connect(saveBtn, &QPushButton::clicked, this, &SettingsPage::onSaveConfig);
|
||||
btnLayout->addWidget(saveBtn);
|
||||
|
||||
auto* resetBtn = new QPushButton("恢复默认", this);
|
||||
auto* resetBtn = new QPushButton("恢复默认", btnBar);
|
||||
connect(resetBtn, &QPushButton::clicked, this, &SettingsPage::onResetConfig);
|
||||
btnLayout->addWidget(resetBtn);
|
||||
btnLayout->addStretch();
|
||||
|
||||
statusLabel_ = new QLabel("配置未修改", this);
|
||||
statusLabel_ = new QLabel("配置未修改", btnBar);
|
||||
statusLabel_->setStyleSheet("color: gray;");
|
||||
btnLayout->addWidget(statusLabel_);
|
||||
|
||||
mainLayout->addLayout(btnLayout);
|
||||
mainLayout->addStretch();
|
||||
mainLayout->addWidget(btnBar);
|
||||
}
|
||||
|
||||
void SettingsPage::loadFromConfig() {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user