Use CQT transform for spectrograph instead of FFT

This commit is contained in:
Skyler Lehmkuhl 2026-02-14 21:17:57 -05:00
parent 777d3ef6be
commit 068715c0fa
13 changed files with 1076 additions and 782 deletions

Binary file not shown.

Binary file not shown.

View File

@ -69,6 +69,11 @@ pub struct ReadAheadBuffer {
channels: u32, channels: u32,
/// Source file sample rate. /// Source file sample rate.
sample_rate: u32, sample_rate: u32,
/// Last file-local frame requested by the audio callback.
/// Written by the consumer (render_from_file), read by the disk reader.
/// The disk reader uses this instead of the global playhead to know
/// where in the file to buffer around.
target_frame: AtomicU64,
} }
// SAFETY: See the doc comment on ReadAheadBuffer for the full safety argument. // SAFETY: See the doc comment on ReadAheadBuffer for the full safety argument.
@ -102,6 +107,7 @@ impl ReadAheadBuffer {
capacity_frames, capacity_frames,
channels, channels,
sample_rate, sample_rate,
target_frame: AtomicU64::new(0),
} }
} }
@ -158,6 +164,20 @@ impl ReadAheadBuffer {
self.valid_frames.load(Ordering::Acquire) self.valid_frames.load(Ordering::Acquire)
} }
/// Update the target frame — the file-local frame the audio callback
/// is currently reading from. Called by `render_from_file` (consumer).
#[inline]
pub fn set_target_frame(&self, frame: u64) {
self.target_frame.store(frame, Ordering::Relaxed);
}
/// Get the target frame set by the audio callback.
/// Called by the disk reader thread (producer).
#[inline]
pub fn target_frame(&self) -> u64 {
self.target_frame.load(Ordering::Relaxed)
}
/// Reset the buffer to start at `new_start` with zero valid frames. /// Reset the buffer to start at `new_start` with zero valid frames.
/// Called by the **disk reader thread** (producer) after a seek. /// Called by the **disk reader thread** (producer) after a seek.
pub fn reset(&self, new_start: u64) { pub fn reset(&self, new_start: u64) {
@ -431,20 +451,16 @@ pub struct DiskReader {
impl DiskReader { impl DiskReader {
/// Create a new disk reader with a background thread. /// Create a new disk reader with a background thread.
///
/// `playhead_frame` should be the same `Arc<AtomicU64>` used by the engine
/// so the disk reader knows where to fill ahead.
pub fn new(playhead_frame: Arc<AtomicU64>, _sample_rate: u32) -> Self { pub fn new(playhead_frame: Arc<AtomicU64>, _sample_rate: u32) -> Self {
let (command_tx, command_rx) = rtrb::RingBuffer::new(64); let (command_tx, command_rx) = rtrb::RingBuffer::new(64);
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
let thread_running = running.clone(); let thread_running = running.clone();
let thread_playhead = playhead_frame.clone();
let thread_handle = std::thread::Builder::new() let thread_handle = std::thread::Builder::new()
.name("disk-reader".into()) .name("disk-reader".into())
.spawn(move || { .spawn(move || {
Self::reader_thread(command_rx, thread_playhead, thread_running); Self::reader_thread(command_rx, thread_running);
}) })
.expect("Failed to spawn disk reader thread"); .expect("Failed to spawn disk reader thread");
@ -473,7 +489,6 @@ impl DiskReader {
/// The disk reader background thread. /// The disk reader background thread.
fn reader_thread( fn reader_thread(
mut command_rx: rtrb::Consumer<DiskReaderCommand>, mut command_rx: rtrb::Consumer<DiskReaderCommand>,
playhead_frame: Arc<AtomicU64>,
running: Arc<AtomicBool>, running: Arc<AtomicBool>,
) { ) {
let mut active_files: HashMap<usize, (CompressedReader, Arc<ReadAheadBuffer>)> = let mut active_files: HashMap<usize, (CompressedReader, Arc<ReadAheadBuffer>)> =
@ -506,6 +521,7 @@ impl DiskReader {
} }
DiskReaderCommand::Seek { frame } => { DiskReaderCommand::Seek { frame } => {
for (_, (reader, buffer)) in active_files.iter_mut() { for (_, (reader, buffer)) in active_files.iter_mut() {
buffer.set_target_frame(frame);
buffer.reset(frame); buffer.reset(frame);
if let Err(e) = reader.seek(frame) { if let Err(e) = reader.seek(frame) {
eprintln!("[DiskReader] Seek error: {}", e); eprintln!("[DiskReader] Seek error: {}", e);
@ -518,26 +534,28 @@ impl DiskReader {
} }
} }
let playhead = playhead_frame.load(Ordering::Relaxed); // Fill each active file's buffer ahead of its target frame.
// Each file's target_frame is set by the audio callback in
// Fill each active file's buffer ahead of the playhead. // render_from_file, giving the file-local frame being read.
// This is independent of the global engine playhead.
for (_pool_index, (reader, buffer)) in active_files.iter_mut() { for (_pool_index, (reader, buffer)) in active_files.iter_mut() {
let target = buffer.target_frame();
let buf_start = buffer.start_frame(); let buf_start = buffer.start_frame();
let buf_valid = buffer.valid_frames_count(); let buf_valid = buffer.valid_frames_count();
let buf_end = buf_start + buf_valid; let buf_end = buf_start + buf_valid;
// If the playhead has jumped behind or far ahead of the buffer, // If the target has jumped behind or far ahead of the buffer,
// seek the decoder and reset. // seek the decoder and reset.
if playhead < buf_start || playhead > buf_end + reader.sample_rate as u64 { if target < buf_start || target > buf_end + reader.sample_rate as u64 {
buffer.reset(playhead); buffer.reset(target);
let _ = reader.seek(playhead); let _ = reader.seek(target);
continue; continue;
} }
// Advance the buffer start to reclaim space behind the playhead. // Advance the buffer start to reclaim space behind the target.
// Keep a small lookback for sinc interpolation (~32 frames). // Keep a small lookback for sinc interpolation (~32 frames).
let lookback = 64u64; let lookback = 64u64;
let advance_to = playhead.saturating_sub(lookback); let advance_to = target.saturating_sub(lookback);
if advance_to > buf_start { if advance_to > buf_start {
buffer.advance_start(advance_to); buffer.advance_start(advance_to);
} }
@ -547,7 +565,7 @@ impl DiskReader {
let buf_valid = buffer.valid_frames_count(); let buf_valid = buffer.valid_frames_count();
let buf_end = buf_start + buf_valid; let buf_end = buf_start + buf_valid;
let prefetch_target = let prefetch_target =
playhead + (PREFETCH_SECONDS * reader.sample_rate as f64) as u64; target + (PREFETCH_SECONDS * reader.sample_rate as f64) as u64;
if buf_end >= prefetch_target { if buf_end >= prefetch_target {
continue; // Already filled far enough ahead. continue; // Already filled far enough ahead.

View File

@ -489,6 +489,10 @@ impl Engine {
self.playhead_atomic.store(0, Ordering::Relaxed); self.playhead_atomic.store(0, Ordering::Relaxed);
// Stop all MIDI notes when stopping playback // Stop all MIDI notes when stopping playback
self.project.stop_all_notes(); self.project.stop_all_notes();
// Reset disk reader buffers to the new playhead position
if let Some(ref mut dr) = self.disk_reader {
dr.send(crate::audio::disk_reader::DiskReaderCommand::Seek { frame: 0 });
}
} }
Command::Pause => { Command::Pause => {
self.playing = false; self.playing = false;
@ -1686,165 +1690,144 @@ impl Engine {
} }
Command::ImportAudio(path) => { Command::ImportAudio(path) => {
let path_str = path.to_string_lossy().to_string(); if let Err(e) = self.do_import_audio(&path) {
eprintln!("[ENGINE] ImportAudio failed for {:?}: {}", path, e);
// Step 1: Read metadata (fast — no decoding)
let metadata = match crate::io::read_metadata(&path) {
Ok(m) => m,
Err(e) => {
eprintln!("[ENGINE] ImportAudio failed to read metadata for {:?}: {}", path, e);
return;
}
};
let pool_index;
eprintln!("[ENGINE] ImportAudio: format={:?}, ch={}, sr={}, n_frames={:?}, duration={:.2}s, path={}",
metadata.format, metadata.channels, metadata.sample_rate, metadata.n_frames, metadata.duration, path_str);
match metadata.format {
crate::io::AudioFormat::Pcm => {
// WAV/AIFF: memory-map the file for instant availability
let file = match std::fs::File::open(&path) {
Ok(f) => f,
Err(e) => {
eprintln!("[ENGINE] ImportAudio failed to open {:?}: {}", path, e);
return;
}
};
// SAFETY: The file is opened read-only. The mmap is shared
// immutably. We never write to it.
let mmap = match unsafe { memmap2::Mmap::map(&file) } {
Ok(m) => m,
Err(e) => {
eprintln!("[ENGINE] ImportAudio mmap failed for {:?}: {}", path, e);
return;
}
};
// Parse WAV header to find PCM data offset and format
let header = match crate::io::parse_wav_header(&mmap) {
Ok(h) => h,
Err(e) => {
eprintln!("[ENGINE] ImportAudio WAV parse failed for {:?}: {}", path, e);
return;
}
};
let audio_file = crate::audio::pool::AudioFile::from_mmap(
path.clone(),
mmap,
header.data_offset,
header.sample_format,
header.channels,
header.sample_rate,
header.total_frames,
);
pool_index = self.audio_pool.add_file(audio_file);
}
crate::io::AudioFormat::Compressed => {
let sync_decode = std::env::var("DAW_SYNC_DECODE").is_ok();
if sync_decode {
// Diagnostic: full synchronous decode to InMemory (bypasses ring buffer)
eprintln!("[ENGINE] DAW_SYNC_DECODE: doing full decode of {:?}", path);
match crate::io::AudioFile::load(&path) {
Ok(loaded) => {
let ext = path.extension()
.and_then(|e| e.to_str())
.map(|s| s.to_lowercase());
let audio_file = crate::audio::pool::AudioFile::with_format(
path.clone(),
loaded.data,
loaded.channels,
loaded.sample_rate,
ext,
);
pool_index = self.audio_pool.add_file(audio_file);
eprintln!("[ENGINE] DAW_SYNC_DECODE: pool_index={}, frames={}", pool_index, loaded.frames);
}
Err(e) => {
eprintln!("[ENGINE] DAW_SYNC_DECODE failed: {}", e);
return;
}
}
} else {
// Normal path: stream decode via disk reader
let ext = path.extension()
.and_then(|e| e.to_str())
.map(|s| s.to_lowercase());
let total_frames = metadata.n_frames.unwrap_or_else(|| {
(metadata.duration * metadata.sample_rate as f64).ceil() as u64
});
let mut audio_file = crate::audio::pool::AudioFile::from_compressed(
path.clone(),
metadata.channels,
metadata.sample_rate,
total_frames,
ext,
);
let buffer = crate::audio::disk_reader::DiskReader::create_buffer(
metadata.sample_rate,
metadata.channels,
);
audio_file.read_ahead = Some(buffer.clone());
pool_index = self.audio_pool.add_file(audio_file);
eprintln!("[ENGINE] Compressed: total_frames={}, pool_index={}, has_disk_reader={}",
total_frames, pool_index, self.disk_reader.is_some());
if let Some(ref mut dr) = self.disk_reader {
dr.send(crate::audio::disk_reader::DiskReaderCommand::ActivateFile {
pool_index,
path: path.clone(),
buffer,
});
}
// Spawn background thread to decode full file for waveform display
let bg_tx = self.chunk_generation_tx.clone();
let bg_path = path.clone();
let _ = std::thread::Builder::new()
.name(format!("waveform-decode-{}", pool_index))
.spawn(move || {
eprintln!("[WAVEFORM DECODE] Starting full decode of {:?}", bg_path);
match crate::io::AudioFile::load(&bg_path) {
Ok(loaded) => {
eprintln!("[WAVEFORM DECODE] Complete: {} frames, {} channels",
loaded.frames, loaded.channels);
let _ = bg_tx.send(AudioEvent::WaveformDecodeComplete {
pool_index,
samples: loaded.data,
});
}
Err(e) => {
eprintln!("[WAVEFORM DECODE] Failed to decode {:?}: {}", bg_path, e);
}
}
});
}
}
} }
// Emit AudioFileReady event
let _ = self.event_tx.push(AudioEvent::AudioFileReady {
pool_index,
path: path_str,
channels: metadata.channels,
sample_rate: metadata.sample_rate,
duration: metadata.duration,
format: metadata.format,
});
} }
} }
} }
/// Import an audio file into the pool: mmap for PCM, streaming for compressed.
/// Returns the pool index on success. Emits AudioFileReady event.
fn do_import_audio(&mut self, path: &std::path::Path) -> Result<usize, String> {
let path_str = path.to_string_lossy().to_string();
let metadata = crate::io::read_metadata(path)
.map_err(|e| format!("Failed to read metadata for {:?}: {}", path, e))?;
eprintln!("[ENGINE] ImportAudio: format={:?}, ch={}, sr={}, n_frames={:?}, duration={:.2}s, path={}",
metadata.format, metadata.channels, metadata.sample_rate, metadata.n_frames, metadata.duration, path_str);
let pool_index = match metadata.format {
crate::io::AudioFormat::Pcm => {
let file = std::fs::File::open(path)
.map_err(|e| format!("Failed to open {:?}: {}", path, e))?;
// SAFETY: The file is opened read-only. The mmap is shared
// immutably. We never write to it.
let mmap = unsafe { memmap2::Mmap::map(&file) }
.map_err(|e| format!("mmap failed for {:?}: {}", path, e))?;
let header = crate::io::parse_wav_header(&mmap)
.map_err(|e| format!("WAV parse failed for {:?}: {}", path, e))?;
let audio_file = crate::audio::pool::AudioFile::from_mmap(
path.to_path_buf(),
mmap,
header.data_offset,
header.sample_format,
header.channels,
header.sample_rate,
header.total_frames,
);
self.audio_pool.add_file(audio_file)
}
crate::io::AudioFormat::Compressed => {
let sync_decode = std::env::var("DAW_SYNC_DECODE").is_ok();
if sync_decode {
eprintln!("[ENGINE] DAW_SYNC_DECODE: doing full decode of {:?}", path);
let loaded = crate::io::AudioFile::load(path)
.map_err(|e| format!("DAW_SYNC_DECODE failed: {}", e))?;
let ext = path.extension()
.and_then(|e| e.to_str())
.map(|s| s.to_lowercase());
let audio_file = crate::audio::pool::AudioFile::with_format(
path.to_path_buf(),
loaded.data,
loaded.channels,
loaded.sample_rate,
ext,
);
let idx = self.audio_pool.add_file(audio_file);
eprintln!("[ENGINE] DAW_SYNC_DECODE: pool_index={}, frames={}", idx, loaded.frames);
idx
} else {
let ext = path.extension()
.and_then(|e| e.to_str())
.map(|s| s.to_lowercase());
let total_frames = metadata.n_frames.unwrap_or_else(|| {
(metadata.duration * metadata.sample_rate as f64).ceil() as u64
});
let mut audio_file = crate::audio::pool::AudioFile::from_compressed(
path.to_path_buf(),
metadata.channels,
metadata.sample_rate,
total_frames,
ext,
);
let buffer = crate::audio::disk_reader::DiskReader::create_buffer(
metadata.sample_rate,
metadata.channels,
);
audio_file.read_ahead = Some(buffer.clone());
let idx = self.audio_pool.add_file(audio_file);
eprintln!("[ENGINE] Compressed: total_frames={}, pool_index={}, has_disk_reader={}",
total_frames, idx, self.disk_reader.is_some());
if let Some(ref mut dr) = self.disk_reader {
dr.send(crate::audio::disk_reader::DiskReaderCommand::ActivateFile {
pool_index: idx,
path: path.to_path_buf(),
buffer,
});
}
// Spawn background thread to decode full file for waveform display
let bg_tx = self.chunk_generation_tx.clone();
let bg_path = path.to_path_buf();
let _ = std::thread::Builder::new()
.name(format!("waveform-decode-{}", idx))
.spawn(move || {
eprintln!("[WAVEFORM DECODE] Starting full decode of {:?}", bg_path);
match crate::io::AudioFile::load(&bg_path) {
Ok(loaded) => {
eprintln!("[WAVEFORM DECODE] Complete: {} frames, {} channels",
loaded.frames, loaded.channels);
let _ = bg_tx.send(AudioEvent::WaveformDecodeComplete {
pool_index: idx,
samples: loaded.data,
});
}
Err(e) => {
eprintln!("[WAVEFORM DECODE] Failed to decode {:?}: {}", bg_path, e);
}
}
});
idx
}
}
};
// Emit AudioFileReady event
let _ = self.event_tx.push(AudioEvent::AudioFileReady {
pool_index,
path: path_str,
channels: metadata.channels,
sample_rate: metadata.sample_rate,
duration: metadata.duration,
format: metadata.format,
});
Ok(pool_index)
}
/// Handle synchronous queries from the UI thread /// Handle synchronous queries from the UI thread
fn handle_query(&mut self, query: Query) { fn handle_query(&mut self, query: Query) {
let response = match query { let response = match query {
@ -2231,6 +2214,9 @@ impl Engine {
QueryResponse::AudioFileAddedSync(Ok(pool_index)) QueryResponse::AudioFileAddedSync(Ok(pool_index))
} }
Query::ImportAudioSync(path) => {
QueryResponse::AudioImportedSync(self.do_import_audio(&path))
}
Query::GetProject => { Query::GetProject => {
// Clone the entire project for serialization // Clone the entire project for serialization
QueryResponse::ProjectRetrieved(Ok(Box::new(self.project.clone()))) QueryResponse::ProjectRetrieved(Ok(Box::new(self.project.clone())))
@ -2674,6 +2660,21 @@ impl EngineController {
let _ = self.command_tx.push(Command::ImportAudio(path)); let _ = self.command_tx.push(Command::ImportAudio(path));
} }
/// Import an audio file synchronously and get the pool index.
/// Does the same work as `import_audio` (mmap for PCM, streaming for
/// compressed) but returns the real pool index directly.
/// NOTE: briefly blocks the UI thread during file setup (sub-ms for PCM
/// mmap; a few ms for compressed streaming init). If this becomes a
/// problem for very large files, switch to async import with event-based
/// pool index reconciliation.
pub fn import_audio_sync(&mut self, path: std::path::PathBuf) -> Result<usize, String> {
let query = Query::ImportAudioSync(path);
match self.send_query(query)? {
QueryResponse::AudioImportedSync(result) => result,
_ => Err("Unexpected query response".to_string()),
}
}
/// Add a clip to an audio track /// Add a clip to an audio track
pub fn add_audio_clip(&mut self, track_id: TrackId, pool_index: usize, start_time: f64, duration: f64, offset: f64) { pub fn add_audio_clip(&mut self, track_id: TrackId, pool_index: usize, start_time: f64, duration: f64, offset: f64) {
let _ = self.command_tx.push(Command::AddAudioClip(track_id, pool_index, start_time, duration, offset)); let _ = self.command_tx.push(Command::AddAudioClip(track_id, pool_index, start_time, duration, offset));

View File

@ -511,6 +511,11 @@ impl AudioClipPool {
let src_start_position = start_time_seconds * audio_file.sample_rate as f64; let src_start_position = start_time_seconds * audio_file.sample_rate as f64;
// Tell the disk reader where we're reading so it buffers the right region.
if use_read_ahead {
read_ahead.unwrap().set_target_frame(src_start_position as u64);
}
let mut rendered_frames = 0; let mut rendered_frames = 0;
if audio_file.sample_rate == engine_sample_rate { if audio_file.sample_rate == engine_sample_rate {

View File

@ -333,6 +333,14 @@ pub enum Query {
AddAudioClipSync(TrackId, usize, f64, f64, f64), AddAudioClipSync(TrackId, usize, f64, f64, f64),
/// Add an audio file to the pool synchronously (path, data, channels, sample_rate) - returns pool index /// Add an audio file to the pool synchronously (path, data, channels, sample_rate) - returns pool index
AddAudioFileSync(String, Vec<f32>, u32, u32), AddAudioFileSync(String, Vec<f32>, u32, u32),
/// Import an audio file synchronously (path) - returns pool index.
/// Does the same work as Command::ImportAudio (mmap for PCM, streaming
/// setup for compressed) but returns the real pool index in the response.
/// NOTE: briefly blocks the UI thread during file setup (sub-ms for PCM
/// mmap; a few ms for compressed streaming init). If this becomes a
/// problem for very large files, switch to async import with event-based
/// pool index reconciliation.
ImportAudioSync(std::path::PathBuf),
/// Get raw audio samples from pool (pool_index) - returns (samples, sample_rate, channels) /// Get raw audio samples from pool (pool_index) - returns (samples, sample_rate, channels)
GetPoolAudioSamples(usize), GetPoolAudioSamples(usize),
/// Get a clone of the current project for serialization /// Get a clone of the current project for serialization
@ -404,6 +412,8 @@ pub enum QueryResponse {
AudioClipInstanceAdded(Result<AudioClipInstanceId, String>), AudioClipInstanceAdded(Result<AudioClipInstanceId, String>),
/// Audio file added to pool (returns pool index) /// Audio file added to pool (returns pool index)
AudioFileAddedSync(Result<usize, String>), AudioFileAddedSync(Result<usize, String>),
/// Audio file imported to pool (returns pool index)
AudioImportedSync(Result<usize, String>),
/// Raw audio samples from pool (samples, sample_rate, channels) /// Raw audio samples from pool (samples, sample_rate, channels)
PoolAudioSamples(Result<(Vec<f32>, u32, u32), String>), PoolAudioSamples(Result<(Vec<f32>, u32, u32), String>),
/// Project retrieved /// Project retrieved

View File

@ -0,0 +1,683 @@
/// GPU-based Constant-Q Transform (CQT) spectrogram with streaming ring-buffer cache.
///
/// Replaces the old FFT spectrogram with a CQT that has logarithmic frequency spacing
/// (bins map directly to MIDI notes). Only the visible viewport is computed, with results
/// cached in a ring-buffer texture so scrolling only computes new columns.
///
/// Architecture:
/// - CqtGpuResources stored in CallbackResources (long-lived, holds pipelines)
/// - CqtCacheEntry per pool_index (cache texture, bin params, ring buffer state)
/// - CqtCallback implements CallbackTrait (per-frame compute + render)
/// - Compute shader reads audio from waveform mip-0 textures (already on GPU)
/// - Render shader reads from cache texture with colormap
use std::collections::HashMap;
use wgpu::util::DeviceExt;
use crate::waveform_gpu::WaveformGpuResources;
/// CQT parameters
const BINS_PER_OCTAVE: u32 = 24;
const FREQ_BINS: u32 = 174; // ceil(log2(4186.0 / 27.5) * 24) = ceil(173.95)
const HOP_SIZE: u32 = 512;
const CACHE_CAPACITY: u32 = 4096;
const MAX_COLS_PER_FRAME: u32 = 256;
const F_MIN: f64 = 27.5; // A0 = MIDI 21
const WAVEFORM_TEX_WIDTH: u32 = 2048;
/// Per-bin CQT kernel parameters, uploaded as a storage buffer.
/// Must match BinInfo in cqt_compute.wgsl.
///
/// 16 bytes per element; the two pad words keep the array stride
/// aligned so WGSL indexing lines up with the Rust layout. Do not
/// reorder fields without updating the shader.
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CqtBinParams {
    /// Kernel window length N_k in samples for this bin.
    window_length: u32,
    /// Per-sample phase increment for the complex exponential.
    phase_step: f32, // 2*pi*Q / N_k
    _pad0: u32,
    _pad1: u32,
}
/// Compute shader uniform params. Must match CqtParams in cqt_compute.wgsl.
///
/// Field order and the trailing pad words are load-bearing: the struct is
/// uploaded verbatim via `queue.write_buffer`, so any change here must be
/// mirrored in the shader.
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CqtComputeParams {
    /// Frames between adjacent CQT columns (HOP_SIZE).
    hop_size: u32,
    /// Number of frequency bins (texture height).
    freq_bins: u32,
    /// Width of the ring-buffer cache texture in columns.
    cache_capacity: u32,
    /// X position (mod capacity) where this dispatch starts writing.
    cache_write_offset: u32,
    /// How many columns this dispatch computes.
    num_columns: u32,
    /// Global index of the first column being computed.
    column_start: u32,
    /// Width of the source waveform texture (WAVEFORM_TEX_WIDTH).
    tex_width: u32,
    /// Total audio frames available in the waveform texture.
    total_frames: u32,
    sample_rate: f32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Render shader uniform params. Must match Params in cqt_render.wgsl exactly.
/// Layout: clip_rect(16) + 18 × f32(72) + pad vec2(8) = 96 bytes
///
/// Everything is passed as f32 (even integral quantities like freq_bins)
/// so the WGSL side can treat the whole tail as a flat run of f32 fields.
/// Keep offsets in sync with the shader when editing.
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct CqtRenderParams {
    pub clip_rect: [f32; 4],        // 16 bytes @ offset 0
    pub viewport_start_time: f32,   // 4 @ 16
    pub pixels_per_second: f32,     // 4 @ 20
    pub audio_duration: f32,        // 4 @ 24
    pub sample_rate: f32,           // 4 @ 28
    pub clip_start_time: f32,       // 4 @ 32
    pub trim_start: f32,            // 4 @ 36
    pub freq_bins: f32,             // 4 @ 40
    pub bins_per_octave: f32,       // 4 @ 44
    pub hop_size: f32,              // 4 @ 48
    pub scroll_y: f32,              // 4 @ 52
    pub note_height: f32,           // 4 @ 56
    pub min_note: f32,              // 4 @ 60
    pub max_note: f32,              // 4 @ 64
    pub gamma: f32,                 // 4 @ 68
    pub cache_capacity: f32,        // 4 @ 72
    pub cache_start_column: f32,    // 4 @ 76
    pub cache_valid_start: f32,     // 4 @ 80
    pub cache_valid_end: f32,       // 4 @ 84
    pub _pad: [f32; 2],             // 8 @ 88, total 96
}
/// Per-pool-index cache entry with ring buffer and GPU resources.
///
/// Columns are addressed by global CQT column index (assumed HOP_SIZE
/// frames apart — confirm against cqt_compute.wgsl); the texture stores
/// a sliding window of `cache_capacity` columns as a ring buffer.
#[allow(dead_code)]
struct CqtCacheEntry {
    // Cache texture (Rgba16Float for universal filterable + storage support)
    cache_texture: wgpu::Texture,
    // Sampled view bound in the render pass.
    cache_texture_view: wgpu::TextureView,
    // Separate view bound as the compute shader's write-only storage target.
    cache_storage_view: wgpu::TextureView,
    cache_capacity: u32,
    freq_bins: u32,
    // Ring buffer state
    // Global column index mapped to texture x = 0.
    cache_start_column: i64,
    // Valid global-column range [start, end); start == end means empty.
    cache_valid_start: i64,
    cache_valid_end: i64,
    // CQT kernel data
    bin_params_buffer: wgpu::Buffer,
    // Waveform texture reference (cloned from WaveformGpuEntry)
    waveform_texture_view: wgpu::TextureView,
    waveform_total_frames: u64,
    // Bind groups
    compute_bind_group: wgpu::BindGroup,
    compute_uniform_buffer: wgpu::Buffer,
    render_bind_group: wgpu::BindGroup,
    render_uniform_buffer: wgpu::Buffer,
    // Metadata
    sample_rate: u32,
}
/// Global GPU resources for CQT (stored in egui_wgpu::CallbackResources).
///
/// Pipelines, layouts, and the sampler are created once in `new` and
/// shared across all cache entries; `entries` is keyed by pool index.
pub struct CqtGpuResources {
    /// One cache entry per audio-pool file, created lazily.
    entries: HashMap<usize, CqtCacheEntry>,
    compute_pipeline: wgpu::ComputePipeline,
    compute_bind_group_layout: wgpu::BindGroupLayout,
    render_pipeline: wgpu::RenderPipeline,
    render_bind_group_layout: wgpu::BindGroupLayout,
    /// Bilinear sampler shared by every render bind group.
    sampler: wgpu::Sampler,
}
/// Per-frame callback for computing and rendering a CQT spectrogram.
///
/// One instance is constructed per painted spectrogram per frame; the
/// heavyweight state lives in `CqtGpuResources` keyed by `pool_index`.
pub struct CqtCallback {
    pub pool_index: usize,
    pub params: CqtRenderParams,
    /// Render-target format; must match the pipeline's color target.
    pub target_format: wgpu::TextureFormat,
    pub sample_rate: u32,
    /// Visible column range (global CQT column indices)
    pub visible_col_start: i64,
    pub visible_col_end: i64,
}
/// Precompute CQT bin parameters for a given sample rate.
///
/// For each of the `FREQ_BINS` log-spaced bins starting at `F_MIN`,
/// computes the window length N_k = ceil(Q * sr / f_k) and the
/// per-sample phase increment 2*pi*Q / N_k, where
/// Q = 1 / (2^(1/B) - 1) is the constant quality factor for
/// B = `BINS_PER_OCTAVE`.
fn precompute_bin_params(sample_rate: u32) -> Vec<CqtBinParams> {
    let bins_per_octave = BINS_PER_OCTAVE as f64;
    let quality = 1.0 / (2.0_f64.powf(1.0 / bins_per_octave) - 1.0);

    let mut params = Vec::with_capacity(FREQ_BINS as usize);
    for bin in 0..FREQ_BINS {
        // Center frequency of this bin, log-spaced upward from F_MIN.
        let center_freq = F_MIN * 2.0_f64.powf(bin as f64 / bins_per_octave);
        // Window length in samples needed for constant-Q resolution here.
        let window_length = (quality * sample_rate as f64 / center_freq).ceil() as u32;
        let phase_step =
            (2.0 * std::f64::consts::PI * quality / window_length as f64) as f32;
        params.push(CqtBinParams {
            window_length,
            phase_step,
            _pad0: 0,
            _pad1: 0,
        });
    }
    params
}
impl CqtGpuResources {
    /// Create the shared CQT GPU resources: compute + render pipelines,
    /// bind group layouts, and the sampler.
    ///
    /// `target_format` must be the format of the surface egui renders
    /// into; the render pipeline's color target is fixed to it here.
    pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self {
        // Compute shader
        let compute_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("cqt_compute_shader"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("panes/shaders/cqt_compute.wgsl").into(),
            ),
        });
        // Render shader
        let render_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("cqt_render_shader"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("panes/shaders/cqt_render.wgsl").into(),
            ),
        });
        // Compute bind group layout:
        // 0: audio_tex (texture_2d<f32>, read)
        // 1: cqt_out (texture_storage_2d<rgba16float, write>)
        // 2: params (uniform)
        // 3: bins (storage, read)
        let compute_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("cqt_compute_bgl"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        // Non-filterable: the compute pass loads raw texels,
                        // no sampler is involved.
                        ty: wgpu::BindingType::Texture {
                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
                            view_dimension: wgpu::TextureViewDimension::D2,
                            multisampled: false,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        // Must match the cache texture format created in
                        // ensure_cache_entry (Rgba16Float).
                        ty: wgpu::BindingType::StorageTexture {
                            access: wgpu::StorageTextureAccess::WriteOnly,
                            format: wgpu::TextureFormat::Rgba16Float,
                            view_dimension: wgpu::TextureViewDimension::D2,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });
        // Render bind group layout: cache_tex + sampler + uniforms
        let render_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("cqt_render_bgl"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Texture {
                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
                            view_dimension: wgpu::TextureViewDimension::D2,
                            multisampled: false,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });
        // Compute pipeline
        let compute_pipeline_layout =
            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("cqt_compute_pipeline_layout"),
                bind_group_layouts: &[&compute_bind_group_layout],
                push_constant_ranges: &[],
            });
        let compute_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("cqt_compute_pipeline"),
                layout: Some(&compute_pipeline_layout),
                module: &compute_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });
        // Render pipeline
        let render_pipeline_layout =
            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("cqt_render_pipeline_layout"),
                bind_group_layouts: &[&render_bind_group_layout],
                push_constant_ranges: &[],
            });
        let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            label: Some("cqt_render_pipeline"),
            layout: Some(&render_pipeline_layout),
            vertex: wgpu::VertexState {
                module: &render_shader,
                entry_point: Some("vs_main"),
                // Fullscreen-style draw: vertices are generated in the
                // shader, so no vertex buffers are bound.
                buffers: &[],
                compilation_options: Default::default(),
            },
            fragment: Some(wgpu::FragmentState {
                module: &render_shader,
                entry_point: Some("fs_main"),
                targets: &[Some(wgpu::ColorTargetState {
                    format: target_format,
                    blend: Some(wgpu::BlendState::ALPHA_BLENDING),
                    write_mask: wgpu::ColorWrites::ALL,
                })],
                compilation_options: Default::default(),
            }),
            primitive: wgpu::PrimitiveState {
                topology: wgpu::PrimitiveTopology::TriangleList,
                ..Default::default()
            },
            depth_stencil: None,
            multisample: wgpu::MultisampleState::default(),
            multiview: None,
            cache: None,
        });
        // Bilinear sampler for smooth interpolation in render shader
        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
            label: Some("cqt_sampler"),
            mag_filter: wgpu::FilterMode::Linear,
            min_filter: wgpu::FilterMode::Linear,
            mipmap_filter: wgpu::FilterMode::Nearest,
            ..Default::default()
        });
        Self {
            entries: HashMap::new(),
            compute_pipeline,
            compute_bind_group_layout,
            render_pipeline,
            render_bind_group_layout,
            sampler,
        }
    }
    /// Create a cache entry for a pool index, referencing the waveform texture.
    ///
    /// Idempotent: if an entry for `pool_index` already exists, this is a
    /// no-op. Allocates the ring-buffer cache texture, uploads precomputed
    /// bin parameters, and builds both bind groups up front so the
    /// per-frame path only writes uniforms and dispatches.
    fn ensure_cache_entry(
        &mut self,
        device: &wgpu::Device,
        pool_index: usize,
        waveform_texture_view: wgpu::TextureView,
        total_frames: u64,
        sample_rate: u32,
    ) {
        if self.entries.contains_key(&pool_index) {
            return;
        }
        // Create cache texture (ring buffer)
        let cache_texture = device.create_texture(&wgpu::TextureDescriptor {
            label: Some(&format!("cqt_cache_{}", pool_index)),
            size: wgpu::Extent3d {
                width: CACHE_CAPACITY,
                height: FREQ_BINS,
                depth_or_array_layers: 1,
            },
            mip_level_count: 1,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            format: wgpu::TextureFormat::Rgba16Float,
            // STORAGE_BINDING: written by the compute pass;
            // TEXTURE_BINDING: sampled by the render pass.
            usage: wgpu::TextureUsages::STORAGE_BINDING | wgpu::TextureUsages::TEXTURE_BINDING,
            view_formats: &[],
        });
        let cache_texture_view = cache_texture.create_view(&wgpu::TextureViewDescriptor {
            label: Some(&format!("cqt_cache_{}_view", pool_index)),
            ..Default::default()
        });
        // Second view of the same texture for the storage binding.
        let cache_storage_view = cache_texture.create_view(&wgpu::TextureViewDescriptor {
            label: Some(&format!("cqt_cache_{}_storage", pool_index)),
            ..Default::default()
        });
        // Precompute bin params
        let bin_params = precompute_bin_params(sample_rate);
        let bin_params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some(&format!("cqt_bins_{}", pool_index)),
            contents: bytemuck::cast_slice(&bin_params),
            usage: wgpu::BufferUsages::STORAGE,
        });
        // Compute uniform buffer
        let compute_uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(&format!("cqt_compute_uniforms_{}", pool_index)),
            size: std::mem::size_of::<CqtComputeParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        // Render uniform buffer
        let render_uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(&format!("cqt_render_uniforms_{}", pool_index)),
            size: std::mem::size_of::<CqtRenderParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        // Compute bind group
        let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(&format!("cqt_compute_bg_{}", pool_index)),
            layout: &self.compute_bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(&waveform_texture_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::TextureView(&cache_storage_view),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: compute_uniform_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: bin_params_buffer.as_entire_binding(),
                },
            ],
        });
        // Render bind group
        let render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(&format!("cqt_render_bg_{}", pool_index)),
            layout: &self.render_bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(&cache_texture_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::Sampler(&self.sampler),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: render_uniform_buffer.as_entire_binding(),
                },
            ],
        });
        // Ring buffer starts empty: valid range [0, 0).
        self.entries.insert(
            pool_index,
            CqtCacheEntry {
                cache_texture,
                cache_texture_view,
                cache_storage_view,
                cache_capacity: CACHE_CAPACITY,
                freq_bins: FREQ_BINS,
                cache_start_column: 0,
                cache_valid_start: 0,
                cache_valid_end: 0,
                bin_params_buffer,
                waveform_texture_view,
                waveform_total_frames: total_frames,
                compute_bind_group,
                compute_uniform_buffer,
                render_bind_group,
                render_uniform_buffer,
                sample_rate,
            },
        );
    }
}
/// Dispatch compute shader to fill CQT columns `[start_col, end_col)` in the
/// ring-buffer cache.
/// Free function to avoid borrow conflicts with CqtGpuResources.entries.
///
/// At most `MAX_COLS_PER_FRAME` columns are computed per call; the caller is
/// responsible for advancing the entry's valid range to match what was
/// actually dispatched. Returns an empty Vec when there is nothing to do.
fn dispatch_cqt_compute(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    entry: &CqtCacheEntry,
    start_col: i64,
    end_col: i64,
) -> Vec<wgpu::CommandBuffer> {
    // Guard inverted ranges explicitly: a negative i64 span cast straight to
    // u32 would wrap to a huge dispatch count.
    let num_cols = (end_col - start_col).max(0) as u32;
    if num_cols == 0 {
        return Vec::new();
    }
    // Clamp to max per frame
    let num_cols = num_cols.min(MAX_COLS_PER_FRAME);
    // Ring-buffer write offset. The scroll-left path calls this with
    // `start_col` below `entry.cache_start_column` (the entry state is only
    // updated after dispatch), so the delta can be negative. `rem_euclid`
    // keeps the result in [0, capacity); the previous `as u32 % capacity`
    // two's-complement wrap was only correct when the capacity happens to be
    // a power of two.
    let cache_write_offset =
        (start_col - entry.cache_start_column).rem_euclid(entry.cache_capacity as i64) as u32;
    let params = CqtComputeParams {
        hop_size: HOP_SIZE,
        freq_bins: FREQ_BINS,
        cache_capacity: entry.cache_capacity,
        cache_write_offset,
        num_columns: num_cols,
        // Negative column indices never reach the shader; clamp to 0.
        column_start: start_col.max(0) as u32,
        tex_width: WAVEFORM_TEX_WIDTH,
        total_frames: entry.waveform_total_frames as u32,
        sample_rate: entry.sample_rate as f32,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    queue.write_buffer(
        &entry.compute_uniform_buffer,
        0,
        bytemuck::cast_slice(&[params]),
    );
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("cqt_compute_encoder"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("cqt_compute_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, &entry.compute_bind_group, &[]);
        // Dispatch: X = ceil(freq_bins / 64) workgroups of 64 threads,
        // Y = one workgroup row per time column.
        let workgroups_x = (FREQ_BINS + 63) / 64;
        pass.dispatch_workgroups(workgroups_x, num_cols, 1);
    }
    vec![encoder.finish()]
}
impl egui_wgpu::CallbackTrait for CqtCallback {
    /// Per-frame GPU preparation: lazily creates the shared CQT resources,
    /// ensures a cache entry exists for this clip's audio pool slot, and
    /// dispatches the compute shader for any visible columns that are not yet
    /// in the ring-buffer cache (at most MAX_COLS_PER_FRAME per frame).
    fn prepare(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        _screen_descriptor: &egui_wgpu::ScreenDescriptor,
        _egui_encoder: &mut wgpu::CommandEncoder,
        resources: &mut egui_wgpu::CallbackResources,
    ) -> Vec<wgpu::CommandBuffer> {
        // Initialize CQT resources if needed (first callback ever).
        if !resources.contains::<CqtGpuResources>() {
            resources.insert(CqtGpuResources::new(device, self.target_format));
        }
        // First, check if waveform data is available and extract what we need.
        // Done in its own scope so the immutable borrow of `resources` ends
        // before we take the mutable borrow below.
        let waveform_info: Option<(wgpu::TextureView, u64)> = {
            let waveform_gpu: Option<&WaveformGpuResources> = resources.get();
            waveform_gpu.and_then(|wgpu_res| {
                wgpu_res.entries.get(&self.pool_index).map(|entry| {
                    // Clone the texture view (Arc internally, cheap)
                    (entry.texture_views[0].clone(), entry.total_frames)
                })
            })
        };
        let (waveform_view, total_frames) = match waveform_info {
            Some(info) => info,
            None => return Vec::new(), // Waveform not uploaded yet
        };
        let cqt_gpu: &mut CqtGpuResources = resources.get_mut().unwrap();
        // Ensure cache entry exists (no-op if already created).
        cqt_gpu.ensure_cache_entry(
            device,
            self.pool_index,
            waveform_view,
            total_frames,
            self.sample_rate,
        );
        // Determine which columns need computing; clamp the requested range
        // to [0, total audio columns].
        let vis_start = self.visible_col_start.max(0);
        let max_col = (total_frames as i64) / HOP_SIZE as i64;
        let vis_end = self.visible_col_end.min(max_col);
        // Read current cache state, compute what's needed, then update state.
        // We split borrows carefully: read entry state, compute, then write back
        // (dispatch_cqt_compute takes &entry, so the mutable re-lookup happens
        // only after the dispatch call returns).
        let cmds;
        {
            let entry = cqt_gpu.entries.get(&self.pool_index).unwrap();
            let cache_valid_start = entry.cache_valid_start;
            let cache_valid_end = entry.cache_valid_end;
            if vis_start >= vis_end {
                // Nothing visible.
                cmds = Vec::new();
            } else if vis_start >= cache_valid_start && vis_end <= cache_valid_end {
                // Fully cached
                cmds = Vec::new();
            } else if vis_start >= cache_valid_start
                && vis_start < cache_valid_end
                && vis_end > cache_valid_end
            {
                // Scrolling right: extend the valid range forward, bounded by
                // the per-frame compute budget.
                let actual_end =
                    cache_valid_end + (vis_end - cache_valid_end).min(MAX_COLS_PER_FRAME as i64);
                cmds = dispatch_cqt_compute(
                    device, queue, &cqt_gpu.compute_pipeline, entry,
                    cache_valid_end, actual_end,
                );
                let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap();
                entry.cache_valid_end = actual_end;
                // If the valid window now exceeds capacity, drop the oldest
                // columns from the front.
                // NOTE(review): moving cache_start_column here changes the
                // (global_col - cache_start_column) % capacity mapping the
                // render shader uses, while already-written texels stay put —
                // this looks safe only if the shift preserves the modulo;
                // confirm against the shader's ring addressing.
                if entry.cache_valid_end - entry.cache_valid_start > entry.cache_capacity as i64 {
                    entry.cache_valid_start = entry.cache_valid_end - entry.cache_capacity as i64;
                    entry.cache_start_column = entry.cache_valid_start;
                }
            } else if vis_end <= cache_valid_end
                && vis_end > cache_valid_start
                && vis_start < cache_valid_start
            {
                // Scrolling left: extend the valid range backward. The
                // dispatch happens with the *pre-update* entry, so
                // dispatch_cqt_compute sees start_col < cache_start_column.
                let actual_start =
                    cache_valid_start - (cache_valid_start - vis_start).min(MAX_COLS_PER_FRAME as i64);
                cmds = dispatch_cqt_compute(
                    device, queue, &cqt_gpu.compute_pipeline, entry,
                    actual_start, cache_valid_start,
                );
                let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap();
                entry.cache_valid_start = actual_start;
                entry.cache_start_column = actual_start;
                // Trim the tail if the window now exceeds capacity.
                if entry.cache_valid_end - entry.cache_valid_start > entry.cache_capacity as i64 {
                    entry.cache_valid_end = entry.cache_valid_start + entry.cache_capacity as i64;
                }
            } else {
                // No overlap or first compute — reset cache and fill from the
                // start of the visible range.
                let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap();
                entry.cache_start_column = vis_start;
                entry.cache_valid_start = vis_start;
                entry.cache_valid_end = vis_start;
                let compute_end = vis_start + (vis_end - vis_start).min(MAX_COLS_PER_FRAME as i64);
                let entry = cqt_gpu.entries.get(&self.pool_index).unwrap();
                cmds = dispatch_cqt_compute(
                    device, queue, &cqt_gpu.compute_pipeline, entry,
                    vis_start, compute_end,
                );
                let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap();
                entry.cache_valid_end = compute_end;
            }
        }
        // Update render uniform buffer: the CPU-side params carry placeholder
        // cache fields; fill them from the (possibly just-updated) entry.
        let entry = cqt_gpu.entries.get(&self.pool_index).unwrap();
        let mut params = self.params;
        params.cache_start_column = entry.cache_start_column as f32;
        params.cache_valid_start = entry.cache_valid_start as f32;
        params.cache_valid_end = entry.cache_valid_end as f32;
        params.cache_capacity = entry.cache_capacity as f32;
        queue.write_buffer(
            &entry.render_uniform_buffer,
            0,
            bytemuck::cast_slice(&[params]),
        );
        cmds
    }
    /// Draw the cached CQT data as a full-screen-triangle pass; silently
    /// skips when resources or cached columns are not yet available.
    fn paint(
        &self,
        _info: eframe::egui::PaintCallbackInfo,
        render_pass: &mut wgpu::RenderPass<'static>,
        resources: &egui_wgpu::CallbackResources,
    ) {
        let cqt_gpu: &CqtGpuResources = match resources.get() {
            Some(r) => r,
            None => return,
        };
        let entry = match cqt_gpu.entries.get(&self.pool_index) {
            Some(e) => e,
            None => return,
        };
        // Don't render if nothing is cached yet
        if entry.cache_valid_start >= entry.cache_valid_end {
            return;
        }
        render_pass.set_pipeline(&cqt_gpu.render_pipeline);
        render_pass.set_bind_group(0, &entry.render_bind_group, &[]);
        // Single triangle covering the clip rect; vertex shader generates it.
        render_pass.draw(0..3, 0..1);
    }
}

View File

@ -20,8 +20,7 @@ mod theme;
use theme::{Theme, ThemeMode}; use theme::{Theme, ThemeMode};
mod waveform_gpu; mod waveform_gpu;
mod spectrogram_gpu; mod cqt_gpu;
mod spectrogram_compute;
mod config; mod config;
use config::AppConfig; use config::AppConfig;
@ -2382,16 +2381,20 @@ impl EditorApp {
let sample_rate = metadata.sample_rate; let sample_rate = metadata.sample_rate;
if let Some(ref controller_arc) = self.audio_controller { if let Some(ref controller_arc) = self.audio_controller {
// Predict the pool index (engine assigns sequentially) // Import synchronously to get the real pool index from the engine.
let pool_index = self.action_executor.document().audio_clips.len(); // NOTE: briefly blocks the UI thread (sub-ms for PCM mmap; a few ms
// for compressed streaming init).
// Send async import command (non-blocking) let pool_index = {
{
let mut controller = controller_arc.lock().unwrap(); let mut controller = controller_arc.lock().unwrap();
controller.import_audio(path.to_path_buf()); match controller.import_audio_sync(path.to_path_buf()) {
} Ok(idx) => idx,
Err(e) => {
eprintln!("Failed to import audio '{}': {}", path.display(), e);
return None;
}
}
};
// Create audio clip in document immediately (metadata is enough)
let clip = AudioClip::new_sampled(&name, pool_index, duration); let clip = AudioClip::new_sampled(&name, pool_index, duration);
let clip_id = self.action_executor.document_mut().add_audio_clip(clip); let clip_id = self.action_executor.document_mut().add_audio_clip(clip);

View File

@ -92,10 +92,6 @@ pub struct PianoRollPane {
// Resolved note cache — tracks when to invalidate // Resolved note cache — tracks when to invalidate
cached_clip_id: Option<u32>, cached_clip_id: Option<u32>,
// Spectrogram cache — keyed by audio pool index
// Stores pre-computed SpectrogramUpload data ready for GPU
spectrogram_computed: HashMap<usize, crate::spectrogram_gpu::SpectrogramUpload>,
// Spectrogram gamma (power curve for colormap) // Spectrogram gamma (power curve for colormap)
spectrogram_gamma: f32, spectrogram_gamma: f32,
} }
@ -126,8 +122,7 @@ impl PianoRollPane {
auto_scroll_enabled: true, auto_scroll_enabled: true,
user_scrolled_since_play: false, user_scrolled_since_play: false,
cached_clip_id: None, cached_clip_id: None,
spectrogram_computed: HashMap::new(), spectrogram_gamma: 0.8,
spectrogram_gamma: 5.0,
} }
} }
@ -1256,80 +1251,51 @@ impl PianoRollPane {
} }
} }
let screen_size = ui.ctx().input(|i| i.content_rect().size()); // Render CQT spectrogram for each sampled clip on this layer
// Render spectrogram for each sampled clip on this layer
for &(pool_index, timeline_start, trim_start, _duration, sample_rate) in &clip_infos { for &(pool_index, timeline_start, trim_start, _duration, sample_rate) in &clip_infos {
// Compute spectrogram if not cached // Get audio duration from the raw audio cache
let needs_compute = !self.spectrogram_computed.contains_key(&pool_index); let audio_duration = if let Some((samples, sr, ch)) = shared.raw_audio_cache.get(&pool_index) {
let pending_upload = if needs_compute { samples.len() as f64 / (*sr as f64 * *ch as f64)
if let Some((samples, sr, ch)) = shared.raw_audio_cache.get(&pool_index) {
let spec_data = crate::spectrogram_compute::compute_spectrogram(
samples, *sr, *ch, 2048, 512,
);
if spec_data.time_bins > 0 {
let upload = crate::spectrogram_gpu::SpectrogramUpload {
magnitudes: spec_data.magnitudes,
time_bins: spec_data.time_bins as u32,
freq_bins: spec_data.freq_bins as u32,
sample_rate: spec_data.sample_rate,
hop_size: spec_data.hop_size as u32,
fft_size: spec_data.fft_size as u32,
duration: spec_data.duration as f32,
};
// Store a marker so we don't recompute
self.spectrogram_computed.insert(pool_index, crate::spectrogram_gpu::SpectrogramUpload {
magnitudes: Vec::new(), // We don't need to keep the data around
time_bins: upload.time_bins,
freq_bins: upload.freq_bins,
sample_rate: upload.sample_rate,
hop_size: upload.hop_size,
fft_size: upload.fft_size,
duration: upload.duration,
});
Some(upload)
} else {
None
}
} else {
None
}
} else { } else {
None continue;
};
// Get cached spectrogram metadata for params
let spec_meta = self.spectrogram_computed.get(&pool_index);
let (time_bins, freq_bins, hop_size, fft_size, audio_duration) = match spec_meta {
Some(m) => (m.time_bins as f32, m.freq_bins as f32, m.hop_size as f32, m.fft_size as f32, m.duration),
None => continue,
}; };
if view_rect.width() > 0.0 && view_rect.height() > 0.0 { if view_rect.width() > 0.0 && view_rect.height() > 0.0 {
let callback = crate::spectrogram_gpu::SpectrogramCallback { // Calculate visible CQT column range for streaming
let viewport_end_time = self.viewport_start_time + (view_rect.width() / self.pixels_per_second) as f64;
let vis_audio_start = (self.viewport_start_time - timeline_start + trim_start).max(0.0);
let vis_audio_end = (viewport_end_time - timeline_start + trim_start).min(audio_duration);
let vis_col_start = (vis_audio_start * sample_rate as f64 / 512.0).floor() as i64;
let vis_col_end = (vis_audio_end * sample_rate as f64 / 512.0).ceil() as i64 + 1;
let callback = crate::cqt_gpu::CqtCallback {
pool_index, pool_index,
params: crate::spectrogram_gpu::SpectrogramParams { params: crate::cqt_gpu::CqtRenderParams {
clip_rect: [view_rect.min.x, view_rect.min.y, view_rect.max.x, view_rect.max.y], clip_rect: [view_rect.min.x, view_rect.min.y, view_rect.max.x, view_rect.max.y],
viewport_start_time: self.viewport_start_time as f32, viewport_start_time: self.viewport_start_time as f32,
pixels_per_second: self.pixels_per_second, pixels_per_second: self.pixels_per_second,
audio_duration, audio_duration: audio_duration as f32,
sample_rate: sample_rate as f32, sample_rate: sample_rate as f32,
clip_start_time: timeline_start as f32, clip_start_time: timeline_start as f32,
trim_start: trim_start as f32, trim_start: trim_start as f32,
time_bins, freq_bins: 174.0,
freq_bins, bins_per_octave: 24.0,
hop_size, hop_size: 512.0,
fft_size,
scroll_y: self.scroll_y, scroll_y: self.scroll_y,
note_height: self.note_height, note_height: self.note_height,
screen_size: [screen_size.x, screen_size.y],
min_note: MIN_NOTE as f32, min_note: MIN_NOTE as f32,
max_note: MAX_NOTE as f32, max_note: MAX_NOTE as f32,
gamma: self.spectrogram_gamma, gamma: self.spectrogram_gamma,
_pad: [0.0; 3], cache_capacity: 0.0, // filled by prepare()
cache_start_column: 0.0,
cache_valid_start: 0.0,
cache_valid_end: 0.0,
_pad: [0.0; 2],
}, },
target_format: shared.target_format, target_format: shared.target_format,
pending_upload, sample_rate,
visible_col_start: vis_col_start,
visible_col_end: vis_col_end,
}; };
ui.painter().add(egui_wgpu::Callback::new_paint_callback( ui.painter().add(egui_wgpu::Callback::new_paint_callback(

View File

@ -0,0 +1,101 @@
// GPU Constant-Q Transform (CQT) compute shader.
//
// Reads raw audio samples from a waveform mip-0 texture (Rgba16Float, packed
// row-major at TEX_WIDTH=2048) and computes CQT magnitude for each
// (freq_bin, time_column) pair, writing normalized dB values into a ring-buffer
// cache texture (Rgba16Float, width=cache_capacity, height=freq_bins).
//
// Dispatch: (ceil(freq_bins / 64), num_columns, 1)
// Each thread handles one frequency bin for one time column.
// Compute uniforms — must match CqtComputeParams on the Rust side
// (field order and padding included).
struct CqtParams {
    hop_size: u32,           // audio frames between consecutive CQT columns
    freq_bins: u32,          // number of CQT frequency bins (cache texture height)
    cache_capacity: u32,     // ring-buffer width in columns
    cache_write_offset: u32, // ring buffer position to start writing
    num_columns: u32,        // how many columns in this dispatch
    column_start: u32,       // global CQT column index of first column
    tex_width: u32,          // waveform texture width (2048)
    total_frames: u32,       // total audio frames in waveform texture
    sample_rate: f32,        // source sample rate in Hz
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
// Per-frequency-bin CQT parameters, precomputed on the CPU
// (see precompute_bin_params in cqt_gpu.rs).
struct BinInfo {
    window_length: u32, // N_k: analysis window length in samples for this bin
    phase_step: f32,    // per-sample phase increment: 2*pi*Q / N_k
    _pad0: u32,
    _pad1: u32,
}
// Bindings — must match the compute bind group built in cqt_gpu.rs.
@group(0) @binding(0) var audio_tex: texture_2d<f32>;                      // waveform mip-0 samples
@group(0) @binding(1) var cqt_out: texture_storage_2d<rgba16float, write>; // ring-buffer cache
@group(0) @binding(2) var<uniform> params: CqtParams;
@group(0) @binding(3) var<storage, read> bins: array<BinInfo>;             // one entry per freq bin
const PI2: f32 = 6.283185307; // 2 * pi
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // One invocation handles one (frequency bin, time column) pair.
    let k = gid.x; // CQT frequency bin
    let c = gid.y; // column index within this dispatch batch
    if k >= params.freq_bins || c >= params.num_columns {
        return;
    }
    let column = params.column_start + c;   // global CQT column index
    let hop_pos = column * params.hop_size; // frame the column's timestamp refers to
    let bin = bins[k];
    let win_len = bin.window_length;
    // Center the analysis window on the hop position so the column timestamp
    // refers to the middle of the window rather than its start — important
    // for low-frequency bins with very long windows.
    let half = win_len / 2u;
    // Complex inner product: sum of x[n] * w[n] * exp(-i * phase_step * n).
    var re: f32 = 0.0;
    var im: f32 = 0.0;
    var n: u32 = 0u;
    loop {
        if n >= win_len {
            break;
        }
        let src = i32(hop_pos) + i32(n) - i32(half);
        // Samples outside the file contribute zero (skip them).
        if src >= 0 && u32(src) < params.total_frames {
            let idx = u32(src);
            // Unpack one mono sample from the 2D waveform texture (mip 0):
            // R=G=left, B=A=right; average the two channels.
            let x = idx % params.tex_width;
            let y = idx / params.tex_width;
            let texel = textureLoad(audio_tex, vec2<i32>(i32(x), i32(y)), 0);
            let mono = (texel.r + texel.b) * 0.5;
            // Analytic Hann window.
            let w = 0.5 * (1.0 - cos(PI2 * f32(n) / f32(win_len)));
            // Multiply by exp(-i * phase_step * n) and accumulate.
            let angle = bin.phase_step * f32(n);
            let xw = mono * w;
            re += xw * cos(angle);
            im -= xw * sin(angle);
        }
        n = n + 1u;
    }
    // Window-length-normalized magnitude, mapped from -80dB..0dB to 0..1.
    // WGSL log() is natural log, so log10(v) = log(v) / log(10).
    let mag = sqrt(re * re + im * im) / f32(win_len);
    let db = 20.0 * log(mag + 1e-10) / 2.302585093;
    let value = clamp((db + 80.0) / 80.0, 0.0, 1.0);
    // Write into the ring-buffer cache column for this batch slot.
    let out_x = (params.cache_write_offset + c) % params.cache_capacity;
    textureStore(cqt_out, vec2<i32>(i32(out_x), i32(k)), vec4(value, 0.0, 0.0, 1.0));
}

View File

@ -1,30 +1,37 @@
// Spectrogram rendering shader for FFT magnitude data. // CQT spectrogram render shader.
// Texture layout: X = frequency bin, Y = time bin //
// Values: normalized magnitude (0.0 = silence, 1.0 = peak) // Reads from a ring-buffer cache texture (Rgba16Float) where:
// Vertical axis maps MIDI notes to frequency bins (matching piano roll) // X = time column (ring buffer index), Y = CQT frequency bin
// CQT bins map directly to MIDI notes via: bin = (note - min_note) * bins_per_octave / 12
//
// Applies the same colormap as the old FFT spectrogram.
// Must match CqtRenderParams in cqt_gpu.rs exactly (96 bytes).
struct Params { struct Params {
clip_rect: vec4<f32>, clip_rect: vec4<f32>, // 16 @ 0
viewport_start_time: f32, viewport_start_time: f32, // 4 @ 16
pixels_per_second: f32, pixels_per_second: f32, // 4 @ 20
audio_duration: f32, audio_duration: f32, // 4 @ 24
sample_rate: f32, sample_rate: f32, // 4 @ 28
clip_start_time: f32, clip_start_time: f32, // 4 @ 32
trim_start: f32, trim_start: f32, // 4 @ 36
time_bins: f32, freq_bins: f32, // 4 @ 40
freq_bins: f32, bins_per_octave: f32, // 4 @ 44
hop_size: f32, hop_size: f32, // 4 @ 48
fft_size: f32, scroll_y: f32, // 4 @ 52
scroll_y: f32, note_height: f32, // 4 @ 56
note_height: f32, min_note: f32, // 4 @ 60
screen_size: vec2<f32>, max_note: f32, // 4 @ 64
min_note: f32, gamma: f32, // 4 @ 68
max_note: f32, cache_capacity: f32, // 4 @ 72
gamma: f32, cache_start_column: f32, // 4 @ 76
cache_valid_start: f32, // 4 @ 80
cache_valid_end: f32, // 4 @ 84
_pad: vec2<f32>, // 8 @ 88, total 96
} }
@group(0) @binding(0) var spec_tex: texture_2d<f32>; @group(0) @binding(0) var cache_tex: texture_2d<f32>;
@group(0) @binding(1) var spec_sampler: sampler; @group(0) @binding(1) var cache_sampler: sampler;
@group(0) @binding(2) var<uniform> params: Params; @group(0) @binding(2) var<uniform> params: Params;
struct VertexOutput { struct VertexOutput {
@ -42,7 +49,6 @@ fn vs_main(@builtin(vertex_index) vi: u32) -> VertexOutput {
return out; return out;
} }
// Signed distance from point to rounded rectangle boundary
fn rounded_rect_sdf(pos: vec2<f32>, rect_min: vec2<f32>, rect_max: vec2<f32>, r: f32) -> f32 { fn rounded_rect_sdf(pos: vec2<f32>, rect_min: vec2<f32>, rect_max: vec2<f32>, r: f32) -> f32 {
let center = (rect_min + rect_max) * 0.5; let center = (rect_min + rect_max) * 0.5;
let half_size = (rect_max - rect_min) * 0.5; let half_size = (rect_max - rect_min) * 0.5;
@ -55,27 +61,21 @@ fn colormap(v: f32, gamma: f32) -> vec4<f32> {
let t = pow(clamp(v, 0.0, 1.0), gamma); let t = pow(clamp(v, 0.0, 1.0), gamma);
if t < 1.0 / 6.0 { if t < 1.0 / 6.0 {
// Black -> blue
let s = t * 6.0; let s = t * 6.0;
return vec4(0.0, 0.0, s, 1.0); return vec4(0.0, 0.0, s, 1.0);
} else if t < 2.0 / 6.0 { } else if t < 2.0 / 6.0 {
// Blue -> purple
let s = (t - 1.0 / 6.0) * 6.0; let s = (t - 1.0 / 6.0) * 6.0;
return vec4(s * 0.6, 0.0, 1.0 - s * 0.2, 1.0); return vec4(s * 0.6, 0.0, 1.0 - s * 0.2, 1.0);
} else if t < 3.0 / 6.0 { } else if t < 3.0 / 6.0 {
// Purple -> red
let s = (t - 2.0 / 6.0) * 6.0; let s = (t - 2.0 / 6.0) * 6.0;
return vec4(0.6 + s * 0.4, 0.0, 0.8 - s * 0.8, 1.0); return vec4(0.6 + s * 0.4, 0.0, 0.8 - s * 0.8, 1.0);
} else if t < 4.0 / 6.0 { } else if t < 4.0 / 6.0 {
// Red -> orange
let s = (t - 3.0 / 6.0) * 6.0; let s = (t - 3.0 / 6.0) * 6.0;
return vec4(1.0, s * 0.5, 0.0, 1.0); return vec4(1.0, s * 0.5, 0.0, 1.0);
} else if t < 5.0 / 6.0 { } else if t < 5.0 / 6.0 {
// Orange -> yellow
let s = (t - 4.0 / 6.0) * 6.0; let s = (t - 4.0 / 6.0) * 6.0;
return vec4(1.0, 0.5 + s * 0.5, 0.0, 1.0); return vec4(1.0, 0.5 + s * 0.5, 0.0, 1.0);
} else { } else {
// Yellow -> white
let s = (t - 5.0 / 6.0) * 6.0; let s = (t - 5.0 / 6.0) * 6.0;
return vec4(1.0, 1.0, s, 1.0); return vec4(1.0, 1.0, s, 1.0);
} }
@ -98,11 +98,9 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let content_top = params.clip_rect.y - params.scroll_y; let content_top = params.clip_rect.y - params.scroll_y;
let content_bottom = params.clip_rect.y + (params.max_note - params.min_note + 1.0) * params.note_height - params.scroll_y; let content_bottom = params.clip_rect.y + (params.max_note - params.min_note + 1.0) * params.note_height - params.scroll_y;
// Rounded corners: content edges on X, visible viewport edges on Y. // Rounded corners
// This rounds left/right where the clip starts/ends, and top/bottom at the view boundary.
let vis_top = max(content_top, params.clip_rect.y); let vis_top = max(content_top, params.clip_rect.y);
let vis_bottom = min(content_bottom, params.clip_rect.w); let vis_bottom = min(content_bottom, params.clip_rect.w);
let corner_radius = 6.0; let corner_radius = 6.0;
let dist = rounded_rect_sdf( let dist = rounded_rect_sdf(
vec2(frag_x, frag_y), vec2(frag_x, frag_y),
@ -114,7 +112,7 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
discard; discard;
} }
// Fragment X -> audio time -> time bin // Fragment X -> audio time -> global CQT column
let timeline_time = params.viewport_start_time + (frag_x - params.clip_rect.x) / params.pixels_per_second; let timeline_time = params.viewport_start_time + (frag_x - params.clip_rect.x) / params.pixels_per_second;
let audio_time = timeline_time - params.clip_start_time + params.trim_start; let audio_time = timeline_time - params.clip_start_time + params.trim_start;
@ -122,32 +120,35 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
discard; discard;
} }
let time_bin = audio_time * params.sample_rate / params.hop_size; let global_col = audio_time * params.sample_rate / params.hop_size;
if time_bin < 0.0 || time_bin >= params.time_bins {
// Check if this column is in the cached range
if global_col < params.cache_valid_start || global_col >= params.cache_valid_end {
discard; discard;
} }
// Fragment Y -> MIDI note -> frequency -> frequency bin // Fragment Y -> MIDI note -> CQT bin (direct mapping!)
let note = params.max_note - ((frag_y - params.clip_rect.y + params.scroll_y) / params.note_height); let note = params.max_note - ((frag_y - params.clip_rect.y + params.scroll_y) / params.note_height);
if note < params.min_note || note > params.max_note { if note < params.min_note || note > params.max_note {
discard; discard;
} }
// MIDI note -> frequency: freq = 440 * 2^((note - 69) / 12) // CQT bin: each octave has bins_per_octave bins, starting from min_note
let freq = 440.0 * pow(2.0, (note - 69.0) / 12.0); let bin = (note - params.min_note) * params.bins_per_octave / 12.0;
// Frequency -> FFT bin index if bin < 0.0 || bin >= params.freq_bins {
let freq_bin = freq * params.fft_size / params.sample_rate;
if freq_bin < 0.0 || freq_bin >= params.freq_bins {
discard; discard;
} }
// Sample texture with bilinear filtering // Map global column to ring buffer position
let u = freq_bin / params.freq_bins; let ring_pos = global_col - params.cache_start_column;
let v = time_bin / params.time_bins; let cache_x = ring_pos % params.cache_capacity;
let magnitude = textureSampleLevel(spec_tex, spec_sampler, vec2(u, v), 0.0).r;
// Sample cache texture with bilinear filtering
let u = (cache_x + 0.5) / params.cache_capacity;
let v = (bin + 0.5) / params.freq_bins;
let magnitude = textureSampleLevel(cache_tex, cache_sampler, vec2(u, v), 0.0).r;
return colormap(magnitude, params.gamma); return colormap(magnitude, params.gamma);
} }

View File

@ -1,144 +0,0 @@
/// CPU-side FFT computation for spectrogram visualization.
///
/// Uses rayon to parallelize FFT across time slices on all CPU cores.
/// Produces a 2D magnitude grid (time bins x frequency bins) for GPU texture upload.
use rayon::prelude::*;
use std::f32::consts::PI;
/// Pre-computed spectrogram data ready for GPU upload
pub struct SpectrogramData {
    /// Flattened 2D array of normalized magnitudes [time_bins * freq_bins], row-major
    /// Each value is 0.0 (silence) to 1.0 (peak), log-scale normalized
    pub magnitudes: Vec<f32>,
    /// Number of time slices (rows); 0 when the audio was shorter than one FFT window.
    pub time_bins: usize,
    /// Number of positive-frequency bins per slice: fft_size / 2 + 1.
    pub freq_bins: usize,
    /// Source audio sample rate in Hz.
    pub sample_rate: u32,
    /// Stride between consecutive time slices, in samples.
    pub hop_size: usize,
    /// FFT window length, in samples.
    pub fft_size: usize,
    /// Total audio duration in seconds (mono length / sample_rate).
    pub duration: f64,
}
/// Compute a spectrogram from raw audio samples using parallel FFT.
///
/// `samples` is interleaved by `channels` and is mixed down to mono first.
/// Each time slice is processed independently via rayon, making this
/// scale well across all CPU cores.
///
/// Returns magnitudes normalized to 0.0..1.0 (-80 dB floor, 0 dB ceiling),
/// laid out row-major as [time_bins * freq_bins].
///
/// `fft_size` must be a power of two (radix-2 Cooley-Tukey) and `hop_size`
/// must be non-zero.
pub fn compute_spectrogram(
    samples: &[f32],
    sample_rate: u32,
    channels: u32,
    fft_size: usize,
    hop_size: usize,
) -> SpectrogramData {
    debug_assert!(fft_size.is_power_of_two(), "radix-2 FFT needs a power-of-two fft_size");
    debug_assert!(hop_size > 0, "hop_size must be non-zero");
    // Mix to mono by averaging each interleaved frame.
    let mono: Vec<f32> = if channels >= 2 {
        samples
            .chunks(channels as usize)
            .map(|frame| frame.iter().sum::<f32>() / channels as f32)
            .collect()
    } else {
        samples.to_vec()
    };
    let freq_bins = fft_size / 2 + 1;
    let duration = mono.len() as f64 / sample_rate as f64;
    // Not enough audio for even one full window: return an empty grid.
    if mono.len() < fft_size {
        return SpectrogramData {
            magnitudes: Vec::new(),
            time_bins: 0,
            freq_bins,
            sample_rate,
            hop_size,
            fft_size,
            duration,
        };
    }
    let time_bins = (mono.len().saturating_sub(fft_size)) / hop_size + 1;
    // Precompute Hann window
    let window: Vec<f32> = (0..fft_size)
        .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / fft_size as f32).cos()))
        .collect();
    // Precompute twiddle factors for Cooley-Tukey radix-2 FFT
    let twiddles: Vec<(f32, f32)> = (0..fft_size / 2)
        .map(|k| {
            let angle = -2.0 * PI * k as f32 / fft_size as f32;
            (angle.cos(), angle.sin())
        })
        .collect();
    // Bit-reversal permutation table. Use integer trailing_zeros() rather
    // than the previous `(fft_size as f32).log2() as u32`: exact for every
    // power-of-two size, and the degenerate fft_size == 1 case (bits == 0,
    // where a 32-bit shift would wrap) is handled explicitly.
    let bits = fft_size.trailing_zeros();
    let bit_rev: Vec<usize> = if bits == 0 {
        vec![0]
    } else {
        (0..fft_size)
            .map(|i| ((i as u32).reverse_bits() >> (32 - bits)) as usize)
            .collect()
    };
    // Process all time slices in parallel
    let magnitudes: Vec<f32> = (0..time_bins)
        .into_par_iter()
        .flat_map(|t| {
            let offset = t * hop_size;
            let mut re = vec![0.0f32; fft_size];
            let mut im = vec![0.0f32; fft_size];
            // Load windowed samples in bit-reversed order
            for i in 0..fft_size {
                let sample = if offset + i < mono.len() {
                    mono[offset + i]
                } else {
                    0.0
                };
                re[bit_rev[i]] = sample * window[i];
            }
            // Iterative Cooley-Tukey radix-2 DIT FFT
            let mut half_size = 1;
            while half_size < fft_size {
                let step = half_size * 2;
                let twiddle_step = fft_size / step;
                for k in (0..fft_size).step_by(step) {
                    for j in 0..half_size {
                        let (tw_re, tw_im) = twiddles[j * twiddle_step];
                        let a = k + j;
                        let b = a + half_size;
                        let t_re = tw_re * re[b] - tw_im * im[b];
                        let t_im = tw_re * im[b] + tw_im * re[b];
                        re[b] = re[a] - t_re;
                        im[b] = im[a] - t_im;
                        re[a] += t_re;
                        im[a] += t_im;
                    }
                }
                half_size = step;
            }
            // Extract magnitudes for the positive frequencies (DC..Nyquist)
            let mut mags = Vec::with_capacity(freq_bins);
            for f in 0..freq_bins {
                let mag = (re[f] * re[f] + im[f] * im[f]).sqrt();
                // dB normalization: -80dB floor to 0dB ceiling → 0.0 to 1.0
                let db = 20.0 * (mag + 1e-10).log10();
                mags.push(((db + 80.0) / 80.0).clamp(0.0, 1.0));
            }
            mags
        })
        .collect();
    SpectrogramData {
        magnitudes,
        time_bins,
        freq_bins,
        sample_rate,
        hop_size,
        fft_size,
        duration,
    }
}

View File

@ -1,350 +0,0 @@
/// GPU resources for spectrogram rendering.
///
/// Follows the same pattern as waveform_gpu.rs:
/// - SpectrogramGpuResources stored in CallbackResources (long-lived)
/// - SpectrogramCallback implements egui_wgpu::CallbackTrait (per-frame)
/// - R32Float texture holds magnitude data (time bins × freq bins)
/// - Fragment shader applies colormap and frequency mapping
use std::collections::HashMap;
/// GPU resources for all spectrograms (stored in egui_wgpu::CallbackResources)
pub struct SpectrogramGpuResources {
    /// Per-audio-pool entries, keyed by pool index.
    pub entries: HashMap<usize, SpectrogramGpuEntry>,
    // Shared pipeline/layout/sampler reused by every entry.
    render_pipeline: wgpu::RenderPipeline,
    render_bind_group_layout: wgpu::BindGroupLayout,
    sampler: wgpu::Sampler,
}
/// Per-audio-pool GPU data for one spectrogram
#[allow(dead_code)]
pub struct SpectrogramGpuEntry {
    // Magnitude texture and its view (see upload_spectrogram for layout).
    pub texture: wgpu::Texture,
    pub texture_view: wgpu::TextureView,
    pub render_bind_group: wgpu::BindGroup,
    pub uniform_buffer: wgpu::Buffer,
    // Metadata copied from the SpectrogramUpload that produced the texture.
    pub time_bins: u32,
    pub freq_bins: u32,
    pub sample_rate: u32,
    pub hop_size: u32,
    pub fft_size: u32,
    // Audio duration in seconds.
    pub duration: f32,
}
/// Uniform buffer struct — must match spectrogram.wgsl Params exactly
/// (field order, offsets, and trailing padding).
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpectrogramParams {
    pub clip_rect: [f32; 4],      // 16 bytes @ offset 0
    pub viewport_start_time: f32, // 4 bytes @ offset 16
    pub pixels_per_second: f32,   // 4 bytes @ offset 20
    pub audio_duration: f32,      // 4 bytes @ offset 24
    pub sample_rate: f32,         // 4 bytes @ offset 28
    pub clip_start_time: f32,     // 4 bytes @ offset 32
    pub trim_start: f32,          // 4 bytes @ offset 36
    pub time_bins: f32,           // 4 bytes @ offset 40
    pub freq_bins: f32,           // 4 bytes @ offset 44
    pub hop_size: f32,            // 4 bytes @ offset 48
    pub fft_size: f32,            // 4 bytes @ offset 52
    pub scroll_y: f32,            // 4 bytes @ offset 56
    pub note_height: f32,         // 4 bytes @ offset 60
    pub screen_size: [f32; 2],    // 8 bytes @ offset 64
    pub min_note: f32,            // 4 bytes @ offset 72
    pub max_note: f32,            // 4 bytes @ offset 76
    pub gamma: f32,               // 4 bytes @ offset 80
    pub _pad: [f32; 3],           // 12 bytes @ offset 84 (pad to 96 for WGSL struct alignment)
}
// Total: 96 bytes (multiple of 16 for vec4 alignment)
/// Data for a pending spectrogram texture upload
pub struct SpectrogramUpload {
    /// Row-major [time_bins * freq_bins] normalized magnitudes (0.0..1.0).
    pub magnitudes: Vec<f32>,
    pub time_bins: u32,
    pub freq_bins: u32,
    pub sample_rate: u32,
    pub hop_size: u32,
    pub fft_size: u32,
    /// Audio duration in seconds.
    pub duration: f32,
}
/// Per-frame callback for rendering one spectrogram instance
pub struct SpectrogramCallback {
    /// Audio pool index — selects the SpectrogramGpuEntry to draw.
    pub pool_index: usize,
    /// Uniforms for this frame (viewport, clip rect, colormap gamma, ...).
    pub params: SpectrogramParams,
    /// Surface format the render pipeline targets.
    pub target_format: wgpu::TextureFormat,
    /// Magnitude data to upload before drawing, if not yet on the GPU.
    pub pending_upload: Option<SpectrogramUpload>,
}
impl SpectrogramGpuResources {
    /// Build the shared render pipeline, bind group layout, and sampler used
    /// by every spectrogram entry. `target_format` is the surface format the
    /// fragment output must match.
    pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self {
        // Shader (vertex + fragment in one WGSL module)
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("spectrogram_render_shader"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("panes/shaders/spectrogram.wgsl").into(),
            ),
        });
        // Bind group layout: texture + sampler + uniforms
        // (bindings 0/1/2, fragment-stage only — must match spectrogram.wgsl)
        let render_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("spectrogram_render_bgl"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Texture {
                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
                            view_dimension: wgpu::TextureViewDimension::D2,
                            multisampled: false,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });
        // Render pipeline: vertex-buffer-less triangle, alpha-blended output
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("spectrogram_pipeline_layout"),
            bind_group_layouts: &[&render_bind_group_layout],
            push_constant_ranges: &[],
        });
        let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            label: Some("spectrogram_render_pipeline"),
            layout: Some(&pipeline_layout),
            vertex: wgpu::VertexState {
                module: &shader,
                entry_point: Some("vs_main"),
                buffers: &[],
                compilation_options: Default::default(),
            },
            fragment: Some(wgpu::FragmentState {
                module: &shader,
                entry_point: Some("fs_main"),
                targets: &[Some(wgpu::ColorTargetState {
                    format: target_format,
                    blend: Some(wgpu::BlendState::ALPHA_BLENDING),
                    write_mask: wgpu::ColorWrites::ALL,
                })],
                compilation_options: Default::default(),
            }),
            primitive: wgpu::PrimitiveState {
                topology: wgpu::PrimitiveTopology::TriangleList,
                ..Default::default()
            },
            depth_stencil: None,
            multisample: wgpu::MultisampleState::default(),
            multiview: None,
            cache: None,
        });
        // Bilinear sampler for smooth frequency interpolation
        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
            label: Some("spectrogram_sampler"),
            mag_filter: wgpu::FilterMode::Linear,
            min_filter: wgpu::FilterMode::Linear,
            mipmap_filter: wgpu::FilterMode::Nearest,
            ..Default::default()
        });
        Self {
            entries: HashMap::new(),
            render_pipeline,
            render_bind_group_layout,
            sampler,
        }
    }
