//! GPU-based Constant-Q Transform (CQT) spectrogram with streaming ring-buffer cache.
//!
//! Replaces the old FFT spectrogram with a CQT that has logarithmic frequency spacing
//! (bins map directly to MIDI notes). Only the visible viewport is computed, with results
//! cached in a ring-buffer texture so scrolling only computes new columns.
//!
//! Architecture:
//! - CqtGpuResources stored in CallbackResources (long-lived, holds pipelines)
//! - CqtCacheEntry per pool_index (cache texture, bin params, ring buffer state)
//! - CqtCallback implements CallbackTrait (per-frame compute + render)
//! - Compute shader reads audio from waveform mip-0 textures (already on GPU)
//! - Render shader reads from cache texture with colormap

use std::collections::HashMap;
use wgpu::util::DeviceExt;

use crate::waveform_gpu::WaveformGpuResources;

// ---------------------------------------------------------------------------
// CQT parameters
// ---------------------------------------------------------------------------

/// Frequency resolution: 24 bins per octave, i.e. two bins per semitone.
const BINS_PER_OCTAVE: u32 = 24;
/// Number of CQT bins: ceil(log2(4186.0 / 27.5) * 24) = 174
/// (87 semitones above `F_MIN`, two bins each).
const FREQ_BINS: u32 = 174;
/// Audio frames between two consecutive CQT columns.
const HOP_SIZE: u32 = 512;
/// Width of the ring-buffer cache texture, in cache slots (columns).
const CACHE_CAPACITY: u32 = 4096;
/// Upper bound on cache slots computed in a single frame.
const MAX_COLS_PER_FRAME: u32 = 128;
/// Centre frequency of bin 0 in Hz: A0 = MIDI 21.
const F_MIN: f64 = 27.5;
/// Width passed to the compute shader as `tex_width` for addressing the waveform texture.
const WAVEFORM_TEX_WIDTH: u32 = 2048;

/// Per-bin CQT kernel parameters, uploaded as a storage buffer.
/// Must match BinInfo in cqt_compute.wgsl.
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CqtBinParams {
    /// Analysis window length N_k in samples.
    window_length: u32,
    /// 2*pi*Q / N_k
    phase_step: f32,
    _pad0: u32,
    _pad1: u32,
}

/// Compute shader uniform params. Must match CqtParams in cqt_compute.wgsl.
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CqtComputeParams {
    hop_size: u32,
    freq_bins: u32,
    cache_capacity: u32,
    cache_write_offset: u32,
    num_columns: u32,
    column_start: u32,
    tex_width: u32,
    total_frames: u32,
    sample_rate: f32,
    column_stride: u32,
    _pad1: u32,
    _pad2: u32,
}

/// Render shader uniform params. Must match Params in cqt_render.wgsl exactly.
/// Layout: clip_rect(16) + 19 × f32(76) + pad f32(4) = 96 bytes
#[repr(C)]
#[derive(Debug, Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct CqtRenderParams {
    pub clip_rect: [f32; 4],       // 16 bytes @ offset 0
    pub viewport_start_time: f32,  // 4 @ 16
    pub pixels_per_second: f32,    // 4 @ 20
    pub audio_duration: f32,       // 4 @ 24
    pub sample_rate: f32,          // 4 @ 28
    pub clip_start_time: f32,      // 4 @ 32
    pub trim_start: f32,           // 4 @ 36
    pub freq_bins: f32,            // 4 @ 40
    pub bins_per_octave: f32,      // 4 @ 44
    pub hop_size: f32,             // 4 @ 48
    pub scroll_y: f32,             // 4 @ 52
    pub note_height: f32,          // 4 @ 56
    pub min_note: f32,             // 4 @ 60
    pub max_note: f32,             // 4 @ 64
    pub gamma: f32,                // 4 @ 68
    pub cache_capacity: f32,       // 4 @ 72
    pub cache_start_column: f32,   // 4 @ 76
    pub cache_valid_start: f32,    // 4 @ 80
    pub cache_valid_end: f32,      // 4 @ 84
    pub column_stride: f32,        // 4 @ 88
    pub _pad: f32,                 // 4 @ 92, total 96
}

// Compile-time layout guards: these structs are uploaded to the GPU byte-for-byte,
// so any size change must be made together with the matching WGSL struct.
const _: () = assert!(std::mem::size_of::<CqtBinParams>() == 16);
const _: () = assert!(std::mem::size_of::<CqtComputeParams>() == 48);
const _: () = assert!(std::mem::size_of::<CqtRenderParams>() == 96);

