- 适配 ort 2.0.0-rc.12 ONNX Runtime API(Session, Value, Shape)
- 实现 log mel fbank 音频特征提取(预加重→分帧→加窗→FFT→Mel滤波器组→对数)
- 实现 cpal 实时音频捕获模块(支持多采样格式: F32/I16/I32/U16)
- 实现 CTC 贪婪解码器和 Vocabulary 词表管理
- 完成 ASR 推理引擎(特征提取→ONNX推理→结果解码完整管线)
- 更新 Tauri 命令和 CLI 工具接入真实 ASR 引擎

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
359 lines
10 KiB
Rust
359 lines
10 KiB
Rust
//! Audio capture module.
//!
//! Real-time audio recording implemented with cpal.
|
|
|
|
use anyhow::{Context, Result};
|
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
use std::path::PathBuf;
|
|
use std::sync::{Arc, Mutex};
|
|
use tracing::{info, warn};
|
|
|
|
/// Configuration for a recording session.
///
/// The sample rate and channel count are requests; the recording functions in
/// this module pick a device configuration based on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordingConfig {
    /// Requested capture sample rate.
    pub sample_rate: u32,
    /// Requested number of input channels.
    pub channels: u16,
    /// Destination WAV file. When `None`, `record_audio` generates a
    /// timestamped file name under `recordings/`.
    pub output_path: Option<PathBuf>,
}

