diff --git a/src/core/model-loader.ts b/src/core/model-loader.ts index 616151e..ca8798c 100644 --- a/src/core/model-loader.ts +++ b/src/core/model-loader.ts @@ -3,8 +3,9 @@ * 负责加载和管理 ONNX 模型 */ -import { existsSync } from 'fs'; -import { join } from 'path'; +import { existsSync, readdirSync } from 'fs'; +import { join, dirname, resolve } from 'path'; +import { fileURLToPath } from 'url'; import * as ort from 'onnxruntime-web'; export interface ModelConfig { @@ -16,11 +17,10 @@ export interface ModelConfig { description: string; } -// 预定义模型配置 -export const MODEL_CONFIGS: Record = { +// 预定义模型配置(不含路径) +export const MODEL_CONFIGS: Record> = { sensevoice: { name: 'SenseVoice', - path: './models/sensevoice.onnx', language: ['zh', 'en', 'ja', 'ko'], sampleRate: 16000, inputShape: [1, 16000], @@ -28,7 +28,6 @@ export const MODEL_CONFIGS: Record = { }, whisper: { name: 'Whisper', - path: './models/whisper.onnx', language: ['zh', 'en', 'ja', 'ko', 'de', 'fr', 'es'], sampleRate: 16000, inputShape: [1, 480000], // 30 秒音频 @@ -36,7 +35,6 @@ export const MODEL_CONFIGS: Record = { }, paraformer: { name: 'Paraformer', - path: './models/paraformer.onnx', language: ['zh'], sampleRate: 16000, inputShape: [1, 16000], @@ -44,6 +42,13 @@ export const MODEL_CONFIGS: Record = { }, }; +// 模型文件名列表(按优先级) +export const MODEL_FILES = ['sensevoice.onnx', 'whisper.onnx', 'paraformer.onnx']; + +// 获取当前模块目录 +const __dirname = dirname(fileURLToPath(import.meta.url)); +const DEFAULT_MODELS_DIR = resolve(__dirname, '../../models'); + export class ModelLoader { private session: ort.InferenceSession | null = null; private config: ModelConfig | null = null; @@ -51,45 +56,65 @@ export class ModelLoader { /** * 获取可用的模型列表 */ - static getAvailableModels(): ModelConfig[] { - return Object.values(MODEL_CONFIGS).filter((config) => - existsSync(config.path) - ); + static getAvailableModels(modelsDir: string = DEFAULT_MODELS_DIR): ModelConfig[] { + const models: ModelConfig[] = []; + + for (const [key, config] of Object.entries(MODEL_CONFIGS)) { + const modelPath = join(modelsDir, `${key}.onnx`); + if (existsSync(modelPath)) { + models.push({ + ...config, + path: modelPath, + }); + } + } + + return models; } /** * 检查模型文件是否存在 */ - static checkModelExists(modelName: string): boolean { - const config = MODEL_CONFIGS[modelName]; - if (!config) return false; - return existsSync(config.path); + static checkModelExists(modelNameOrPath: string): boolean { + // 如果是完整路径 + if (existsSync(modelNameOrPath)) { + return true; + } + // 检查预定义模型 + const config = MODEL_CONFIGS[modelNameOrPath]; + if (config) { + return existsSync(join(DEFAULT_MODELS_DIR, `${modelNameOrPath}.onnx`)); + } + return false; } /** * 从目录加载模型 */ static async loadFromDir( - modelsDir: string + modelsDir: string = DEFAULT_MODELS_DIR ): Promise<{ session: ort.InferenceSession; config: ModelConfig } | null> { - // 按优先级查找模型 - const modelOrder = ['sensevoice.onnx', 'whisper.onnx', 'paraformer.onnx']; + if (!existsSync(modelsDir)) { + return null; + } - for (const modelName of modelOrder) { + // 按优先级查找模型 + for (const modelName of MODEL_FILES) { const modelPath = join(modelsDir, modelName); if (existsSync(modelPath)) { try { const session = await ort.InferenceSession.create(modelPath); - const config = Object.values(MODEL_CONFIGS).find((c) => - c.path.endsWith(modelName) - ) || { - name: modelName.replace('.onnx', ''), - path: modelPath, - language: ['zh'], - sampleRate: 16000, - inputShape: [1, 16000], - description: '自定义模型', - }; + const baseConfig = MODEL_CONFIGS[modelName.replace('.onnx', '')]; + const config: ModelConfig = baseConfig + ? { ...baseConfig, path: modelPath } + : { + name: modelName.replace('.onnx', ''), + path: modelPath, + language: ['zh'], + sampleRate: 16000, + inputShape: [1, 16000], + description: '自定义模型', + }; return { session, config }; } catch (error) { console.warn(`加载模型 ${modelName} 失败:`, error); @@ -105,23 +130,16 @@ export class ModelLoader { */ async load(modelNameOrPath: string): Promise { let modelPath: string; - let modelConfig: ModelConfig | undefined; + let baseConfig: Omit | undefined; // 检查是否为预定义模型名称 if (MODEL_CONFIGS[modelNameOrPath]) { - modelConfig = MODEL_CONFIGS[modelNameOrPath]; - modelPath = modelConfig.path; + baseConfig = MODEL_CONFIGS[modelNameOrPath]; + modelPath = join(DEFAULT_MODELS_DIR, `${modelNameOrPath}.onnx`); } else { // 直接使用路径 modelPath = modelNameOrPath; - modelConfig = { - name: 'custom', - path: modelPath, - language: ['zh'], - sampleRate: 16000, - inputShape: [1, 16000], - description: '自定义模型路径', - }; + baseConfig = undefined; } if (!existsSync(modelPath)) { @@ -136,11 +154,16 @@ export class ModelLoader { }; this.session = await ort.InferenceSession.create(modelPath, sessionOptions); - this.config = modelConfig; - console.log(`✅ 模型加载成功:${modelConfig.name}`); - console.log(` 支持语言:${modelConfig.language.join(', ')}`); - console.log(` 采样率:${modelConfig.sampleRate}Hz`); + const base = baseConfig || MODEL_CONFIGS['sensevoice']; + this.config = { + ...base, + path: modelPath, + }; + + console.log(`✅ 模型加载成功:${this.config.name}`); + console.log(` 支持语言:${this.config.language.join(', ')}`); + console.log(` 采样率:${this.config.sampleRate}Hz`); } catch (error) { throw new Error(`模型加载失败:${error}`); } diff --git a/src/main.ts b/src/main.ts index a1ac3f6..b195a9c 100644 --- a/src/main.ts +++ b/src/main.ts @@ -26,16 +26,37 @@ program .command('start') .description('开始语音识别') .option('-l, --language ', '识别语言', 'zh') - .option('-m, --model ', '模型文件路径', join(__dirname, '../models/model.onnx')) + .option('-m, --model ', '模型文件路径(可选,无模型时以配置模式启动)') .option('-o, --output ', '输出模式:clipboard|keyboard|both', 'clipboard') .action(async (options) => { - console.log('🎤 启动语音识别...'); + console.log('🎤 Impress ASR Input'); + console.log(` 版本:${packageJson.version}`); console.log(` 语言:${options.language}`); - console.log(` 模型:${options.model}`); console.log(` 输出:${options.output}`); + // 检查模型文件是否存在 + const { existsSync } = await import('fs'); + const modelPath = options.model || join(__dirname, '../models/model.onnx'); + + if (!existsSync(modelPath)) { + console.log('\n⚠️ 未检测到模型文件,以配置模式启动'); + console.log('\n📥 模型下载指引:'); + console.log(' 1. SenseVoice (推荐): https://huggingface.co/FunAudioLLM/SenseVoice'); + console.log(' 2. Whisper: https://huggingface.co/onnx-community/whisper-base'); + console.log(' 3. Paraformer: https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punct'); + console.log('\n📁 将下载的模型文件放入以下目录之一:'); + console.log(' - ./models/sensevoice.onnx'); + console.log(' - ./models/whisper.onnx'); + console.log(' - ./models/paraformer.onnx'); + console.log('\n💡 或使用 --model 参数指定模型路径'); + console.log(' 示例:npm start -- start -m /path/to/your/model.onnx'); + return; + } + + console.log(` 模型:${modelPath}`); + const recognizer = new SpeechRecognizer({ - modelPath: options.model, + modelPath, language: options.language, useVad: true, beamSize: 5,