/// Upload pre-computed spectrogram magnitude data as a GPU texture
pub fn upload_spectrogram(
&mut self,
device: &wgpu::Device,
queue: &wgpu::Queue,
pool_index: usize,
upload: &SpectrogramUpload,
) {
// Remove old entry
self.entries.remove(&pool_index);
if upload.time_bins == 0 || upload.freq_bins == 0 {
return;
}
// Data layout: magnitudes[t * freq_bins + f] — each row is one time slice
// with freq_bins values. So texture width = freq_bins, height = time_bins.
// R8Unorm is filterable (unlike R32Float) for bilinear interpolation.
let tex_width = upload.freq_bins;
let tex_height = upload.time_bins;
let texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some(&format!("spectrogram_{}", pool_index)),
size: wgpu::Extent3d {
width: tex_width,
height: tex_height,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
format: wgpu::TextureFormat::R8Unorm,
usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
view_formats: &[],
});
// Convert f32 magnitudes to u8 for R8Unorm, with row padding for alignment.
// wgpu requires bytes_per_row to be a multiple of COPY_BYTES_PER_ROW_ALIGNMENT (256).
let align = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
let unpadded_row = tex_width; // 1 byte per texel for R8Unorm
let padded_row = (unpadded_row + align - 1) / align * align;
let mut texel_data = vec![0u8; padded_row as usize * tex_height as usize];
for row in 0..tex_height as usize {
let src_offset = row * tex_width as usize;
let dst_offset = row * padded_row as usize;
for col in 0..tex_width as usize {
let m = upload.magnitudes[src_offset + col];
texel_data[dst_offset + col] = (m.clamp(0.0, 1.0) * 255.0) as u8;
}
}
// Upload magnitude data
queue.write_texture(
wgpu::TexelCopyTextureInfo {
texture: &texture,
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
&texel_data,
wgpu::TexelCopyBufferLayout {
offset: 0,
bytes_per_row: Some(padded_row),
rows_per_image: Some(tex_height),
},
wgpu::Extent3d {
width: tex_width,
height: tex_height,
depth_or_array_layers: 1,
},
);
let texture_view = texture.create_view(&wgpu::TextureViewDescriptor {
label: Some(&format!("spectrogram_{}_view", pool_index)),
..Default::default()
});
let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some(&format!("spectrogram_{}_uniforms", pool_index)),
size: std::mem::size_of::<SpectrogramParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(&format!("spectrogram_{}_bg", pool_index)),
layout: &self.render_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::TextureView(&texture_view),
},
wgpu::BindGroupEntry {
binding: 1,
resource: wgpu::BindingResource::Sampler(&self.sampler),
},
wgpu::BindGroupEntry {
binding: 2,
resource: uniform_buffer.as_entire_binding(),
},
],
});
self.entries.insert(
pool_index,
SpectrogramGpuEntry {
texture,
texture_view,
render_bind_group,
uniform_buffer,
time_bins: upload.time_bins,
freq_bins: upload.freq_bins,
sample_rate: upload.sample_rate,
hop_size: upload.hop_size,
fft_size: upload.fft_size,
duration: upload.duration,
},
);
}
}
impl egui_wgpu::CallbackTrait for SpectrogramCallback {
    /// Runs before painting: lazily creates the shared GPU resources,
    /// applies any pending magnitude upload, and refreshes this instance's
    /// uniform buffer with the current frame's parameters.
    fn prepare(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        _screen_descriptor: &egui_wgpu::ScreenDescriptor,
        _egui_encoder: &mut wgpu::CommandEncoder,
        resources: &mut egui_wgpu::CallbackResources,
    ) -> Vec<wgpu::CommandBuffer> {
        // First callback to run creates the shared pipeline/sampler/layout.
        if !resources.contains::<SpectrogramGpuResources>() {
            resources.insert(SpectrogramGpuResources::new(device, self.target_format));
        }
        let gpu = resources.get_mut::<SpectrogramGpuResources>().unwrap();

        // New analysis data arrived this frame — (re)create our texture entry.
        if let Some(upload) = &self.pending_upload {
            gpu.upload_spectrogram(device, queue, self.pool_index, upload);
        }

        // Push this frame's params; skipped if no entry exists yet.
        if let Some(entry) = gpu.entries.get(&self.pool_index) {
            queue.write_buffer(&entry.uniform_buffer, 0, bytemuck::cast_slice(&[self.params]));
        }
        Vec::new()
    }

    /// Draws the spectrogram as a fullscreen triangle; silently does nothing
    /// if resources or this instance's entry are missing.
    fn paint(
        &self,
        _info: eframe::egui::PaintCallbackInfo,
        render_pass: &mut wgpu::RenderPass<'static>,
        resources: &egui_wgpu::CallbackResources,
    ) {
        let Some(gpu) = resources.get::<SpectrogramGpuResources>() else {
            return;
        };
        let Some(entry) = gpu.entries.get(&self.pool_index) else {
            return;
        };
        render_pass.set_pipeline(&gpu.render_pipeline);
        render_pass.set_bind_group(0, &entry.render_bind_group, &[]);
        render_pass.draw(0..3, 0..1); // Fullscreen triangle
    }
}