/// Per-pool-index cache entry with ring buffer and GPU resources.
// dead_code is allowed because several fields (texture, views, buffers) are stored
// here but never read back in this file.
#[allow(dead_code)]
struct CqtCacheEntry {
    // Cache texture (Rgba16Float for universal filterable + storage support)
    cache_texture: wgpu::Texture,
    cache_texture_view: wgpu::TextureView,
    cache_storage_view: wgpu::TextureView,
    cache_capacity: u32,
    freq_bins: u32,

    // Ring buffer state (global CQT column indices)
    /// Global column that maps to cache slot 0.
    cache_start_column: i64,
    /// First cached column (inclusive).
    cache_valid_start: i64,
    /// One past the last cached column (exclusive).
    cache_valid_end: i64,

    // CQT kernel data
    bin_params_buffer: wgpu::Buffer,

    // Waveform texture reference (cloned from WaveformGpuEntry)
    waveform_texture_view: wgpu::TextureView,
    waveform_total_frames: u64,

    // Bind groups
    compute_bind_group: wgpu::BindGroup,
    compute_uniform_buffer: wgpu::Buffer,
    render_bind_group: wgpu::BindGroup,
    render_uniform_buffer: wgpu::Buffer,

    // Metadata
    sample_rate: u32,
    current_stride: u32,
}