impl Default for RecordingConfig {
    /// Returns a configuration with a sample rate of 16000, one channel and
    /// no explicit output path.
    fn default() -> Self {
        Self {
            sample_rate: 16_000,
            channels: 1,
            output_path: None,
        }
    }
}
|
|
|
|
/// 录制音频到文件
|
|
///
|
|
/// 录音直到调用方发送停止信号 (通过 drop RecordingHandle)
|
|
pub async fn record_audio(config: RecordingConfig) -> Result<(String, f32)> {
|
|
info!("开始录音: 采样率={}, 声道={}", config.sample_rate, config.channels);
|
|
|
|
let host = cpal::default_host();
|
|
let device = host
|
|
.default_input_device()
|
|
.context("没有可用的输入设备")?;
|
|
|
|
info!("使用设备: {}", device.name().unwrap_or_else(|_| "未知".to_string()));
|
|
|
|
// 获取支持的配置
|
|
let mut supported_configs = device
|
|
.supported_input_configs()
|
|
.context("获取音频配置失败")?;
|
|
|
|
// 查找匹配的采样率配置
|
|
let config_found = supported_configs
|
|
.find(|c| {
|
|
c.min_sample_rate().0 <= config.sample_rate
|
|
&& c.max_sample_rate().0 >= config.sample_rate
|
|
&& c.channels() == config.channels
|
|
})
|
|
.or_else(|| {
|
|
// 回退: 使用默认配置
|
|
device
|
|
.supported_input_configs()
|
|
.ok()
|
|
.and_then(|mut configs| configs.next())
|
|
})
|
|
.context("没有匹配的音频配置")?;
|
|
|
|
let actual_sample_rate = config_found
|
|
.min_sample_rate()
|
|
.max(cpal::SampleRate(config.sample_rate))
|
|
.min(config_found.max_sample_rate());
|
|
|
|
let actual_config: cpal::StreamConfig = cpal::StreamConfig {
|
|
sample_rate: actual_sample_rate,
|
|
channels: config.channels,
|
|
buffer_size: cpal::BufferSize::Default,
|
|
};
|
|
|
|
// 音频缓冲区
|
|
let samples = Arc::new(Mutex::new(Vec::<f32>::new()));
|
|
let samples_clone = Arc::clone(&samples);
|
|
|
|
// 创建录音数据回调
|
|
let err_fn = |err: cpal::StreamError| {
|
|
warn!("音频流错误: {}", err);
|
|
};
|
|
|
|
let stream = match config_found.sample_format() {
|
|
cpal::SampleFormat::F32 => device.build_input_stream(
|
|
&actual_config,
|
|
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
buf.extend_from_slice(data);
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
cpal::SampleFormat::I16 => device.build_input_stream(
|
|
&actual_config,
|
|
move |data: &[i16], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
for &sample in data {
|
|
buf.push(sample as f32 / 32768.0);
|
|
}
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
cpal::SampleFormat::I32 => device.build_input_stream(
|
|
&actual_config,
|
|
move |data: &[i32], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
for &sample in data {
|
|
buf.push(sample as f32 / 2147483648.0);
|
|
}
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
cpal::SampleFormat::U16 => device.build_input_stream(
|
|
&actual_config,
|
|
move |data: &[u16], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
for &sample in data {
|
|
buf.push((sample as f32 - 32768.0) / 32768.0);
|
|
}
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
other => {
|
|
anyhow::bail!("不支持的采样格式: {:?}", other);
|
|
}
|
|
};
|
|
|
|
// 播放流
|
|
stream.play()?;
|
|
info!("音频流已启动");
|
|
|
|
// 录音 5 秒 (可配置的默认值)
|
|
let duration_secs = 5.0;
|
|
tokio::time::sleep(std::time::Duration::from_secs_f32(duration_secs)).await;
|
|
|
|
// 停止流
|
|
drop(stream);
|
|
info!("音频流已停止");
|
|
|
|
// 获取采集的样本
|
|
let collected_samples = samples.lock().unwrap().clone();
|
|
|
|
if collected_samples.is_empty() {
|
|
anyhow::bail!("未采集到任何音频数据");
|
|
}
|
|
|
|
info!("采集到 {} 个样本", collected_samples.len());
|
|
|
|
// 保存到文件
|
|
let output_path = config.output_path.unwrap_or_else(|| {
|
|
let ts = chrono::Local::now().format("%Y%m%d_%H%M%S");
|
|
PathBuf::from(format!("recordings/rec_{}.wav", ts))
|
|
});
|
|
|
|
if let Some(parent) = output_path.parent() {
|
|
let _ = std::fs::create_dir_all(parent);
|
|
}
|
|
|
|
let channels = actual_config.channels;
|
|
let sr = actual_config.sample_rate.0;
|
|
|
|
// 写入 WAV 文件
|
|
let spec = hound::WavSpec {
|
|
channels,
|
|
sample_rate: sr,
|
|
bits_per_sample: 16,
|
|
sample_format: hound::SampleFormat::Int,
|
|
};
|
|
|
|
let mut writer = hound::WavWriter::create(&output_path, spec)
|
|
.context("无法创建 WAV 文件")?;
|
|
|
|
for sample in &collected_samples {
|
|
writer.write_sample((sample.clamp(-1.0, 1.0) * 32767.0) as i16)?;
|
|
}
|
|
|
|
writer.finalize()?;
|
|
|
|
let actual_duration = collected_samples.len() as f32 / (sr as f32 * channels as f32);
|
|
|
|
info!("录音完成: {:?}, 时长={:.2}s", output_path, actual_duration);
|
|
|
|
Ok((output_path.to_string_lossy().to_string(), actual_duration))
|
|
}
|
|
|
|
/// 创建录音句柄 (非阻塞)
|
|
///
|
|
/// 返回 (RecordingHandle, 样本 Arc)
|
|
/// drop RecordingHandle 会停止录音
|
|
pub fn start_recording(
|
|
sample_rate: u32,
|
|
channels: u16,
|
|
) -> Result<(RecordingHandle, Arc<Mutex<Vec<f32>>>)> {
|
|
let host = cpal::default_host();
|
|
let device = host
|
|
.default_input_device()
|
|
.context("没有可用的输入设备")?;
|
|
|
|
let supported_configs: Vec<_> = device
|
|
.supported_input_configs()
|
|
.context("获取音频配置失败")?
|
|
.collect();
|
|
|
|
let config_found = supported_configs
|
|
.iter()
|
|
.find(|c| {
|
|
c.min_sample_rate().0 <= sample_rate
|
|
&& c.max_sample_rate().0 >= sample_rate
|
|
})
|
|
.or_else(|| supported_configs.first())
|
|
.context("没有匹配的音频配置")?;
|
|
|
|
let actual_sample_rate = config_found
|
|
.min_sample_rate()
|
|
.max(cpal::SampleRate(sample_rate))
|
|
.min(config_found.max_sample_rate());
|
|
|
|
let actual_channels = config_found.channels().max(channels);
|
|
|
|
let stream_config = cpal::StreamConfig {
|
|
sample_rate: actual_sample_rate,
|
|
channels: actual_channels,
|
|
buffer_size: cpal::BufferSize::Default,
|
|
};
|
|
|
|
let samples = Arc::new(Mutex::new(Vec::<f32>::new()));
|
|
let samples_clone = Arc::clone(&samples);
|
|
|
|
let err_fn = |err: cpal::StreamError| {
|
|
warn!("音频流错误: {}", err);
|
|
};
|
|
|
|
let stream = match config_found.sample_format() {
|
|
cpal::SampleFormat::F32 => device.build_input_stream(
|
|
&stream_config,
|
|
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
buf.extend_from_slice(data);
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
cpal::SampleFormat::I16 => device.build_input_stream(
|
|
&stream_config,
|
|
move |data: &[i16], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
for &s in data {
|
|
buf.push(s as f32 / 32768.0);
|
|
}
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
cpal::SampleFormat::I32 => device.build_input_stream(
|
|
&stream_config,
|
|
move |data: &[i32], _: &cpal::InputCallbackInfo| {
|
|
let mut buf = samples_clone.lock().unwrap();
|
|
for &s in data {
|
|
buf.push(s as f32 / 2147483648.0);
|
|
}
|
|
},
|
|
err_fn,
|
|
None,
|
|
)?,
|
|
other => {
|
|
anyhow::bail!("不支持的采样格式: {:?}", other);
|
|
}
|
|
};
|
|
|
|
stream.play()?;
|
|
info!("录音已启动: 采样率={}, 声道={}", actual_sample_rate.0, actual_channels);
|
|
|
|
Ok((
|
|
RecordingHandle {
|
|
stream: Some(stream),
|
|
sample_rate: actual_sample_rate.0,
|
|
channels: actual_channels,
|
|
},
|
|
samples,
|
|
))
|
|
}
|
|
|
|
/// 录音句柄 - drop 时自动停止
|
|
pub struct RecordingHandle {
|
|
stream: Option<cpal::Stream>,
|
|
sample_rate: u32,
|
|
channels: u16,
|
|
}
|
|
|
|
impl RecordingHandle {
|
|
/// 停止录音并保存
|
|
pub fn stop_and_save(
|
|
&mut self,
|
|
samples: Arc<Mutex<Vec<f32>>>,
|
|
output_path: &PathBuf,
|
|
) -> Result<(String, f32)> {
|
|
// 停止流
|
|
self.stream.take();
|
|
info!("录音已停止");
|
|
|
|
let collected = samples.lock().unwrap().clone();
|
|
|
|
if collected.is_empty() {
|
|
anyhow::bail!("未采集到音频数据");
|
|
}
|
|
|
|
if let Some(parent) = output_path.parent() {
|
|
let _ = std::fs::create_dir_all(parent);
|
|
}
|
|
|
|
let spec = hound::WavSpec {
|
|
channels: self.channels,
|
|
sample_rate: self.sample_rate,
|
|
bits_per_sample: 16,
|
|
sample_format: hound::SampleFormat::Int,
|
|
};
|
|
|
|
let mut writer = hound::WavWriter::create(output_path, spec)?;
|
|
for sample in &collected {
|
|
writer.write_sample((sample.clamp(-1.0, 1.0) * 32767.0) as i16)?;
|
|
}
|
|
writer.finalize()?;
|
|
|
|
let duration = collected.len() as f32 / (self.sample_rate as f32 * self.channels as f32);
|
|
|
|
Ok((output_path.to_string_lossy().to_string(), duration))
|
|
}
|
|
|
|
pub fn sample_rate(&self) -> u32 { self.sample_rate }
|
|
pub fn channels(&self) -> u16 { self.channels }
|
|
}
|
|
|
|
/// 获取可用的输入设备列表
|
|
pub fn list_input_devices() -> Vec<String> {
|
|
let host = cpal::default_host();
|
|
match host.input_devices() {
|
|
Ok(devices) => devices
|
|
.filter_map(|d| d.name().ok())
|
|
.collect(),
|
|
Err(e) => {
|
|
warn!("获取输入设备失败: {}", e);
|
|
vec![]
|
|
}
|
|
}
|
|
}
|
|
|
|
/// 获取默认输入设备信息
|
|
pub fn get_default_input_device_info() -> Option<String> {
|
|
let host = cpal::default_host();
|
|
host.default_input_device()
|
|
.and_then(|d| d.name().ok())
|
|
}
|