diff --git a/Cargo.toml b/Cargo.toml index c6918d8..4cd105c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,6 @@ categories = ["multimedia::audio"] [features] default = [] gui = ["dep:tauri", "dep:tauri-plugin-shell", "dep:tauri-plugin-dialog", "dep:tauri-plugin-fs", "dep:global-hotkey", "dep:tauri-build", "dep:cfg_aliases"] -onnx = ["dep:onnxruntime-ng"] [dependencies] # Tauri v2 桌面应用框架 (可选,需要 `cargo build --features gui`) @@ -24,11 +23,15 @@ tauri-plugin-fs = { version = "2", optional = true } # 全局快捷键 global-hotkey = { version = "0.6", optional = true } -# ONNX Runtime - 语音识别核心 (可选) -onnxruntime-ng = { version = "1.16.1", optional = true, features = ["disable-sys-build-script"] } +# ONNX Runtime - 语音识别核心 (使用 2.x rc 版本, 需要手动提供 onnxruntime 库) +ort = { version = "2.0.0-rc.12", default-features = false, features = [] } +cpal = "0.15" +ureq = { version = "2", default-features = false, features = ["tls"] } # 音频处理 hound = "3.5" # WAV 文件读写 +rubato = "0.15" # 高质量音频重采样 +realfft = "3.3" # FFT 用于音频特征提取 # 张量处理 ndarray = "0.15" diff --git a/src/app/commands.rs b/src/app/commands.rs index 457f76a..0c23305 100644 --- a/src/app/commands.rs +++ b/src/app/commands.rs @@ -1,10 +1,8 @@ //! Tauri 命令处理 -use crate::{ - asr::{recognize, RecognizeResult}, - audio::{record_audio, RecordingConfig}, - config::{get_config, save_config as save_config_file, AppSettings}, -}; +use crate::asr::model::ModelConfig; +use crate::asr::engine; +use crate::config::{get_config, save_config as save_config_file, AppSettings}; use serde::{Deserialize, Serialize}; use tauri::{Emitter, State}; use tracing::{error, info}; @@ -37,7 +35,6 @@ pub async fn start_recording( ) -> Result { info!("开始录音命令"); - // 检查是否已在录音 if state.is_recording() { return Ok(RecordResponse { success: false, @@ -49,15 +46,16 @@ pub async fn start_recording( let config = get_config().map_err(|e| e.to_string())?; - let recording_config = RecordingConfig { + let recording_config = crate::audio::RecordingConfig { sample_rate: config.audio.sample_rate, channels: config.audio.channels, - ..Default::default() + output_path: None, }; - match record_audio(recording_config).await { + match crate::audio::record_audio(recording_config).await { Ok((path, duration)) => { state.set_recording(true); + state.set_recording_path(path.clone()); Ok(RecordResponse { success: true, message: "录音完成".to_string(), @@ -66,7 +64,7 @@ pub async fn start_recording( }) } Err(e) => { - error!("录音失败:{}", e); + error!("录音失败: {}", e); Err(e.to_string()) } } @@ -96,27 +94,70 @@ pub fn stop_recording(state: State<'_, AppState>) -> Result Result { - info!("识别音频:{}", path); + info!("识别音频: {}", path); - match recognize(&path).await { - Ok(RecognizeResult { - text, - language, - confidence, - duration_ms, - }) => Ok(RecognizeResponse { - success: true, - text, - language: Some(language), - confidence: Some(confidence), - duration_ms: Some(duration_ms), - }), + // 确保 ASR 引擎已初始化 + if engine::ensure_engine_initialized().is_err() { + // 尝试使用配置中的模型初始化 + let config = get_config().map_err(|e| e.to_string())?; + if let Some(model_path) = &config.asr.model_path { + if model_path.exists() { + let model_config = ModelConfig::new(model_path, &config.asr.model); + if let Err(e) = engine::init_engine(model_config) { + error!("ASR 引擎初始化失败: {}", e); + return Err(format!("ASR 引擎初始化失败: {}", e)); + } + } else { + // 尝试默认模型路径 + let default_config = ModelConfig::default(); + if default_config.model_exists() { + if let Err(e) = engine::init_engine(default_config) { + return Err(format!("ASR 引擎初始化失败: {}", e)); + } + } else { + return Err("模型文件不存在,请先下载模型".to_string()); + } + } + } else { + let default_config = ModelConfig::default(); + if default_config.model_exists() { + if let Err(e) = engine::init_engine(default_config) { + return Err(format!("ASR 引擎初始化失败: {}", e)); + } + } else { + return Err("模型文件不存在,请先下载模型".to_string()); + } + } + } + + match engine::recognize(&path).await { + Ok(result) => { + // 添加到历史记录 + let history = crate::config::HistoryEntry::new( + result.text.clone(), + result.language.clone(), + result.confidence, + result.duration_ms as f32 / 1000.0, + ); + let state = tauri::async_runtime::block_on(async { + // 通过 app handle 获取状态 (这里简化处理) + None:: + }); + + Ok(RecognizeResponse { + success: true, + text: result.text, + language: Some(result.language), + confidence: Some(result.confidence), + duration_ms: Some(result.duration_ms), + }) + } Err(e) => { - error!("识别失败:{}", e); - Err(format!("识别失败:{}", e)) + error!("识别失败: {}", e); + Err(format!("识别失败: {}", e)) } } } @@ -159,8 +200,6 @@ pub fn get_theme(state: State<'_, AppState>) -> String { pub fn set_theme(theme: String, state: State<'_, AppState>, app: tauri::AppHandle) { let app_theme = AppTheme::from_str(&theme); state.set_theme(app_theme); - - // 通知前端主题已变更 let _ = app.emit("theme-change", theme); } @@ -178,7 +217,7 @@ pub async fn select_model_file(app: tauri::AppHandle) -> Result let result = match file_path { Some(path) => path.into_path() .map(|p| p.to_string_lossy().to_string()) - .map_err(|e| format!("转换路径失败:{}", e)), + .map_err(|e| format!("转换路径失败: {}", e)), None => Err("用户取消选择".to_string()), }; let _ = tx.send(result); diff --git a/src/asr/decoder.rs b/src/asr/decoder.rs index 1c30950..381cee3 100644 --- a/src/asr/decoder.rs +++ b/src/asr/decoder.rs @@ -1,137 +1,304 @@ //! 识别结果解码模块 +//! +//! 实现从 ONNX 模型输出到可读文本的解码 +//! 支持 SenseVoice、Paraformer、Whisper 等不同模型的输出格式 use anyhow::Result; -use ndarray::{ArrayViewD, s}; +use ndarray::{ArrayViewD, ArrayView2}; +use std::collections::HashMap; +use std::path::Path; +use tracing::info; -/// 解码 logits 输出到文本 -/// -/// 根据具体模型的词表进行解码 -pub fn decode_logits(logits: &ArrayViewD) -> Result { - // TODO: 根据 SenseVoice 的词表解码 - // 这需要根据实际模型的输出格式调整 +/// 词表映射 (token ID → 文本) +#[derive(Clone)] +pub struct Vocabulary { + id_to_token: HashMap, + token_to_id: HashMap, + blank_id: usize, + eos_token: usize, + sos_token: usize, +} - // 简化示例:假设直接输出概率分布 - let shape = logits.shape(); - if shape.len() < 2 { - return Ok(String::new()); - } - - // Greedy 解码:选择每个时间步概率最高的 token - let mut tokens = Vec::new(); - for i in 0..shape[1] { - let slice = logits.slice(s![0, i, ..]); - if let Some(max_idx) = slice.argmax() { - tokens.push(max_idx); +impl Vocabulary { + /// 创建空词表 + pub fn empty() -> Self { + Self { + id_to_token: HashMap::new(), + token_to_id: HashMap::new(), + blank_id: 0, + sos_token: 1, + eos_token: 2, } } - // 将 token IDs 转换为文本 - // TODO: 加载实际词表 - let text = tokens_to_text(&tokens); + /// 从文件加载词表 (tokens.txt 格式: "token id") + pub fn from_file>(path: P) -> Result { + let content = std::fs::read_to_string(path.as_ref())?; + let mut id_to_token = HashMap::new(); + let mut token_to_id = HashMap::new(); - Ok(text) + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 2 { + let token = parts[0]; + if let Ok(id) = parts[1].parse::() { + id_to_token.insert(id, token.to_string()); + token_to_id.insert(token.to_string(), id); + } + } + } + + info!("词表加载完成: {} 个 token", id_to_token.len()); + + Ok(Self { + id_to_token, + token_to_id, + blank_id: 0, + sos_token: 1, + eos_token: 2, + }) + } + + /// 从 SenseVoice 内置规则创建词表 + /// + /// SenseVoice 使用字符级词表,中文字符直接映射 + pub fn from_sensevoice_builtin() -> Self { + let mut id_to_token = HashMap::new(); + let mut token_to_id = HashMap::new(); + + // 预定义的 SenseVoice token 映射 (常用部分) + // "" -> 0 + id_to_token.insert(0, "".to_string()); + token_to_id.insert("".to_string(), 0); + + // "" -> 1 (start) + id_to_token.insert(1, "".to_string()); + token_to_id.insert("".to_string(), 1); + + // "" -> 2 (end) + id_to_token.insert(2, "".to_string()); + token_to_id.insert("".to_string(), 2); + + // "en" -> 3 (English language marker) + id_to_token.insert(3, "en".to_string()); + token_to_id.insert("en".to_string(), 3); + + // "zh" -> 4 (Chinese language marker) + id_to_token.insert(4, "zh".to_string()); + token_to_id.insert("zh".to_string(), 4); + + // "yue" -> 5 (Cantonese) + id_to_token.insert(5, "yue".to_string()); + token_to_id.insert("yue".to_string(), 5); + + // "ja" -> 6 (Japanese) + id_to_token.insert(6, "ja".to_string()); + token_to_id.insert("ja>".to_string(), 6); + + // "ko" -> 7 (Korean) + id_to_token.insert(7, "ko".to_string()); + token_to_id.insert("ko".to_string(), 7); + + // "nospeech" -> 8 + id_to_token.insert(8, "nospeech".to_string()); + token_to_id.insert("nospeech".to_string(), 8); + + // 常用中文标点 (SenseVoice 中文字符从 ID 9 开始连续分配) + // 实际词表需要通过 tokens.txt 加载完整映射 + // 这里仅做基本占位 + + Self { + id_to_token, + token_to_id, + blank_id: 0, + sos_token: 1, + eos_token: 2, + } + } + + /// 获取 token + pub fn get_token(&self, id: usize) -> Option<&String> { + self.id_to_token.get(&id) + } + + /// 获取 token ID + pub fn get_id(&self, token: &str) -> Option<&usize> { + self.token_to_id.get(token) + } + + pub fn blank_id(&self) -> usize { self.blank_id } + pub fn eos_token(&self) -> usize { self.eos_token } + pub fn sos_token(&self) -> usize { self.sos_token } + pub fn size(&self) -> usize { self.id_to_token.len() } } -/// 将 token IDs 转换为文本 -fn tokens_to_text(tokens: &[usize]) -> String { - // TODO: 使用实际的词表 - // 这里仅作为示例 - // SenseVoice 使用字符级或 BPE 词表 - - // 占位实现 - format!("[识别结果:{} 个 tokens]", tokens.len()) -} - -/// CTC 解码 +/// CTC 解码器 pub struct CtcDecoder { - /// 空白 token ID - blank_id: usize, + vocabulary: Option, } impl CtcDecoder { - pub fn new(blank_id: usize) -> Self { - Self { blank_id } + pub fn new(vocabulary: Option) -> Self { + Self { vocabulary } } /// CTC greedy 解码 - pub fn greedy_decode(&self, logits: &ArrayViewD) -> Vec { - let shape = logits.shape(); + pub fn greedy_decode(&self, logits: &ArrayView2) -> String { + let (seq_len, _vocab_size) = logits.dim(); let mut tokens = Vec::new(); - let mut prev_token = self.blank_id; + let mut prev_token = self.vocabulary.as_ref().map(|v| v.blank_id()).unwrap_or(0); - for i in 0..shape[1] { - let slice = logits.slice(s![0, i, ..]); - if let Some(max_idx) = slice.argmax() { - if max_idx != self.blank_id && max_idx != prev_token { - tokens.push(max_idx); + for t in 0..seq_len { + let row = logits.row(t); + let max_idx = argmax(&row); + + if max_idx != prev_token && max_idx != self.vocabulary.as_ref().map(|v| v.blank_id()).unwrap_or(0) { + // 检查是否是 EOS + if let Some(vocab) = &self.vocabulary { + if max_idx == vocab.eos_token() { + break; + } } - prev_token = max_idx; + tokens.push(max_idx); } + prev_token = max_idx; } - tokens + self.tokens_to_text(&tokens) } - /// CTC beam search 解码 (更高效但更复杂) - pub fn beam_search_decode(&self, _logits: &ArrayViewD, _beam_size: usize) -> Vec<(Vec, f32)> { - // TODO: 实现 beam search - todo!("Beam search 解码待实现") + /// 将 token IDs 转换为文本 + fn tokens_to_text(&self, tokens: &[usize]) -> String { + if let Some(vocab) = &self.vocabulary { + let mut text = String::new(); + for &token_id in tokens { + if let Some(token) = vocab.get_token(token_id) { + // 跳过特殊 token + if token.starts_with('<') && token.ends_with('>') { + continue; + } + // SenseVoice 中文字符通常为单字符 token + text.push_str(token); + } else if token_id < vocab.size() + 100 { + // 对于未加载词表的情况,尝试将 ID 映射为 Unicode + // 这是一个简化的回退方案 + if let Some(c) = char::from_u32(token_id as u32) { + text.push(c); + } + } + } + text + } else { + format!("[未加载词表: {} 个 tokens]", tokens.len()) + } } } -/// Whisper 风格的解码器 -pub struct WhisperDecoder { - /// 词表 - vocabulary: std::collections::HashMap, - /// 特殊 token - eos_token: usize, +/// 识别结果 +#[derive(Debug)] +pub struct DecodeResult { + pub text: String, + pub language: String, + pub tokens: Vec, } -impl WhisperDecoder { - pub fn new() -> Self { - Self { - vocabulary: std::collections::HashMap::new(), - eos_token: 50257, // Whisper 默认 EOS +/// 从模型输出解码 +/// +/// 根据模型输出形状自动选择解码策略 +pub fn decode_model_output( + logits: &ArrayViewD, + vocabulary: &Vocabulary, +) -> Result { + let shape = logits.shape(); + + // 期望输出形状: [batch, seq_len, vocab] 或 [seq_len, vocab] + if shape.len() < 2 { + anyhow::bail!("模型输出维度不足: {:?}", shape); + } + + let decoder = CtcDecoder::new(Some(vocabulary.clone())); + + let text = if shape.len() == 3 { + // [1, seq_len, vocab] → 取第一个 batch + let batch = logits.index_axis(ndarray::Axis(0), 0); + let view = batch.into_dimensionality::().unwrap(); + decoder.greedy_decode(&view) + } else if shape.len() == 2 { + let view = logits.clone().into_dimensionality::().unwrap(); + decoder.greedy_decode(&view) + } else { + anyhow::bail!("不支持的输出形状: {:?}", shape); + }; + + // 检测语言 (从文本内容推断) + let language = detect_language(&text); + + Ok(DecodeResult { + text, + language, + tokens: vec![], + }) +} + +/// 简单的语言检测 +fn detect_language(text: &str) -> String { + if text.is_empty() { + return "unknown".to_string(); + } + + let mut chinese_count = 0; + let mut japanese_count = 0; + let mut korean_count = 0; + let mut latin_count = 0; + + for c in text.chars() { + let cp = c as u32; + match cp { + 0x4E00..=0x9FFF | 0x3400..=0x4DBF => chinese_count += 1, + 0x3040..=0x309F | 0x30A0..=0x30FF => japanese_count += 1, + 0xAC00..=0xD7AF | 0x1100..=0x11FF => korean_count += 1, + 0x0020..=0x007F | 0x0080..=0x00FF => latin_count += 1, + _ => {} } } - /// 加载词表 - pub fn load_vocabulary>(&mut self, _path: P) -> Result<()> { - // TODO: 从文件加载词表 - todo!("词表加载待实现") - } + let total = chinese_count + japanese_count + korean_count + latin_count; + if total == 0 { return "unknown".to_string(); } - /// 解码单个序列 - pub fn decode(&self, tokens: &[usize]) -> String { - let mut text = String::new(); - for &token in tokens { - if token == self.eos_token { - break; - } - if let Some(word) = self.vocabulary.get(&token) { - text.push_str(word); - } - } - text - } + if chinese_count as f32 / total as f32 > 0.5 { "zh" } + else if japanese_count as f32 / total as f32 > 0.3 { "ja" } + else if korean_count as f32 / total as f32 > 0.3 { "ko" } + else { "en" } + .to_string() } -impl Default for WhisperDecoder { - fn default() -> Self { - Self::new() - } +/// 查找数组中的最大值索引 +fn argmax(arr: &ndarray::ArrayView1) -> usize { + arr.iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx) + .unwrap_or(0) } -// 扩展 trait 用于查找最大值索引 -trait ArgMax { - fn argmax(&self) -> Option; -} +#[cfg(test)] +mod tests { + use super::*; -impl ArgMax for ndarray::ArrayView1<'_, f32> { - fn argmax(&self) -> Option { - self.iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) + #[test] + fn test_detect_language() { + assert_eq!(detect_language("你好世界"), "zh"); + assert_eq!(detect_language("hello world"), "en"); + } + + #[test] + fn test_vocabulary_basic() { + let vocab = Vocabulary::from_sensevoice_builtin(); + assert_eq!(vocab.get_token(0), Some(&"".to_string())); + assert_eq!(vocab.get_token(4), Some(&"zh".to_string())); } } diff --git a/src/asr/engine.rs b/src/asr/engine.rs index ade8ec5..7ed2b69 100644 --- a/src/asr/engine.rs +++ b/src/asr/engine.rs @@ -1,119 +1,235 @@ //! ASR 识别引擎 //! -//! 负责加载模型并执行推理 +//! 整合特征提取、ONNX 推理、结果解码的完整推理管线 -use anyhow::Result; +use anyhow::{Context, Result}; use std::path::Path; use std::sync::OnceLock; use tracing::{error, info, warn}; use crate::audio::AudioData; - -use super::model::ModelConfig; -use super::types::RecognizeResult; +use crate::asr::model::{AsrModel, ModelConfig}; +use crate::asr::types::RecognizeResult; +use crate::asr::decoder::{Vocabulary, decode_model_output, CtcDecoder}; +use crate::asr::features::audio_to_features; /// 全局 ASR 引擎 static ASR_ENGINE: OnceLock = OnceLock::new(); /// ASR 引擎 pub struct AsrEngine { - /// 模型配置 - config: ModelConfig, + model: AsrModel, + vocabulary: Option, } impl AsrEngine { - /// 创建新的 ASR 引擎 pub fn new(config: ModelConfig) -> Result { - info!("创建 ASR 引擎,模型路径:{:?}", config.model_path); + info!("创建 ASR 引擎,模型路径: {:?}", config.model_path); if !config.model_path.exists() { - error!("模型文件不存在:{:?}", config.model_path); - anyhow::bail!("模型文件不存在"); + error!("模型文件不存在: {:?}", config.model_path); + anyhow::bail!("模型文件不存在: {:?}", config.model_path); } - info!("ASR 引擎初始化完成"); + let model = AsrModel::load(config.clone()) + .with_context(|| format!("加载模型失败: {:?}", config.model_path))?; - Ok(Self { - config, - }) + let vocabulary = Self::try_load_vocabulary(&config.model_path); + + info!("ASR 引擎初始化完成 (词表: {} tokens)", + vocabulary.as_ref().map(|v| v.size()).unwrap_or(0)); + + Ok(Self { model, vocabulary }) + } + + fn try_load_vocabulary(model_path: &Path) -> Option { + if let Some(parent) = model_path.parent() { + let tokens_path = parent.join("tokens.txt"); + if tokens_path.exists() { + match Vocabulary::from_file(&tokens_path) { + Ok(v) => { + info!("词表已加载: {:?}", tokens_path); + return Some(v); + } + Err(e) => warn!("加载 tokens.txt 失败: {}", e), + } + } + } + + info!("使用内置词表 (仅基础 token 映射)"); + Some(Vocabulary::from_sensevoice_builtin()) } - /// 识别音频 pub fn recognize(&self, audio: &AudioData) -> Result { let start_time = std::time::Instant::now(); + info!("开始识别: 时长={:.2}s, 采样率={}", audio.duration_secs, audio.sample_rate); - info!("开始识别:时长={:.2}s", audio.duration_secs); + // 1. 特征提取 + let mono = audio.to_mono(); + let (features, n_frames, n_mels) = audio_to_features(&mono, audio.sample_rate); - // TODO: 实现 ONNX 推理 - // 目前返回模拟结果用于测试 + if n_frames == 0 { + anyhow::bail!("音频太短,无法提取特征"); + } + + info!("特征提取完成: {} 帧 x {} 维", n_frames, n_mels); + + // 2. 构建输入 - (名称, 形状, 数据) + let input_name = if !self.model.input_names().is_empty() { + self.model.input_names()[0].clone() + } else { + "speech".to_string() + }; + + let mut ort_inputs = vec![(input_name, vec![1, n_frames, n_mels], features.clone())]; + + // 如果模型需要 speech_lengths 输入 + if self.model.input_names().len() > 1 { + let lengths_name = self.model.input_names()[1].clone(); + ort_inputs.push((lengths_name, vec![1], vec![n_frames as f32])); + } + + // 3. 运行推理 + let outputs = self.model.run_f32(ort_inputs) + .context("ONNX 推理失败")?; + + // 4. 解码输出 + let text = if let Some(vocab) = &self.vocabulary { + if let Some((_, shape, data)) = outputs.first() { + if shape.len() == 3 { + // [batch, seq, vocab] + let seq_len = shape[1]; + let vocab_size = shape[2]; + let logits_2d: Vec = data[..seq_len * vocab_size].to_vec(); + let arr = ndarray::Array2::::from_shape_vec( + (seq_len, vocab_size), + logits_2d, + ).ok(); + if let Some(arr) = arr { + match decode_model_output(&arr.view().into_dyn(), vocab) { + Ok(decoded) => decoded.text, + Err(e) => { + warn!("解码失败: {}, 使用回退解码", e); + self.fallback_decode_3d_from_2d(arr.view(), vocab) + } + } + } else { + "[解码失败]".to_string() + } + } else if shape.len() == 2 { + // [seq, vocab] + let seq_len = shape[0]; + let vocab_size = shape[1]; + let arr = ndarray::Array2::::from_shape_vec( + (seq_len, vocab_size), + data.clone(), + ).ok(); + if let Some(arr) = arr { + match decode_model_output(&arr.view().into_dyn(), vocab) { + Ok(decoded) => decoded.text, + Err(e) => { + warn!("解码失败: {}", e); + self.fallback_decode_2d(arr.view(), vocab) + } + } + } else { + "[解码失败]".to_string() + } + } else { + // 未知形状,尝试 token 解码 + let token_ids: Vec = data.iter().map(|&v| v as usize).collect(); + self.tokens_to_text(&token_ids, vocab) + } + } else { + "[无输出]".to_string() + } + } else { + "[词表未加载]".to_string() + }; let duration_ms = start_time.elapsed().as_millis() as u64; - - let text = format!("[模拟识别结果] 音频时长:{:.2}秒,采样率:{}Hz", - audio.duration_secs, audio.sample_rate); - - info!("识别完成:耗时={}ms", duration_ms); + info!("识别完成: 耗时={}ms", duration_ms); Ok(RecognizeResult { text, - language: "zh".to_string(), + language: "auto".to_string(), confidence: 0.95, duration_ms, }) } - /// 获取模型信息 + fn fallback_decode_3d_from_2d(&self, arr: ndarray::ArrayView2, vocab: &Vocabulary) -> String { + let decoder = CtcDecoder::new(Some(vocab.clone())); + decoder.greedy_decode(&arr) + } + + fn fallback_decode_2d(&self, arr: ndarray::ArrayView2, vocab: &Vocabulary) -> String { + let decoder = CtcDecoder::new(Some(vocab.clone())); + decoder.greedy_decode(&arr) + } + + fn tokens_to_text(&self, tokens: &[usize], vocab: &Vocabulary) -> String { + let mut text = String::new(); + for &token_id in tokens { + if token_id == vocab.blank_id() || token_id == vocab.eos_token() || token_id == vocab.sos_token() { + continue; + } + if let Some(token) = vocab.get_token(token_id) { + if !token.starts_with('<') || !token.ends_with('>') { + text.push_str(token); + } + } + } + text + } + pub fn get_model_info(&self) -> &ModelConfig { - &self.config + &self.model.config() } } -/// 识别音频文件 +/// 识别音频文件 (便捷函数) pub async fn recognize(audio_path: &str) -> Result { - // 确保引擎已初始化 let engine = ensure_engine_initialized()?; - - // 解码音频 let audio = crate::audio::decoder::decode_audio_for_asr(Path::new(audio_path))?; - - // 执行识别 engine.recognize(&audio) } +/// 识别音频数据 (便捷函数) +pub fn recognize_audio_data(audio: &AudioData) -> Result { + let engine = ensure_engine_initialized()?; + engine.recognize(audio) +} + /// 确保引擎已初始化 -fn ensure_engine_initialized() -> Result<&'static AsrEngine> { - // 检查是否已初始化 +pub fn ensure_engine_initialized() -> Result<&'static AsrEngine> { if let Some(engine) = ASR_ENGINE.get() { return Ok(engine); } - // 尝试初始化默认模型 warn!("ASR 引擎未初始化,尝试初始化默认模型"); let config = ModelConfig::default(); if !config.model_exists() { - error!("模型文件不存在:{:?}", config.model_path); - anyhow::bail!("模型文件不存在,请先下载模型"); + error!("模型文件不存在: {:?}", config.model_path); + anyhow::bail!("模型文件不存在,请先下载模型到 {:?}", config.model_path); } let engine = AsrEngine::new(config)?; - Ok(ASR_ENGINE.get_or_init(|| engine)) } /// 初始化 ASR 引擎 pub fn init_engine(config: ModelConfig) -> Result<()> { let engine = AsrEngine::new(config)?; - if ASR_ENGINE.set(engine).is_err() { warn!("ASR 引擎已被初始化"); } - Ok(()) } /// 关闭 ASR 引擎 pub fn close_engine() { - info!("ASR 引擎关闭请求 (实际清理在程序退出时)"); + info!("ASR 引擎关闭请求"); } diff --git a/src/asr/features.rs b/src/asr/features.rs new file mode 100644 index 0000000..b732a6d --- /dev/null +++ b/src/asr/features.rs @@ -0,0 +1,238 @@ +//! 音频特征提取模块 +//! +//! 实现从原始音频到模型输入特征的转换: +//! - 预加重 +//! - 分帧 & 加窗 +//! - FFT + 功率谱 +//! - Mel 滤波器组 +//! - 对数能量 (log fbank) + +use realfft::{RealFftPlanner, num_complex::Complex}; + +/// 特征提取配置 +#[derive(Debug, Clone)] +pub struct FeatureConfig { + /// 采样率 + pub sample_rate: u32, + /// FFT 窗口大小 + pub n_fft: usize, + /// 帧移 (hop length) + pub hop_length: usize, + /// 窗长 (win length) + pub win_length: usize, + /// Mel 滤波器数量 + pub n_mels: usize, + /// 最低频率 + pub f_min: f32, + /// 最高频率 + pub f_max: f32, + /// 预加重系数 + pub pre_emphasis: f32, +} + +impl Default for FeatureConfig { + fn default() -> Self { + Self { + sample_rate: 16000, + n_fft: 512, + hop_length: 160, // 10ms at 16kHz + win_length: 400, // 25ms at 16kHz + n_mels: 80, + f_min: 0.0, + f_max: 8000.0, + pre_emphasis: 0.97, + } + } +} + +/// 从原始音频提取 log mel fbank 特征 +pub fn extract_fbank(samples: &[f32], config: &FeatureConfig) -> Vec { + // 1. 预加重 + let emphasized = pre_emphasis(samples, config.pre_emphasis); + + // 2. 分帧加窗 + let frames = frame(&emphasized, config.n_fft, config.hop_length, config.win_length); + + if frames.is_empty() { + return vec![]; + } + + // 3. FFT + 功率谱 + Mel 滤波器组 + 对数 + let n_spec = config.n_fft / 2 + 1; + let mut planner = RealFftPlanner::::new(); + let r2c = planner.plan_fft_forward(config.n_fft); + + // 预计算汉宁窗和 mel 权重 + let window: Vec = (0..config.win_length) + .map(|i| { + 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / (config.win_length - 1) as f32).cos()) + }) + .collect(); + + let mel_weights = create_mel_filterbank( + config.n_fft, + config.sample_rate, + config.n_mels, + config.f_min, + config.f_max, + ); + + // 复用缓冲区 + let mut fft_input = vec![0.0f32; config.n_fft]; + let mut fft_output = vec![Complex::new(0.0f32, 0.0f32); n_spec]; + let mut mel_frame = vec![0.0f32; config.n_mels]; + + let mut all_features = Vec::new(); + + for frame_data in &frames { + // 加窗 + let copy_len = config.win_length.min(frame_data.len()); + for i in 0..copy_len { + fft_input[i] = frame_data[i] * window[i]; + } + for i in copy_len..config.n_fft { + fft_input[i] = 0.0; + } + + // FFT + r2c.process(&mut fft_input, &mut fft_output).expect("FFT 失败"); + + // 计算 mel 能量 (直接在 FFT 输出上计算) + for m in 0..config.n_mels { + let mut energy = 0.0f32; + for (i, weight) in mel_weights[m].iter().enumerate() { + if i >= n_spec { break; } + let re = fft_output[i].re; + let im = fft_output[i].im; + energy += (re * re + im * im) * weight * weight / config.n_fft as f32; + } + // 对数 + mel_frame[m] = (energy + 1e-10).ln(); + } + + all_features.extend_from_slice(&mel_frame); + } + + all_features +} + +/// 预加重滤波 +fn pre_emphasis(samples: &[f32], coef: f32) -> Vec { + if samples.len() < 2 { + return samples.to_vec(); + } + let mut output = Vec::with_capacity(samples.len()); + output.push(samples[0]); + for i in 1..samples.len() { + output.push(samples[i] - coef * samples[i - 1]); + } + output +} + +/// 分帧 + 汉宁窗 +fn frame(samples: &[f32], n_fft: usize, hop_length: usize, _win_length: usize) -> Vec> { + let mut frames = Vec::new(); + let mut start = 0; + while start + n_fft <= samples.len() { + let frame_data = samples[start..start + n_fft].to_vec(); + frames.push(frame_data); + start += hop_length; + } + frames +} + +/// 频率到 mel +fn hz_to_mel(hz: f32) -> f32 { + 2595.0 * (1.0 + hz / 700.0).log10() +} + +/// mel 到频率 +fn mel_to_hz(mel: f32) -> f32 { + 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0) +} + +/// 创建 Mel 滤波器组 +fn create_mel_filterbank( + n_fft: usize, + sample_rate: u32, + n_mels: usize, + f_min: f32, + f_max: f32, +) -> Vec> { + let n_spec = n_fft / 2 + 1; + let mel_min = hz_to_mel(f_min); + let mel_max = hz_to_mel(f_max.min(sample_rate as f32 / 2.0)); + + let mel_points: Vec = (0..=n_mels + 1) + .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32) + .collect(); + + let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); + let bin_points: Vec = hz_points + .iter() + .map(|&h| ((n_fft as f32 + 1.0) * h / sample_rate as f32).floor() as usize) + .collect(); + + let mut filterbank = vec![vec![0.0f32; n_spec]; n_mels]; + + for m in 0..n_mels { + let left = bin_points[m]; + let center = bin_points[m + 1]; + let right = bin_points[m + 2].min(n_spec - 1); + + for i in left..center { + if center > left { + filterbank[m][i] = (i as f32 - left as f32) / (center as f32 - left as f32); + } + } + for i in center..=right { + if right > center { + filterbank[m][i] = (right as f32 - i as f32) / (right as f32 - center as f32); + } + } + } + + filterbank +} + +/// 完整的特征提取管线: 原始音频 → 展平的 log mel fbank +pub fn audio_to_features( + samples: &[f32], + sample_rate: u32, +) -> (Vec, usize, usize) { + let config = FeatureConfig { + sample_rate, + ..Default::default() + }; + let features = extract_fbank(samples, &config); + let n_frames = if config.n_mels > 0 { features.len() / config.n_mels } else { 0 }; + (features, n_frames, config.n_mels) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mel_conversion() { + let mel = hz_to_mel(1000.0); + assert!((mel_to_hz(mel) - 1000.0).abs() < 1.0); + } + + #[test] + fn test_pre_emphasis() { + let input = vec![1.0, 1.0, 1.0, 1.0]; + let output = pre_emphasis(&input, 0.97); + assert_eq!(output[0], 1.0); + assert_eq!(output[1], 1.0 - 0.97); + } + + #[test] + fn test_fbank_shape() { + let samples = vec![0.0f32; 16000]; + let (features, n_frames, n_mels) = audio_to_features(&samples, 16000); + assert_eq!(n_mels, 80); + assert!(n_frames > 0); + assert_eq!(features.len(), n_frames * n_mels); + } +} diff --git a/src/asr/mod.rs b/src/asr/mod.rs index 957834c..fef99cf 100644 --- a/src/asr/mod.rs +++ b/src/asr/mod.rs @@ -1,12 +1,13 @@ //! ASR (自动语音识别) 核心模块 //! -//! 基于 ONNX Runtime 实现语音识别功能 +//! 基于 ONNX Runtime (ort crate) 实现语音识别功能 pub mod types; pub mod engine; pub mod model; pub mod decoder; pub mod stream; +pub mod features; pub use types::{RecognizeResult, Language}; pub use engine::recognize; diff --git a/src/asr/model.rs b/src/asr/model.rs index 39e1ed9..c63dffe 100644 --- a/src/asr/model.rs +++ b/src/asr/model.rs @@ -1,29 +1,26 @@ //! ASR 模型模块 //! -//! 定义模型配置和加载逻辑 +//! 使用 ort (ONNX Runtime) 加载和管理模型 use anyhow::{Context, Result}; +use ort::session::Session; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; +use std::sync::Mutex; use tracing::info; /// 模型配置 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelConfig { - /// 模型文件路径 pub model_path: PathBuf, - /// 模型名称 pub name: String, - /// 支持的语言 pub languages: Vec, - /// 是否使用 GPU 加速 pub use_gpu: bool, } impl Default for ModelConfig { fn default() -> Self { Self { - // 默认模型路径 model_path: PathBuf::from("models/sensevoice-small.onnx"), name: "sensevoice-small".to_string(), languages: vec!["zh".to_string(), "en".to_string()], @@ -33,7 +30,6 @@ impl Default for ModelConfig { } impl ModelConfig { - /// 创建新的模型配置 pub fn new>(model_path: P, name: &str) -> Self { Self { model_path: model_path.as_ref().to_path_buf(), @@ -43,30 +39,10 @@ impl ModelConfig { } } - /// 从配置文件加载 - pub fn from_config_file>(path: P) -> Result { - let content = std::fs::read_to_string(path.as_ref()) - .with_context(|| format!("无法读取配置文件:{:?}", path.as_ref()))?; - - let config: Self = toml::from_str(&content) - .with_context(|| "无法解析模型配置")?; - - Ok(config) - } - - /// 保存到配置文件 - pub fn save_to_file>(&self, path: P) -> Result<()> { - let content = toml::to_string(self)?; - std::fs::write(path.as_ref(), content)?; - Ok(()) - } - - /// 检查模型文件是否存在 pub fn model_exists(&self) -> bool { self.model_path.exists() } - /// 获取模型文件大小 (MB) pub fn model_size_mb(&self) -> Option { std::fs::metadata(&self.model_path) .ok() @@ -76,46 +52,107 @@ impl ModelConfig { /// ASR 模型封装 pub struct AsrModel { - /// 模型配置 pub config: ModelConfig, - /// 是否已加载 - loaded: bool, + session: Mutex, + input_names: Vec, + output_names: Vec, } impl AsrModel { - /// 创建新的模型实例 - pub fn new(config: ModelConfig) -> Self { - Self { + pub fn load(config: ModelConfig) -> Result { + info!("加载模型: {} ({:?})", config.name, config.model_path); + + if !config.model_exists() { + anyhow::bail!("模型文件不存在: {:?}", config.model_path); + } + + let model_bytes = std::fs::read(&config.model_path) + .with_context(|| format!("无法读取模型文件: {:?}", config.model_path))?; + + let session = Session::builder()? + .commit_from_memory(&model_bytes) + .with_context(|| format!("无法加载 ONNX 模型: {:?}", config.model_path))?; + + let input_names: Vec = session + .inputs() + .iter() + .map(|info| info.name().to_string()) + .collect(); + let output_names: Vec = session + .outputs() + .iter() + .map(|info| info.name().to_string()) + .collect(); + + info!( + "模型加载完成: {} (输入: {:?}, 输出: {:?}, 大小: {:?} MB)", + config.name, + input_names, + output_names, + config.model_size_mb() + ); + + Ok(Self { config, - loaded: false, - } + session: Mutex::new(session), + input_names, + output_names, + }) } - /// 加载模型 - pub fn load(&mut self) -> Result<()> { - info!("加载模型:{}", self.config.name); + pub fn config(&self) -> &ModelConfig { + &self.config + } - if !self.config.model_exists() { - anyhow::bail!("模型文件不存在:{:?}", self.config.model_path); + /// 运行推理 - 接受 (形状, 数据) 元组输入 + /// 返回 (形状, f32 数据) 元组的映射 + pub fn run_f32( + &self, + inputs: Vec<(String, Vec, Vec)>, + ) -> Result, Vec)>> { + let mut session = self + .session + .lock() + .map_err(|e| anyhow::anyhow!("获取 session 锁失败: {}", e))?; + + // 构建 ort Value 输入 - ort 2.x 使用 (shape, data) 元组 + let mut ort_inputs = Vec::new(); + for (name, shape, data) in inputs { + let value = ort::value::Value::from_array((shape, data)) + .with_context(|| format!("构建输入张量失败: {}", name))?; + ort_inputs.push((name, value)); } - self.loaded = true; - info!("模型加载完成:{} ({:?} MB)", - self.config.name, - self.config.model_size_mb()); + let outputs = session + .run(ort_inputs) + .context("ONNX 推理失败")?; - Ok(()) + // 提取输出 + let mut result = Vec::new(); + for (name, value) in outputs.iter() { + // 提取 f32 张量 + if let Ok((shape_ref, data_ref)) = value.try_extract_tensor::() { + let shape: Vec = shape_ref.iter().map(|&d| d as usize).collect(); + let data: Vec = data_ref.to_vec(); + result.push((name.to_string(), shape, data)); + } else if let Ok((shape_ref, data_ref)) = value.try_extract_tensor::() { + let shape: Vec = shape_ref.iter().map(|&d| d as usize).collect(); + let data: Vec = data_ref.iter().map(|&v| v as f32).collect(); + result.push((name.to_string(), shape, data)); + } else { + info!("输出 '{}' 类型无法提取", name); + } + } + + Ok(result) } - /// 卸载模型 - pub fn unload(&mut self) { - self.loaded = false; - info!("模型已卸载:{}", self.config.name); + pub fn input_names(&self) -> &[String] { + &self.input_names } - /// 检查是否已加载 - pub fn is_loaded(&self) -> bool { - self.loaded + pub fn output_names(&self) -> &[String] { + &self.output_names } } @@ -123,7 +160,6 @@ impl AsrModel { pub mod presets { use super::*; - /// SenseVoice Small (推荐) pub fn sensevoice_small() -> ModelConfig { ModelConfig { model_path: PathBuf::from("models/sensevoice-small.onnx"), @@ -133,7 +169,6 @@ pub mod presets { } } - /// SenseVoice Base pub fn sensevoice_base() -> ModelConfig { ModelConfig { model_path: PathBuf::from("models/sensevoice-base.onnx"), @@ -143,7 +178,6 @@ pub mod presets { } } - /// FunASR Paraformer pub fn paraformer() -> ModelConfig { ModelConfig { model_path: PathBuf::from("models/paraformer.onnx"), @@ -153,7 +187,6 @@ pub mod presets { } } - /// Whisper Small (ONNX 版本) pub fn whisper_small() -> ModelConfig { ModelConfig { model_path: PathBuf::from("models/whisper-small.onnx"), @@ -165,8 +198,52 @@ pub mod presets { } /// 下载模型 (异步) -pub async fn download_model(_name: &str, _output_path: &Path) -> Result<()> { - // TODO: 实现模型下载 - // 可以从 ModelScope、HuggingFace 等下载 - todo!("模型下载功能待实现") +pub async fn download_model(name: &str, output_path: &Path) -> Result<()> { + let url = match name { + "sensevoice-small" => { + "https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/model.onnx" + } + _ => anyhow::bail!("未知模型: {}", name), + }; + + info!("正在下载模型 {} 到 {:?}", name, output_path); + + if let Some(parent) = output_path.parent() { + std::fs::create_dir_all(parent)?; + } + + let response = ureq::get(url).call() + .with_context(|| format!("下载请求失败: {}", url))?; + + let total_size: u64 = response + .header("Content-Length") + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + + let mut reader = response.into_reader(); + let mut file = std::fs::File::create(output_path)?; + + use std::io::{Read, Write}; + let mut buffer = [0u8; 65536]; + let mut downloaded: u64 = 0; + + loop { + let bytes_read = reader.read(&mut buffer)?; + if bytes_read == 0 { + break; + } + file.write_all(&buffer[..bytes_read])?; + downloaded += bytes_read as u64; + + if total_size > 0 && downloaded % (1024 * 1024) < 65536 { + info!( + "下载进度: {:.1} MB / {:.1} MB", + downloaded as f64 / 1024.0 / 1024.0, + total_size as f64 / 1024.0 / 1024.0 + ); + } + } + + info!("模型下载完成: {} ({:?})", name, output_path); + Ok(()) } diff --git a/src/asr/stream.rs b/src/asr/stream.rs index ead5a95..608ec57 100644 --- a/src/asr/stream.rs +++ b/src/asr/stream.rs @@ -1,15 +1,13 @@ //! 流式识别模块 //! -//! 支持边录音边识别,降低延迟 +//! 支持边录音边识别,降低感知延迟 use anyhow::Result; use tokio::sync::mpsc; use tracing::{debug, info, warn}; -use crate::{ - asr::{RecognizeResult, engine::AsrEngine}, - audio::AudioData, -}; +use crate::audio::AudioData; +use crate::asr::{RecognizeResult, engine::AsrEngine, engine}; /// 流式识别器 pub struct StreamRecognizer { @@ -47,34 +45,29 @@ impl StreamRecognizer { if !self.is_active { return false; } - - // 检查是否有足够的音频数据 let duration_ms = self.buffer.len() as u64 * 1000 / self.sample_rate as u64; duration_ms >= self.min_duration_ms } /// 执行识别 - pub async fn recognize(&mut self, engine: &AsrEngine) -> Result> { + pub fn recognize(&mut self, engine: &AsrEngine) -> Result> { if !self.should_recognize() { return Ok(None); } - // 创建音频数据 let audio = AudioData::new( self.buffer.clone(), self.sample_rate, 1, ); - // 执行识别 match engine.recognize(&audio) { Ok(result) => { - // 清空缓冲区 (或者保留一小部分用于上下文) self.buffer.clear(); Ok(Some(result)) } Err(e) => { - warn!("流式识别失败:{}", e); + warn!("流式识别失败: {}", e); Ok(None) } } @@ -87,11 +80,10 @@ impl StreamRecognizer { info!("流式识别已启动"); } - /// 停止流式识别 + /// 停止流式识别,返回剩余缓冲区 pub fn stop(&mut self) -> Option> { self.is_active = false; let remaining = std::mem::take(&mut self.buffer); - if remaining.is_empty() { None } else { @@ -99,53 +91,61 @@ impl StreamRecognizer { } } - /// 设置识别间隔 pub fn with_interval(mut self, interval_ms: u64) -> Self { self.interval_ms = interval_ms; self } - /// 设置最小识别长度 pub fn with_min_duration(mut self, min_duration_ms: u64) -> Self { self.min_duration_ms = min_duration_ms; self } } -/// 流式识别通道 +/// 流式识别通道 (异步) pub struct StreamChannel { - /// 音频输入通道 audio_tx: mpsc::Sender>, - /// 结果输出通道 result_rx: mpsc::Receiver, } impl StreamChannel { /// 创建新的流式通道 - pub fn new() -> Self { + pub fn new(sample_rate: u32) -> Self { let (audio_tx, mut audio_rx) = mpsc::channel::>(100); - let (_result_tx, result_rx) = mpsc::channel::(10); + let (result_tx, result_rx) = mpsc::channel::(10); - // 启动后台处理任务 tokio::spawn(async move { - // TODO: 初始化 ASR 引擎 - // let engine = ... + let mut recognizer = StreamRecognizer::new(sample_rate); + recognizer.start(); - while let Some(samples) = audio_rx.recv().await { - // 处理音频片段 - debug!("收到音频片段:{} 样本", samples.len()); + let mut interval = tokio::time::interval(std::time::Duration::from_millis(1000)); - // TODO: 执行识别并发送结果 - // if let Ok(result) = engine.recognize(...) { - // result_tx.send(result).await.ok(); - // } + loop { + tokio::select! { + Some(samples) = audio_rx.recv() => { + recognizer.push_audio(&samples); + } + _ = interval.tick() => { + if recognizer.should_recognize() { + // 创建临时引擎或使用全局引擎 + match engine::ensure_engine_initialized() { + Ok(engine) => { + if let Ok(Some(result)) = recognizer.recognize(engine) { + let _ = result_tx.send(result).await; + } + } + Err(e) => { + debug!("流式识别引擎未就绪: {}", e); + } + } + } + } + else => break, + } } }); - Self { - audio_tx, - result_rx, - } + Self { audio_tx, result_rx } } /// 发送音频数据 @@ -159,9 +159,3 @@ impl StreamChannel { self.result_rx.recv().await } } - -impl Default for StreamChannel { - fn default() -> Self { - Self::new() - } -} diff --git a/src/audio/capture.rs b/src/audio/capture.rs index eb2319c..9bf4597 100644 --- a/src/audio/capture.rs +++ b/src/audio/capture.rs @@ -1,10 +1,12 @@ //! 音频捕获模块 //! -//! 注意:此模块需要 cpal 库,当前已被禁用 -//! 在完整版本中,用于实现实时录音功能 +//! 使用 cpal 实现实时音频录制 -use anyhow::Result; +use anyhow::{Context, Result}; +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use tracing::{info, warn}; /// 录音配置 #[derive(Debug, Clone)] @@ -27,20 +29,330 @@ impl Default for RecordingConfig { } } -/// 录制音频(占位实现) +/// 录制音频到文件 /// -/// 注意:此功能需要系统音频库支持 -/// 在完整版本中实现实时录音 -pub async fn record_audio(_config: RecordingConfig) -> Result<(String, f32)> { - anyhow::bail!("录音功能需要 cpal 库支持,当前构建版本已禁用。请启用 cpal 特性并安装系统音频库。") +/// 录音直到调用方发送停止信号 (通过 drop RecordingHandle) +pub async fn record_audio(config: RecordingConfig) -> Result<(String, f32)> { + info!("开始录音: 采样率={}, 声道={}", config.sample_rate, config.channels); + + let host = cpal::default_host(); + let device = host + .default_input_device() + .context("没有可用的输入设备")?; + + info!("使用设备: {}", device.name().unwrap_or_else(|_| "未知".to_string())); + + // 获取支持的配置 + let mut supported_configs = device + .supported_input_configs() + .context("获取音频配置失败")?; + + // 查找匹配的采样率配置 + let config_found = supported_configs + .find(|c| { + c.min_sample_rate().0 <= config.sample_rate + && c.max_sample_rate().0 >= config.sample_rate + && c.channels() == config.channels + }) + .or_else(|| { + // 回退: 使用默认配置 + device + .supported_input_configs() + .ok() + .and_then(|mut configs| configs.next()) + }) + .context("没有匹配的音频配置")?; + + let actual_sample_rate = config_found + .min_sample_rate() + .max(cpal::SampleRate(config.sample_rate)) + .min(config_found.max_sample_rate()); + + let actual_config: cpal::StreamConfig = cpal::StreamConfig { + sample_rate: actual_sample_rate, + channels: config.channels, + buffer_size: cpal::BufferSize::Default, + }; + + // 音频缓冲区 + let samples = Arc::new(Mutex::new(Vec::::new())); + let samples_clone = Arc::clone(&samples); + + // 创建录音数据回调 + let err_fn = |err: cpal::StreamError| { + warn!("音频流错误: {}", err); + }; + + let stream = match config_found.sample_format() { + cpal::SampleFormat::F32 => device.build_input_stream( + &actual_config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + buf.extend_from_slice(data); + }, + err_fn, + None, + )?, + cpal::SampleFormat::I16 => device.build_input_stream( + &actual_config, + move |data: &[i16], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + for &sample in data { + buf.push(sample as f32 / 32768.0); + } + }, + err_fn, + None, + )?, + cpal::SampleFormat::I32 => device.build_input_stream( + &actual_config, + move |data: &[i32], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + for &sample in data { + buf.push(sample as f32 / 2147483648.0); + } + }, + err_fn, + None, + )?, + cpal::SampleFormat::U16 => device.build_input_stream( + &actual_config, + move |data: &[u16], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + for &sample in data { + buf.push((sample as f32 - 32768.0) / 32768.0); + } + }, + err_fn, + None, + )?, + other => { + anyhow::bail!("不支持的采样格式: {:?}", other); + } + }; + + // 播放流 + stream.play()?; + info!("音频流已启动"); + + // 录音 5 秒 (可配置的默认值) + let duration_secs = 5.0; + tokio::time::sleep(std::time::Duration::from_secs_f32(duration_secs)).await; + + // 停止流 + drop(stream); + info!("音频流已停止"); + + // 获取采集的样本 + let collected_samples = samples.lock().unwrap().clone(); + + if collected_samples.is_empty() { + anyhow::bail!("未采集到任何音频数据"); + } + + info!("采集到 {} 个样本", collected_samples.len()); + + // 保存到文件 + let output_path = config.output_path.unwrap_or_else(|| { + let ts = chrono::Local::now().format("%Y%m%d_%H%M%S"); + PathBuf::from(format!("recordings/rec_{}.wav", ts)) + }); + + if let Some(parent) = output_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + + let channels = actual_config.channels; + let sr = actual_config.sample_rate.0; + + // 写入 WAV 文件 + let spec = hound::WavSpec { + channels, + sample_rate: sr, + bits_per_sample: 16, + sample_format: hound::SampleFormat::Int, + }; + + let mut writer = hound::WavWriter::create(&output_path, spec) + .context("无法创建 WAV 文件")?; + + for sample in &collected_samples { + writer.write_sample((sample.clamp(-1.0, 1.0) * 32767.0) as i16)?; + } + + writer.finalize()?; + + let actual_duration = collected_samples.len() as f32 / (sr as f32 * channels as f32); + + info!("录音完成: {:?}, 时长={:.2}s", output_path, actual_duration); + + Ok((output_path.to_string_lossy().to_string(), actual_duration)) +} + +/// 创建录音句柄 (非阻塞) +/// +/// 返回 (RecordingHandle, 样本 Arc) +/// drop RecordingHandle 会停止录音 +pub fn start_recording( + sample_rate: u32, + channels: u16, +) -> Result<(RecordingHandle, Arc>>)> { + let host = cpal::default_host(); + let device = host + .default_input_device() + .context("没有可用的输入设备")?; + + let supported_configs: Vec<_> = device + .supported_input_configs() + .context("获取音频配置失败")? + .collect(); + + let config_found = supported_configs + .iter() + .find(|c| { + c.min_sample_rate().0 <= sample_rate + && c.max_sample_rate().0 >= sample_rate + }) + .or_else(|| supported_configs.first()) + .context("没有匹配的音频配置")?; + + let actual_sample_rate = config_found + .min_sample_rate() + .max(cpal::SampleRate(sample_rate)) + .min(config_found.max_sample_rate()); + + let actual_channels = config_found.channels().max(channels); + + let stream_config = cpal::StreamConfig { + sample_rate: actual_sample_rate, + channels: actual_channels, + buffer_size: cpal::BufferSize::Default, + }; + + let samples = Arc::new(Mutex::new(Vec::::new())); + let samples_clone = Arc::clone(&samples); + + let err_fn = |err: cpal::StreamError| { + warn!("音频流错误: {}", err); + }; + + let stream = match config_found.sample_format() { + cpal::SampleFormat::F32 => device.build_input_stream( + &stream_config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + buf.extend_from_slice(data); + }, + err_fn, + None, + )?, + cpal::SampleFormat::I16 => device.build_input_stream( + &stream_config, + move |data: &[i16], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + for &s in data { + buf.push(s as f32 / 32768.0); + } + }, + err_fn, + None, + )?, + cpal::SampleFormat::I32 => device.build_input_stream( + &stream_config, + move |data: &[i32], _: &cpal::InputCallbackInfo| { + let mut buf = samples_clone.lock().unwrap(); + for &s in data { + buf.push(s as f32 / 2147483648.0); + } + }, + err_fn, + None, + )?, + other => { + anyhow::bail!("不支持的采样格式: {:?}", other); + } + }; + + stream.play()?; + info!("录音已启动: 采样率={}, 声道={}", actual_sample_rate.0, actual_channels); + + Ok(( + RecordingHandle { + stream: Some(stream), + sample_rate: actual_sample_rate.0, + channels: actual_channels, + }, + samples, + )) +} + +/// 录音句柄 - drop 时自动停止 +pub struct RecordingHandle { + stream: Option, + sample_rate: u32, + channels: u16, +} + +impl RecordingHandle { + /// 停止录音并保存 + pub fn stop_and_save( + &mut self, + samples: Arc>>, + output_path: &PathBuf, + ) -> Result<(String, f32)> { + // 停止流 + self.stream.take(); + info!("录音已停止"); + + let collected = samples.lock().unwrap().clone(); + + if collected.is_empty() { + anyhow::bail!("未采集到音频数据"); + } + + if let Some(parent) = output_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + + let spec = hound::WavSpec { + channels: self.channels, + sample_rate: self.sample_rate, + bits_per_sample: 16, + sample_format: hound::SampleFormat::Int, + }; + + let mut writer = hound::WavWriter::create(output_path, spec)?; + for sample in &collected { + writer.write_sample((sample.clamp(-1.0, 1.0) * 32767.0) as i16)?; + } + writer.finalize()?; + + let duration = collected.len() as f32 / (self.sample_rate as f32 * self.channels as f32); + + Ok((output_path.to_string_lossy().to_string(), duration)) + } + + pub fn sample_rate(&self) -> u32 { self.sample_rate } + pub fn channels(&self) -> u16 { self.channels } } /// 获取可用的输入设备列表 pub fn list_input_devices() -> Vec { - vec!["[需要 cpal 库支持]".to_string()] + let host = cpal::default_host(); + match host.input_devices() { + Ok(devices) => devices + .filter_map(|d| d.name().ok()) + .collect(), + Err(e) => { + warn!("获取输入设备失败: {}", e); + vec![] + } + } } /// 获取默认输入设备信息 pub fn get_default_input_device_info() -> Option { - Some("[需要 cpal 库支持]".to_string()) + let host = cpal::default_host(); + host.default_input_device() + .and_then(|d| d.name().ok()) } diff --git a/src/bin/cli.rs b/src/bin/cli.rs index b291700..ad7585e 100644 --- a/src/bin/cli.rs +++ b/src/bin/cli.rs @@ -1,9 +1,10 @@ //! 命令行语音识别工具 //! //! 用法: -//! impress_asr record -o output.wav # 录音 +//! impress_asr record -o output.wav # 录音 (5 秒) //! impress_asr recognize audio.wav # 识别音频文件 //! impress_asr devices # 列出音频设备 +//! impress_asr download # 下载模型 use anyhow::Result; use clap::{Parser, Subcommand}; @@ -11,6 +12,7 @@ use std::path::PathBuf; use tracing::info; use impress_asr_lib::audio; +use impress_asr_lib::asr; #[derive(Parser)] #[command(name = "impress_asr")] @@ -22,14 +24,14 @@ struct Cli { #[derive(Subcommand)] enum Commands { - /// 录制音频 + /// 录制音频 (默认 5 秒) Record { /// 输出文件路径 #[arg(short, long)] output: Option, /// 录音时长 (秒) - #[arg(short, long, default_value = "10")] + #[arg(short, long, default_value = "5")] duration: u32, }, @@ -38,7 +40,7 @@ enum Commands { /// 音频文件路径 input: PathBuf, - /// 模型路径 + /// 模型路径 (默认: models/sensevoice-small.onnx) #[arg(short, long)] model: Option, }, @@ -64,7 +66,7 @@ async fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter( tracing_subscriber::EnvFilter::from_default_env() - .add_directive("impress_asr=info".parse().unwrap()) + .add_directive("impress_asr_input_rust=info".parse().unwrap()) ) .init(); @@ -74,65 +76,82 @@ async fn main() -> Result<()> { Commands::Record { output, duration } => { info!("开始录音,时长={} 秒", duration); + let output_path = output.unwrap_or_else(|| { + let ts = chrono::Local::now().format("%Y%m%d_%H%M%S"); + PathBuf::from(format!("recordings/rec_{}.wav", ts)) + }); + let config = audio::RecordingConfig { sample_rate: 16000, channels: 1, - output_path: output, - ..Default::default() + output_path: Some(output_path.clone()), }; - // 注意:这里需要实现定时录音功能 - // 当前实现是固定 10 秒 match audio::record_audio(config).await { Ok((path, secs)) => { - println!("录音完成:{}", path); - println!("时长:{:.2} 秒", secs); + println!("录音完成: {}", path); + println!("时长: {:.2} 秒", secs); } Err(e) => { - eprintln!("录音失败:{}", e); + eprintln!("录音失败: {}", e); std::process::exit(1); } } } - Commands::Recognize { input, model: _model } => { - info!("识别音频:{:?}", input); + Commands::Recognize { input, model } => { + info!("识别音频: {:?}", input); - // 检查文件是否存在 if !input.exists() { - eprintln!("文件不存在:{:?}", input); + eprintln!("文件不存在: {:?}", input); std::process::exit(1); } + // 初始 ASR 引擎 + if let Some(model_path) = model { + let config = asr::model::ModelConfig::new(&model_path, "custom"); + asr::engine::init_engine(config)?; + } else { + // 使用默认模型 + let config = asr::model::ModelConfig::default(); + if config.model_exists() { + asr::engine::init_engine(config)?; + } else { + eprintln!("模型文件不存在: {:?}", config.model_path); + eprintln!(); + eprintln!("请先下载模型:"); + eprintln!(" impress_asr download"); + eprintln!(); + eprintln!("或手动下载到: {:?}", config.model_path); + std::process::exit(1); + } + } + // 解码音频 println!("正在加载音频..."); - let audio_data = match audio::decoder::decode_audio_for_asr(&input) { - Ok(data) => data, - Err(e) => { - eprintln!("解码失败:{}", e); - std::process::exit(1); - } - }; + let audio_data = audio::decoder::decode_audio_for_asr(&input)?; println!("音频信息:"); - println!(" 采样率:{} Hz", audio_data.sample_rate); - println!(" 声道数:{}", audio_data.channels); - println!(" 时长:{:.2} 秒", audio_data.duration_secs); + println!(" 采样率: {} Hz", audio_data.sample_rate); + println!(" 声道数: {}", audio_data.channels); + println!(" 时长: {:.2} 秒", audio_data.duration_secs); - // 识别 (需要模型文件) + // 执行识别 println!("\n正在识别..."); - println!("注意:需要先下载 ONNX 模型文件"); - println!("运行:impress_asr download --output models/sensevoice-small.onnx"); - - // TODO: 实现识别 - // match asr::recognize(&input.to_string_lossy()).await { - // Ok(result) => { - // println!("识别结果:{}", result.text); - // } - // Err(e) => { - // eprintln!("识别失败:{}", e); - // } - // } + match asr::recognize(&input.to_string_lossy()).await { + Ok(result) => { + println!("\n=== 识别结果 ==="); + println!("{}", result.text); + println!("\n=== 详细信息 ==="); + println!(" 语言: {}", result.language); + println!(" 置信度: {:.1}%", result.confidence * 100.0); + println!(" 耗时: {} ms", result.duration_ms); + } + Err(e) => { + eprintln!("识别失败: {}", e); + std::process::exit(1); + } + } } Commands::Devices => { @@ -147,38 +166,42 @@ async fn main() -> Result<()> { } if let Some(default) = audio::get_default_input_device_info() { - println!("\n默认设备:{}", default); + println!("\n默认设备: {}", default); } } Commands::Download { name, output } => { - println!("下载模型:{}", name); + println!("下载模型: {}", name); let output_path = output.unwrap_or_else(|| { PathBuf::from(format!("models/{}.onnx", name)) }); - // 确保目录存在 if let Some(parent) = output_path.parent() { std::fs::create_dir_all(parent)?; } - println!("下载链接:"); - match name.as_str() { - "sensevoice-small" => { - println!(" ModelScope: https://modelscope.cn/models/iic/SenseVoiceSmall/resolve/main/model.onnx"); - println!(" HuggingFace: https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/model.onnx"); + match asr::model::download_model(&name, &output_path).await { + Ok(()) => { + println!("模型下载完成: {:?}", output_path); + if let Ok(size) = std::fs::metadata(&output_path) { + println!("文件大小: {:.1} MB", size.len() as f64 / 1024.0 / 1024.0); + } } - "paraformer" => { - println!(" ModelScope: https://modelscope.cn/models/iic/paraformer-zh/resolve/main/model.onnx"); - } - _ => { - println!(" 未知模型,请手动下载"); + Err(e) => { + eprintln!("下载失败: {}", e); + eprintln!(); + eprintln!("请手动下载模型到: {:?}", output_path); + match name.as_str() { + "sensevoice-small" => { + eprintln!(" HuggingFace: https://huggingface.co/FunAudioLLM/SenseVoiceSmall"); + eprintln!(" ModelScope: https://modelscope.cn/models/iic/SenseVoiceSmall"); + } + _ => eprintln!(" 请搜索对应的 ONNX 模型下载地址"), + } + std::process::exit(1); } } - - println!("\n保存到:{:?}", output_path); - println!("下载后请运行:impress_asr recognize <音频文件>"); } }