feat: 添加 Windows 交叉编译支持与 ONNX Runtime MinGW 兼容方案
- 新增 C API shim (ort_api_shim.h) 解决 MinGW 与 ONNX Runtime 的 SAL 注解/_stdcall 兼容性问题 - 新增轻量级 C++ 包装器 (ort_minimal) 替代 onnxruntime_cxx_api.h - cmake/dependencies.cmake 支持 Windows/ Linux 平台自动识别依赖路径 - 修复音频采集 paNonInterleaved bug(指针被误解析为 float 导致 RMS=inf) - 修复 Windows 热键和 UI 相关代码 - 添加 MinGW 交叉编译工具链配置 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
01a39ddc8c
commit
8c2e787a25
9
.gitignore
vendored
9
.gitignore
vendored
@ -64,3 +64,12 @@ configs/config.json
|
|||||||
|
|
||||||
# 日志
|
# 日志
|
||||||
*.log
|
*.log
|
||||||
|
|
||||||
|
# 构建目录(所有平台)
|
||||||
|
build-win/
|
||||||
|
build_linux/
|
||||||
|
build_win/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# Windows 第三方依赖(ONNX Runtime Win x64 预编译二进制)
|
||||||
|
third_party/onnxruntime-win-x64/
|
||||||
|
|||||||
@ -55,6 +55,7 @@ set(SOURCES
|
|||||||
# Core (平台无关)
|
# Core (平台无关)
|
||||||
src/core/stt_engine.cpp
|
src/core/stt_engine.cpp
|
||||||
src/core/sense_voice_engine.cpp
|
src/core/sense_voice_engine.cpp
|
||||||
|
src/core/ort_minimal.cpp
|
||||||
src/core/sense_voice_features.cpp
|
src/core/sense_voice_features.cpp
|
||||||
src/core/sense_voice_tokenizer.cpp
|
src/core/sense_voice_tokenizer.cpp
|
||||||
src/core/mel_spectrogram.cpp
|
src/core/mel_spectrogram.cpp
|
||||||
@ -166,6 +167,36 @@ target_compile_options(${PROJECT_NAME} PRIVATE
|
|||||||
$<$<CXX_COMPILER_ID:MSVC>:/W4>
|
$<$<CXX_COMPILER_ID:MSVC>:/W4>
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Windows DLL 复制(交叉编译时自动拷贝到输出目录)
|
||||||
|
# ============================================================================
|
||||||
|
if(WIN32)
|
||||||
|
if(ONNXRUNTIME_DLL)
|
||||||
|
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||||
|
"${ONNXRUNTIME_DLL}"
|
||||||
|
$<TARGET_FILE_DIR:${PROJECT_NAME}>
|
||||||
|
COMMENT "拷贝 onnxruntime.dll 到输出目录"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
if(ONNXRUNTIME_PROVIDERS_DLL)
|
||||||
|
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||||
|
"${ONNXRUNTIME_PROVIDERS_DLL}"
|
||||||
|
$<TARGET_FILE_DIR:${PROJECT_NAME}>
|
||||||
|
COMMENT "拷贝 onnxruntime_providers_shared.dll 到输出目录"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
if(PORTAUDIO_DLL)
|
||||||
|
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||||
|
"${PORTAUDIO_DLL}"
|
||||||
|
$<TARGET_FILE_DIR:${PROJECT_NAME}>
|
||||||
|
COMMENT "拷贝 libportaudio.dll 到输出目录"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 资源文件
|
# 资源文件
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@ -2,21 +2,39 @@ include(FetchContent)
|
|||||||
|
|
||||||
set(THIRD_PARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party")
|
set(THIRD_PARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party")
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
# ONNX Runtime
|
# ONNX Runtime
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime")
|
if(WIN32)
|
||||||
|
# Windows 版本:onnxruntime.dll
|
||||||
find_library(ONNXRUNTIME_LIB
|
set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime-win-x64")
|
||||||
NAMES onnxruntime
|
if(NOT EXISTS "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
|
||||||
PATHS "${ONNXRUNTIME_ROOT}/lib"
|
# 回退到旧目录名
|
||||||
NO_DEFAULT_PATH
|
set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime")
|
||||||
)
|
endif()
|
||||||
find_path(ONNXRUNTIME_INCLUDE_DIR
|
# 直接用 DLL 路径(MinGW 可直接链接 DLL)
|
||||||
NAMES onnxruntime_cxx_api.h
|
if(EXISTS "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
|
||||||
PATHS "${ONNXRUNTIME_ROOT}/include"
|
set(ONNXRUNTIME_LIB "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
|
||||||
NO_DEFAULT_PATH
|
set(ONNXRUNTIME_DLL "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll")
|
||||||
)
|
set(ONNXRUNTIME_INCLUDE_DIR "${ONNXRUNTIME_ROOT}/include")
|
||||||
|
endif()
|
||||||
|
if(NOT ONNXRUNTIME_INCLUDE_DIR)
|
||||||
|
set(ONNXRUNTIME_INCLUDE_DIR "${ONNXRUNTIME_ROOT}/include")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
# Linux 版本:libonnxruntime.so
|
||||||
|
set(ONNXRUNTIME_ROOT "${THIRD_PARTY_DIR}/onnxruntime")
|
||||||
|
find_library(ONNXRUNTIME_LIB
|
||||||
|
NAMES onnxruntime
|
||||||
|
PATHS "${ONNXRUNTIME_ROOT}/lib"
|
||||||
|
NO_DEFAULT_PATH
|
||||||
|
)
|
||||||
|
find_path(ONNXRUNTIME_INCLUDE_DIR
|
||||||
|
NAMES onnxruntime_cxx_api.h
|
||||||
|
PATHS "${ONNXRUNTIME_ROOT}/include"
|
||||||
|
NO_DEFAULT_PATH
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(ONNXRUNTIME_LIB AND ONNXRUNTIME_INCLUDE_DIR)
|
if(ONNXRUNTIME_LIB AND ONNXRUNTIME_INCLUDE_DIR)
|
||||||
set(ONNXRUNTIME_LIBRARIES ${ONNXRUNTIME_LIB})
|
set(ONNXRUNTIME_LIBRARIES ${ONNXRUNTIME_LIB})
|
||||||
@ -27,21 +45,45 @@ else()
|
|||||||
message(WARNING "未找到 ONNX Runtime,推理功能将使用占位实现")
|
message(WARNING "未找到 ONNX Runtime,推理功能将使用占位实现")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
# PortAudio
|
# PortAudio
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
set(PORTAUDIO_ROOT "${THIRD_PARTY_DIR}/portaudio")
|
set(PORTAUDIO_ROOT "${THIRD_PARTY_DIR}/portaudio")
|
||||||
|
|
||||||
find_library(PORTAUDIO_LIB
|
if(WIN32)
|
||||||
NAMES portaudio libportaudio
|
# Windows 版本:libportaudio.dll 在 bin/ 目录
|
||||||
PATHS "${PORTAUDIO_ROOT}/lib"
|
if(EXISTS "${PORTAUDIO_ROOT}/bin/libportaudio.dll")
|
||||||
NO_DEFAULT_PATH
|
set(PORTAUDIO_LIB "${PORTAUDIO_ROOT}/bin/libportaudio.dll")
|
||||||
)
|
set(PORTAUDIO_DLL "${PORTAUDIO_ROOT}/bin/libportaudio.dll")
|
||||||
find_path(PORTAUDIO_INCLUDE_DIR
|
endif()
|
||||||
NAMES portaudio.h
|
if(EXISTS "${PORTAUDIO_ROOT}/include/portaudio.h")
|
||||||
PATHS "${PORTAUDIO_ROOT}/include"
|
set(PORTAUDIO_INCLUDE_DIR "${PORTAUDIO_ROOT}/include")
|
||||||
NO_DEFAULT_PATH
|
endif()
|
||||||
)
|
else()
|
||||||
|
# Linux 版本:优先使用构建好的本地库
|
||||||
|
find_library(PORTAUDIO_LIB
|
||||||
|
NAMES portaudio libportaudio
|
||||||
|
PATHS "${PORTAUDIO_ROOT}/lib" /usr/lib64 /usr/lib /usr/local/lib
|
||||||
|
NO_DEFAULT_PATH
|
||||||
|
)
|
||||||
|
find_path(PORTAUDIO_INCLUDE_DIR
|
||||||
|
NAMES portaudio.h
|
||||||
|
PATHS "${PORTAUDIO_ROOT}/include" /usr/include /usr/include/portaudio /usr/local/include
|
||||||
|
NO_DEFAULT_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
# 回退:通过 pkg-config 查找
|
||||||
|
if(NOT PORTAUDIO_LIB OR NOT PORTAUDIO_INCLUDE_DIR)
|
||||||
|
find_package(PkgConfig QUIET)
|
||||||
|
if(PKG_CONFIG_FOUND)
|
||||||
|
pkg_check_modules(PORTAUDIO_PC portaudio-2.0 QUIET)
|
||||||
|
if(PORTAUDIO_PC_FOUND)
|
||||||
|
set(PORTAUDIO_LIBRARIES ${PORTAUDIO_PC_LIBRARIES})
|
||||||
|
set(PORTAUDIO_INCLUDE_DIRS ${PORTAUDIO_PC_INCLUDE_DIRS})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
if(PORTAUDIO_LIB AND PORTAUDIO_INCLUDE_DIR)
|
if(PORTAUDIO_LIB AND PORTAUDIO_INCLUDE_DIR)
|
||||||
set(PORTAUDIO_LIBRARIES ${PORTAUDIO_LIB})
|
set(PORTAUDIO_LIBRARIES ${PORTAUDIO_LIB})
|
||||||
@ -52,9 +94,9 @@ else()
|
|||||||
message(WARNING "未找到 PortAudio,音频采集功能将使用占位实现")
|
message(WARNING "未找到 PortAudio,音频采集功能将使用占位实现")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
# dr_libs (header-only)
|
# dr_libs (header-only)
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
set(DR_LIBS_INCLUDE_DIR "${THIRD_PARTY_DIR}/dr_libs")
|
set(DR_LIBS_INCLUDE_DIR "${THIRD_PARTY_DIR}/dr_libs")
|
||||||
if(EXISTS "${DR_LIBS_INCLUDE_DIR}/dr_wav.h")
|
if(EXISTS "${DR_LIBS_INCLUDE_DIR}/dr_wav.h")
|
||||||
message(STATUS "找到 dr_libs: ${DR_LIBS_INCLUDE_DIR}")
|
message(STATUS "找到 dr_libs: ${DR_LIBS_INCLUDE_DIR}")
|
||||||
@ -63,7 +105,7 @@ else()
|
|||||||
message(WARNING "未找到 dr_libs 头文件")
|
message(WARNING "未找到 dr_libs 头文件")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
# nlohmann/json (header-only)
|
# nlohmann/json (header-only)
|
||||||
# ----------------------------------------------------------------------------
|
# ============================================================================
|
||||||
set(NLOHMANN_JSON_INCLUDE_DIR "${THIRD_PARTY_DIR}/nlohmann_json")
|
set(NLOHMANN_JSON_INCLUDE_DIR "${THIRD_PARTY_DIR}/nlohmann_json")
|
||||||
|
|||||||
15
cmake/mingw-toolchain.cmake
Normal file
15
cmake/mingw-toolchain.cmake
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# MinGW Windows 交叉编译工具链
|
||||||
|
set(CMAKE_SYSTEM_NAME Windows)
|
||||||
|
set(CMAKE_SYSTEM_PROCESSOR x86_64)
|
||||||
|
|
||||||
|
set(CMAKE_C_COMPILER x86_64-w64-mingw32-gcc)
|
||||||
|
set(CMAKE_CXX_COMPILER x86_64-w64-mingw32-g++)
|
||||||
|
set(CMAKE_RC_COMPILER x86_64-w64-mingw32-windres)
|
||||||
|
|
||||||
|
set(CMAKE_FIND_ROOT_PATH /usr/x86_64-w64-mingw32/sys-root/mingw)
|
||||||
|
|
||||||
|
# 搜索策略:允许在 sysroot 和宿主路径之外查找
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||||
26
run.sh
26
run.sh
@ -1,11 +1,31 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Impress Voice Input 启动脚本
|
# Impress Voice Input 启动脚本
|
||||||
|
# 设置 ONNX Runtime / PortAudio 库路径并启动应用
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
BUILD_DIR="${SCRIPT_DIR}/build"
|
BUILD_DIR="${SCRIPT_DIR}/build"
|
||||||
|
ONNXRUNTIME_LIB_DIR="${SCRIPT_DIR}/third_party/onnxruntime/lib"
|
||||||
|
PORTAUDIO_LIB_DIR="${SCRIPT_DIR}/third_party/portaudio/lib"
|
||||||
|
|
||||||
# 设置库路径
|
# 检查可执行文件
|
||||||
export LD_LIBRARY_PATH="${SCRIPT_DIR}/third_party/onnxruntime/lib:${SCRIPT_DIR}/third_party/portaudio/lib:${LD_LIBRARY_PATH}"
|
if [ ! -f "${BUILD_DIR}/impress_voice_input" ]; then
|
||||||
|
echo "错误:未找到可执行文件,请先编译:"
|
||||||
|
echo " mkdir -p build && cd build"
|
||||||
|
echo " cmake .. -DCMAKE_BUILD_TYPE=RelWithDebInfo"
|
||||||
|
echo " cmake --build . -j\$(nproc)"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
# 运行
|
# 检查 ONNX Runtime
|
||||||
|
if [ ! -f "${ONNXRUNTIME_LIB_DIR}/libonnxruntime.so" ]; then
|
||||||
|
echo "警告:ONNX Runtime 未部署,推理功能将不可用"
|
||||||
|
echo " 请按照 third_party/README.md 部署 ONNX Runtime"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 设置库路径(ONNX Runtime 优先,PortAudio 回退到系统)
|
||||||
|
export LD_LIBRARY_PATH="${ONNXRUNTIME_LIB_DIR}:${PORTAUDIO_LIB_DIR}:${LD_LIBRARY_PATH}"
|
||||||
|
|
||||||
|
# 启动应用
|
||||||
exec "${BUILD_DIR}/impress_voice_input" "$@"
|
exec "${BUILD_DIR}/impress_voice_input" "$@"
|
||||||
|
|||||||
@ -13,6 +13,7 @@ namespace impress {
|
|||||||
// 预分配缓冲区,避免在实时回调中分配内存
|
// 预分配缓冲区,避免在实时回调中分配内存
|
||||||
static constexpr int kMaxBufferSize = 8192;
|
static constexpr int kMaxBufferSize = 8192;
|
||||||
|
|
||||||
|
#ifdef HAVE_PORTAUDIO
|
||||||
// 全局 PortAudio 初始化状态
|
// 全局 PortAudio 初始化状态
|
||||||
static bool gPaInitialized = false;
|
static bool gPaInitialized = false;
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ static void safePaTerminate() {
|
|||||||
gPaInitialized = false;
|
gPaInitialized = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// 回调上下文:独立于 Impl 的 POD 结构,供静态回调使用
|
// 回调上下文:独立于 Impl 的 POD 结构,供静态回调使用
|
||||||
struct CallbackContext {
|
struct CallbackContext {
|
||||||
@ -52,13 +54,13 @@ struct AudioCapture::Impl {
|
|||||||
CallbackContext ctx;
|
CallbackContext ctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef HAVE_PORTAUDIO
|
||||||
static int paCallback(const void* input, void* /*output*/,
|
static int paCallback(const void* input, void* /*output*/,
|
||||||
unsigned long frameCount,
|
unsigned long frameCount,
|
||||||
const PaStreamCallbackTimeInfo* /*timeInfo*/,
|
const PaStreamCallbackTimeInfo* /*timeInfo*/,
|
||||||
PaStreamCallbackFlags /*statusFlags*/,
|
PaStreamCallbackFlags /*statusFlags*/,
|
||||||
void* userData)
|
void* userData)
|
||||||
{
|
{
|
||||||
#ifdef HAVE_PORTAUDIO
|
|
||||||
auto* ctx = static_cast<CallbackContext*>(userData);
|
auto* ctx = static_cast<CallbackContext*>(userData);
|
||||||
|
|
||||||
const float* samples = static_cast<const float*>(input);
|
const float* samples = static_cast<const float*>(input);
|
||||||
@ -88,11 +90,15 @@ static int paCallback(const void* input, void* /*output*/,
|
|||||||
emit ctx->owner->audioDataReady(data, ctx->sampleRate);
|
emit ctx->owner->audioDataReady(data, ctx->sampleRate);
|
||||||
|
|
||||||
return paContinue;
|
return paContinue;
|
||||||
#else
|
|
||||||
(void)input; (void)frameCount; (void)userData;
|
|
||||||
return 0;
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
// 占位回调(无 PortAudio 时不使用)
|
||||||
|
static int paCallbackStub(const void*, void*, unsigned long,
|
||||||
|
const int*, int, void*)
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
AudioCapture::AudioCapture(QObject* parent)
|
AudioCapture::AudioCapture(QObject* parent)
|
||||||
: QObject(parent)
|
: QObject(parent)
|
||||||
@ -107,8 +113,8 @@ AudioCapture::~AudioCapture() {
|
|||||||
|
|
||||||
QStringList AudioCapture::getDeviceList() {
|
QStringList AudioCapture::getDeviceList() {
|
||||||
QStringList devices;
|
QStringList devices;
|
||||||
devices << "默认设备";
|
|
||||||
#ifdef HAVE_PORTAUDIO
|
#ifdef HAVE_PORTAUDIO
|
||||||
|
devices << "默认设备";
|
||||||
if (!ensurePaInitialized()) {
|
if (!ensurePaInitialized()) {
|
||||||
LOG_ERROR(kTag, "PortAudio 初始化失败");
|
LOG_ERROR(kTag, "PortAudio 初始化失败");
|
||||||
return devices;
|
return devices;
|
||||||
@ -124,6 +130,9 @@ QStringList AudioCapture::getDeviceList() {
|
|||||||
.arg(info->defaultSampleRate).arg(hostApiName);
|
.arg(info->defaultSampleRate).arg(hostApiName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
devices << "PortAudio 未启用(占位设备)";
|
||||||
|
LOG_WARNING(kTag, "PortAudio 未编译启用,设备列表为占位");
|
||||||
#endif
|
#endif
|
||||||
return devices;
|
return devices;
|
||||||
}
|
}
|
||||||
@ -195,7 +204,9 @@ bool AudioCapture::start(int deviceIndex, int sampleRate, int bufferSizeMs) {
|
|||||||
PaStreamParameters inputParams{};
|
PaStreamParameters inputParams{};
|
||||||
inputParams.device = devIdx;
|
inputParams.device = devIdx;
|
||||||
inputParams.channelCount = 1;
|
inputParams.channelCount = 1;
|
||||||
inputParams.sampleFormat = paFloat32 | paNonInterleaved;
|
inputParams.sampleFormat = paFloat32;
|
||||||
|
// 不使用 paNonInterleaved:input 指针直接是 float* 数组(interleaved mono),
|
||||||
|
// 回调中可以安全地 static_cast<const float*>(input)
|
||||||
// 使用高延迟以避免回调过快
|
// 使用高延迟以避免回调过快
|
||||||
inputParams.suggestedLatency = devInfo->defaultHighInputLatency;
|
inputParams.suggestedLatency = devInfo->defaultHighInputLatency;
|
||||||
|
|
||||||
@ -232,8 +243,9 @@ bool AudioCapture::start(int deviceIndex, int sampleRate, int bufferSizeMs) {
|
|||||||
.arg(deviceIndex).arg(sampleRate).arg(bufferSizeMs));
|
.arg(deviceIndex).arg(sampleRate).arg(bufferSizeMs));
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
LOG_ERROR(kTag, "PortAudio 未编译启用");
|
(void)deviceIndex; (void)sampleRate; (void)bufferSizeMs;
|
||||||
emit error("PortAudio 未编译启用");
|
LOG_ERROR(kTag, "PortAudio 未编译启用,无法启动采集");
|
||||||
|
emit error("PortAudio 未编译启用,请在 third_party/portaudio/ 中部署后重新编译");
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
62
src/core/ort_api_shim.h
Normal file
62
src/core/ort_api_shim.h
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
#pragma once
|
||||||
|
/**
|
||||||
|
* @brief ONNX Runtime C API Shim
|
||||||
|
*
|
||||||
|
* 包装 onnxruntime_c_api.h 以解决 MinGW 交叉编译兼容性问题:
|
||||||
|
* 1. SAL 注解(specstrings.h 中的宏与 MinGW 不兼容)
|
||||||
|
* 2. _stdcall 调用约定导致函数声明语法错误
|
||||||
|
*
|
||||||
|
* 关键:必须在包含 onnxruntime_c_api.h 之前定义这些宏,
|
||||||
|
* 因为 header 内部的 #define 会使用它们。
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
|
#ifndef ORT_API_SHIM_H
|
||||||
|
#define ORT_API_SHIM_H
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define ORT_DLL_IMPORT
|
||||||
|
|
||||||
|
/* 在 specstrings.h 被包含之前,抢先定义 SAL 注解为空。
|
||||||
|
onnxruntime_c_api.h 第 74 行 #include <specstrings.h>,
|
||||||
|
如果 specstrings.h 用 #ifndef 保护,我们的定义就不会被覆盖。
|
||||||
|
即使被覆盖,_stdcall 下面的定义也会生效。 */
|
||||||
|
#define _Success_(x)
|
||||||
|
#define _Check_return_
|
||||||
|
#define _Ret_maybenull_
|
||||||
|
#define _In_
|
||||||
|
#define _In_z_
|
||||||
|
#define _In_opt_
|
||||||
|
#define _In_opt_z_
|
||||||
|
#define _Out_
|
||||||
|
#define _Outptr_
|
||||||
|
#define _Out_opt_
|
||||||
|
#define _Inout_
|
||||||
|
#define _Inout_opt_
|
||||||
|
#define _Frees_ptr_opt_
|
||||||
|
#define _Ret_notnull_
|
||||||
|
#define _In_reads_(x)
|
||||||
|
#define _Inout_updates_(x)
|
||||||
|
#define _Out_writes_(x)
|
||||||
|
#define _Inout_updates_all_(x)
|
||||||
|
#define _Out_writes_bytes_all_(x)
|
||||||
|
#define _Out_writes_all_(x)
|
||||||
|
#define _Outptr_result_maybenull_(x)
|
||||||
|
#define _In_reads_opt_(x)
|
||||||
|
#define _Outptr_result_buffer_maybenull_(x)
|
||||||
|
#define _Return_type_success_(x)
|
||||||
|
#define _Out_writes_bytes_all_opt_(x)
|
||||||
|
#define _In_reads_bytes_(x)
|
||||||
|
|
||||||
|
/* 将 _stdcall 定义为空。
|
||||||
|
onnxruntime_c_api.h 第 86 行 #define ORT_API_CALL _stdcall,
|
||||||
|
所以当 ORT_API_CALL 展开为 _stdcall 后,_stdcall 再展开为空。
|
||||||
|
这样最终的函数声明没有调用约定修饰,MinGW 可以正常解析。 */
|
||||||
|
#define _stdcall
|
||||||
|
#define __stdcall
|
||||||
|
#endif /* _WIN32 */
|
||||||
|
|
||||||
|
#include <onnxruntime_c_api.h>
|
||||||
|
|
||||||
|
#endif /* ORT_API_SHIM_H */
|
||||||
|
#endif /* HAVE_ONNXRUNTIME */
|
||||||
365
src/core/ort_minimal.cpp
Normal file
365
src/core/ort_minimal.cpp
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
/**
|
||||||
|
* @brief ONNX Runtime 轻量级 C API 包装器 - 实现
|
||||||
|
*
|
||||||
|
* 直接使用 C API,避免 onnxruntime_cxx_api.h 的 MinGW 兼容性问题。
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
|
|
||||||
|
#include "ort_minimal.h"
|
||||||
|
#include <cstring>
|
||||||
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <windows.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace ort {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 OrtApi 指针
|
||||||
|
*
|
||||||
|
* 调用路径: OrtGetApiBase() -> OrtApiBase -> GetApi(version) -> OrtApi
|
||||||
|
*/
|
||||||
|
const OrtApi* getApi() {
|
||||||
|
static const OrtApi* api = nullptr;
|
||||||
|
if (!api) {
|
||||||
|
const OrtApiBase* apiBase = OrtGetApiBase();
|
||||||
|
api = apiBase->GetApi(ORT_API_VERSION);
|
||||||
|
}
|
||||||
|
return api;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Env
|
||||||
|
// ============================================================================
|
||||||
|
Env::Env(OrtLoggingLevel logLevel, const char* logId) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtStatus* status = api->CreateEnv(logLevel, logId, &env_);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateEnv failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Env::~Env() {
|
||||||
|
if (env_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseEnv(env_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// SessionOptions
|
||||||
|
// ============================================================================
|
||||||
|
SessionOptions::SessionOptions() {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtStatus* status = api->CreateSessionOptions(&opts_);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateSessionOptions failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SessionOptions::~SessionOptions() {
|
||||||
|
if (opts_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseSessionOptions(opts_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SessionOptions::setIntraOpNumThreads(int n) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtStatus* status = api->SetIntraOpNumThreads(opts_, n);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "SetIntraOpNumThreads failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SessionOptions::setGraphOptimizationLevel(GraphOptimizationLevel level) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtStatus* status = api->SetSessionGraphOptimizationLevel(opts_, level);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "SetSessionGraphOptimizationLevel failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Session
|
||||||
|
// ============================================================================
|
||||||
|
Session::Session(const Env& env, const char* modelPath, const SessionOptions& opts) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
#ifdef _WIN32
|
||||||
|
// Windows: ORTCHAR_T = wchar_t, 需要将 UTF-8 转换为 UTF-16
|
||||||
|
int len = MultiByteToWideChar(CP_UTF8, 0, modelPath, -1, nullptr, 0);
|
||||||
|
std::vector<wchar_t> widePath(len);
|
||||||
|
MultiByteToWideChar(CP_UTF8, 0, modelPath, -1, widePath.data(), len);
|
||||||
|
OrtStatus* status = api->CreateSession(env.ptr(), widePath.data(), opts.ptr(), &session_);
|
||||||
|
#else
|
||||||
|
OrtStatus* status = api->CreateSession(env.ptr(), modelPath, opts.ptr(), &session_);
|
||||||
|
#endif
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateSession failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
Session::Session(const Env& env, const wchar_t* modelPath, const SessionOptions& opts) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtStatus* status = api->CreateSession(env.ptr(), modelPath, opts.ptr(), &session_);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateSession failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
Session::~Session() {
|
||||||
|
if (session_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseSession(session_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Session::getInputCount() const {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
size_t count = 0;
|
||||||
|
OrtStatus* status = api->SessionGetInputCount(session_, &count);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "SessionGetInputCount failed");
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Session::getOutputCount() const {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
size_t count = 0;
|
||||||
|
OrtStatus* status = api->SessionGetOutputCount(session_, &count);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "SessionGetOutputCount failed");
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Session::getInputName(size_t index) const {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
char* name = nullptr;
|
||||||
|
OrtAllocator* allocator = nullptr;
|
||||||
|
OrtStatus* status = api->GetAllocatorWithDefaultOptions(&allocator);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "GetAllocatorWithDefaultOptions failed");
|
||||||
|
}
|
||||||
|
status = api->SessionGetInputName(session_, index, allocator, &name);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "SessionGetInputName failed");
|
||||||
|
}
|
||||||
|
std::string result(name);
|
||||||
|
// 不要释放 name - allocator 分配的内存由 allocator 管理
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Session::getOutputName(size_t index) const {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
char* name = nullptr;
|
||||||
|
OrtAllocator* allocator = nullptr;
|
||||||
|
OrtStatus* status = api->GetAllocatorWithDefaultOptions(&allocator);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "GetAllocatorWithDefaultOptions failed");
|
||||||
|
}
|
||||||
|
status = api->SessionGetOutputName(session_, index, allocator, &name);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "SessionGetOutputName failed");
|
||||||
|
}
|
||||||
|
std::string result(name);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// MemoryInfo
|
||||||
|
// ============================================================================
|
||||||
|
MemoryInfo::MemoryInfo(OrtMemoryInfo* info) : info_(info) {}
|
||||||
|
|
||||||
|
MemoryInfo MemoryInfo::createCpu(OrtAllocatorType type, OrtMemType memType) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtMemoryInfo* info = nullptr;
|
||||||
|
OrtStatus* status = api->CreateCpuMemoryInfo(type, memType, &info);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateCpuMemoryInfo failed");
|
||||||
|
}
|
||||||
|
return MemoryInfo(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
MemoryInfo::~MemoryInfo() {
|
||||||
|
if (info_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseMemoryInfo(info_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Value
|
||||||
|
// ============================================================================
|
||||||
|
Value::Value(OrtValue* value) : value_(value) {}
|
||||||
|
|
||||||
|
Value Value::fromRaw(OrtValue* value) {
|
||||||
|
return Value(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value Value::createTensor(const MemoryInfo& info, float* data, size_t elemCount,
|
||||||
|
const int64_t* shape, size_t shapeLen) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtValue* value = nullptr;
|
||||||
|
OrtStatus* status = api->CreateTensorWithDataAsOrtValue(
|
||||||
|
info.ptr(), data, elemCount * sizeof(float),
|
||||||
|
shape, shapeLen, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateTensorWithDataAsOrtValue failed");
|
||||||
|
}
|
||||||
|
return Value(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value Value::createTensor(const MemoryInfo& info, int32_t* data, size_t elemCount,
|
||||||
|
const int64_t* shape, size_t shapeLen) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtValue* value = nullptr;
|
||||||
|
OrtStatus* status = api->CreateTensorWithDataAsOrtValue(
|
||||||
|
info.ptr(), data, elemCount * sizeof(int32_t),
|
||||||
|
shape, shapeLen, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, &value);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateTensorWithDataAsOrtValue failed");
|
||||||
|
}
|
||||||
|
return Value(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value::~Value() {
|
||||||
|
if (value_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseValue(value_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> Value::getShape() const {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtTensorTypeAndShapeInfo* info = nullptr;
|
||||||
|
OrtStatus* status = api->GetTensorTypeAndShape(value_, &info);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "GetTensorTypeAndShape failed");
|
||||||
|
}
|
||||||
|
size_t dimCount = 0;
|
||||||
|
status = api->GetDimensionsCount(info, &dimCount);
|
||||||
|
std::vector<int64_t> shape(dimCount);
|
||||||
|
if (dimCount > 0) {
|
||||||
|
status = api->GetDimensions(info, shape.data(), dimCount);
|
||||||
|
}
|
||||||
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "GetDimensions failed");
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float* Value::getTensorData() const {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
void* data = nullptr;
|
||||||
|
OrtStatus* status = api->GetTensorMutableData(value_, &data);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "GetTensorMutableData failed");
|
||||||
|
}
|
||||||
|
return static_cast<const float*>(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value::Value(Value&& other) noexcept : value_(other.value_) {
|
||||||
|
other.value_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value& Value::operator=(Value&& other) noexcept {
|
||||||
|
if (this != &other) {
|
||||||
|
if (value_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseValue(value_);
|
||||||
|
}
|
||||||
|
value_ = other.value_;
|
||||||
|
other.value_ = nullptr;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// RunOptions
|
||||||
|
// ============================================================================
|
||||||
|
RunOptions::RunOptions() {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
OrtStatus* status = api->CreateRunOptions(&opts_);
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "CreateRunOptions failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RunOptions::~RunOptions() {
|
||||||
|
if (opts_) {
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
api->ReleaseRunOptions(opts_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// run
|
||||||
|
// ============================================================================
|
||||||
|
std::vector<Value> run(Session& session,
|
||||||
|
const RunOptions& runOptions,
|
||||||
|
const char* const* inputNames,
|
||||||
|
Value* inputValues,
|
||||||
|
size_t inputCount,
|
||||||
|
const char* const* outputNames,
|
||||||
|
size_t outputCount)
|
||||||
|
{
|
||||||
|
const OrtApi* api = getApi();
|
||||||
|
|
||||||
|
// 准备输入输出 OrtValue 指针数组
|
||||||
|
std::vector<OrtValue*> inputPtrs(inputCount);
|
||||||
|
for (size_t i = 0; i < inputCount; i++) {
|
||||||
|
inputPtrs[i] = inputValues[i].ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<OrtValue*> outputPtrs(outputCount, nullptr);
|
||||||
|
|
||||||
|
OrtStatus* status = api->Run(
|
||||||
|
session.ptr(),
|
||||||
|
runOptions.ptr(),
|
||||||
|
inputNames, inputPtrs.data(), static_cast<int>(inputCount),
|
||||||
|
outputNames, static_cast<int>(outputCount),
|
||||||
|
outputPtrs.data());
|
||||||
|
|
||||||
|
if (status) {
|
||||||
|
const char* msg = api->GetErrorMessage(status);
|
||||||
|
throw Exception(msg ? msg : "Run failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Value> results;
|
||||||
|
results.reserve(outputCount);
|
||||||
|
for (size_t i = 0; i < outputCount; i++) {
|
||||||
|
results.emplace_back(Value::fromRaw(outputPtrs[i]));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ort
|
||||||
|
|
||||||
|
#endif // HAVE_ONNXRUNTIME
|
||||||
134
src/core/ort_minimal.h
Normal file
134
src/core/ort_minimal.h
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
#pragma once
|
||||||
|
/**
|
||||||
|
* @brief ONNX Runtime 轻量级 C API 包装器
|
||||||
|
*
|
||||||
|
* 替代 onnxruntime_cxx_api.h(该文件与 MinGW 存在 ABI 兼容性问题)。
|
||||||
|
* 直接使用 C API(onnxruntime_c_api.h),用异常替代 C 风格返回值。
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
|
|
||||||
|
#ifndef ORT_MINIMAL_H
|
||||||
|
#define ORT_MINIMAL_H
|
||||||
|
|
||||||
|
/* 使用 shim 头文件处理 MinGW 兼容性问题 */
|
||||||
|
#include "ort_api_shim.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
namespace ort {
|
||||||
|
|
||||||
|
/** 异常类型 */
|
||||||
|
class Exception : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
explicit Exception(const char* msg) : std::runtime_error(msg) {}
|
||||||
|
explicit Exception(const std::string& msg) : std::runtime_error(msg) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** 获取 API 基础指针(内部使用) */
|
||||||
|
const OrtApi* getApi();
|
||||||
|
|
||||||
|
/** Env: ONNX Runtime 环境 */
|
||||||
|
class Env {
|
||||||
|
public:
|
||||||
|
explicit Env(OrtLoggingLevel logLevel = ORT_LOGGING_LEVEL_WARNING,
|
||||||
|
const char* logId = "ort");
|
||||||
|
~Env();
|
||||||
|
OrtEnv* ptr() const { return env_; }
|
||||||
|
Env(const Env&) = delete;
|
||||||
|
Env& operator=(const Env&) = delete;
|
||||||
|
private:
|
||||||
|
OrtEnv* env_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** SessionOptions: 会话配置选项 */
|
||||||
|
class SessionOptions {
|
||||||
|
public:
|
||||||
|
SessionOptions();
|
||||||
|
~SessionOptions();
|
||||||
|
OrtSessionOptions* ptr() const { return opts_; }
|
||||||
|
void setIntraOpNumThreads(int n);
|
||||||
|
void setGraphOptimizationLevel(GraphOptimizationLevel level);
|
||||||
|
SessionOptions(const SessionOptions&) = delete;
|
||||||
|
SessionOptions& operator=(const SessionOptions&) = delete;
|
||||||
|
private:
|
||||||
|
OrtSessionOptions* opts_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Session: 推理会话 */
|
||||||
|
class Session {
|
||||||
|
public:
|
||||||
|
Session(const Env& env, const char* modelPath, const SessionOptions& opts);
|
||||||
|
Session(const Env& env, const wchar_t* modelPath, const SessionOptions& opts);
|
||||||
|
~Session();
|
||||||
|
OrtSession* ptr() const { return session_; }
|
||||||
|
size_t getInputCount() const;
|
||||||
|
size_t getOutputCount() const;
|
||||||
|
std::string getInputName(size_t index) const;
|
||||||
|
std::string getOutputName(size_t index) const;
|
||||||
|
Session(const Session&) = delete;
|
||||||
|
Session& operator=(const Session&) = delete;
|
||||||
|
private:
|
||||||
|
OrtSession* session_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** MemoryInfo: 内存信息 */
|
||||||
|
class MemoryInfo {
|
||||||
|
public:
|
||||||
|
static MemoryInfo createCpu(OrtAllocatorType type = OrtDeviceAllocator,
|
||||||
|
OrtMemType memType = OrtMemTypeCPU);
|
||||||
|
~MemoryInfo();
|
||||||
|
const OrtMemoryInfo* ptr() const { return info_; }
|
||||||
|
private:
|
||||||
|
explicit MemoryInfo(OrtMemoryInfo* info);
|
||||||
|
OrtMemoryInfo* info_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Value: 张量值 */
|
||||||
|
class Value {
|
||||||
|
public:
|
||||||
|
static Value createTensor(const MemoryInfo& info, float* data, size_t elemCount,
|
||||||
|
const int64_t* shape, size_t shapeLen);
|
||||||
|
static Value createTensor(const MemoryInfo& info, int32_t* data, size_t elemCount,
|
||||||
|
const int64_t* shape, size_t shapeLen);
|
||||||
|
~Value();
|
||||||
|
OrtValue* ptr() const { return value_; }
|
||||||
|
std::vector<int64_t> getShape() const;
|
||||||
|
const float* getTensorData() const;
|
||||||
|
Value(Value&& other) noexcept;
|
||||||
|
Value& operator=(Value&& other) noexcept;
|
||||||
|
Value(const Value&) = delete;
|
||||||
|
Value& operator=(const Value&) = delete;
|
||||||
|
private:
|
||||||
|
explicit Value(OrtValue* value);
|
||||||
|
OrtValue* value_ = nullptr;
|
||||||
|
public:
|
||||||
|
/** @brief 从原始指针构造(用于接收 C API 返回的值) */
|
||||||
|
static Value fromRaw(OrtValue* value);
|
||||||
|
};
|
||||||
|
|
||||||
|
/** RunOptions: 推理选项 */
|
||||||
|
class RunOptions {
|
||||||
|
public:
|
||||||
|
RunOptions();
|
||||||
|
~RunOptions();
|
||||||
|
OrtRunOptions* ptr() const { return opts_; }
|
||||||
|
private:
|
||||||
|
OrtRunOptions* opts_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** 推理执行 */
|
||||||
|
std::vector<Value> run(Session& session,
|
||||||
|
const RunOptions& runOptions,
|
||||||
|
const char* const* inputNames,
|
||||||
|
Value* inputValues,
|
||||||
|
size_t inputCount,
|
||||||
|
const char* const* outputNames,
|
||||||
|
size_t outputCount);
|
||||||
|
|
||||||
|
} // namespace ort
|
||||||
|
|
||||||
|
#endif // ORT_MINIMAL_H
|
||||||
|
#endif // HAVE_ONNXRUNTIME
|
||||||
@ -3,6 +3,7 @@
|
|||||||
#include "sense_voice_tokenizer.h"
|
#include "sense_voice_tokenizer.h"
|
||||||
#include "sense_voice_cmvn.h"
|
#include "sense_voice_cmvn.h"
|
||||||
#include "audio_processor.h"
|
#include "audio_processor.h"
|
||||||
|
#include "ort_minimal.h"
|
||||||
#include "utils/logger.h"
|
#include "utils/logger.h"
|
||||||
#include "utils/timer.h"
|
#include "utils/timer.h"
|
||||||
|
|
||||||
@ -19,11 +20,6 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
// ONNX Runtime headers
|
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
|
||||||
#include <onnxruntime_cxx_api.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static const char* const kTag = "SenseVoiceEngine";
|
static const char* const kTag = "SenseVoiceEngine";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -93,9 +89,9 @@ static int languageToInt(const QString& lang) {
|
|||||||
*/
|
*/
|
||||||
struct SenseVoiceEngine::Impl {
|
struct SenseVoiceEngine::Impl {
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
std::unique_ptr<Ort::Env> env;
|
std::unique_ptr<ort::Env> env;
|
||||||
std::unique_ptr<Ort::SessionOptions> sessionOptions;
|
std::unique_ptr<ort::SessionOptions> sessionOptions;
|
||||||
std::unique_ptr<Ort::Session> session;
|
std::unique_ptr<ort::Session> session;
|
||||||
|
|
||||||
std::vector<std::string> inputNames;
|
std::vector<std::string> inputNames;
|
||||||
std::vector<std::string> outputNames;
|
std::vector<std::string> outputNames;
|
||||||
@ -111,11 +107,11 @@ struct SenseVoiceEngine::Impl {
|
|||||||
{
|
{
|
||||||
QMutexLocker locker(&mutex);
|
QMutexLocker locker(&mutex);
|
||||||
try {
|
try {
|
||||||
auto envPtr = std::make_unique<Ort::Env>(
|
auto envPtr = std::make_unique<ort::Env>(
|
||||||
ORT_LOGGING_LEVEL_WARNING, "impress_sensevoice");
|
ORT_LOGGING_LEVEL_WARNING, "impress_sensevoice");
|
||||||
auto optionsPtr = std::make_unique<Ort::SessionOptions>();
|
auto optionsPtr = std::make_unique<ort::SessionOptions>();
|
||||||
optionsPtr->SetIntraOpNumThreads(numThreads);
|
optionsPtr->setIntraOpNumThreads(numThreads);
|
||||||
optionsPtr->SetGraphOptimizationLevel(
|
optionsPtr->setGraphOptimizationLevel(
|
||||||
GraphOptimizationLevel::ORT_ENABLE_ALL);
|
GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||||
|
|
||||||
if (device == "gpu") {
|
if (device == "gpu") {
|
||||||
@ -125,14 +121,17 @@ struct SenseVoiceEngine::Impl {
|
|||||||
LOG_INFO(kTag, QString("正在加载 SenseVoice 模型: %1 (线程: %2)")
|
LOG_INFO(kTag, QString("正在加载 SenseVoice 模型: %1 (线程: %2)")
|
||||||
.arg(modelPath).arg(numThreads));
|
.arg(modelPath).arg(numThreads));
|
||||||
|
|
||||||
auto sessionPtr = std::make_unique<Ort::Session>(
|
auto sessionPtr = std::make_unique<ort::Session>(
|
||||||
*envPtr,
|
*envPtr,
|
||||||
|
#ifdef _WIN32
|
||||||
|
modelPath.toStdWString().c_str(),
|
||||||
|
#else
|
||||||
modelPath.toUtf8().constData(),
|
modelPath.toUtf8().constData(),
|
||||||
|
#endif
|
||||||
*optionsPtr);
|
*optionsPtr);
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
size_t inputCount = sessionPtr->getInputCount();
|
||||||
size_t inputCount = sessionPtr->GetInputCount();
|
size_t outputCount = sessionPtr->getOutputCount();
|
||||||
size_t outputCount = sessionPtr->GetOutputCount();
|
|
||||||
|
|
||||||
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
||||||
.arg(inputCount).arg(outputCount));
|
.arg(inputCount).arg(outputCount));
|
||||||
@ -141,15 +140,13 @@ struct SenseVoiceEngine::Impl {
|
|||||||
outputNames.clear();
|
outputNames.clear();
|
||||||
|
|
||||||
for (size_t i = 0; i < inputCount; i++) {
|
for (size_t i = 0; i < inputCount; i++) {
|
||||||
auto namePtr = sessionPtr->GetInputNameAllocated(i, allocator);
|
inputNames.emplace_back(sessionPtr->getInputName(i));
|
||||||
inputNames.emplace_back(namePtr.get());
|
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(QString::fromStdString(inputNames.back())));
|
||||||
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(namePtr.get()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < outputCount; i++) {
|
for (size_t i = 0; i < outputCount; i++) {
|
||||||
auto namePtr = sessionPtr->GetOutputNameAllocated(i, allocator);
|
outputNames.emplace_back(sessionPtr->getOutputName(i));
|
||||||
outputNames.emplace_back(namePtr.get());
|
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(QString::fromStdString(outputNames.back())));
|
||||||
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(namePtr.get()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
env = std::move(envPtr);
|
env = std::move(envPtr);
|
||||||
@ -174,7 +171,7 @@ struct SenseVoiceEngine::Impl {
|
|||||||
|
|
||||||
LOG_INFO(kTag, QString("SenseVoice 模型加载成功: %1").arg(modelPath));
|
LOG_INFO(kTag, QString("SenseVoice 模型加载成功: %1").arg(modelPath));
|
||||||
return true;
|
return true;
|
||||||
} catch (const Ort::Exception& e) {
|
} catch (const ort::Exception& e) {
|
||||||
errorMsg = QString("ONNX 异常: %1").arg(e.what());
|
errorMsg = QString("ONNX 异常: %1").arg(e.what());
|
||||||
LOG_ERROR(kTag, errorMsg);
|
LOG_ERROR(kTag, errorMsg);
|
||||||
return false;
|
return false;
|
||||||
@ -185,6 +182,21 @@ struct SenseVoiceEngine::Impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
QMutex mutex;
|
||||||
|
#else
|
||||||
|
// 占位实现:无 ONNX Runtime 时仅提供基本结构
|
||||||
|
bool loadInWorker(const QString& /*modelPath*/,
|
||||||
|
const QString& /*tokensPath*/,
|
||||||
|
const QString& /*device*/,
|
||||||
|
int /*numThreads*/,
|
||||||
|
QString& errorMsg)
|
||||||
|
{
|
||||||
|
errorMsg = "ONNX Runtime 未安装,推理功能不可用。"
|
||||||
|
"请在 third_party/onnxruntime/ 中部署 ONNX Runtime 后重新编译。";
|
||||||
|
LOG_ERROR(kTag, errorMsg);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
QMutex mutex;
|
QMutex mutex;
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
@ -206,11 +218,13 @@ bool SenseVoiceEngine::loadModelSync(const QString& modelPath,
|
|||||||
if (loaded_) {
|
if (loaded_) {
|
||||||
LOG_WARNING(kTag, "模型已加载,先卸载再加载");
|
LOG_WARNING(kTag, "模型已加载,先卸载再加载");
|
||||||
// 内联清理,避免调用 unloadModel() 导致 mutex 递归死锁
|
// 内联清理,避免调用 unloadModel() 导致 mutex 递归死锁
|
||||||
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
impl_->session.reset();
|
impl_->session.reset();
|
||||||
impl_->sessionOptions.reset();
|
impl_->sessionOptions.reset();
|
||||||
impl_->env.reset();
|
impl_->env.reset();
|
||||||
impl_->features.reset();
|
impl_->features.reset();
|
||||||
impl_->tokenizer = SenseVoiceTokenizer();
|
impl_->tokenizer = SenseVoiceTokenizer();
|
||||||
|
#endif
|
||||||
loaded_ = false;
|
loaded_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,11 +249,13 @@ void SenseVoiceEngine::loadModelAsync(const QString& modelPath,
|
|||||||
if (loaded_) {
|
if (loaded_) {
|
||||||
LOG_WARNING(kTag, "模型已加载,先卸载再加载");
|
LOG_WARNING(kTag, "模型已加载,先卸载再加载");
|
||||||
// 内联清理,避免调用 unloadModel() 导致 mutex 递归死锁
|
// 内联清理,避免调用 unloadModel() 导致 mutex 递归死锁
|
||||||
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
impl_->session.reset();
|
impl_->session.reset();
|
||||||
impl_->sessionOptions.reset();
|
impl_->sessionOptions.reset();
|
||||||
impl_->env.reset();
|
impl_->env.reset();
|
||||||
impl_->features.reset();
|
impl_->features.reset();
|
||||||
impl_->tokenizer = SenseVoiceTokenizer();
|
impl_->tokenizer = SenseVoiceTokenizer();
|
||||||
|
#endif
|
||||||
loaded_ = false;
|
loaded_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -417,7 +433,7 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
// 输入: x, x_length, language, text_norm
|
// 输入: x, x_length, language, text_norm
|
||||||
int64_t xShape[] = {1, numFrames, kLFROutputDim};
|
int64_t xShape[] = {1, numFrames, kLFROutputDim};
|
||||||
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
auto memInfo = ort::MemoryInfo::createCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||||
|
|
||||||
int32_t xLengthVal = numFrames;
|
int32_t xLengthVal = numFrames;
|
||||||
int64_t xLengthShape[] = {1};
|
int64_t xLengthShape[] = {1};
|
||||||
@ -429,14 +445,14 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
int32_t textNormVal = kTextNormWithITN;
|
int32_t textNormVal = kTextNormWithITN;
|
||||||
int64_t textNormShape[] = {1};
|
int64_t textNormShape[] = {1};
|
||||||
|
|
||||||
std::vector<Ort::Value> inputTensors;
|
std::vector<ort::Value> inputTensors;
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(
|
inputTensors.push_back(ort::Value::createTensor(
|
||||||
memInfo, lfrFeatures.data(), lfrFeatures.size(), xShape, 3));
|
memInfo, lfrFeatures.data(), lfrFeatures.size(), xShape, 3));
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<int32_t>(
|
inputTensors.push_back(ort::Value::createTensor(
|
||||||
memInfo, &xLengthVal, 1, xLengthShape, 1));
|
memInfo, &xLengthVal, 1, xLengthShape, 1));
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<int32_t>(
|
inputTensors.push_back(ort::Value::createTensor(
|
||||||
memInfo, &langVal, 1, langShape, 1));
|
memInfo, &langVal, 1, langShape, 1));
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<int32_t>(
|
inputTensors.push_back(ort::Value::createTensor(
|
||||||
memInfo, &textNormVal, 1, textNormShape, 1));
|
memInfo, &textNormVal, 1, textNormShape, 1));
|
||||||
|
|
||||||
// 4. 运行推理
|
// 4. 运行推理
|
||||||
@ -446,8 +462,10 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
std::vector<const char*> outputNamePtrs;
|
std::vector<const char*> outputNamePtrs;
|
||||||
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
||||||
|
|
||||||
auto outputTensors = impl_->session->Run(
|
ort::RunOptions runOptions;
|
||||||
Ort::RunOptions{nullptr},
|
auto outputTensors = ort::run(
|
||||||
|
*impl_->session,
|
||||||
|
runOptions,
|
||||||
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
||||||
outputNamePtrs.data(), outputNamePtrs.size());
|
outputNamePtrs.data(), outputNamePtrs.size());
|
||||||
|
|
||||||
@ -455,8 +473,8 @@ RecognitionResult SenseVoiceEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
// 5. 解析输出 logits [1, seq_len, 25055]
|
// 5. 解析输出 logits [1, seq_len, 25055]
|
||||||
auto& outputTensor = outputTensors[0];
|
auto& outputTensor = outputTensors[0];
|
||||||
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
auto shape = outputTensor.getShape();
|
||||||
const float* logitsData = outputTensor.GetTensorData<float>();
|
const float* logitsData = outputTensor.getTensorData();
|
||||||
|
|
||||||
LOG_DEBUG(kTag, QString("输出维度: [%1, %2, %3]")
|
LOG_DEBUG(kTag, QString("输出维度: [%1, %2, %3]")
|
||||||
.arg(shape[0]).arg(shape[1]).arg(shape[2]));
|
.arg(shape[0]).arg(shape[1]).arg(shape[2]));
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
#include "mel_spectrogram.h"
|
#include "mel_spectrogram.h"
|
||||||
#include "whisper_tokenizer.h"
|
#include "whisper_tokenizer.h"
|
||||||
#include "audio_processor.h"
|
#include "audio_processor.h"
|
||||||
|
#include "ort_minimal.h"
|
||||||
#include "utils/logger.h"
|
#include "utils/logger.h"
|
||||||
#include "utils/timer.h"
|
#include "utils/timer.h"
|
||||||
|
|
||||||
@ -15,11 +16,6 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
// ONNX Runtime headers
|
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
|
||||||
#include <onnxruntime_cxx_api.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static const char* const kTag = "STTEngine";
|
static const char* const kTag = "STTEngine";
|
||||||
|
|
||||||
// Whisper 常量
|
// Whisper 常量
|
||||||
@ -33,9 +29,9 @@ namespace impress {
|
|||||||
*/
|
*/
|
||||||
struct STTEngine::Impl {
|
struct STTEngine::Impl {
|
||||||
#ifdef HAVE_ONNXRUNTIME
|
#ifdef HAVE_ONNXRUNTIME
|
||||||
std::unique_ptr<Ort::Env> env;
|
std::unique_ptr<ort::Env> env;
|
||||||
std::unique_ptr<Ort::SessionOptions> sessionOptions;
|
std::unique_ptr<ort::SessionOptions> sessionOptions;
|
||||||
std::unique_ptr<Ort::Session> session;
|
std::unique_ptr<ort::Session> session;
|
||||||
|
|
||||||
std::vector<std::string> inputNames;
|
std::vector<std::string> inputNames;
|
||||||
std::vector<std::string> outputNames;
|
std::vector<std::string> outputNames;
|
||||||
@ -50,11 +46,11 @@ struct STTEngine::Impl {
|
|||||||
{
|
{
|
||||||
QMutexLocker locker(&mutex);
|
QMutexLocker locker(&mutex);
|
||||||
try {
|
try {
|
||||||
auto envPtr = std::make_unique<Ort::Env>(
|
auto envPtr = std::make_unique<ort::Env>(
|
||||||
ORT_LOGGING_LEVEL_WARNING, "impress_voice");
|
ORT_LOGGING_LEVEL_WARNING, "impress_voice");
|
||||||
auto optionsPtr = std::make_unique<Ort::SessionOptions>();
|
auto optionsPtr = std::make_unique<ort::SessionOptions>();
|
||||||
optionsPtr->SetIntraOpNumThreads(numThreads);
|
optionsPtr->setIntraOpNumThreads(numThreads);
|
||||||
optionsPtr->SetGraphOptimizationLevel(
|
optionsPtr->setGraphOptimizationLevel(
|
||||||
GraphOptimizationLevel::ORT_ENABLE_ALL);
|
GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||||
|
|
||||||
if (device == "gpu") {
|
if (device == "gpu") {
|
||||||
@ -63,14 +59,17 @@ struct STTEngine::Impl {
|
|||||||
|
|
||||||
LOG_INFO(kTag, QString("正在加载模型: %1 (线程: %2)").arg(modelPath).arg(numThreads));
|
LOG_INFO(kTag, QString("正在加载模型: %1 (线程: %2)").arg(modelPath).arg(numThreads));
|
||||||
|
|
||||||
auto sessionPtr = std::make_unique<Ort::Session>(
|
auto sessionPtr = std::make_unique<ort::Session>(
|
||||||
*envPtr,
|
*envPtr,
|
||||||
|
#ifdef _WIN32
|
||||||
|
modelPath.toStdWString().c_str(),
|
||||||
|
#else
|
||||||
modelPath.toUtf8().constData(),
|
modelPath.toUtf8().constData(),
|
||||||
|
#endif
|
||||||
*optionsPtr);
|
*optionsPtr);
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
size_t inputCount = sessionPtr->getInputCount();
|
||||||
size_t inputCount = sessionPtr->GetInputCount();
|
size_t outputCount = sessionPtr->getOutputCount();
|
||||||
size_t outputCount = sessionPtr->GetOutputCount();
|
|
||||||
|
|
||||||
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
LOG_INFO(kTag, QString("模型有 %1 个输入, %2 个输出")
|
||||||
.arg(inputCount).arg(outputCount));
|
.arg(inputCount).arg(outputCount));
|
||||||
@ -79,15 +78,13 @@ struct STTEngine::Impl {
|
|||||||
outputNames.clear();
|
outputNames.clear();
|
||||||
|
|
||||||
for (size_t i = 0; i < inputCount; i++) {
|
for (size_t i = 0; i < inputCount; i++) {
|
||||||
auto namePtr = sessionPtr->GetInputNameAllocated(i, allocator);
|
inputNames.emplace_back(sessionPtr->getInputName(i));
|
||||||
inputNames.emplace_back(namePtr.get());
|
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(QString::fromStdString(inputNames.back())));
|
||||||
LOG_DEBUG(kTag, QString("输入 #%1: %2").arg(i).arg(namePtr.get()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < outputCount; i++) {
|
for (size_t i = 0; i < outputCount; i++) {
|
||||||
auto namePtr = sessionPtr->GetOutputNameAllocated(i, allocator);
|
outputNames.emplace_back(sessionPtr->getOutputName(i));
|
||||||
outputNames.emplace_back(namePtr.get());
|
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(QString::fromStdString(outputNames.back())));
|
||||||
LOG_DEBUG(kTag, QString("输出 #%1: %2").arg(i).arg(namePtr.get()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
env = std::move(envPtr);
|
env = std::move(envPtr);
|
||||||
@ -106,7 +103,7 @@ struct STTEngine::Impl {
|
|||||||
|
|
||||||
LOG_INFO(kTag, QString("模型加载成功: %1").arg(modelPath));
|
LOG_INFO(kTag, QString("模型加载成功: %1").arg(modelPath));
|
||||||
return true;
|
return true;
|
||||||
} catch (const Ort::Exception& e) {
|
} catch (const ort::Exception& e) {
|
||||||
errorMsg = QString("ONNX 异常: %1").arg(e.what());
|
errorMsg = QString("ONNX 异常: %1").arg(e.what());
|
||||||
LOG_ERROR(kTag, errorMsg);
|
LOG_ERROR(kTag, errorMsg);
|
||||||
return false;
|
return false;
|
||||||
@ -117,6 +114,20 @@ struct STTEngine::Impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
QMutex mutex;
|
||||||
|
#else
|
||||||
|
// 占位实现:无 ONNX Runtime 时仅提供基本结构
|
||||||
|
bool loadInWorker(const QString& /*modelPath*/,
|
||||||
|
const QString& /*device*/,
|
||||||
|
int /*numThreads*/,
|
||||||
|
QString& errorMsg)
|
||||||
|
{
|
||||||
|
errorMsg = "ONNX Runtime 未安装,推理功能不可用。"
|
||||||
|
"请在 third_party/onnxruntime/ 中部署 ONNX Runtime 后重新编译。";
|
||||||
|
LOG_ERROR(kTag, errorMsg);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
QMutex mutex;
|
QMutex mutex;
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
@ -276,9 +287,10 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
QMutexLocker locker(&impl_->mutex);
|
QMutexLocker locker(&impl_->mutex);
|
||||||
|
|
||||||
int64_t melShape[] = {1, kMelBins, static_cast<int64_t>(nFrames)};
|
int64_t melShape[] = {1, kMelBins, static_cast<int64_t>(nFrames)};
|
||||||
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
auto memInfo = ort::MemoryInfo::createCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||||
std::vector<Ort::Value> inputTensors;
|
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(
|
std::vector<ort::Value> inputTensors;
|
||||||
|
inputTensors.push_back(ort::Value::createTensor(
|
||||||
memInfo, melSpec.data(), melSpec.size(), melShape, 3));
|
memInfo, melSpec.data(), melSpec.size(), melShape, 3));
|
||||||
|
|
||||||
std::vector<const char*> inputNamePtrs;
|
std::vector<const char*> inputNamePtrs;
|
||||||
@ -286,8 +298,10 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
std::vector<const char*> outputNamePtrs;
|
std::vector<const char*> outputNamePtrs;
|
||||||
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
for (auto& name : impl_->outputNames) outputNamePtrs.push_back(name.c_str());
|
||||||
|
|
||||||
auto outputTensors = impl_->session->Run(
|
ort::RunOptions runOptions;
|
||||||
Ort::RunOptions{nullptr},
|
auto outputTensors = ort::run(
|
||||||
|
*impl_->session,
|
||||||
|
runOptions,
|
||||||
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
inputNamePtrs.data(), inputTensors.data(), inputTensors.size(),
|
||||||
outputNamePtrs.data(), impl_->outputNames.size());
|
outputNamePtrs.data(), impl_->outputNames.size());
|
||||||
|
|
||||||
@ -295,8 +309,8 @@ RecognitionResult STTEngine::infer(const std::vector<float>& samples,
|
|||||||
|
|
||||||
// 4. 解析输出
|
// 4. 解析输出
|
||||||
auto& outputTensor = outputTensors[0];
|
auto& outputTensor = outputTensors[0];
|
||||||
auto shape = outputTensor.GetTensorTypeAndShapeInfo().GetShape();
|
auto shape = outputTensor.getShape();
|
||||||
const float* outputData = outputTensor.GetTensorMutableData<float>();
|
const float* outputData = outputTensor.getTensorData();
|
||||||
|
|
||||||
LOG_DEBUG(kTag, QString("输出维度: %1").arg(shape.size()));
|
LOG_DEBUG(kTag, QString("输出维度: %1").arg(shape.size()));
|
||||||
for (size_t i = 0; i < shape.size(); i++) {
|
for (size_t i = 0; i < shape.size(); i++) {
|
||||||
|
|||||||
@ -5,6 +5,8 @@
|
|||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
#include <QAbstractNativeEventFilter>
|
#include <QAbstractNativeEventFilter>
|
||||||
#include <QGuiApplication>
|
#include <QGuiApplication>
|
||||||
|
#include <QThread>
|
||||||
|
#include <QWidget>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static const char* const kTag = "CapsLockVoiceHotkey";
|
static const char* const kTag = "CapsLockVoiceHotkey";
|
||||||
@ -20,6 +22,7 @@ struct CapsLockVoiceHotkey::Impl {
|
|||||||
bool longPressFired = false;
|
bool longPressFired = false;
|
||||||
bool pollThreadRunning = false;
|
bool pollThreadRunning = false;
|
||||||
void* nativeEventFilter = nullptr;
|
void* nativeEventFilter = nullptr;
|
||||||
|
QWidget* hiddenWindow = nullptr;
|
||||||
static constexpr int kLongPressMs = 1000;
|
static constexpr int kLongPressMs = 1000;
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
@ -63,14 +66,19 @@ bool CapsLockVoiceHotkey::start() {
|
|||||||
if (active_) return true;
|
if (active_) return true;
|
||||||
|
|
||||||
#ifdef Q_OS_WIN
|
#ifdef Q_OS_WIN
|
||||||
HWND hwnd = reinterpret_cast<HWND>(QGuiApplication::instance()->winId());
|
// 创建隐藏窗口用于接收 WM_HOTKEY 消息
|
||||||
if (!hwnd) {
|
if (!impl_->hiddenWindow) {
|
||||||
// Try to get the top-level widget's window handle
|
impl_->hiddenWindow = new QWidget();
|
||||||
hwnd = GetForegroundWindow();
|
impl_->hiddenWindow->setObjectName("HotkeyReceiver");
|
||||||
|
impl_->hiddenWindow->setWindowFlags(Qt::Tool | Qt::FramelessWindowHint);
|
||||||
|
impl_->hiddenWindow->resize(0, 0);
|
||||||
}
|
}
|
||||||
|
// 确保窗口已创建(show 会创建原生句柄)
|
||||||
|
impl_->hiddenWindow->show();
|
||||||
|
|
||||||
|
HWND hwnd = reinterpret_cast<HWND>(impl_->hiddenWindow->winId());
|
||||||
if (!hwnd) {
|
if (!hwnd) {
|
||||||
emit error("无法获取窗口句柄");
|
emit error("无法创建窗口句柄");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,11 +154,14 @@ void CapsLockVoiceHotkey::stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 注销快捷键
|
// 注销快捷键
|
||||||
HWND hwnd = reinterpret_cast<HWND>(QGuiApplication::instance()->winId());
|
if (impl_->hiddenWindow) {
|
||||||
if (hwnd && impl_->hotkeyId) {
|
HWND hwnd = reinterpret_cast<HWND>(impl_->hiddenWindow->winId());
|
||||||
UnregisterHotKey(hwnd, impl_->hotkeyId);
|
if (hwnd && impl_->hotkeyId) {
|
||||||
GlobalDeleteAtom(impl_->hotkeyId);
|
UnregisterHotKey(hwnd, impl_->hotkeyId);
|
||||||
impl_->hotkeyId = 0;
|
GlobalDeleteAtom(impl_->hotkeyId);
|
||||||
|
impl_->hotkeyId = 0;
|
||||||
|
}
|
||||||
|
impl_->hiddenWindow->hide();
|
||||||
}
|
}
|
||||||
|
|
||||||
active_ = false;
|
active_ = false;
|
||||||
@ -208,7 +219,7 @@ void CapsLockVoiceHotkey::onHotkeyEvent(int /*hotkeyId*/) {
|
|||||||
QThread::msleep(50);
|
QThread::msleep(50);
|
||||||
}
|
}
|
||||||
impl_->pollThreadRunning = false;
|
impl_->pollThreadRunning = false;
|
||||||
}).start();
|
})->start();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
@ -49,6 +49,7 @@ signals:
|
|||||||
void error(const QString& message);
|
void error(const QString& message);
|
||||||
|
|
||||||
/** @brief 处理 WM_HOTKEY 事件(由原生事件过滤器调用) */
|
/** @brief 处理 WM_HOTKEY 事件(由原生事件过滤器调用) */
|
||||||
|
public:
|
||||||
void onHotkeyEvent(int hotkeyId);
|
void onHotkeyEvent(int hotkeyId);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
#include <QLabel>
|
#include <QLabel>
|
||||||
#include <QFileDialog>
|
#include <QFileDialog>
|
||||||
#include <QMessageBox>
|
#include <QMessageBox>
|
||||||
|
#include <QScrollArea>
|
||||||
|
|
||||||
static const char* const kTag = "SettingsPage";
|
static const char* const kTag = "SettingsPage";
|
||||||
|
|
||||||
@ -34,10 +35,24 @@ SettingsPage::~SettingsPage() = default;
|
|||||||
|
|
||||||
void SettingsPage::setupUI() {
|
void SettingsPage::setupUI() {
|
||||||
auto* mainLayout = new QVBoxLayout(this);
|
auto* mainLayout = new QVBoxLayout(this);
|
||||||
|
mainLayout->setContentsMargins(0, 0, 0, 0);
|
||||||
|
|
||||||
|
// ---- 可滚动内容区域 ----
|
||||||
|
auto* scrollArea = new QScrollArea(this);
|
||||||
|
scrollArea->setWidgetResizable(true);
|
||||||
|
scrollArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
|
||||||
|
scrollArea->setFrameShape(QFrame::NoFrame);
|
||||||
|
|
||||||
|
auto* contentWidget = new QWidget(scrollArea);
|
||||||
|
auto* contentLayout = new QVBoxLayout(contentWidget);
|
||||||
|
contentLayout->setContentsMargins(12, 8, 12, 8);
|
||||||
|
contentLayout->setSpacing(12);
|
||||||
|
|
||||||
// STT 设置
|
// STT 设置
|
||||||
auto* sttGroup = new QGroupBox("STT 推理设置", this);
|
auto* sttGroup = new QGroupBox("STT 推理设置", contentWidget);
|
||||||
auto* sttLayout = new QFormLayout(sttGroup);
|
auto* sttLayout = new QFormLayout(sttGroup);
|
||||||
|
sttLayout->setSpacing(8);
|
||||||
|
sttLayout->setContentsMargins(10, 12, 10, 12);
|
||||||
|
|
||||||
auto* modelRow = new QHBoxLayout();
|
auto* modelRow = new QHBoxLayout();
|
||||||
modelPathEdit_ = new QLineEdit(this);
|
modelPathEdit_ = new QLineEdit(this);
|
||||||
@ -109,11 +124,13 @@ void SettingsPage::setupUI() {
|
|||||||
temperatureSpin_->setValue(0.0);
|
temperatureSpin_->setValue(0.0);
|
||||||
sttLayout->addRow("温度 (Temperature):", temperatureSpin_);
|
sttLayout->addRow("温度 (Temperature):", temperatureSpin_);
|
||||||
|
|
||||||
mainLayout->addWidget(sttGroup);
|
contentLayout->addWidget(sttGroup);
|
||||||
|
|
||||||
// 音频设置
|
// 音频设置
|
||||||
auto* audioGroup = new QGroupBox("音频设置", this);
|
auto* audioGroup = new QGroupBox("音频设置", contentWidget);
|
||||||
auto* audioLayout = new QFormLayout(audioGroup);
|
auto* audioLayout = new QFormLayout(audioGroup);
|
||||||
|
audioLayout->setSpacing(8);
|
||||||
|
audioLayout->setContentsMargins(10, 12, 10, 12);
|
||||||
|
|
||||||
// 音频输入设备选择器
|
// 音频输入设备选择器
|
||||||
audioDeviceCombo_ = new QComboBox(this);
|
audioDeviceCombo_ = new QComboBox(this);
|
||||||
@ -150,11 +167,13 @@ void SettingsPage::setupUI() {
|
|||||||
paddingSpin_->setSuffix(" ms");
|
paddingSpin_->setSuffix(" ms");
|
||||||
audioLayout->addRow("块间重叠:", paddingSpin_);
|
audioLayout->addRow("块间重叠:", paddingSpin_);
|
||||||
|
|
||||||
mainLayout->addWidget(audioGroup);
|
contentLayout->addWidget(audioGroup);
|
||||||
|
|
||||||
// UI 设置
|
// UI 设置
|
||||||
auto* uiGroup = new QGroupBox("界面设置", this);
|
auto* uiGroup = new QGroupBox("界面设置", contentWidget);
|
||||||
auto* uiLayout = new QFormLayout(uiGroup);
|
auto* uiLayout = new QFormLayout(uiGroup);
|
||||||
|
uiLayout->setSpacing(8);
|
||||||
|
uiLayout->setContentsMargins(10, 12, 10, 12);
|
||||||
|
|
||||||
themeCombo_ = new QComboBox(this);
|
themeCombo_ = new QComboBox(this);
|
||||||
themeCombo_->addItems({"light", "dark"});
|
themeCombo_->addItems({"light", "dark"});
|
||||||
@ -173,26 +192,32 @@ void SettingsPage::setupUI() {
|
|||||||
showConfidenceCheck_->setChecked(true);
|
showConfidenceCheck_->setChecked(true);
|
||||||
uiLayout->addRow("置信度显示:", showConfidenceCheck_);
|
uiLayout->addRow("置信度显示:", showConfidenceCheck_);
|
||||||
|
|
||||||
mainLayout->addWidget(uiGroup);
|
contentLayout->addWidget(uiGroup);
|
||||||
|
contentLayout->addStretch();
|
||||||
|
|
||||||
// 操作按钮
|
scrollArea->setWidget(contentWidget);
|
||||||
auto* btnLayout = new QHBoxLayout();
|
mainLayout->addWidget(scrollArea);
|
||||||
auto* saveBtn = new QPushButton("保存配置", this);
|
|
||||||
|
// ---- 底部操作按钮(固定不滚动) ----
|
||||||
|
auto* btnBar = new QWidget(this);
|
||||||
|
auto* btnLayout = new QHBoxLayout(btnBar);
|
||||||
|
btnLayout->setContentsMargins(12, 4, 12, 8);
|
||||||
|
|
||||||
|
auto* saveBtn = new QPushButton("保存配置", btnBar);
|
||||||
saveBtn->setStyleSheet("QPushButton { font-weight: bold; padding: 8px 16px; }");
|
saveBtn->setStyleSheet("QPushButton { font-weight: bold; padding: 8px 16px; }");
|
||||||
connect(saveBtn, &QPushButton::clicked, this, &SettingsPage::onSaveConfig);
|
connect(saveBtn, &QPushButton::clicked, this, &SettingsPage::onSaveConfig);
|
||||||
btnLayout->addWidget(saveBtn);
|
btnLayout->addWidget(saveBtn);
|
||||||
|
|
||||||
auto* resetBtn = new QPushButton("恢复默认", this);
|
auto* resetBtn = new QPushButton("恢复默认", btnBar);
|
||||||
connect(resetBtn, &QPushButton::clicked, this, &SettingsPage::onResetConfig);
|
connect(resetBtn, &QPushButton::clicked, this, &SettingsPage::onResetConfig);
|
||||||
btnLayout->addWidget(resetBtn);
|
btnLayout->addWidget(resetBtn);
|
||||||
btnLayout->addStretch();
|
btnLayout->addStretch();
|
||||||
|
|
||||||
statusLabel_ = new QLabel("配置未修改", this);
|
statusLabel_ = new QLabel("配置未修改", btnBar);
|
||||||
statusLabel_->setStyleSheet("color: gray;");
|
statusLabel_->setStyleSheet("color: gray;");
|
||||||
btnLayout->addWidget(statusLabel_);
|
btnLayout->addWidget(statusLabel_);
|
||||||
|
|
||||||
mainLayout->addLayout(btnLayout);
|
mainLayout->addWidget(btnBar);
|
||||||
mainLayout->addStretch();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SettingsPage::loadFromConfig() {
|
void SettingsPage::loadFromConfig() {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user