diff --git a/daw-backend/src/audio/disk_reader.rs b/daw-backend/src/audio/disk_reader.rs index 4a21c07..377f4f0 100644 --- a/daw-backend/src/audio/disk_reader.rs +++ b/daw-backend/src/audio/disk_reader.rs @@ -74,6 +74,9 @@ pub struct ReadAheadBuffer { /// The disk reader uses this instead of the global playhead to know /// where in the file to buffer around. target_frame: AtomicU64, + /// When true, `render_from_file` will block-wait for frames instead of + /// returning silence on buffer miss. Used during offline export. + export_mode: AtomicBool, } // SAFETY: See the doc comment on ReadAheadBuffer for the full safety argument. @@ -108,6 +111,7 @@ impl ReadAheadBuffer { channels, sample_rate, target_frame: AtomicU64::new(0), + export_mode: AtomicBool::new(false), } } @@ -200,6 +204,18 @@ impl ReadAheadBuffer { self.target_frame.load(Ordering::Relaxed) != u64::MAX } + /// Enable or disable export (blocking) mode. When enabled, + /// `render_from_file` will spin-wait for frames instead of returning + /// silence on buffer miss. + pub fn set_export_mode(&self, export: bool) { + self.export_mode.store(export, Ordering::Release); + } + + /// Check if export (blocking) mode is active. + pub fn is_export_mode(&self) -> bool { + self.export_mode.load(Ordering::Acquire) + } + /// Reset the buffer to start at `new_start` with zero valid frames. /// Called by the **disk reader thread** (producer) after a seek. pub fn reset(&self, new_start: u64) { @@ -614,8 +630,12 @@ impl DiskReader { } } - // Sleep briefly to avoid busy-spinning when all buffers are full. - std::thread::sleep(std::time::Duration::from_millis(POLL_INTERVAL_MS)); + // In export mode, skip the sleep so decoding runs at full speed. + // Otherwise sleep briefly to avoid busy-spinning. + let any_exporting = active_files.values().any(|(_, buf)| buf.is_export_mode()); + if !any_exporting { + std::thread::sleep(std::time::Duration::from_millis(POLL_INTERVAL_MS)); + } } } } diff --git a/daw-backend/src/audio/engine.rs b/daw-backend/src/audio/engine.rs index b329493..5fa2452 100644 --- a/daw-backend/src/audio/engine.rs +++ b/daw-backend/src/audio/engine.rs @@ -1334,6 +1334,18 @@ impl Engine { } } + Command::GraphSetNodePosition(track_id, node_index, x, y) => { + let graph = match self.project.get_track_mut(track_id) { + Some(TrackNode::Midi(track)) => Some(&mut track.instrument_graph), + Some(TrackNode::Audio(track)) => Some(&mut track.effects_graph), + _ => None, + }; + if let Some(graph) = graph { + let node_idx = NodeIndex::new(node_index as usize); + graph.set_node_position(node_idx, x, y); + } + } + Command::GraphSetMidiTarget(track_id, node_index, enabled) => { if let Some(TrackNode::Midi(track)) = self.project.get_track_mut(track_id) { let graph = &mut track.instrument_graph; @@ -2245,7 +2257,10 @@ impl Engine { QueryResponse::AudioImportedSync(self.do_import_audio(&path)) } Query::GetProject => { - // Clone the entire project for serialization + // Save graph presets before cloning β€” AudioTrack::clone() creates + // a fresh default graph (not a copy), so the preset must be populated + // first so the clone carries the serialized graph data. + self.project.prepare_for_save(); QueryResponse::ProjectRetrieved(Ok(Box::new(self.project.clone()))) } Query::SetProject(new_project) => { @@ -2950,6 +2965,11 @@ impl EngineController { let _ = self.command_tx.push(Command::GraphSetParameter(track_id, node_id, param_id, value)); } + /// Set the UI position of a node in a track's graph + pub fn graph_set_node_position(&mut self, track_id: TrackId, node_id: u32, x: f32, y: f32) { + let _ = self.command_tx.push(Command::GraphSetNodePosition(track_id, node_id, x, y)); + } + /// Set which node receives MIDI events in a track's instrument graph pub fn graph_set_midi_target(&mut self, track_id: TrackId, node_id: u32, enabled: bool) { let _ = self.command_tx.push(Command::GraphSetMidiTarget(track_id, node_id, enabled)); diff --git a/daw-backend/src/audio/export.rs b/daw-backend/src/audio/export.rs index 5ca541a..4fc1bbf 100644 --- a/daw-backend/src/audio/export.rs +++ b/daw-backend/src/audio/export.rs @@ -72,29 +72,44 @@ pub fn export_audio>( midi_pool: &MidiClipPool, settings: &ExportSettings, output_path: P, - event_tx: Option<&mut rtrb::Producer>, + mut event_tx: Option<&mut rtrb::Producer>, ) -> Result<(), String> { - // Route to appropriate export implementation based on format - match settings.format { + // Reset all node graphs to clear stale effect buffers (echo, reverb, etc.) + project.reset_all_graphs(); + + // Enable blocking mode on all read-ahead buffers so compressed audio + // streams block until decoded frames are available (instead of returning + // silence when the disk reader hasn't caught up with offline rendering). + project.set_export_mode(true); + + // Route to appropriate export implementation based on format. + // Ensure export mode is disabled even if an error occurs. + let result = match settings.format { ExportFormat::Wav | ExportFormat::Flac => { - // Render to memory then write (existing path) - let samples = render_to_memory(project, pool, midi_pool, settings, event_tx)?; + let samples = render_to_memory(project, pool, midi_pool, settings, event_tx.as_mut().map(|tx| &mut **tx))?; + // Signal that rendering is done and we're now writing the file + if let Some(ref mut tx) = event_tx { + let _ = tx.push(AudioEvent::ExportFinalizing); + } match settings.format { - ExportFormat::Wav => write_wav(&samples, settings, output_path)?, - ExportFormat::Flac => write_flac(&samples, settings, output_path)?, + ExportFormat::Wav => write_wav(&samples, settings, &output_path), + ExportFormat::Flac => write_flac(&samples, settings, &output_path), _ => unreachable!(), } } ExportFormat::Mp3 => { - export_mp3(project, pool, midi_pool, settings, output_path, event_tx)?; + export_mp3(project, pool, midi_pool, settings, output_path, event_tx) } ExportFormat::Aac => { - export_aac(project, pool, midi_pool, settings, output_path, event_tx)?; + export_aac(project, pool, midi_pool, settings, output_path, event_tx) } - } + }; - Ok(()) + // Always disable export mode, even on error + project.set_export_mode(false); + + result } /// Render the project to memory @@ -437,6 +452,11 @@ fn export_mp3>( )?; } + // Signal that rendering is done and we're now flushing/finalizing + if let Some(ref mut tx) = event_tx { + let _ = tx.push(AudioEvent::ExportFinalizing); + } + // Flush encoder encoder.send_eof() .map_err(|e| format!("Failed to send EOF: {}", e))?; @@ -602,6 +622,11 @@ fn export_aac>( )?; } + // Signal that rendering is done and we're now flushing/finalizing + if let Some(ref mut tx) = event_tx { + let _ = tx.push(AudioEvent::ExportFinalizing); + } + // Flush encoder encoder.send_eof() .map_err(|e| format!("Failed to send EOF: {}", e))?; diff --git a/daw-backend/src/audio/node_graph/graph.rs b/daw-backend/src/audio/node_graph/graph.rs index 5c65043..cb8b108 100644 --- a/daw-backend/src/audio/node_graph/graph.rs +++ b/daw-backend/src/audio/node_graph/graph.rs @@ -161,6 +161,17 @@ impl AudioGraph { // Validate the connection self.validate_connection(from, from_port, to, to_port)?; + // Remove any existing connection to the same input port (replace semantics). + // The frontend UI enforces single-connection inputs, so when a new connection + // targets the same port, the old one should be replaced. + let edges_to_remove: Vec<_> = self.graph.edges_directed(to, petgraph::Direction::Incoming) + .filter(|e| e.weight().to_port == to_port) + .map(|e| e.id()) + .collect(); + for edge_id in edges_to_remove { + self.graph.remove_edge(edge_id); + } + // Add the edge self.graph.add_edge(from, to, Connection { from_port, to_port }); self.topo_cache = None; diff --git a/daw-backend/src/audio/pool.rs b/daw-backend/src/audio/pool.rs index 7a9465a..ebf46f4 100644 --- a/daw-backend/src/audio/pool.rs +++ b/daw-backend/src/audio/pool.rs @@ -95,6 +95,8 @@ pub struct AudioFile { /// Original file format (mp3, ogg, wav, flac, etc.) /// Used to determine if we should preserve lossy encoding during save pub original_format: Option, + /// Original compressed file bytes (preserved across save/load to avoid re-encoding) + pub original_bytes: Option>, } impl AudioFile { @@ -108,6 +110,7 @@ impl AudioFile { sample_rate, frames, original_format: None, + original_bytes: None, } } @@ -121,6 +124,7 @@ impl AudioFile { sample_rate, frames, original_format, + original_bytes: None, } } @@ -152,6 +156,7 @@ impl AudioFile { sample_rate, frames: total_frames, original_format: Some("wav".to_string()), + original_bytes: None, } } @@ -174,6 +179,7 @@ impl AudioFile { sample_rate, frames: total_frames, original_format, + original_bytes: None, } } @@ -470,6 +476,31 @@ impl AudioClipPool { return 0; } + // In export mode, block-wait until the disk reader has filled the + // frames we need, so offline rendering never gets buffer misses. + if use_read_ahead { + let ra = read_ahead.unwrap(); + if ra.is_export_mode() { + let src_start = (start_time_seconds * audio_file.sample_rate as f64) as u64; + // Tell the disk reader where we need data BEFORE waiting + ra.set_target_frame(src_start); + // Pad by 64 frames for sinc interpolation taps + let frames_needed = (output.len() / engine_channels as usize) as u64 + 64; + // Spin-wait with small sleeps until the disk reader fills the buffer + let mut wait_iters = 0u64; + while !ra.has_range(src_start, frames_needed) { + std::thread::sleep(std::time::Duration::from_micros(100)); + wait_iters += 1; + if wait_iters > 100_000 { + // Safety valve: 10 seconds of waiting + eprintln!("[EXPORT] Timed out waiting for disk reader (need frames {}..{})", + src_start, src_start + frames_needed); + break; + } + } + } + } + // Snapshot the read-ahead buffer range once for the entire render call. // This ensures all sinc interpolation taps within a single callback // see a consistent range, preventing crackle from concurrent updates. @@ -834,6 +865,15 @@ impl AudioClipPool { || fmt_lower == "m4a" || fmt_lower == "opus" }); + // Check for preserved original bytes first (from previous load cycle) + if let Some(ref original_bytes) = audio_file.original_bytes { + let data_base64 = general_purpose::STANDARD.encode(original_bytes); + return EmbeddedAudioData { + data_base64, + format: audio_file.original_format.clone().unwrap_or_else(|| "wav".to_string()), + }; + } + if is_lossy { // For lossy formats, read the original file bytes (if it still exists) if let Ok(original_bytes) = std::fs::read(&audio_file.path) { @@ -1012,9 +1052,12 @@ impl AudioClipPool { // Clean up temporary file let _ = std::fs::remove_file(&temp_path); - // Update the path to reflect it was embedded + // Update the path to reflect it was embedded, and preserve original bytes if result.is_ok() && pool_index < self.files.len() { self.files[pool_index].path = PathBuf::from(format!("", name)); + // Preserve the original compressed/encoded bytes so re-save doesn't need to re-encode + self.files[pool_index].original_bytes = Some(data); + self.files[pool_index].original_format = Some(embedded.format.clone()); } eprintln!("πŸ“Š [POOL] βœ… Total load_from_embedded time: {:.2}ms", fn_start.elapsed().as_secs_f64() * 1000.0); diff --git a/daw-backend/src/audio/project.rs b/daw-backend/src/audio/project.rs index 107f646..0e70cc0 100644 --- a/daw-backend/src/audio/project.rs +++ b/daw-backend/src/audio/project.rs @@ -506,6 +506,21 @@ impl Project { } } + /// Set export (blocking) mode on all clip read-ahead buffers. + /// When enabled, `render_from_file` blocks until the disk reader + /// has filled the needed frames instead of returning silence. + pub fn set_export_mode(&self, export: bool) { + for track in self.tracks.values() { + if let TrackNode::Audio(t) = track { + for clip in &t.clips { + if let Some(ref ra) = clip.read_ahead { + ra.set_export_mode(export); + } + } + } + } + } + /// Reset all node graphs (clears effect buffers on seek) pub fn reset_all_graphs(&mut self) { for track in self.tracks.values_mut() { diff --git a/daw-backend/src/command/types.rs b/daw-backend/src/command/types.rs index d776679..114a4b6 100644 --- a/daw-backend/src/command/types.rs +++ b/daw-backend/src/command/types.rs @@ -148,6 +148,8 @@ pub enum Command { GraphDisconnect(TrackId, u32, usize, u32, usize), /// Set a parameter on a node (track_id, node_index, param_id, value) GraphSetParameter(TrackId, u32, u32, f32), + /// Set the UI position of a node (track_id, node_index, x, y) + GraphSetNodePosition(TrackId, u32, f32, f32), /// Set which node receives MIDI events (track_id, node_index, enabled) GraphSetMidiTarget(TrackId, u32, bool), /// Set which node is the audio output (track_id, node_index) @@ -250,6 +252,8 @@ pub enum AudioEvent { frames_rendered: usize, total_frames: usize, }, + /// Export rendering complete, now writing/encoding the output file + ExportFinalizing, /// Waveform generated for audio pool file (pool_index, waveform) WaveformGenerated(usize, Vec), diff --git a/lightningbeam-ui/egui_node_graph2/Cargo.toml b/lightningbeam-ui/egui_node_graph2/Cargo.toml new file mode 100644 index 0000000..8d3e4d5 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "egui_node_graph2" +description = "A helper library to create interactive node graphs using egui" +homepage = "https://github.com/trevyn/egui_node_graph2" +repository = "https://github.com/trevyn/egui_node_graph2" +license = "MIT" +version = "0.7.0" +keywords = ["egui_node_graph", "ui", "egui", "graph", "node"] +edition = "2021" +readme = "../README.md" +workspace = ".." + +[features] +persistence = ["serde", "slotmap/serde", "smallvec/serde", "egui/persistence"] + +[dependencies] +egui = "0.33.3" +slotmap = { version = "1.0" } +smallvec = { version = "1.10.0" } +serde = { version = "1.0", optional = true, features = ["derive"] } +thiserror = "1.0" diff --git a/lightningbeam-ui/egui_node_graph2/src/color_hex_utils.rs b/lightningbeam-ui/egui_node_graph2/src/color_hex_utils.rs new file mode 100644 index 0000000..0479a8c --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/color_hex_utils.rs @@ -0,0 +1,94 @@ +use egui::Color32; + +/// Converts a hex string with a leading '#' into a egui::Color32. +/// - The first three channels are interpreted as R, G, B. +/// - The fourth channel, if present, is used as the alpha value. +/// - Both upper and lowercase characters can be used for the hex values. +/// +/// *Adapted from: https://docs.rs/raster/0.1.0/src/raster/lib.rs.html#425-725. +/// Credit goes to original authors.* +pub fn color_from_hex(hex: &str) -> Result { + // Convert a hex string to decimal. Eg. "00" -> 0. "FF" -> 255. + fn _hex_dec(hex_string: &str) -> Result { + match u8::from_str_radix(hex_string, 16) { + Ok(o) => Ok(o), + Err(e) => Err(format!("Error parsing hex: {}", e)), + } + } + + if hex.len() == 9 && hex.starts_with('#') { + // #FFFFFFFF (Red Green Blue Alpha) + return Ok(Color32::from_rgba_premultiplied( + _hex_dec(&hex[1..3])?, + _hex_dec(&hex[3..5])?, + _hex_dec(&hex[5..7])?, + _hex_dec(&hex[7..9])?, + )); + } else if hex.len() == 7 && hex.starts_with('#') { + // #FFFFFF (Red Green Blue) + return Ok(Color32::from_rgb( + _hex_dec(&hex[1..3])?, + _hex_dec(&hex[3..5])?, + _hex_dec(&hex[5..7])?, + )); + } + + Err(format!( + "Error parsing hex: {}. Example of valid formats: #FFFFFF or #ffffffff", + hex + )) +} + +/// Converts a Color32 into its canonical hexadecimal representation. +/// - The color string will be preceded by '#'. +/// - If the alpha channel is completely opaque, it will be ommitted. +/// - Characters from 'a' to 'f' will be written in lowercase. +#[allow(dead_code)] +pub fn color_to_hex(color: Color32) -> String { + if color.a() < 255 { + format!( + "#{:02x?}{:02x?}{:02x?}{:02x?}", + color.r(), + color.g(), + color.b(), + color.a() + ) + } else { + format!("#{:02x?}{:02x?}{:02x?}", color.r(), color.g(), color.b()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn test_color_from_and_to_hex() { + assert_eq!( + color_from_hex("#00ff00").unwrap(), + Color32::from_rgb(0, 255, 0) + ); + assert_eq!( + color_from_hex("#5577AA").unwrap(), + Color32::from_rgb(85, 119, 170) + ); + assert_eq!( + color_from_hex("#E2e2e277").unwrap(), + Color32::from_rgba_premultiplied(226, 226, 226, 119) + ); + assert!(color_from_hex("abcdefgh").is_err()); + + assert_eq!( + color_to_hex(Color32::from_rgb(0, 255, 0)), + "#00ff00".to_string() + ); + assert_eq!( + color_to_hex(Color32::from_rgb(85, 119, 170)), + "#5577aa".to_string() + ); + assert_eq!( + color_to_hex(Color32::from_rgba_premultiplied(226, 226, 226, 119)), + "#e2e2e277".to_string() + ); + } +} diff --git a/lightningbeam-ui/egui_node_graph2/src/editor_ui.rs b/lightningbeam-ui/egui_node_graph2/src/editor_ui.rs new file mode 100644 index 0000000..e32d241 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/editor_ui.rs @@ -0,0 +1,1222 @@ +use std::collections::HashSet; +use std::num::NonZeroU32; + +use crate::color_hex_utils::*; +use crate::utils::ColorUtils; + +use super::*; +use egui::epaint::{CubicBezierShape, RectShape}; +use egui::*; + +/// Mapping from parameter id to positions of hooks it contains. +/// +/// Outputs and short inputs always only have one hook, so the value is +/// just `vec![port_position]`. Wide inputs may have multiple hooks. +pub type PortLocations = std::collections::HashMap>; + +/// Destination positions of connections made to a given input. +/// +/// This is not equivalent to [`PortLocations`] because connections may be moved +/// around (e.g. while an in-progress connection is hovered over a wide port), +/// while hooks within a port are strictly a function of the port. +pub type ConnLocations = std::collections::HashMap>; + +/// Rectangle containing each node. +pub type NodeRects = std::collections::HashMap; + +const DISTANCE_TO_CONNECT: f32 = 20.0; + +/// Nodes communicate certain events to the parent graph when drawn. There is +/// one special `User` variant which can be used by users as the return value +/// when executing some custom actions in the UI of the node. +#[derive(Clone, Debug)] +pub enum NodeResponse { + ConnectEventStarted(NodeId, AnyParameterId), + ConnectEventEnded { + output: OutputId, + input: InputId, + /// Index of the connection in wide input ports. + /// + /// If the input isn't a wide port this is always 0 and may be ignored. + input_hook: usize, + }, + CreatedNode(NodeId), + SelectNode(NodeId), + /// As a user of this library, prefer listening for `DeleteNodeFull` which + /// will also contain the user data for the deleted node. + DeleteNodeUi(NodeId), + /// Emitted when a node is deleted. The node will no longer exist in the + /// graph after this response is returned from the draw function, but its + /// contents are passed along with the event. + DeleteNodeFull { + node_id: NodeId, + node: Node, + }, + DisconnectEvent { + output: OutputId, + input: InputId, + }, + /// Emitted when a node is interacted with, and should be raised + RaiseNode(NodeId), + MoveNode { + node: NodeId, + drag_delta: Vec2, + }, + User(UserResponse), +} + +/// The return value of [`draw_graph_editor`]. This value can be used to make +/// user code react to specific events that happened when drawing the graph. +#[derive(Clone, Debug)] +pub struct GraphResponse { + /// Events that occurred during this frame of rendering the graph. Check the + /// [`UserResponse`] type for a description of each event. + pub node_responses: Vec>, + /// Is the mouse currently hovering the graph editor? Note that the node + /// finder is considered part of the graph editor, even when it floats + /// outside the graph editor rect. + pub cursor_in_editor: bool, + /// Is the mouse currently hovering the node finder? + pub cursor_in_finder: bool, +} +impl Default + for GraphResponse +{ + fn default() -> Self { + Self { + node_responses: Default::default(), + cursor_in_editor: false, + cursor_in_finder: false, + } + } +} +pub struct GraphNodeWidget<'a, NodeData, DataType, ValueType> { + pub position: &'a mut Pos2, + pub graph: &'a mut Graph, + pub port_locations: &'a mut PortLocations, + pub conn_locations: &'a mut ConnLocations, + pub node_rects: &'a mut NodeRects, + pub node_id: NodeId, + pub ongoing_drag: Option<(NodeId, AnyParameterId)>, + pub selected: bool, + pub pan: egui::Vec2, +} + +impl + GraphEditorState +where + NodeData: NodeDataTrait< + Response = UserResponse, + UserState = UserState, + DataType = DataType, + ValueType = ValueType, + >, + UserResponse: UserResponseTrait, + ValueType: + WidgetValueTrait, + NodeTemplate: NodeTemplateTrait< + NodeData = NodeData, + DataType = DataType, + ValueType = ValueType, + UserState = UserState, + CategoryType = CategoryType, + >, + DataType: DataTypeTrait, + CategoryType: CategoryTrait, +{ + #[must_use] + pub fn draw_graph_editor( + &mut self, + ui: &mut Ui, + all_kinds: impl NodeTemplateIter, + user_state: &mut UserState, + prepend_responses: Vec>, + ) -> GraphResponse { + ui.set_clip_rect(ui.max_rect()); + let clip_rect = ui.clip_rect(); + // Zoom may have never taken place, so ensure we use parent style + if !self.pan_zoom.started { + self.zoom(ui, 1.0); + self.pan_zoom.started = true; + } + + // Zoom only within area where graph is shown + if ui.rect_contains_pointer(clip_rect) { + let scroll_delta = ui.input(|i| i.smooth_scroll_delta.y); + if scroll_delta != 0.0 { + let zoom_delta = (scroll_delta * 0.002).exp(); + self.zoom(ui, zoom_delta); + } + } + + // Render graph zoomed + let zoomed_style = self.pan_zoom.zoomed_style.clone(); + let graph_response = show_zoomed(ui.style().clone(), zoomed_style, ui, |ui| { + self.draw_graph_editor_inside_zoom(ui, all_kinds, user_state, prepend_responses) + }); + + graph_response + } + + /// Reset zoom to 1.0 + pub fn reset_zoom(&mut self, ui: &Ui) { + let new_zoom = 1.0 / self.pan_zoom.zoom; + self.zoom(ui, new_zoom); + } + + /// Zoom within the where you call `draw_graph_editor`. Use values like 1.01, or 0.99 to zoom. + /// For example: `let zoom_delta = (scroll_delta * 0.002).exp();` + pub fn zoom(&mut self, ui: &Ui, zoom_delta: f32) { + // Update zoom, and styles + let zoom_before = self.pan_zoom.zoom; + self.pan_zoom.zoom(ui.clip_rect(), ui.style(), zoom_delta); + if zoom_before != self.pan_zoom.zoom { + let actual_delta = self.pan_zoom.zoom / zoom_before; + self.update_node_positions_after_zoom(actual_delta); + } + } + + fn update_node_positions_after_zoom(&mut self, zoom_delta: f32) { + // Update node positions, zoom towards center + let half_size = self.pan_zoom.clip_rect.size() / 2.0; + for (_id, node_pos) in self.node_positions.iter_mut() { + // 1. Get node local position (relative to origo) + let local_pos = node_pos.to_vec2() - half_size + self.pan_zoom.pan; + // 2. Scale local position by zoom delta + let scaled_local_pos = (local_pos * zoom_delta).to_pos2(); + // 3. Transform back to global position + *node_pos = scaled_local_pos + half_size - self.pan_zoom.pan; + // This way we can retain pan untouched when zooming :) + } + } + + fn draw_graph_editor_inside_zoom( + &mut self, + ui: &mut Ui, + all_kinds: impl NodeTemplateIter, + user_state: &mut UserState, + prepend_responses: Vec>, + ) -> GraphResponse { + // This causes the graph editor to use as much free space as it can. + // (so for windows it will use up to the resizeably set limit + // and for a Panel it will fill it completely) + let editor_rect = ui.max_rect(); + let resp = ui.allocate_rect(editor_rect, Sense::hover()); + + let cursor_pos = ui + .ctx() + .input(|i| i.pointer.hover_pos().unwrap_or(Pos2::ZERO)); + let mut cursor_in_editor = resp.contains_pointer(); + let mut cursor_in_finder = false; + + // Gets filled with the node metrics as they are drawn + let mut port_locations = PortLocations::new(); + let mut node_rects = NodeRects::new(); + + // actual dest location of each connection + let mut conn_locations = ConnLocations::default(); + + // The responses returned from node drawing have side effects that are best + // executed at the end of this function. + let mut delayed_responses: Vec> = prepend_responses; + + // Used to detect when the background was clicked + let mut click_on_background = false; + + // Used to detect drag events in the background + let mut drag_started_on_background = false; + let mut drag_released_on_background = false; + + debug_assert_eq!( + self.node_order.iter().copied().collect::>(), + self.graph.iter_nodes().collect::>(), + "The node_order field of the GraphEditorself was left in an \ + inconsistent self. It has either more or less values than the graph." + ); + + // Allocate rect before the nodes, otherwise this will block the interaction + // with the nodes. + let r = ui.allocate_rect(ui.min_rect(), Sense::click().union(Sense::drag())); + if r.clicked() { + click_on_background = true; + } else if r.drag_started() { + drag_started_on_background = true; + } else if r.drag_stopped() { + drag_released_on_background = true; + } + + /* Draw nodes */ + for node_id in self.node_order.iter().copied() { + let responses = GraphNodeWidget { + position: self.node_positions.get_mut(node_id).unwrap(), + graph: &mut self.graph, + port_locations: &mut port_locations, + conn_locations: &mut conn_locations, + node_rects: &mut node_rects, + node_id, + ongoing_drag: self.connection_in_progress, + selected: self + .selected_nodes + .iter() + .any(|selected| *selected == node_id), + pan: self.pan_zoom.pan + editor_rect.min.to_vec2(), + } + .show(&self.pan_zoom, ui, user_state); + + // Actions executed later + delayed_responses.extend(responses); + } + + /* Draw the node finder, if open */ + let mut should_close_node_finder = false; + if let Some(ref mut node_finder) = self.node_finder { + let mut node_finder_area = Area::new(Id::new("node_finder")).order(Order::Foreground); + if let Some(pos) = node_finder.position { + node_finder_area = node_finder_area.current_pos(pos); + } + node_finder_area.show(ui.ctx(), |ui| { + if let Some(node_kind) = node_finder.show(ui, all_kinds, user_state) { + let new_node = self.graph.add_node( + node_kind.node_graph_label(user_state), + node_kind.user_data(user_state), + |graph, node_id| node_kind.build_node(graph, user_state, node_id), + ); + self.node_positions.insert( + new_node, + node_finder.position.unwrap_or(cursor_pos) + - self.pan_zoom.pan + - editor_rect.min.to_vec2(), + ); + self.node_order.push(new_node); + + should_close_node_finder = true; + delayed_responses.push(NodeResponse::CreatedNode(new_node)); + } + let finder_rect = ui.min_rect(); + // If the cursor is not in the main editor, check if the cursor is in the finder + // if the cursor is in the finder, then we can consider that also in the editor. + if finder_rect.contains(cursor_pos) { + cursor_in_editor = true; + cursor_in_finder = true; + } + }); + } + if should_close_node_finder { + self.node_finder = None; + } + + // draw in-progress connections + if let Some((_, ref locator)) = self.connection_in_progress { + let port_type = self.graph.any_param_type(*locator).unwrap(); + let connection_color = port_type.data_type_color(user_state); + + // outputs can't be wide yet so this is fine. + let start_pos = *port_locations[locator].last().unwrap(); + + // Find a port to connect to + fn snap_to_ports< + NodeData, + UserState, + DataType: DataTypeTrait, + ValueType, + Key: slotmap::Key + Into, + Value, + >( + graph: &Graph, + port_type: &DataType, + ports: &SlotMap, + port_locations: &PortLocations, + cursor_pos: Pos2, + ) -> Pos2 { + ports + .iter() + .find_map(|(port_id, _)| { + let compatible_ports = graph + .any_param_type(port_id.into()) + .map(|other| other == port_type) + .unwrap_or(false); + + if compatible_ports { + port_locations.get(&port_id.into()).and_then(|hooks| { + hooks + .iter() + .min_by(|hook1, hook2| { + hook1 + .distance(cursor_pos) + .partial_cmp(&hook2.distance(cursor_pos)) + .unwrap() + }) + .filter(|nearest_hook| { + nearest_hook.distance(cursor_pos) < DISTANCE_TO_CONNECT + }) + .copied() + }) + } else { + None + } + }) + .unwrap_or(cursor_pos) + } + + let (src_pos, dst_pos) = match locator { + AnyParameterId::Output(_) => ( + start_pos, + snap_to_ports( + &self.graph, + port_type, + &self.graph.inputs, + &port_locations, + cursor_pos, + ), + ), + AnyParameterId::Input(_) => ( + snap_to_ports( + &self.graph, + port_type, + &self.graph.outputs, + &port_locations, + cursor_pos, + ), + start_pos, + ), + }; + draw_connection( + &self.pan_zoom, + ui.painter(), + src_pos, + dst_pos, + connection_color, + ); + } + + // draw existing connections + for (input, outputs) in self.graph.iter_connection_groups() { + for (hook_n, &output) in outputs.iter().enumerate() { + let port_type = self + .graph + .any_param_type(AnyParameterId::Output(output)) + .unwrap(); + let connection_color = port_type.data_type_color(user_state); + // outputs can't be wide yet so this is fine. + let src_pos = port_locations[&AnyParameterId::Output(output)][0]; + let dst_pos = conn_locations[&input][hook_n]; + draw_connection( + &self.pan_zoom, + ui.painter(), + src_pos, + dst_pos, + connection_color, + ); + } + } + + /* Handle responses from drawing nodes */ + + // Some responses generate additional responses when processed. These + // are stored here to report them back to the user. + let mut extra_responses: Vec> = Vec::new(); + + for response in delayed_responses.iter() { + match response { + NodeResponse::ConnectEventStarted(node_id, port) => { + self.connection_in_progress = Some((*node_id, *port)); + } + NodeResponse::ConnectEventEnded { + output, + input, + input_hook, + } => self.graph.add_connection(*output, *input, *input_hook), + NodeResponse::CreatedNode(_) => { + //Convenience NodeResponse for users + } + NodeResponse::SelectNode(node_id) => { + self.selected_nodes = Vec::from([*node_id]); + } + NodeResponse::DeleteNodeUi(node_id) => { + let (node, disc_events) = self.graph.remove_node(*node_id); + + // Pass the disconnection responses first so user code can perform cleanup + // before node removal response. + extra_responses.extend( + disc_events + .into_iter() + .map(|(input, output)| NodeResponse::DisconnectEvent { input, output }), + ); + // Pass the full node as a response so library users can + // listen for it and get their user data. + extra_responses.push(NodeResponse::DeleteNodeFull { + node_id: *node_id, + node, + }); + self.node_positions.remove(*node_id); + // Make sure to not leave references to old nodes hanging + self.selected_nodes.retain(|id| *id != *node_id); + self.node_order.retain(|id| *id != *node_id); + } + NodeResponse::DisconnectEvent { input, output } => { + let other_node = self.graph.get_output(*output).node; + self.graph.remove_connection(*input, *output); + self.connection_in_progress = + Some((other_node, AnyParameterId::Output(*output))); + } + NodeResponse::RaiseNode(node_id) => { + let old_pos = self + .node_order + .iter() + .position(|id| *id == *node_id) + .expect("Node to be raised should be in `node_order`"); + self.node_order.remove(old_pos); + self.node_order.push(*node_id); + } + NodeResponse::MoveNode { node, drag_delta } => { + self.node_positions[*node] += *drag_delta; + // Handle multi-node selection movement + if self.selected_nodes.contains(node) && self.selected_nodes.len() > 1 { + for n in self.selected_nodes.iter().copied() { + if n != *node { + self.node_positions[n] += *drag_delta; + } + } + } + } + NodeResponse::User(_) => { + // These are handled by the user code. + } + NodeResponse::DeleteNodeFull { .. } => { + unreachable!("The UI should never produce a DeleteNodeFull event.") + } + } + } + + // Handle box selection + if let Some(box_start) = self.ongoing_box_selection { + let selection_rect = Rect::from_two_pos(cursor_pos, box_start); + let bg_color = Color32::from_rgba_unmultiplied(200, 200, 200, 20); + let stroke_color = Color32::from_rgba_unmultiplied(200, 200, 200, 180); + ui.painter().rect( + selection_rect, + 2.0, + bg_color, + Stroke::new(3.0, stroke_color), + StrokeKind::Middle, + ); + + self.selected_nodes = node_rects + .into_iter() + .filter_map(|(node_id, rect)| { + if selection_rect.intersects(rect) { + Some(node_id) + } else { + None + } + }) + .collect(); + } + + // Push any responses that were generated during response handling. + // These are only informative for the end-user and need no special + // treatment here. + delayed_responses.extend(extra_responses); + + /* Mouse input handling */ + + // This locks the context, so don't hold on to it for too long. + let mouse = &ui.ctx().input(|i| i.pointer.clone()); + + if mouse.any_released() && self.connection_in_progress.is_some() { + self.connection_in_progress = None; + } + + if mouse.secondary_released() && cursor_in_editor && !cursor_in_finder { + self.node_finder = Some(NodeFinder::new_at(cursor_pos)); + } + if ui.ctx().input(|i| i.key_pressed(Key::Escape)) { + self.node_finder = None; + } + + if r.dragged() + && ui + .ctx() + .input(|i| i.pointer.middle_down() || i.modifiers.command_only()) + { + self.pan_zoom.pan += ui.ctx().input(|i| i.pointer.delta()); + } + + // Deselect and deactivate finder if the editor background is clicked, + // *or* if the the mouse clicks off the ui + if click_on_background || (mouse.any_click() && !cursor_in_editor) { + self.selected_nodes = Vec::new(); + self.node_finder = None; + } + + if drag_started_on_background + && mouse.primary_down() + && !ui.ctx().input(|i| i.modifiers.command_only()) + { + self.ongoing_box_selection = Some(cursor_pos); + } + if mouse.primary_released() || drag_released_on_background { + self.ongoing_box_selection = None; + } + + GraphResponse { + node_responses: delayed_responses, + cursor_in_editor, + cursor_in_finder, + } + } +} + +fn draw_connection( + pan_zoom: &PanZoom, + painter: &Painter, + src_pos: Pos2, + dst_pos: Pos2, + color: Color32, +) { + let connection_stroke = egui::Stroke { + width: 5.0 * pan_zoom.zoom, + color, + }; + + let control_scale = ((dst_pos.x - src_pos.x) * pan_zoom.zoom / 2.0).max(30.0 * pan_zoom.zoom); + let src_control = src_pos + Vec2::X * control_scale; + let dst_control = dst_pos - Vec2::X * control_scale; + + let bezier = CubicBezierShape::from_points_stroke( + [src_pos, src_control, dst_control, dst_pos], + false, + Color32::TRANSPARENT, + connection_stroke, + ); + + painter.add(bezier); +} + +#[derive(Clone, Copy, Debug)] +struct OuterRectMemory(Rect); + +impl<'a, NodeData, DataType, ValueType, UserResponse, UserState> + GraphNodeWidget<'a, NodeData, DataType, ValueType> +where + NodeData: NodeDataTrait< + Response = UserResponse, + UserState = UserState, + DataType = DataType, + ValueType = ValueType, + >, + UserResponse: UserResponseTrait, + ValueType: + WidgetValueTrait, + DataType: DataTypeTrait, +{ + pub const MAX_NODE_SIZE: [f32; 2] = [200.0, 200.0]; + + pub fn show( + self, + pan_zoom: &PanZoom, + ui: &mut Ui, + user_state: &mut UserState, + ) -> Vec> { + let mut child_ui = ui.child_ui_with_id_source( + Rect::from_min_size(*self.position + self.pan, Self::MAX_NODE_SIZE.into()), + Layout::default(), + self.node_id, + None, + ); + + Self::show_graph_node(self, pan_zoom, &mut child_ui, user_state) + } + + /// Draws this node. Also fills in the list of port locations with all of its ports. + /// Returns responses indicating multiple events. + fn show_graph_node( + self, + pan_zoom: &PanZoom, + ui: &mut Ui, + user_state: &mut UserState, + ) -> Vec> { + let margin = egui::vec2(15.0, 5.0) * pan_zoom.zoom; + let mut responses = Vec::>::new(); + + let background_color; + let text_color; + if ui.visuals().dark_mode { + background_color = color_from_hex("#3f3f3f").unwrap(); + text_color = color_from_hex("#fefefe").unwrap(); + } else { + background_color = color_from_hex("#ffffff").unwrap(); + text_color = color_from_hex("#505050").unwrap(); + } + + ui.visuals_mut().widgets.noninteractive.fg_stroke = + Stroke::new(2.0 * pan_zoom.zoom, text_color); + + // Preallocate shapes to paint below contents + let outline_shape = ui.painter().add(Shape::Noop); + let background_shape = ui.painter().add(Shape::Noop); + + let mut outer_rect_bounds = ui.available_rect_before_wrap(); + // Scale hack, otherwise some (larger) rects expand too much when zoomed out + outer_rect_bounds.max.x = + outer_rect_bounds.min.x + outer_rect_bounds.width() * pan_zoom.zoom; + + let mut inner_rect = outer_rect_bounds.shrink2(margin); + + // Make sure we don't shrink to the negative: + inner_rect.max.x = inner_rect.max.x.max(inner_rect.min.x); + inner_rect.max.y = inner_rect.max.y.max(inner_rect.min.y); + + let mut child_ui = ui.child_ui(inner_rect, *ui.layout(), None); + + // Get interaction rect from memory, it may expand after the window response on resize. + let interaction_rect = ui + .ctx() + .memory_mut(|mem| { + mem.data + .get_temp::(child_ui.id()) + .map(|stored| stored.0) + }) + .unwrap_or(outer_rect_bounds); + // After 0.20, layers added over others can block hover interaction. Call this first + // before creating the node content. + let window_response = ui.interact( + interaction_rect, + Id::new((self.node_id, "window")), + Sense::click_and_drag(), + ); + + let mut title_height = 0.0; + + let mut input_port_heights = vec![]; + let mut output_port_heights = vec![]; + + child_ui.vertical(|ui| { + ui.horizontal(|ui| { + ui.add(Label::new( + RichText::new(&self.graph[self.node_id].label) + .text_style(TextStyle::Button) + .color(text_color), + ).selectable(false)); + responses.extend(self.graph[self.node_id].user_data.top_bar_ui( + ui, + self.node_id, + self.graph, + user_state, + )); + ui.add_space(8.0 * pan_zoom.zoom); // The size of the little cross icon + }); + ui.add_space(margin.y); + title_height = ui.min_size().y; + + // First pass: Draw the inner fields. Compute port heights + let inputs = self.graph[self.node_id].inputs.clone(); + for (param_name, param_id) in inputs { + if self.graph[param_id].shown_inline { + let height_before = ui.min_rect().bottom(); + + if self.graph[param_id].max_connections == NonZeroU32::new(1) { + // NOTE: We want to pass the `user_data` to + // `value_widget`, but we can't since that would require + // borrowing the graph twice. Here, we make the + // assumption that the value is cheaply replaced, and + // use `std::mem::take` to temporarily replace it with a + // dummy value. This requires `ValueType` to implement + // Default, but results in a totally safe alternative. + let mut value = std::mem::take(&mut self.graph[param_id].value); + + if !self.graph.connections(param_id).is_empty() { + let node_responses = value.value_widget_connected( + ¶m_name, + self.node_id, + ui, + user_state, + &self.graph[self.node_id].user_data, + ); + + responses.extend(node_responses.into_iter().map(NodeResponse::User)); + } else { + let node_responses = value.value_widget( + ¶m_name, + self.node_id, + ui, + user_state, + &self.graph[self.node_id].user_data, + ); + + responses.extend(node_responses.into_iter().map(NodeResponse::User)); + } + + self.graph[param_id].value = value; + } else { + ui.label(param_name); + } + + let height_intermediate = ui.min_rect().bottom(); + + let max_connections = self.graph[param_id] + .max_connections + .map(NonZeroU32::get) + .unwrap_or(std::u32::MAX) + as usize; + let port_height = port_height( + max_connections != 1, + self.graph.connections(param_id).len(), + max_connections, + ); + let margin = 5.0; + let missing_space = + port_height - (height_intermediate - height_before) + margin; + if missing_space > 0.0 { + ui.add_space(missing_space); + } + + self.graph[self.node_id].user_data.separator( + ui, + self.node_id, + AnyParameterId::Input(param_id), + self.graph, + user_state, + ); + + let height_after = ui.min_rect().bottom(); + + input_port_heights.push((height_before + height_after) / 2.0); + } + } + + let outputs = self.graph[self.node_id].outputs.clone(); + for (param_name, param_id) in outputs { + let height_before = ui.min_rect().bottom(); + responses.extend( + self.graph[self.node_id] + .user_data + .output_ui(ui, self.node_id, self.graph, user_state, ¶m_name) + .into_iter(), + ); + + self.graph[self.node_id].user_data.separator( + ui, + self.node_id, + AnyParameterId::Output(param_id), + self.graph, + user_state, + ); + + let height_after = ui.min_rect().bottom(); + output_port_heights.push((height_before + height_after) / 2.0); + } + + responses.extend(self.graph[self.node_id].user_data.bottom_ui( + ui, + self.node_id, + self.graph, + user_state, + )); + }); + + // Second pass, iterate again to draw the ports. This happens outside + // the child_ui because we want ports to overflow the node background. + + let outer_rect = child_ui.min_rect().expand2(margin); + let port_left = outer_rect.left(); + let port_right = outer_rect.right(); + + // Save expanded rect to memory. + ui.ctx().memory_mut(|mem| { + mem.data + .insert_temp(child_ui.id(), OuterRectMemory(outer_rect)) + }); + + fn port_height(wide_port: bool, connections: usize, max_connections: usize) -> f32 { + let port_full = connections == max_connections; + if wide_port { + let hooks = connections + if port_full { 0 } else { 1 }; + + 5.0 + (10.0 * hooks as f32).max(10.0) + } else { + 10.0 + } + } + + #[allow(clippy::too_many_arguments)] + fn draw_port( + pan_zoom: &PanZoom, + ui: &mut Ui, + graph: &Graph, + node_id: NodeId, + user_state: &mut UserState, + port_pos: Pos2, + responses: &mut Vec>, + param_id: AnyParameterId, + port_locations: &mut PortLocations, + conn_locations: &mut ConnLocations, + ongoing_drag: Option<(NodeId, AnyParameterId)>, + wide_port: bool, + connections: usize, + max_connections: usize, + ) where + DataType: DataTypeTrait, + UserResponse: UserResponseTrait, + NodeData: NodeDataTrait, + { + let port_type = graph.any_param_type(param_id).unwrap(); + + let port_rect = Rect::from_center_size( + port_pos, + egui::vec2(DISTANCE_TO_CONNECT * 2.0, port_height(wide_port, connections, max_connections)) + * pan_zoom.zoom, + ); + + let port_full = connections == max_connections; + + let inner_ports = if wide_port { + connections + if port_full { 0 } else { 1 } + } else { + 1 + }; + + port_locations.insert( + param_id, + (0..inner_ports) + .map(|k| { + port_rect.center_top() + + Vec2::new(0.0, 5.0 * pan_zoom.zoom) + + Vec2::new(0.0, 10.0 * pan_zoom.zoom) * k as f32 + }) + .collect(), + ); + + let sense = if ongoing_drag.is_some() { + Sense::hover() + } else { + Sense::click_and_drag() + }; + + let resp = ui.allocate_rect(port_rect, sense); + + // Check if the mouse is within the port's interaction rect + let close_enough = if let Some(pointer_pos) = ui.ctx().pointer_hover_pos() { + port_rect.contains(pointer_pos) + } else { + false + }; + + let port_color = if close_enough { + Color32::WHITE + } else { + port_type.data_type_color(user_state) + }; + + if wide_port { + ui.painter() + .rect_filled(port_rect, 5.0 * pan_zoom.zoom, port_color); + } else { + ui.painter().circle( + port_rect.center(), + 5.0 * pan_zoom.zoom, + port_color, + Stroke::NONE, + ); + } + + if connections > 0 { + if let AnyParameterId::Input(input) = param_id { + for (k, dst_pos) in port_locations[&AnyParameterId::Input(input)] + .iter() + .enumerate() + { + conn_locations.entry(input).or_default().insert(k, *dst_pos); + } + } + } + + let nearest_hook = ui + .input(|in_state| in_state.pointer.hover_pos()) + .and_then(|mouse_pos| match param_id { + AnyParameterId::Input(input) => Some((mouse_pos, input)), + AnyParameterId::Output(_) => None, + }) + .and_then(|(mouse_pos, input)| { + let hooks = 0..inner_ports; + hooks.min_by(|&hook1, &hook2| { + let out1_dist = conn_locations[&input][hook1].distance(mouse_pos); + let out2_dist = conn_locations[&input][hook2].distance(mouse_pos); + + out1_dist.partial_cmp(&out2_dist).unwrap() + }) + }); + + if resp.drag_started() { + match param_id { + AnyParameterId::Input(input) => { + match nearest_hook + .and_then(|hook| graph.connections(input).get(hook).copied()) + { + Some(output) => { + responses.push(NodeResponse::DisconnectEvent { input, output }); + } + None => { + responses + .push(NodeResponse::ConnectEventStarted(node_id, param_id)); + } + } + } + AnyParameterId::Output(_) => { + responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); + } + } + } + + if let Some((origin_node, origin_param)) = ongoing_drag { + if origin_node != node_id { + // Don't allow self-loops + if graph.any_param_type(origin_param).unwrap() == port_type && close_enough { + match (param_id, origin_param) { + (AnyParameterId::Input(input), AnyParameterId::Output(output)) + | (AnyParameterId::Output(output), AnyParameterId::Input(input)) => { + let input_hook = + nearest_hook.unwrap_or(graph.connections(input).len()); + + if ui.input(|i| i.pointer.any_released()) { + responses.push(NodeResponse::ConnectEventEnded { + output, + input, + input_hook, + }); + } else if wide_port && !port_full { + // move connections below the in-progress one to a lower position + for k in input_hook..graph.connections(input).len() { + conn_locations.get_mut(&input).unwrap()[k].y += 7.5; + } + } + } + _ => { /* Ignore in-in or out-out connections */ } + } + } + } + } + } + + // Input ports + for ((_, param), port_height) in self.graph[self.node_id] + .inputs + .iter() + .zip(input_port_heights.into_iter()) + { + let should_draw = match self.graph[*param].kind() { + InputParamKind::ConnectionOnly => true, + InputParamKind::ConstantOnly => false, + InputParamKind::ConnectionOrConstant => true, + }; + + if should_draw { + let pos_left = pos2(port_left, port_height); + let max_connections = self.graph[*param] + .max_connections + .map(NonZeroU32::get) + .unwrap_or(std::u32::MAX) as usize; + draw_port( + pan_zoom, + ui, + self.graph, + self.node_id, + user_state, + pos_left, + &mut responses, + AnyParameterId::Input(*param), + self.port_locations, + self.conn_locations, + self.ongoing_drag, + max_connections > 1, + self.graph.connections(*param).len(), + max_connections, + ); + } + } + + // Output ports + for ((_, param), port_height) in self.graph[self.node_id] + .outputs + .iter() + .zip(output_port_heights.into_iter()) + { + let pos_right = pos2(port_right, port_height); + draw_port( + pan_zoom, + ui, + self.graph, + self.node_id, + user_state, + pos_right, + &mut responses, + AnyParameterId::Output(*param), + self.port_locations, + self.conn_locations, + self.ongoing_drag, + false, + 0, + 1, + ); + } + + // Draw the background shape. + // NOTE: This code is a bit more involved than it needs to be because egui + // does not support drawing rectangles with asymmetrical round corners. + + let (shape, outline) = { + let rounding_radius = 4.0 * pan_zoom.zoom; + let rounding = CornerRadius::same(rounding_radius as u8); + + let titlebar_height = title_height + margin.y; + let titlebar_rect = + Rect::from_min_size(outer_rect.min, vec2(outer_rect.width(), titlebar_height)); + let titlebar = Shape::Rect(RectShape { + rect: titlebar_rect, + corner_radius: rounding, + fill: self.graph[self.node_id] + .user_data + .titlebar_color(ui, self.node_id, self.graph, user_state) + .unwrap_or_else(|| background_color.lighten(0.8)), + stroke: Stroke::NONE, + blur_width: 0.0, + round_to_pixels: None, + brush: None, + stroke_kind: StrokeKind::Inside, + }); + + let body_rect = Rect::from_min_size( + outer_rect.min + vec2(0.0, titlebar_height - rounding_radius), + vec2(outer_rect.width(), outer_rect.height() - titlebar_height), + ); + let body = Shape::Rect(RectShape { + rect: body_rect, + corner_radius: CornerRadius::ZERO, + fill: background_color, + stroke: Stroke::NONE, + blur_width: 0.0, + round_to_pixels: None, + brush: None, + stroke_kind: StrokeKind::Inside, + }); + + let bottom_body_rect = Rect::from_min_size( + body_rect.min + vec2(0.0, body_rect.height() - titlebar_height * 0.5), + vec2(outer_rect.width(), titlebar_height), + ); + let bottom_body = Shape::Rect(RectShape { + rect: bottom_body_rect, + corner_radius: rounding, + fill: background_color, + stroke: Stroke::NONE, + blur_width: 0.0, + round_to_pixels: None, + brush: None, + stroke_kind: StrokeKind::Inside, + }); + + let node_rect = titlebar_rect.union(body_rect).union(bottom_body_rect); + let outline = if self.selected { + Shape::Rect(RectShape { + rect: node_rect.expand(1.0 * pan_zoom.zoom), + corner_radius: rounding, + fill: Color32::WHITE.lighten(0.8), + stroke: Stroke::NONE, + blur_width: 0.0, + round_to_pixels: None, + brush: None, + stroke_kind: StrokeKind::Inside, + }) + } else { + Shape::Noop + }; + + // Take note of the node rect, so the editor can use it later to compute intersections. + self.node_rects.insert(self.node_id, node_rect); + + (Shape::Vec(vec![titlebar, body, bottom_body]), outline) + }; + + ui.painter().set(background_shape, shape); + ui.painter().set(outline_shape, outline); + + // --- Interaction --- + + // Titlebar buttons + let can_delete = self.graph.nodes[self.node_id].user_data.can_delete( + self.node_id, + self.graph, + user_state, + ); + + if can_delete && Self::close_button(pan_zoom, ui, outer_rect).clicked() { + responses.push(NodeResponse::DeleteNodeUi(self.node_id)); + }; + + // Movement + let drag_delta = window_response.drag_delta(); + if drag_delta.length_sq() > 0.0 { + responses.push(NodeResponse::MoveNode { + node: self.node_id, + drag_delta, + }); + responses.push(NodeResponse::RaiseNode(self.node_id)); + } + + // Node selection + // + // HACK: Only set the select response when no other response is active. + // This prevents some issues. + if responses.is_empty() && window_response.clicked_by(PointerButton::Primary) { + responses.push(NodeResponse::SelectNode(self.node_id)); + responses.push(NodeResponse::RaiseNode(self.node_id)); + } + + responses + } + + fn close_button(pan_zoom: &PanZoom, ui: &mut Ui, node_rect: Rect) -> Response { + // Measurements + let margin = 8.0 * pan_zoom.zoom; + let size = 10.0 * pan_zoom.zoom; + let stroke_width = 2.0; + let offs = margin + size / 2.0; + + let position = pos2(node_rect.right() - offs, node_rect.top() + offs); + let rect = Rect::from_center_size(position, vec2(size, size)); + let resp = ui.allocate_rect(rect, Sense::click()); + + let dark_mode = ui.visuals().dark_mode; + let color = if resp.clicked() { + if dark_mode { + color_from_hex("#ffffff").unwrap() + } else { + color_from_hex("#000000").unwrap() + } + } else if resp.hovered() { + if dark_mode { + color_from_hex("#dddddd").unwrap() + } else { + color_from_hex("#222222").unwrap() + } + } else { + #[allow(clippy::collapsible_else_if)] + if dark_mode { + color_from_hex("#aaaaaa").unwrap() + } else { + color_from_hex("#555555").unwrap() + } + }; + let stroke = Stroke { + width: stroke_width, + color, + }; + + ui.painter() + .line_segment([rect.left_top(), rect.right_bottom()], stroke); + ui.painter() + .line_segment([rect.right_top(), rect.left_bottom()], stroke); + + resp + } +} diff --git a/lightningbeam-ui/egui_node_graph2/src/error.rs b/lightningbeam-ui/egui_node_graph2/src/error.rs new file mode 100644 index 0000000..8033727 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/error.rs @@ -0,0 +1,10 @@ +use super::*; + +#[derive(Debug, thiserror::Error)] +pub enum EguiGraphError { + #[error("Node {0:?} has no parameter named {1}")] + NoParameterNamed(NodeId, String), + + #[error("Parameter {0:?} was not found in the graph.")] + InvalidParameterId(AnyParameterId), +} diff --git a/lightningbeam-ui/egui_node_graph2/src/graph.rs b/lightningbeam-ui/egui_node_graph2/src/graph.rs new file mode 100644 index 0000000..85c142e --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/graph.rs @@ -0,0 +1,95 @@ +use std::num::NonZeroU32; + +use super::*; + +#[cfg(feature = "persistence")] +use serde::{Deserialize, Serialize}; + +/// A node inside the [`Graph`]. Nodes have input and output parameters, stored +/// as ids. They also contain a custom `NodeData` struct with whatever data the +/// user wants to store per-node. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct Node { + pub id: NodeId, + pub label: String, + pub inputs: Vec<(String, InputId)>, + pub outputs: Vec<(String, OutputId)>, + pub user_data: NodeData, +} + +/// The three kinds of input params. These describe how the graph must behave +/// with respect to inline widgets and connections for this parameter. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub enum InputParamKind { + /// No constant value can be set. Only incoming connections can produce it + ConnectionOnly, + /// Only a constant value can be set. No incoming connections accepted. + ConstantOnly, + /// Both incoming connections and constants are accepted. Connections take + /// precedence over the constant values. + ConnectionOrConstant, +} + +#[cfg(feature = "persistence")] +fn shown_inline_default() -> bool { + true +} + +/// An input parameter. Input parameters are inside a node, and represent data +/// that this node receives. Unlike their [`OutputParam`] counterparts, input +/// parameters also display an inline widget which allows setting its "value". +/// The `DataType` generic parameter is used to restrict the range of input +/// connections for this parameter, and the `ValueType` is use to represent the +/// data for the inline widget (i.e. constant) value. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct InputParam { + pub id: InputId, + /// The data type of this node. Used to determine incoming connections. This + /// should always match the type of the InputParamValue, but the property is + /// not actually enforced. + pub typ: DataType, + /// The constant value stored in this parameter. + pub value: ValueType, + /// The input kind. See [`InputParamKind`] + pub kind: InputParamKind, + /// Back-reference to the node containing this parameter. + pub node: NodeId, + /// How many connections can be made with this input. `None` means no limit. + pub max_connections: Option, + /// When true, the node is shown inline inside the node graph. + #[cfg_attr(feature = "persistence", serde(default = "shown_inline_default"))] + pub shown_inline: bool, +} + +/// An output parameter. Output parameters are inside a node, and represent the +/// data that the node produces. Output parameters can be linked to the input +/// parameters of other nodes. Unlike an [`InputParam`], output parameters +/// cannot have a constant inline value. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct OutputParam { + pub id: OutputId, + /// Back-reference to the node containing this parameter. + pub node: NodeId, + pub typ: DataType, +} + +/// The graph, containing nodes, input parameters and output parameters. Because +/// graphs are full of self-referential structures, this type uses the `slotmap` +/// crate to represent all the inner references in the data. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct Graph { + /// The [`Node`]s of the graph + pub nodes: SlotMap>, + /// The [`InputParam`]s of the graph + pub inputs: SlotMap>, + /// The [`OutputParam`]s of the graph + pub outputs: SlotMap>, + // Connects the input of a node, to the output of its predecessor that + // produces it + pub connections: SecondaryMap>, +} diff --git a/lightningbeam-ui/egui_node_graph2/src/graph_impls.rs b/lightningbeam-ui/egui_node_graph2/src/graph_impls.rs new file mode 100644 index 0000000..a117b95 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/graph_impls.rs @@ -0,0 +1,292 @@ +use std::num::NonZeroU32; + +use super::*; + +impl Graph { + pub fn new() -> Self { + Self { + nodes: SlotMap::default(), + inputs: SlotMap::default(), + outputs: SlotMap::default(), + connections: SecondaryMap::default(), + } + } + + pub fn add_node( + &mut self, + label: String, + user_data: NodeData, + f: impl FnOnce(&mut Graph, NodeId), + ) -> NodeId { + let node_id = self.nodes.insert_with_key(|node_id| { + Node { + id: node_id, + label, + // These get filled in later by the user function + inputs: Vec::default(), + outputs: Vec::default(), + user_data, + } + }); + + f(self, node_id); + + node_id + } + + #[allow(clippy::too_many_arguments)] + pub fn add_wide_input_param( + &mut self, + node_id: NodeId, + name: String, + typ: DataType, + value: ValueType, + kind: InputParamKind, + max_connections: Option, + shown_inline: bool, + ) -> InputId { + let input_id = self.inputs.insert_with_key(|input_id| InputParam { + id: input_id, + typ, + value, + kind, + node: node_id, + max_connections, + shown_inline, + }); + self.nodes[node_id].inputs.push((name, input_id)); + input_id + } + + pub fn add_input_param( + &mut self, + node_id: NodeId, + name: String, + typ: DataType, + value: ValueType, + kind: InputParamKind, + shown_inline: bool, + ) -> InputId { + self.add_wide_input_param( + node_id, + name, + typ, + value, + kind, + NonZeroU32::new(1), + shown_inline, + ) + } + + pub fn remove_input_param(&mut self, param: InputId) { + let node = self[param].node; + self[node].inputs.retain(|(_, id)| *id != param); + self.inputs.remove(param); + self.connections.retain(|i, _| i != param); + } + + pub fn remove_output_param(&mut self, param: OutputId) { + let node = self[param].node; + self[node].outputs.retain(|(_, id)| *id != param); + self.outputs.remove(param); + for (_, conns) in &mut self.connections { + conns.retain(|o| *o != param); + } + } + + pub fn add_output_param(&mut self, node_id: NodeId, name: String, typ: DataType) -> OutputId { + let output_id = self.outputs.insert_with_key(|output_id| OutputParam { + id: output_id, + node: node_id, + typ, + }); + self.nodes[node_id].outputs.push((name, output_id)); + output_id + } + + /// Removes a node from the graph with given `node_id`. This also removes + /// any incoming or outgoing connections from that node + /// + /// This function returns the list of connections that has been removed + /// after deleting this node as input-output pairs. Note that one of the two + /// ids in the pair (the one on `node_id`'s end) will be invalid after + /// calling this function. + pub fn remove_node(&mut self, node_id: NodeId) -> (Node, Vec<(InputId, OutputId)>) { + let mut disconnect_events = vec![]; + + for (i, conns) in &mut self.connections { + conns.retain(|o| { + if self.outputs[*o].node == node_id || self.inputs[i].node == node_id { + disconnect_events.push((i, *o)); + false + } else { + true + } + }); + } + + // NOTE: Collect is needed because we can't borrow the input ids while + // we remove them inside the loop. + for input in self[node_id].input_ids().collect::>() { + self.inputs.remove(input); + } + for output in self[node_id].output_ids().collect::>() { + self.outputs.remove(output); + } + let removed_node = self.nodes.remove(node_id).expect("Node should exist"); + + (removed_node, disconnect_events) + } + + pub fn remove_connection(&mut self, input_id: InputId, output_id: OutputId) -> bool { + self.connections + .get_mut(input_id) + .map(|conns| { + let old_size = conns.len(); + conns.retain(|id| id != &output_id); + + // connection removed if `conn` size changes + old_size != conns.len() + }) + .unwrap_or(false) + } + + pub fn iter_nodes(&self) -> impl Iterator + '_ { + self.nodes.iter().map(|(id, _)| id) + } + + pub fn add_connection(&mut self, output: OutputId, input: InputId, pos: usize) { + if !self.connections.contains_key(input) { + self.connections.insert(input, Vec::default()); + } + + let max_connections = self + .get_input(input) + .max_connections + .map(NonZeroU32::get) + .unwrap_or(std::u32::MAX) as usize; + let already_in = self.connections[input].contains(&output); + + // connecting twice to the same port is a no-op + // even for wide ports. + if already_in { + return; + } + + if self.connections[input].len() == max_connections { + // if full, replace the connected output + self.connections[input][pos] = output; + } else { + // otherwise, insert at a selected position + self.connections[input].insert(pos, output); + } + } + + pub fn iter_connection_groups(&self) -> impl Iterator)> + '_ { + self.connections.iter().map(|(i, conns)| (i, conns.clone())) + } + + pub fn iter_connections(&self) -> impl Iterator + '_ { + self.iter_connection_groups() + .flat_map(|(i, conns)| conns.into_iter().map(move |o| (i, o))) + } + + pub fn connections(&self, input: InputId) -> Vec { + self.connections.get(input).cloned().unwrap_or_default() + } + + pub fn connection(&self, input: InputId) -> Option { + let is_limit_1 = self.get_input(input).max_connections == NonZeroU32::new(1); + let connections = self.connections(input); + + if is_limit_1 && connections.len() == 1 { + connections.into_iter().next() + } else { + None + } + } + + pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> { + match param { + AnyParameterId::Input(input) => self.inputs.get(input).map(|x| &x.typ), + AnyParameterId::Output(output) => self.outputs.get(output).map(|x| &x.typ), + } + .ok_or(EguiGraphError::InvalidParameterId(param)) + } + + pub fn try_get_input(&self, input: InputId) -> Option<&InputParam> { + self.inputs.get(input) + } + + pub fn get_input(&self, input: InputId) -> &InputParam { + &self.inputs[input] + } + + pub fn try_get_output(&self, output: OutputId) -> Option<&OutputParam> { + self.outputs.get(output) + } + + pub fn get_output(&self, output: OutputId) -> &OutputParam { + &self.outputs[output] + } +} + +impl Default for Graph { + fn default() -> Self { + Self::new() + } +} + +impl Node { + pub fn inputs<'a, DataType, DataValue>( + &'a self, + graph: &'a Graph, + ) -> impl Iterator> + 'a { + self.input_ids().map(|id| graph.get_input(id)) + } + + pub fn outputs<'a, DataType, DataValue>( + &'a self, + graph: &'a Graph, + ) -> impl Iterator> + 'a { + self.output_ids().map(|id| graph.get_output(id)) + } + + pub fn input_ids(&self) -> impl Iterator + '_ { + self.inputs.iter().map(|(_name, id)| *id) + } + + pub fn output_ids(&self) -> impl Iterator + '_ { + self.outputs.iter().map(|(_name, id)| *id) + } + + pub fn get_input(&self, name: &str) -> Result { + self.inputs + .iter() + .find(|(param_name, _id)| param_name == name) + .map(|x| x.1) + .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into())) + } + + pub fn get_output(&self, name: &str) -> Result { + self.outputs + .iter() + .find(|(param_name, _id)| param_name == name) + .map(|x| x.1) + .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into())) + } +} + +impl InputParam { + pub fn value(&self) -> &ValueType { + &self.value + } + + pub fn kind(&self) -> InputParamKind { + self.kind + } + + pub fn node(&self) -> NodeId { + self.node + } +} diff --git a/lightningbeam-ui/egui_node_graph2/src/id_type.rs b/lightningbeam-ui/egui_node_graph2/src/id_type.rs new file mode 100644 index 0000000..5c272e1 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/id_type.rs @@ -0,0 +1,37 @@ +slotmap::new_key_type! { pub struct NodeId; } +slotmap::new_key_type! { pub struct InputId; } +slotmap::new_key_type! { pub struct OutputId; } + +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum AnyParameterId { + Input(InputId), + Output(OutputId), +} + +impl AnyParameterId { + pub fn assume_input(&self) -> InputId { + match self { + AnyParameterId::Input(input) => *input, + AnyParameterId::Output(output) => panic!("{:?} is not an InputId", output), + } + } + pub fn assume_output(&self) -> OutputId { + match self { + AnyParameterId::Output(output) => *output, + AnyParameterId::Input(input) => panic!("{:?} is not an OutputId", input), + } + } +} + +impl From for AnyParameterId { + fn from(output: OutputId) -> Self { + Self::Output(output) + } +} + +impl From for AnyParameterId { + fn from(input: InputId) -> Self { + Self::Input(input) + } +} diff --git a/lightningbeam-ui/egui_node_graph2/src/index_impls.rs b/lightningbeam-ui/egui_node_graph2/src/index_impls.rs new file mode 100644 index 0000000..f002330 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/index_impls.rs @@ -0,0 +1,35 @@ +use super::*; + +macro_rules! impl_index_traits { + ($id_type:ty, $output_type:ty, $arena:ident) => { + impl std::ops::Index<$id_type> for Graph { + type Output = $output_type; + + fn index(&self, index: $id_type) -> &Self::Output { + self.$arena.get(index).unwrap_or_else(|| { + panic!( + "{} index error for {:?}. Has the value been deleted?", + stringify!($id_type), + index + ) + }) + } + } + + impl std::ops::IndexMut<$id_type> for Graph { + fn index_mut(&mut self, index: $id_type) -> &mut Self::Output { + self.$arena.get_mut(index).unwrap_or_else(|| { + panic!( + "{} index error for {:?}. Has the value been deleted?", + stringify!($id_type), + index + ) + }) + } + } + }; +} + +impl_index_traits!(NodeId, Node, nodes); +impl_index_traits!(InputId, InputParam, inputs); +impl_index_traits!(OutputId, OutputParam, outputs); diff --git a/lightningbeam-ui/egui_node_graph2/src/lib.rs b/lightningbeam-ui/egui_node_graph2/src/lib.rs new file mode 100644 index 0000000..a3d562e --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/lib.rs @@ -0,0 +1,47 @@ +#![forbid(unsafe_code)] + +use slotmap::{SecondaryMap, SlotMap}; + +pub type SVec = smallvec::SmallVec<[T; 4]>; + +/// Contains the main definitions for the node graph model. +pub mod graph; +pub use graph::*; + +/// Type declarations for the different id types (node, input, output) +pub mod id_type; +pub use id_type::*; + +/// Implements the index trait for the Graph type, allowing indexing by all +/// three id types +pub mod index_impls; + +/// Implementing the main methods for the `Graph` +pub mod graph_impls; + +/// Custom error types, crate-wide +pub mod error; +pub use error::*; + +/// The main struct in the library, contains all the necessary state to draw the +/// UI graph +pub mod ui_state; +pub use ui_state::*; + +/// The node finder is a tiny widget allowing to create new node types +pub mod node_finder; +pub use node_finder::*; + +/// The inner details of the egui implementation. Most egui code lives here. +pub mod editor_ui; +pub use editor_ui::*; + +/// Several traits that must be implemented by the user to customize the +/// behavior of this library. +pub mod traits; +pub use traits::*; + +mod utils; + +mod color_hex_utils; +mod scale; diff --git a/lightningbeam-ui/egui_node_graph2/src/node_finder.rs b/lightningbeam-ui/egui_node_graph2/src/node_finder.rs new file mode 100644 index 0000000..d1d370d --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/node_finder.rs @@ -0,0 +1,152 @@ +use std::{collections::BTreeMap, marker::PhantomData}; + +use crate::{color_hex_utils::*, CategoryTrait, NodeTemplateIter, NodeTemplateTrait}; + +use egui::*; + +#[derive(Clone)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +pub struct NodeFinder { + pub query: String, + /// Reset every frame. When set, the node finder will be moved at that position + pub position: Option, + pub just_spawned: bool, + _phantom: PhantomData, +} + +impl NodeFinder +where + NodeTemplate: + NodeTemplateTrait, + CategoryType: CategoryTrait, +{ + pub fn new_at(pos: Pos2) -> Self { + NodeFinder { + query: "".into(), + position: Some(pos), + just_spawned: true, + _phantom: Default::default(), + } + } + + /// Shows the node selector panel with a search bar. Returns whether a node + /// archetype was selected and, in that case, the finder should be hidden on + /// the next frame. + pub fn show( + &mut self, + ui: &mut Ui, + all_kinds: impl NodeTemplateIter, + user_state: &mut UserState, + ) -> Option { + let background_color; + let text_color; + + if ui.visuals().dark_mode { + background_color = color_from_hex("#3f3f3f").unwrap(); + text_color = color_from_hex("#fefefe").unwrap(); + } else { + background_color = color_from_hex("#fefefe").unwrap(); + text_color = color_from_hex("#3f3f3f").unwrap(); + } + + ui.visuals_mut().widgets.noninteractive.fg_stroke = Stroke::new(2.0, text_color); + + let frame = Frame::dark_canvas(ui.style()) + .fill(background_color) + .inner_margin(vec2(5.0, 5.0)); + + // The archetype that will be returned. + let mut submitted_archetype = None; + frame.show(ui, |ui| { + ui.vertical(|ui| { + let resp = ui.text_edit_singleline(&mut self.query); + if self.just_spawned { + resp.request_focus(); + self.just_spawned = false; + } + let update_open = resp.changed(); + + let mut query_submit = resp.lost_focus() && ui.input(|i| i.key_pressed(Key::Enter)); + + let max_height = ui.input(|i| i.content_rect().height() * 0.5); + let scroll_area_width = resp.rect.width() - 30.0; + + let all_kinds = all_kinds.all_kinds(); + let mut categories: BTreeMap> = Default::default(); + let mut orphan_kinds = Vec::new(); + + for kind in &all_kinds { + let kind_categories = kind.node_finder_categories(user_state); + + if kind_categories.is_empty() { + orphan_kinds.push(kind); + } else { + for category in kind_categories { + categories.entry(category.name()).or_default().push(kind); + } + } + } + + Frame::default() + .inner_margin(vec2(10.0, 10.0)) + .show(ui, |ui| { + ScrollArea::vertical() + .max_height(max_height) + .show(ui, |ui| { + ui.set_width(scroll_area_width); + ui.set_min_height(1000.); + for (category, kinds) in categories { + let filtered_kinds: Vec<_> = kinds + .into_iter() + .map(|kind| { + let kind_name = + kind.node_finder_label(user_state).to_string(); + (kind, kind_name) + }) + .filter(|(_kind, kind_name)| { + kind_name + .to_lowercase() + .contains(self.query.to_lowercase().as_str()) + }) + .collect(); + + if !filtered_kinds.is_empty() { + let default_open = !self.query.is_empty(); + + CollapsingHeader::new(&category) + .default_open(default_open) + .open(update_open.then_some(default_open)) + .show(ui, |ui| { + for (kind, kind_name) in filtered_kinds { + if ui + .selectable_label(false, kind_name) + .clicked() + { + submitted_archetype = Some(kind.clone()); + } else if query_submit { + submitted_archetype = Some(kind.clone()); + query_submit = false; + } + } + }); + } + } + + for kind in orphan_kinds { + let kind_name = kind.node_finder_label(user_state).to_string(); + + if ui.selectable_label(false, kind_name).clicked() { + submitted_archetype = Some(kind.clone()); + } else if query_submit { + submitted_archetype = Some(kind.clone()); + query_submit = false; + } + } + }); + }); + }); + }); + + submitted_archetype + } +} diff --git a/lightningbeam-ui/egui_node_graph2/src/scale.rs b/lightningbeam-ui/egui_node_graph2/src/scale.rs new file mode 100644 index 0000000..2787690 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/scale.rs @@ -0,0 +1,109 @@ +use egui::epaint::Shadow; +use egui::{style::WidgetVisuals, CornerRadius, Margin, Stroke, Style, Vec2}; + +// Copied from https://github.com/gzp-crey/shine + +pub trait Scale { + fn scale(&mut self, amount: f32); + + fn scaled(&self, amount: f32) -> Self + where + Self: Clone, + { + let mut scaled = self.clone(); + scaled.scale(amount); + scaled + } +} + +impl Scale for Vec2 { + fn scale(&mut self, amount: f32) { + self.x *= amount; + self.y *= amount; + } +} + +impl Scale for Margin { + fn scale(&mut self, amount: f32) { + self.left = (self.left as f32 * amount) as i8; + self.right = (self.right as f32 * amount) as i8; + self.top = (self.top as f32 * amount) as i8; + self.bottom = (self.bottom as f32 * amount) as i8; + } +} + +impl Scale for CornerRadius { + fn scale(&mut self, amount: f32) { + self.ne = (self.ne as f32 * amount) as u8; + self.nw = (self.nw as f32 * amount) as u8; + self.se = (self.se as f32 * amount) as u8; + self.sw = (self.sw as f32 * amount) as u8; + } +} + +impl Scale for Stroke { + fn scale(&mut self, amount: f32) { + self.width *= amount; + } +} + +impl Scale for Shadow { + fn scale(&mut self, amount: f32) { + self.spread = (self.spread as f32 * amount.clamp(0.4, 1.)) as u8; + } +} + +impl Scale for WidgetVisuals { + fn scale(&mut self, amount: f32) { + self.bg_stroke.scale(amount); + self.fg_stroke.scale(amount); + self.corner_radius.scale(amount); + self.expansion *= amount.clamp(0.4, 1.); + } +} + +impl Scale for Style { + fn scale(&mut self, amount: f32) { + if let Some(ov_font_id) = &mut self.override_font_id { + ov_font_id.size *= amount; + } + + for text_style in self.text_styles.values_mut() { + text_style.size *= amount; + } + + self.spacing.item_spacing.scale(amount); + self.spacing.window_margin.scale(amount); + self.spacing.button_padding.scale(amount); + self.spacing.indent *= amount; + self.spacing.interact_size.scale(amount); + self.spacing.slider_width *= amount; + self.spacing.text_edit_width *= amount; + self.spacing.icon_width *= amount; + self.spacing.icon_width_inner *= amount; + self.spacing.icon_spacing *= amount; + self.spacing.tooltip_width *= amount; + self.spacing.combo_height *= amount; + self.spacing.scroll.bar_width *= amount; + self.spacing.scroll.floating_allocated_width *= amount; + self.spacing.scroll.floating_width *= amount; + + self.interaction.resize_grab_radius_side *= amount; + self.interaction.resize_grab_radius_corner *= amount; + + self.visuals.widgets.noninteractive.scale(amount); + self.visuals.widgets.inactive.scale(amount); + self.visuals.widgets.hovered.scale(amount); + self.visuals.widgets.active.scale(amount); + self.visuals.widgets.open.scale(amount); + + self.visuals.selection.stroke.scale(amount); + + self.visuals.resize_corner_size *= amount; + self.visuals.text_cursor.stroke.width *= amount; + self.visuals.clip_rect_margin *= amount; + self.visuals.window_corner_radius.scale(amount); + self.visuals.window_shadow.scale(amount); + self.visuals.popup_shadow.scale(amount); + } +} diff --git a/lightningbeam-ui/egui_node_graph2/src/traits.rs b/lightningbeam-ui/egui_node_graph2/src/traits.rs new file mode 100644 index 0000000..694c090 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/traits.rs @@ -0,0 +1,284 @@ +use super::*; + +/// This trait must be implemented by the `ValueType` generic parameter of the +/// [`Graph`]. The trait allows drawing custom inline widgets for the different +/// types of the node graph. +/// +/// The [`Default`] trait bound is required to circumvent borrow checker issues +/// using `std::mem::take` Otherwise, it would be impossible to pass the +/// `node_data` parameter during `value_widget`. The default value is never +/// used, so the implementation is not important, but it should be reasonably +/// cheap to construct. +pub trait WidgetValueTrait: Default { + type Response; + type UserState; + type NodeData; + + /// This method will be called for each input parameter with a widget with an disconnected + /// input only. To display UI for connected inputs use [`WidgetValueTrait::value_widget_connected`]. + /// The return value is a vector of custom response objects which can be used + /// to implement handling of side effects. If unsure, the response Vec can + /// be empty. + fn value_widget( + &mut self, + param_name: &str, + node_id: NodeId, + ui: &mut egui::Ui, + user_state: &mut Self::UserState, + node_data: &Self::NodeData, + ) -> Vec; + + /// This method will be called for each input parameter with a widget with a connected + /// input only. To display UI for diconnected inputs use [`WidgetValueTrait::value_widget`]. + /// The return value is a vector of custom response objects which can be used + /// to implement handling of side effects. If unsure, the response Vec can + /// be empty. + /// + /// Shows the input name label by default. + fn value_widget_connected( + &mut self, + param_name: &str, + _node_id: NodeId, + ui: &mut egui::Ui, + _user_state: &mut Self::UserState, + _node_data: &Self::NodeData, + ) -> Vec { + ui.label(param_name); + + Default::default() + } +} + +/// This trait must be implemented by the `DataType` generic parameter of the +/// [`Graph`]. This trait tells the library how to visually expose data types +/// to the user. +pub trait DataTypeTrait: PartialEq + Eq { + /// The associated port color of this datatype + fn data_type_color(&self, user_state: &mut UserState) -> egui::Color32; + + /// The name of this datatype. Return type is specified as Cow because + /// some implementations will need to allocate a new string to provide an + /// answer while others won't. + /// + /// ## Example (borrowed value) + /// Use this when you can get the name of the datatype from its fields or as + /// a &'static str. Prefer this method when possible. + /// ```ignore + /// pub struct DataType { name: String } + /// + /// impl DataTypeTrait<()> for DataType { + /// fn name(&self) -> std::borrow::Cow { + /// Cow::Borrowed(&self.name) + /// } + /// } + /// ``` + /// + /// ## Example (owned value) + /// Use this when you can't derive the name of the datatype from its fields. + /// ```ignore + /// pub struct DataType { some_tag: i32 } + /// + /// impl DataTypeTrait<()> for DataType { + /// fn name(&self) -> std::borrow::Cow { + /// Cow::Owned(format!("Super amazing type #{}", self.some_tag)) + /// } + /// } + /// ``` + fn name(&self) -> std::borrow::Cow; +} + +/// This trait must be implemented for the `NodeData` generic parameter of the +/// [`Graph`]. This trait allows customizing some aspects of the node drawing. +pub trait NodeDataTrait +where + Self: Sized, +{ + /// Must be set to the custom user `NodeResponse` type + type Response; + /// Must be set to the custom user `UserState` type + type UserState; + /// Must be set to the custom user `DataType` type + type DataType; + /// Must be set to the custom user `ValueType` type + type ValueType; + + /// Additional UI elements to draw in the nodes, after the parameters. + fn bottom_ui( + &self, + ui: &mut egui::Ui, + node_id: NodeId, + graph: &Graph, + user_state: &mut Self::UserState, + ) -> Vec> + where + Self::Response: UserResponseTrait; + + /// UI to draw on the top bar of the node. + fn top_bar_ui( + &self, + _ui: &mut egui::Ui, + _node_id: NodeId, + _graph: &Graph, + _user_state: &mut Self::UserState, + ) -> Vec> + where + Self::Response: UserResponseTrait, + { + Default::default() + } + + /// UI to draw for each output + /// + /// Defaults to showing param_name as a simple label. + fn output_ui( + &self, + ui: &mut egui::Ui, + _node_id: NodeId, + _graph: &Graph, + _user_state: &mut Self::UserState, + param_name: &str, + ) -> Vec> + where + Self::Response: UserResponseTrait, + { + ui.label(param_name); + + Default::default() + } + + /// Set background color on titlebar + /// If the return value is None, the default color is set. + fn titlebar_color( + &self, + _ui: &egui::Ui, + _node_id: NodeId, + _graph: &Graph, + _user_state: &mut Self::UserState, + ) -> Option { + None + } + + /// Separator to put between elements in the node. + /// + /// Invoked between inputs, outputs and bottom UI. Useful for + /// complicated UIs that start to lose structure without explicit + /// separators. The `param_id` argument is the id of input or output + /// *preceeding* the separator. + /// + /// Default implementation does nothing. + fn separator( + &self, + _ui: &mut egui::Ui, + _node_id: NodeId, + _param_id: AnyParameterId, + _graph: &Graph, + _user_state: &mut Self::UserState, + ) { + } + + fn can_delete( + &self, + _node_id: NodeId, + _graph: &Graph, + _user_state: &mut Self::UserState, + ) -> bool { + true + } +} + +/// This trait can be implemented by any user type. The trait tells the library +/// how to enumerate the node templates it will present to the user as part of +/// the node finder. +pub trait NodeTemplateIter { + type Item; + fn all_kinds(&self) -> Vec; +} + +/// Describes a category of nodes. +/// +/// Used by [`NodeTemplateTrait::node_finder_categories`] to categorize nodes +/// templates into groups. +/// +/// If all nodes in a program are known beforehand, it's usefult to define +/// an enum containing all categories and implement [`CategoryTrait`] for it. This will +/// make it impossible to accidentally create a new category by mis-typing an existing +/// one, like in the case of using string types. +pub trait CategoryTrait { + /// Name of the category. + fn name(&self) -> String; +} + +impl CategoryTrait for () { + fn name(&self) -> String { + String::new() + } +} + +impl<'a> CategoryTrait for &'a str { + fn name(&self) -> String { + self.to_string() + } +} + +impl CategoryTrait for String { + fn name(&self) -> String { + self.clone() + } +} + +/// This trait must be implemented by the `NodeTemplate` generic parameter of +/// the [`GraphEditorState`]. It allows the customization of node templates. A +/// node template is what describes what kinds of nodes can be added to the +/// graph, what is their name, and what are their input / output parameters. +pub trait NodeTemplateTrait: Clone { + /// Must be set to the custom user `NodeData` type + type NodeData; + /// Must be set to the custom user `DataType` type + type DataType; + /// Must be set to the custom user `ValueType` type + type ValueType; + /// Must be set to the custom user `UserState` type + type UserState; + /// Must be a type that implements the [`CategoryTrait`] trait. + /// + /// `&'static str` is a good default if you intend to simply type out + /// the categories of your node. Use `()` if you don't need categories + /// at all. + type CategoryType; + + /// Returns a descriptive name for the node kind, used in the node finder. + /// + /// The return type is Cow to allow returning owned or borrowed values + /// more flexibly. Refer to the documentation for `DataTypeTrait::name` for + /// more information + fn node_finder_label(&self, user_state: &mut Self::UserState) -> std::borrow::Cow; + + /// Vec of categories to which the node belongs. + /// + /// It's often useful to organize similar nodes into categories, which will + /// then be used by the node finder to show a more manageable UI, especially + /// if the node template are numerous. + fn node_finder_categories(&self, _user_state: &mut Self::UserState) -> Vec { + Vec::default() + } + + /// Returns a descriptive name for the node kind, used in the graph. + fn node_graph_label(&self, user_state: &mut Self::UserState) -> String; + + /// Returns the user data for this node kind. + fn user_data(&self, user_state: &mut Self::UserState) -> Self::NodeData; + + /// This function is run when this node kind gets added to the graph. The + /// node will be empty by default, and this function can be used to fill its + /// parameters. + fn build_node( + &self, + graph: &mut Graph, + user_state: &mut Self::UserState, + node_id: NodeId, + ); +} + +/// The custom user response types when drawing nodes in the graph must +/// implement this trait. +pub trait UserResponseTrait: Clone + std::fmt::Debug {} diff --git a/lightningbeam-ui/egui_node_graph2/src/ui_state.rs b/lightningbeam-ui/egui_node_graph2/src/ui_state.rs new file mode 100644 index 0000000..cf641e5 --- /dev/null +++ b/lightningbeam-ui/egui_node_graph2/src/ui_state.rs @@ -0,0 +1,129 @@ +use super::*; +use egui::{Rect, Style, Ui, Vec2}; +use std::marker::PhantomData; +use std::sync::Arc; + +use crate::scale::Scale; +#[cfg(feature = "persistence")] +use serde::{Deserialize, Serialize}; + +const MIN_ZOOM: f32 = 0.2; +const MAX_ZOOM: f32 = 2.0; + +#[derive(Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct GraphEditorState { + pub graph: Graph, + /// Nodes are drawn in this order. Draw order is important because nodes + /// that are drawn last are on top. + pub node_order: Vec, + /// An ongoing connection interaction: The mouse has dragged away from a + /// port and the user is holding the click + pub connection_in_progress: Option<(NodeId, AnyParameterId)>, + /// The currently selected node. Some interface actions depend on the + /// currently selected node. + pub selected_nodes: Vec, + /// The mouse drag start position for an ongoing box selection. + pub ongoing_box_selection: Option, + /// The position of each node. + pub node_positions: SecondaryMap, + /// The node finder is used to create new nodes. + pub node_finder: Option>, + /// The panning of the graph viewport. + pub pan_zoom: PanZoom, + pub _user_state: PhantomData UserState>, +} + +impl + GraphEditorState +{ + pub fn new(default_zoom: f32) -> Self { + Self { + pan_zoom: PanZoom::new(default_zoom), + ..Default::default() + } + } +} +impl Default + for GraphEditorState +{ + fn default() -> Self { + Self { + graph: Default::default(), + node_order: Default::default(), + connection_in_progress: Default::default(), + selected_nodes: Default::default(), + ongoing_box_selection: Default::default(), + node_positions: Default::default(), + node_finder: Default::default(), + pan_zoom: Default::default(), + _user_state: Default::default(), + } + } +} + +#[cfg(feature = "persistence")] +fn _default_clip_rect() -> Rect { + Rect::NOTHING +} + +#[derive(Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct PanZoom { + pub pan: Vec2, + pub zoom: f32, + #[cfg_attr(feature = "persistence", serde(skip, default = "_default_clip_rect"))] + pub clip_rect: Rect, + #[cfg_attr(feature = "persistence", serde(skip, default))] + pub zoomed_style: Arc