311 lines
10 KiB
Rust
311 lines
10 KiB
Rust
//! BPM detection using autocorrelation and onset detection.
//!
//! This module provides both offline analysis (for audio import)
//! and real-time streaming analysis (for the BPM detector node).
|
|
|
|
use std::collections::VecDeque;
|
|
|
|
/// Detects BPM from a complete audio buffer (offline analysis)
|
|
pub fn detect_bpm_offline(audio: &[f32], sample_rate: u32) -> Option<f32> {
|
|
if audio.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
// Convert to mono if needed (already mono in our case)
|
|
// Downsample for efficiency (analyze every 4th sample for faster processing)
|
|
let downsampled: Vec<f32> = audio.iter().step_by(4).copied().collect();
|
|
let effective_sample_rate = sample_rate / 4;
|
|
|
|
// Detect onsets using energy-based method
|
|
let onsets = detect_onsets(&downsampled, effective_sample_rate);
|
|
|
|
if onsets.len() < 4 {
|
|
return None;
|
|
}
|
|
|
|
// Calculate onset strength function for autocorrelation
|
|
let onset_envelope = calculate_onset_envelope(&onsets, downsampled.len(), effective_sample_rate);
|
|
|
|
// Further downsample onset envelope for BPM analysis
|
|
// For 60-200 BPM (1-3.33 Hz), we only need ~10 Hz sample rate by Nyquist
|
|
// Use 100 Hz for good margin (100 samples per second)
|
|
let tempo_sample_rate = 100.0;
|
|
let downsample_factor = (effective_sample_rate as f32 / tempo_sample_rate) as usize;
|
|
let downsampled_envelope: Vec<f32> = onset_envelope
|
|
.iter()
|
|
.step_by(downsample_factor.max(1))
|
|
.copied()
|
|
.collect();
|
|
|
|
// Use autocorrelation to find the fundamental period
|
|
let bpm = detect_bpm_autocorrelation(&downsampled_envelope, tempo_sample_rate as u32);
|
|
|
|
bpm
|
|
}
|
|
|
|
/// Builds an onset envelope from a list of onset positions.
///
/// The result has `total_length` samples. Every onset that lies inside the
/// buffer becomes a unit impulse followed by an exponentially decaying tail
/// of about 50 ms (`sample_rate * 0.05` samples). Where tails overlap the
/// larger value is kept. Onsets at or beyond `total_length` are ignored.
fn calculate_onset_envelope(onsets: &[usize], total_length: usize, sample_rate: u32) -> Vec<f32> {
    let mut envelope = vec![0.0_f32; total_length];

    // Length of the decaying tail in samples (50 ms).
    let tail_len = (sample_rate as f32 * 0.05) as usize;

    for &start in onsets.iter().filter(|&&position| position < total_length) {
        envelope[start] = 1.0;

        // The tail is clipped at the end of the buffer.
        let tail_end = start + tail_len.min(total_length - start);

        // Offset 0 is the impulse itself, so the decay starts at offset 1.
        for (offset, slot) in envelope[start..tail_end].iter_mut().enumerate().skip(1) {
            let decayed = (-3.0 * offset as f32 / tail_len as f32).exp();
            *slot = slot.max(decayed);
        }
    }

    envelope
}
|
|
|
|
/// Estimates the tempo of an onset envelope using autocorrelation.
///
/// Every lag corresponding to 60-200 BPM is scored by the mean product of the
/// envelope with a copy of itself shifted by that lag. Longer lags are
/// penalised by up to 10 % so that faster tempos win ties. The best lag is
/// converted to BPM, moved by octaves into the 100-200 BPM range when one of
/// x4, x2, x1 or /2 (tried in that order) lands there, and rounded to the
/// nearest 0.5 BPM.
///
/// # Arguments
///
/// * `onset_envelope` - onset strength over time.
/// * `sample_rate` - rate of `onset_envelope` in Hz.
///
/// # Returns
///
/// `Some(bpm)` on success. `None` when the envelope is shorter than twice the
/// longest lag, when `sample_rate` is too low to give a non-zero shortest lag,
/// or when no lag shows any positive correlation (e.g. a silent envelope).
fn detect_bpm_autocorrelation(onset_envelope: &[f32], sample_rate: u32) -> Option<f32> {
    // Tempo search range.
    const MIN_BPM: f32 = 60.0;
    const MAX_BPM: f32 = 200.0;

    let min_lag = (60.0 * sample_rate as f32 / MAX_BPM) as usize;
    let max_lag = (60.0 * sample_rate as f32 / MIN_BPM) as usize;

    // A zero lag always correlates perfectly and would map to an infinite
    // BPM; an empty lag range would make the bias below divide by zero.
    if min_lag == 0 || max_lag <= min_lag {
        return None;
    }

    // Require at least two full periods of the slowest tempo.
    if max_lag >= onset_envelope.len() / 2 {
        return None;
    }

    let lag_span = (max_lag - min_lag) as f32;
    let mut best_lag = None;
    let mut best_score = 0.0_f32;

    for lag in min_lag..=max_lag {
        // Number of overlapping sample pairs; positive because lag < len / 2.
        let overlap = onset_envelope.len() - lag;

        let raw: f32 = onset_envelope
            .iter()
            .zip(&onset_envelope[lag..])
            .map(|(a, b)| a * b)
            .sum();

        // Bias toward faster tempos slightly (common in EDM).
        let bias = 1.0 + (lag as f32 - min_lag as f32) / lag_span * 0.1;
        let score = raw / overlap as f32 / bias;

        if score > best_score {
            best_score = score;
            best_lag = Some(lag);
        }
    }

    // No lag with positive correlation means there is no periodicity to report.
    let best_lag = best_lag?;

    let bpm = 60.0 * sample_rate as f32 / best_lag as f32;

    // Octave correction: prefer the first multiple that falls in the range
    // most common for EDM/pop; otherwise keep the raw estimate.
    let final_bpm = [bpm * 4.0, bpm * 2.0, bpm, bpm / 2.0]
        .iter()
        .copied()
        .find(|candidate| (100.0_f32..=200.0).contains(candidate))
        .unwrap_or(bpm);

    // Round to the nearest 0.5 BPM for cleaner values.
    Some((final_bpm * 2.0).round() / 2.0)
}
|
|
|
|
/// Detects onsets (beat events) in `audio` with an energy-based method.
///
/// The signal is cut into ~20 ms windows that advance by half a window. The
/// mean energy of each window is computed, and the positive energy increase
/// from one window to the next is used as onset strength. A window boundary is
/// reported as an onset when its strength exceeds 1.5x the mean strength, is
/// a local maximum, and lies at least `sample_rate / 10` samples (100 ms)
/// after the previously reported onset.
///
/// # Arguments
///
/// * `audio` - samples to analyse.
/// * `sample_rate` - sample rate of `audio` in Hz.
///
/// # Returns
///
/// Onset positions as sample indices into `audio`, in ascending order. The
/// result is empty when `audio` is shorter than one window or yields fewer
/// than three windows.
fn detect_onsets(audio: &[f32], sample_rate: u32) -> Vec<usize> {
    // Window size for energy calculation (~20 ms), never shorter than one sample.
    let window_size = ((sample_rate as f32 * 0.02) as usize).max(1);
    // Half-window hop. Clamped to 1 so that a one-sample window (very low
    // sample rates) still advances instead of framing the same spot forever.
    let hop_size = (window_size / 2).max(1);

    if audio.len() < window_size {
        return Vec::new();
    }

    // Mean energy of every analysis window.
    let energies: Vec<f32> = audio
        .windows(window_size)
        .step_by(hop_size)
        .map(|frame| frame.iter().map(|&s| s * s).sum::<f32>() / window_size as f32)
        .collect();

    if energies.len() < 3 {
        return Vec::new();
    }

    // Onset strength: only rises in energy count.
    let onset_strengths: Vec<f32> = energies
        .windows(2)
        .map(|pair| (pair[1] - pair[0]).max(0.0))
        .collect();

    // Adaptive threshold relative to the average strength.
    let mean_strength = onset_strengths.iter().sum::<f32>() / onset_strengths.len() as f32;
    let threshold = mean_strength * 1.5;

    // Minimum 100 ms between reported onsets.
    let min_distance = sample_rate as usize / 10;

    let mut onsets: Vec<usize> = Vec::new();
    let last_index = onset_strengths.len() - 1;

    for (i, &strength) in onset_strengths.iter().enumerate() {
        if strength > threshold {
            // Strength `i` describes the step into window `i + 1`.
            let sample_pos = (i + 1) * hop_size;

            // On a plateau only the last element qualifies (<= left, < right).
            let is_local_max = (i == 0 || onset_strengths[i - 1] <= strength)
                && (i == last_index || onset_strengths[i + 1] < strength);

            let far_enough = onsets
                .last()
                .map_or(true, |&previous| sample_pos - previous >= min_distance);

            if is_local_max && far_enough {
                onsets.push(sample_pos);
            }
        }
    }

    onsets
}
|
|
|
|
/// Real-time BPM detector for streaming audio.
///
/// Keeps a sliding window of recent samples together with a short history of
/// tempo estimates that is used to smooth the reported value.
#[derive(Debug, Clone)]
pub struct BpmDetectorRealtime {
    /// Sample rate of the incoming audio in Hz.
    sample_rate: u32,

