impress_asr_input_rust/src/asr/features.rs
impressionyang b5b7930304
Some checks failed
Build Windows GUI / build-windows (push) Has been cancelled
Build Windows GUI / release (push) Has been cancelled
feat: 完成 ASR 识别核心链路实现
- 适配 ort 2.0.0-rc.12 ONNX Runtime API(Session, Value, Shape)
- 实现 log mel fbank 音频特征提取(预加重→分帧→加窗→FFT→Mel滤波器组→对数)
- 实现 cpal 实时音频捕获模块(支持多采样格式: F32/I16/I32/U16)
- 实现 CTC 贪婪解码器和 Vocabulary 词表管理
- 完成 ASR 推理引擎(特征提取→ONNX推理→结果解码完整管线)
- 更新 Tauri 命令和 CLI 工具接入真实 ASR 引擎

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-02 19:41:11 +08:00

239 lines
6.5 KiB
Rust

//! 音频特征提取模块
//!
//! 实现从原始音频到模型输入特征的转换:
//! - 预加重
//! - 分帧 & 加窗
//! - FFT + 功率谱
//! - Mel 滤波器组
//! - 对数能量 (log fbank)
use realfft::{RealFftPlanner, num_complex::Complex};
/// Configuration for log-mel filterbank (fbank) feature extraction.
///
/// All lengths are expressed in samples. The `Default` implementation in this
/// file supplies values annotated for 16 kHz audio.
#[derive(Debug, Clone)]
pub struct FeatureConfig {
/// Sample rate of the input audio; used to map filter frequencies to FFT bins.
pub sample_rate: u32,
/// FFT size: each frame is zero-padded to this many samples before the transform.
pub n_fft: usize,
/// Frame shift (hop length): distance in samples between the starts of consecutive frames.
pub hop_length: usize,
/// Window length: number of samples per frame that are multiplied by the Hann window.
pub win_length: usize,
/// Number of mel filters, i.e. the number of feature values produced per frame.
pub n_mels: usize,
/// Lowest frequency covered by the mel filterbank.
pub f_min: f32,
/// Highest frequency covered by the mel filterbank (capped at half the sample rate
/// when the filterbank is built).
pub f_max: f32,
/// Pre-emphasis coefficient `c` of the filter `y[i] = x[i] - c * x[i - 1]`.
pub pre_emphasis: f32,
}
impl Default for FeatureConfig {
fn default() -> Self {
Self {
sample_rate: 16000,
n_fft: 512,
hop_length: 160, // 10ms at 16kHz
win_length: 400, // 25ms at 16kHz
n_mels: 80,
f_min: 0.0,
f_max: 8000.0,
pre_emphasis: 0.97,
}
}
}
/// Extracts log-mel filterbank (fbank) features from raw audio samples.
///
/// Pipeline: pre-emphasis → framing → Hann window → real FFT → power spectrum
/// → mel filterbank → natural logarithm.
///
/// # Returns
///
/// A flattened, frame-major buffer: `n_frames * config.n_mels` values, where
/// the features of frame `t` occupy `[t * n_mels, (t + 1) * n_mels)`. The
/// result is empty when the input is too short to form a single frame, or when
/// `config.n_fft` or `config.n_mels` is zero.
///
/// # Panics
///
/// Panics if the FFT call reports an error.
pub fn extract_fbank(samples: &[f32], config: &FeatureConfig) -> Vec<f32> {
    // A zero-sized FFT or an empty filterbank cannot produce any feature.
    if config.n_fft == 0 || config.n_mels == 0 {
        return Vec::new();
    }

    // 1. Pre-emphasis.
    let emphasized = pre_emphasis(samples, config.pre_emphasis);

    // 2. Framing.
    let frames = frame(&emphasized, config.n_fft, config.hop_length, config.win_length);
    if frames.is_empty() {
        return Vec::new();
    }

    // 3. FFT + power spectrum + mel filterbank + log.
    let n_spec = config.n_fft / 2 + 1;
    let mut planner = RealFftPlanner::<f32>::new();
    let r2c = planner.plan_fft_forward(config.n_fft);

    // Symmetric Hann window. A window of length 1 is defined as [1.0]; the
    // general formula would divide by zero there and poison every feature
    // with NaN.
    let window: Vec<f32> = if config.win_length < 2 {
        vec![1.0; config.win_length]
    } else {
        let denom = (config.win_length - 1) as f32;
        (0..config.win_length)
            .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / denom).cos()))
            .collect()
    };

    let mel_weights = create_mel_filterbank(
        config.n_fft,
        config.sample_rate,
        config.n_mels,
        config.f_min,
        config.f_max,
    );

    // Buffers reused across frames.
    let norm = config.n_fft as f32;
    let mut fft_input = vec![0.0f32; config.n_fft];
    let mut fft_output = vec![Complex::new(0.0f32, 0.0f32); n_spec];
    let mut power = vec![0.0f32; n_spec];
    let mut all_features = Vec::with_capacity(frames.len() * config.n_mels);

    for frame_data in &frames {
        // Apply the window to the leading samples and zero-pad up to n_fft.
        // The whole buffer is rewritten every frame because the FFT call
        // takes it mutably and may leave scratch data behind.
        let copy_len = config
            .win_length
            .min(frame_data.len())
            .min(config.n_fft);
        for ((slot, &sample), &coef) in fft_input.iter_mut().zip(frame_data).zip(&window) {
            *slot = sample * coef;
        }
        fft_input[copy_len..].fill(0.0);

        r2c.process(&mut fft_input, &mut fft_output).expect("FFT 失败");

        // Power spectrum, computed once per frame and shared by all filters.
        for (p, bin) in power.iter_mut().zip(&fft_output) {
            *p = (bin.re * bin.re + bin.im * bin.im) / norm;
        }

        // Mel energy is the filter-weighted sum of the power spectrum. The
        // weights are applied linearly; squaring them would narrow every
        // triangular filter.
        all_features.extend(mel_weights.iter().map(|filter| {
            let energy: f32 = filter.iter().zip(&power).map(|(w, p)| w * p).sum();
            // Small floor keeps the logarithm finite on silent frames.
            (energy + 1e-10).ln()
        }));
    }

    all_features
}
/// Applies a first-order pre-emphasis filter: `y[0] = x[0]` and
/// `y[i] = x[i] - coef * x[i - 1]` for `i >= 1`.
///
/// Empty and single-sample inputs are returned as an unchanged copy.
fn pre_emphasis(samples: &[f32], coef: f32) -> Vec<f32> {
    match samples.split_first() {
        None => Vec::new(),
        Some((&head, _)) => {
            // Each adjacent pair (previous, current) yields one filtered sample.
            let filtered_tail = samples.windows(2).map(|pair| pair[1] - coef * pair[0]);
            std::iter::once(head).chain(filtered_tail).collect()
        }
    }
}
/// Splits `samples` into overlapping analysis frames (no windowing is applied here).
///
/// Each frame holds `min(win_length, n_fft)` consecutive samples — the part of
/// the signal that is actually windowed; zero-padding up to `n_fft` is left to
/// the caller. Frame `k` starts at sample `k * hop_length`, and only frames
/// that fit entirely inside `samples` are produced, giving
/// `1 + (len - frame_len) / hop_length` frames.
///
/// Returns an empty vector when the input is shorter than one frame, or when
/// the frame length or `hop_length` is zero (a zero hop would never advance).
fn frame(samples: &[f32], n_fft: usize, hop_length: usize, win_length: usize) -> Vec<Vec<f32>> {
    // Never hand back more samples than the FFT buffer can hold.
    let frame_len = win_length.min(n_fft);
    if frame_len == 0 || hop_length == 0 || samples.len() < frame_len {
        return Vec::new();
    }

    samples
        .windows(frame_len)
        .step_by(hop_length)
        .map(|chunk| chunk.to_vec())
        .collect()
}
/// Converts a frequency in Hz to the mel scale: `mel = 2595 * log10(1 + hz / 700)`.
fn hz_to_mel(hz: f32) -> f32 {
    let ratio = hz / 700.0;
    let log_term = f32::log10(1.0 + ratio);
    2595.0 * log_term
}
/// Converts a mel-scale value back to Hz: `hz = 700 * (10^(mel / 2595) - 1)`.
fn mel_to_hz(mel: f32) -> f32 {
    let exponent = mel / 2595.0;
    let ratio = 10.0_f32.powf(exponent);
    700.0 * (ratio - 1.0)
}
/// Builds a bank of `n_mels` triangular mel filters over the one-sided FFT spectrum.
///
/// `n_mels + 2` band edges are spaced evenly on the mel scale between `f_min`
/// and `min(f_max, sample_rate / 2)`, then mapped to FFT bin indices. Filter
/// `m` rises from edge `m` to a peak of 1.0 at edge `m + 1` and falls to 0.0
/// at edge `m + 2`.
///
/// # Returns
///
/// `n_mels` rows of `n_fft / 2 + 1` weights, each weight in `[0, 1]`.
fn create_mel_filterbank(
    n_fft: usize,
    sample_rate: u32,
    n_mels: usize,
    f_min: f32,
    f_max: f32,
) -> Vec<Vec<f32>> {
    let n_spec = n_fft / 2 + 1;
    let last_bin = n_spec - 1;

    let mel_min = hz_to_mel(f_min);
    let mel_max = hz_to_mel(f_max.min(sample_rate as f32 / 2.0));

    // Band edges: evenly spaced in mel, converted to Hz, then to FFT bin indices.
    let bin_points: Vec<usize> = (0..=n_mels + 1)
        .map(|i| {
            let mel = mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32;
            let hz = mel_to_hz(mel);
            ((n_fft as f32 + 1.0) * hz / sample_rate as f32).floor() as usize
        })
        .collect();

    let mut filterbank = vec![vec![0.0f32; n_spec]; n_mels];
    for (filter, edges) in filterbank.iter_mut().zip(bin_points.windows(3)) {
        let left = edges[0];
        let center = edges[1];
        let right = edges[2].min(last_bin);

        // Rising slope: 0.0 at `left`, approaching 1.0 just below `center`.
        if center > left {
            let rise = center as f32 - left as f32;
            for i in left..center.min(n_spec) {
                filter[i] = (i as f32 - left as f32) / rise;
            }
        }

        if right > center {
            // Falling slope: 1.0 at `center`, 0.0 at `right`.
            let fall = right as f32 - center as f32;
            for i in center..=right {
                filter[i] = (right as f32 - i as f32) / fall;
            }
        } else if right == center {
            // Neighbouring edges share one FFT bin (common for the narrow
            // low-frequency bands). Keep the peak so the filter is not left
            // all-zero, which would pin its feature to the log floor.
            filter[center] = 1.0;
        }
    }

    filterbank
}
/// 完整的特征提取管线: 原始音频 → 展平的 log mel fbank
pub fn audio_to_features(
samples: &[f32],
sample_rate: u32,
) -> (Vec<f32>, usize, usize) {
let config = FeatureConfig {
sample_rate,
..Default::default()
};
let features = extract_fbank(samples, &config);
let n_frames = if config.n_mels > 0 { features.len() / config.n_mels } else { 0 };
(features, n_frames, config.n_mels)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hz → mel → Hz must land within 1 Hz of the starting frequency.
    #[test]
    fn test_mel_conversion() {
        let hz = 1000.0;
        let round_trip = mel_to_hz(hz_to_mel(hz));
        assert!((round_trip - hz).abs() < 1.0);
    }

    /// The first sample passes through; later ones have the scaled predecessor subtracted.
    #[test]
    fn test_pre_emphasis() {
        let filtered = pre_emphasis(&[1.0; 4], 0.97);
        assert_eq!(filtered[0], 1.0);
        assert_eq!(filtered[1], 1.0 - 0.97);
    }

    /// One second of silence yields a non-empty `n_frames * n_mels` buffer with 80 mel bands.
    #[test]
    fn test_fbank_shape() {
        let silence = vec![0.0f32; 16000];
        let (features, n_frames, n_mels) = audio_to_features(&silence, 16000);
        assert_eq!(n_mels, 80);
        assert!(n_frames > 0);
        assert_eq!(features.len(), n_frames * n_mels);
    }
}