/// Global GPU resources for CQT (stored in egui_wgpu::CallbackResources).
pub struct CqtGpuResources { entries: HashMap, compute_pipeline: wgpu::ComputePipeline, compute_bind_group_layout: wgpu::BindGroupLayout, render_pipeline: wgpu::RenderPipeline, render_bind_group_layout: wgpu::BindGroupLayout, sampler: wgpu::Sampler, } /// Per-frame callback for computing and rendering a CQT spectrogram. pub struct CqtCallback { pub pool_index: usize, pub params: CqtRenderParams, pub target_format: wgpu::TextureFormat, pub sample_rate: u32, /// Visible column range (global CQT column indices) pub visible_col_start: i64, pub visible_col_end: i64, /// Column stride: 1 = full resolution, N = compute every Nth column pub stride: u32, } /// Precompute CQT bin parameters for a given sample rate. fn precompute_bin_params(sample_rate: u32) -> Vec { let b = BINS_PER_OCTAVE as f64; let q = 1.0 / (2.0_f64.powf(1.0 / b) - 1.0); (0..FREQ_BINS) .map(|k| { let f_k = F_MIN * 2.0_f64.powf(k as f64 / b); let n_k = (q * sample_rate as f64 / f_k).ceil() as u32; let phase_step = (2.0 * std::f64::consts::PI * q / n_k as f64) as f32; CqtBinParams { window_length: n_k, phase_step, _pad0: 0, _pad1: 0, } }) .collect() } impl CqtGpuResources { pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self { // Compute shader let compute_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { label: Some("cqt_compute_shader"), source: wgpu::ShaderSource::Wgsl( include_str!("panes/shaders/cqt_compute.wgsl").into(), ), }); // Render shader let render_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { label: Some("cqt_render_shader"), source: wgpu::ShaderSource::Wgsl( include_str!("panes/shaders/cqt_render.wgsl").into(), ), }); // Compute bind group layout: // 0: audio_tex (texture_2d, read) // 1: cqt_out (texture_storage_2d) // 2: params (uniform) // 3: bins (storage, read) let compute_bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: Some("cqt_compute_bgl"), entries: &[ 
wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Texture { sample_type: wgpu::TextureSampleType::Float { filterable: false }, view_dimension: wgpu::TextureViewDimension::D2, multisampled: false, }, count: None, }, wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::StorageTexture { access: wgpu::StorageTextureAccess::WriteOnly, format: wgpu::TextureFormat::Rgba16Float, view_dimension: wgpu::TextureViewDimension::D2, }, count: None, }, wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None, }, count: None, }, wgpu::BindGroupLayoutEntry { binding: 3, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None, }, count: None, }, ], }); // Render bind group layout: cache_tex + sampler + uniforms let render_bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: Some("cqt_render_bgl"), entries: &[ wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::FRAGMENT, ty: wgpu::BindingType::Texture { sample_type: wgpu::TextureSampleType::Float { filterable: true }, view_dimension: wgpu::TextureViewDimension::D2, multisampled: false, }, count: None, }, wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::FRAGMENT, ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering), count: None, }, wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::FRAGMENT, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None, }, count: None, }, ], }); // Compute pipeline let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor 
{ label: Some("cqt_compute_pipeline_layout"), bind_group_layouts: &[&compute_bind_group_layout], push_constant_ranges: &[], }); let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: Some("cqt_compute_pipeline"), layout: Some(&compute_pipeline_layout), module: &compute_shader, entry_point: Some("main"), compilation_options: Default::default(), cache: None, }); // Render pipeline let render_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: Some("cqt_render_pipeline_layout"), bind_group_layouts: &[&render_bind_group_layout], push_constant_ranges: &[], }); let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { label: Some("cqt_render_pipeline"), layout: Some(&render_pipeline_layout), vertex: wgpu::VertexState { module: &render_shader, entry_point: Some("vs_main"), buffers: &[], compilation_options: Default::default(), }, fragment: Some(wgpu::FragmentState { module: &render_shader, entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { format: target_format, blend: Some(wgpu::BlendState::ALPHA_BLENDING), write_mask: wgpu::ColorWrites::ALL, })], compilation_options: Default::default(), }), primitive: wgpu::PrimitiveState { topology: wgpu::PrimitiveTopology::TriangleList, ..Default::default() }, depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, cache: None, }); // Bilinear sampler for smooth interpolation in render shader let sampler = device.create_sampler(&wgpu::SamplerDescriptor { label: Some("cqt_sampler"), mag_filter: wgpu::FilterMode::Linear, min_filter: wgpu::FilterMode::Linear, mipmap_filter: wgpu::FilterMode::Nearest, ..Default::default() }); Self { entries: HashMap::new(), compute_pipeline, compute_bind_group_layout, render_pipeline, render_bind_group_layout, sampler, } } /// Create a cache entry for a pool index, referencing the waveform texture. 
fn ensure_cache_entry( &mut self, device: &wgpu::Device, pool_index: usize, waveform_texture_view: wgpu::TextureView, total_frames: u64, sample_rate: u32, ) { // If entry exists, check if waveform data has grown (progressive decode) if let Some(entry) = self.entries.get_mut(&pool_index) { if entry.waveform_total_frames != total_frames { // Waveform texture updated in-place with more data. // Invalidate cached CQT columns near the old boundary — they were // computed with truncated windows and show artifacts. Rewind // cache_valid_end so those columns get recomputed with full data. // The lowest bin (A0=27.5Hz) has the longest window: // Q = 1/(2^(1/B)-1), N_k = ceil(Q * sr / f_min), margin = N_k/2/hop let q = 1.0 / (2.0_f64.powf(1.0 / BINS_PER_OCTAVE as f64) - 1.0); let max_window = (q * entry.sample_rate as f64 / F_MIN).ceil() as i64; let margin_cols = max_window / 2 / HOP_SIZE as i64 + 1; let old_max_col = entry.waveform_total_frames as i64 / HOP_SIZE as i64; entry.waveform_total_frames = total_frames; let invalidate_from = old_max_col - margin_cols; if entry.cache_valid_end > invalidate_from { entry.cache_valid_end = invalidate_from.max(entry.cache_valid_start); } } return; } // Create cache texture (ring buffer) let cache_texture = device.create_texture(&wgpu::TextureDescriptor { label: Some(&format!("cqt_cache_{}", pool_index)), size: wgpu::Extent3d { width: CACHE_CAPACITY, height: FREQ_BINS, depth_or_array_layers: 1, }, mip_level_count: 1, sample_count: 1, dimension: wgpu::TextureDimension::D2, format: wgpu::TextureFormat::Rgba16Float, usage: wgpu::TextureUsages::STORAGE_BINDING | wgpu::TextureUsages::TEXTURE_BINDING, view_formats: &[], }); let cache_texture_view = cache_texture.create_view(&wgpu::TextureViewDescriptor { label: Some(&format!("cqt_cache_{}_view", pool_index)), ..Default::default() }); let cache_storage_view = cache_texture.create_view(&wgpu::TextureViewDescriptor { label: Some(&format!("cqt_cache_{}_storage", pool_index)), ..Default::default() 
}); // Precompute bin params let bin_params = precompute_bin_params(sample_rate); let bin_params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { label: Some(&format!("cqt_bins_{}", pool_index)), contents: bytemuck::cast_slice(&bin_params), usage: wgpu::BufferUsages::STORAGE, }); // Compute uniform buffer let compute_uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: Some(&format!("cqt_compute_uniforms_{}", pool_index)), size: std::mem::size_of::() as u64, usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, }); // Render uniform buffer let render_uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: Some(&format!("cqt_render_uniforms_{}", pool_index)), size: std::mem::size_of::() as u64, usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, }); // Compute bind group let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { label: Some(&format!("cqt_compute_bg_{}", pool_index)), layout: &self.compute_bind_group_layout, entries: &[ wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(&waveform_texture_view), }, wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(&cache_storage_view), }, wgpu::BindGroupEntry { binding: 2, resource: compute_uniform_buffer.as_entire_binding(), }, wgpu::BindGroupEntry { binding: 3, resource: bin_params_buffer.as_entire_binding(), }, ], }); // Render bind group let render_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { label: Some(&format!("cqt_render_bg_{}", pool_index)), layout: &self.render_bind_group_layout, entries: &[ wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(&cache_texture_view), }, wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::Sampler(&self.sampler), }, wgpu::BindGroupEntry { binding: 2, resource: render_uniform_buffer.as_entire_binding(), 
}, ], }); self.entries.insert( pool_index, CqtCacheEntry { cache_texture, cache_texture_view, cache_storage_view, cache_capacity: CACHE_CAPACITY, freq_bins: FREQ_BINS, cache_start_column: 0, cache_valid_start: 0, cache_valid_end: 0, bin_params_buffer, waveform_texture_view, waveform_total_frames: total_frames, compute_bind_group, compute_uniform_buffer, render_bind_group, render_uniform_buffer, sample_rate, current_stride: 1, }, ); } } /// Dispatch compute shader to fill CQT columns in the cache. /// Free function to avoid borrow conflicts with CqtGpuResources.entries. fn dispatch_cqt_compute( device: &wgpu::Device, queue: &wgpu::Queue, pipeline: &wgpu::ComputePipeline, entry: &CqtCacheEntry, start_col: i64, end_col: i64, stride: u32, ) -> Vec { // Number of cache slots needed (each slot covers `stride` global columns) let num_cols = ((end_col - start_col) as u32 / stride).max(1); if end_col <= start_col { return Vec::new(); } // Clamp to max per frame let num_cols = num_cols.min(MAX_COLS_PER_FRAME); // Calculate ring buffer write offset (in cache slots, not global columns) let cache_write_offset = (((start_col - entry.cache_start_column) / stride as i64) as u32) % entry.cache_capacity; let params = CqtComputeParams { hop_size: HOP_SIZE, freq_bins: FREQ_BINS, cache_capacity: entry.cache_capacity, cache_write_offset, num_columns: num_cols, column_start: start_col.max(0) as u32, tex_width: WAVEFORM_TEX_WIDTH, total_frames: entry.waveform_total_frames as u32, sample_rate: entry.sample_rate as f32, column_stride: stride, _pad1: 0, _pad2: 0, }; queue.write_buffer( &entry.compute_uniform_buffer, 0, bytemuck::cast_slice(&[params]), ); let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cqt_compute_encoder"), }); { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("cqt_compute_pass"), timestamp_writes: None, }); pass.set_pipeline(pipeline); pass.set_bind_group(0, &entry.compute_bind_group, &[]); 
// Dispatch: X = ceil(freq_bins / 64), Y = num_columns let workgroups_x = (FREQ_BINS + 63) / 64; pass.dispatch_workgroups(workgroups_x, num_cols, 1); } vec![encoder.finish()] } impl egui_wgpu::CallbackTrait for CqtCallback { fn prepare( &self, device: &wgpu::Device, queue: &wgpu::Queue, _screen_descriptor: &egui_wgpu::ScreenDescriptor, _egui_encoder: &mut wgpu::CommandEncoder, resources: &mut egui_wgpu::CallbackResources, ) -> Vec { // Initialize CQT resources if needed if !resources.contains::() { resources.insert(CqtGpuResources::new(device, self.target_format)); } // First, check if waveform data is available and extract what we need let waveform_info: Option<(wgpu::TextureView, u64)> = { let waveform_gpu: Option<&WaveformGpuResources> = resources.get(); waveform_gpu.and_then(|wgpu_res| { wgpu_res.entries.get(&self.pool_index).map(|entry| { // Clone the texture view (Arc internally, cheap) (entry.texture_views[0].clone(), entry.total_frames) }) }) }; let (waveform_view, total_frames) = match waveform_info { Some(info) => info, None => return Vec::new(), // Waveform not uploaded yet }; let cqt_gpu: &mut CqtGpuResources = resources.get_mut().unwrap(); // Ensure cache entry exists cqt_gpu.ensure_cache_entry( device, self.pool_index, waveform_view, total_frames, self.sample_rate, ); // Determine which columns need computing let stride = self.stride.max(1) as i64; let vis_start = self.visible_col_start.max(0); let max_col = (total_frames as i64) / HOP_SIZE as i64; let vis_end_raw = self.visible_col_end.min(max_col); // Clamp visible range to cache capacity (in global columns, accounting for stride) let vis_end = vis_end_raw.min(vis_start + CACHE_CAPACITY as i64 * stride); // If stride changed, invalidate cache { let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap(); if entry.current_stride != self.stride { entry.current_stride = self.stride; entry.cache_start_column = vis_start; entry.cache_valid_start = vis_start; entry.cache_valid_end = vis_start; } } // 
Stride-aware max columns per frame (in global column units) let max_cols_global = MAX_COLS_PER_FRAME as i64 * stride; // Read current cache state, compute what's needed, then update state. // We split borrows carefully: read entry state, compute, then write back. let cmds; { let entry = cqt_gpu.entries.get(&self.pool_index).unwrap(); let cache_valid_start = entry.cache_valid_start; let cache_valid_end = entry.cache_valid_end; if vis_start >= vis_end { cmds = Vec::new(); } else if vis_start >= cache_valid_start && vis_end <= cache_valid_end { // Fully cached cmds = Vec::new(); } else if vis_start >= cache_valid_start && vis_start < cache_valid_end && vis_end > cache_valid_end { // Scrolling right — compute new columns at the right edge. // cache_start_column stays fixed; the ring buffer wraps naturally. let actual_end = cache_valid_end + (vis_end - cache_valid_end).min(max_cols_global); cmds = dispatch_cqt_compute( device, queue, &cqt_gpu.compute_pipeline, entry, cache_valid_end, actual_end, self.stride, ); let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap(); entry.cache_valid_end = actual_end; let cache_cap_global = entry.cache_capacity as i64 * stride; if entry.cache_valid_end - entry.cache_valid_start > cache_cap_global { entry.cache_valid_start = entry.cache_valid_end - cache_cap_global; } } else if vis_end <= cache_valid_end && vis_end > cache_valid_start && vis_start < cache_valid_start { // Scrolling left — move anchor BEFORE dispatch so compute and // render use the same cache_start_column mapping. 
let actual_start = cache_valid_start - (cache_valid_start - vis_start).min(max_cols_global); { let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap(); entry.cache_start_column = actual_start; } let entry = cqt_gpu.entries.get(&self.pool_index).unwrap(); cmds = dispatch_cqt_compute( device, queue, &cqt_gpu.compute_pipeline, entry, actual_start, cache_valid_start, self.stride, ); let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap(); entry.cache_valid_start = actual_start; let cache_cap_global = entry.cache_capacity as i64 * stride; if entry.cache_valid_end - entry.cache_valid_start > cache_cap_global { entry.cache_valid_end = entry.cache_valid_start + cache_cap_global; } } else { // No overlap or first compute — reset cache let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap(); entry.cache_start_column = vis_start; entry.cache_valid_start = vis_start; entry.cache_valid_end = vis_start; let compute_end = vis_start + (vis_end - vis_start).min(max_cols_global); let entry = cqt_gpu.entries.get(&self.pool_index).unwrap(); cmds = dispatch_cqt_compute( device, queue, &cqt_gpu.compute_pipeline, entry, vis_start, compute_end, self.stride, ); let entry = cqt_gpu.entries.get_mut(&self.pool_index).unwrap(); entry.cache_valid_end = compute_end; } } // Update render uniform buffer let entry = cqt_gpu.entries.get(&self.pool_index).unwrap(); let mut params = self.params; params.cache_start_column = entry.cache_start_column as f32; params.cache_valid_start = entry.cache_valid_start as f32; params.cache_valid_end = entry.cache_valid_end as f32; params.cache_capacity = entry.cache_capacity as f32; params.column_stride = self.stride as f32; queue.write_buffer( &entry.render_uniform_buffer, 0, bytemuck::cast_slice(&[params]), ); cmds } fn paint( &self, _info: eframe::egui::PaintCallbackInfo, render_pass: &mut wgpu::RenderPass<'static>, resources: &egui_wgpu::CallbackResources, ) { let cqt_gpu: &CqtGpuResources = match resources.get() { Some(r) => r, None => 
return, }; let entry = match cqt_gpu.entries.get(&self.pool_index) { Some(e) => e, None => return, }; // Don't render if nothing is cached yet if entry.cache_valid_start >= entry.cache_valid_end { return; } render_pass.set_pipeline(&cqt_gpu.render_pipeline); render_pass.set_bind_group(0, &entry.render_bind_group, &[]); render_pass.draw(0..3, 0..1); } }