    /// Sliding window holding the most recent audio samples (e.g. 10 seconds).
    audio_buffer: VecDeque<f32>,
    /// Capacity limit of `audio_buffer`, in samples.
    max_buffer_samples: usize,

    /// Current (smoothed) BPM estimate.
    current_bpm: f32,

    /// Samples received since the last analysis pass.
    samples_since_update: usize,
    /// Number of samples between two analysis passes.
    update_interval: usize,

    /// Most recent raw detections, oldest first, used for smoothing.
    bpm_history: VecDeque<f32>,
    /// Maximum number of entries kept in `bpm_history`.
    history_size: usize,
}
|
|
|
|
impl BpmDetectorRealtime {
|
|
pub fn new(sample_rate: u32, buffer_duration_seconds: f32) -> Self {
|
|
let max_buffer_samples = (sample_rate as f32 * buffer_duration_seconds) as usize;
|
|
let update_interval = sample_rate as usize; // Update every 1 second
|
|
|
|
Self {
|
|
sample_rate,
|
|
audio_buffer: VecDeque::with_capacity(max_buffer_samples),
|
|
max_buffer_samples,
|
|
current_bpm: 120.0, // Default BPM
|
|
samples_since_update: 0,
|
|
update_interval,
|
|
bpm_history: VecDeque::with_capacity(8),
|
|
history_size: 8,
|
|
}
|
|
}
|
|
|
|
/// Process a chunk of audio and return current BPM estimate
|
|
pub fn process(&mut self, audio: &[f32]) -> f32 {
|
|
// Add samples to buffer
|
|
for &sample in audio {
|
|
if self.audio_buffer.len() >= self.max_buffer_samples {
|
|
self.audio_buffer.pop_front();
|
|
}
|
|
self.audio_buffer.push_back(sample);
|
|
}
|
|
|
|
self.samples_since_update += audio.len();
|
|
|
|
// Periodically re-analyze
|
|
if self.samples_since_update >= self.update_interval && self.audio_buffer.len() > self.sample_rate as usize {
|
|
self.samples_since_update = 0;
|
|
|
|
// Convert buffer to slice for analysis
|
|
let buffer_vec: Vec<f32> = self.audio_buffer.iter().copied().collect();
|
|
|
|
if let Some(detected_bpm) = detect_bpm_offline(&buffer_vec, self.sample_rate) {
|
|
// Add to history for smoothing
|
|
if self.bpm_history.len() >= self.history_size {
|
|
self.bpm_history.pop_front();
|
|
}
|
|
self.bpm_history.push_back(detected_bpm);
|
|
|
|
// Use median of recent detections for stability
|
|
let mut sorted_history: Vec<f32> = self.bpm_history.iter().copied().collect();
|
|
sorted_history.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
self.current_bpm = sorted_history[sorted_history.len() / 2];
|
|
}
|
|
}
|
|
|
|
self.current_bpm
|
|
}
|
|
|
|
pub fn get_bpm(&self) -> f32 {
|
|
self.current_bpm
|
|
}
|
|
|
|
pub fn reset(&mut self) {
|
|
self.audio_buffer.clear();
|
|
self.bpm_history.clear();
|
|
self.samples_since_update = 0;
|
|
self.current_bpm = 120.0;
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A synthetic click track at 120 BPM must be detected to within 5 %.
    #[test]
    fn test_120_bpm_detection() {
        const SAMPLE_RATE: u32 = 48000;
        const TARGET_BPM: f32 = 120.0;
        const BEAT_COUNT: usize = 8;
        const CLICK_LEN: usize = 100;

        let seconds_per_beat = 60.0 / TARGET_BPM;
        let samples_per_beat = (SAMPLE_RATE as f32 * seconds_per_beat) as usize;

        // Silence with a short, linearly decaying transient at the start of
        // every beat.
        let mut click_track = vec![0.0_f32; samples_per_beat * BEAT_COUNT];
        for beat in click_track.chunks_mut(samples_per_beat) {
            for (i, sample) in beat.iter_mut().take(CLICK_LEN).enumerate() {
                *sample = (1.0 - i as f32 / 100.0) * 0.8;
            }
        }

        let result = detect_bpm_offline(&click_track, SAMPLE_RATE);
        assert!(result.is_some());
        let measured = result.unwrap();

        // Allow 5% tolerance.
        let relative_error = (measured - TARGET_BPM).abs() / TARGET_BPM;
        assert!(relative_error < 0.05,
            "Expected ~{} BPM, got {}", TARGET_BPM, measured);
    }
}
|