From 673c909b047717467cbaf05de2a3efc8c325c32e Mon Sep 17 00:00:00 2001 From: Matt Campbell Date: Fri, 20 Dec 2024 07:38:23 -0600 Subject: [PATCH] feat: Cache Windows node COM objects Co-authored-by: Arnold Loubriat --- platforms/windows/src/adapter.rs | 84 +++++++++++++++++++------------- platforms/windows/src/context.rs | 27 +++++++++- platforms/windows/src/node.rs | 76 ++++++++++++++--------------- platforms/windows/src/text.rs | 25 +++------- 4 files changed, 116 insertions(+), 96 deletions(-) diff --git a/platforms/windows/src/adapter.rs b/platforms/windows/src/adapter.rs index 4198ae37..e0eb7e00 100644 --- a/platforms/windows/src/adapter.rs +++ b/platforms/windows/src/adapter.rs @@ -25,8 +25,8 @@ use crate::{ }; fn focus_event(context: &Arc, node_id: NodeId) -> QueuedEvent { - let platform_node = PlatformNode::new(context, node_id); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = context.get_or_create_platform_node(node_id); + let element: IRawElementProviderSimple = platform_node.into_interface(); QueuedEvent::Simple { element, event_id: UIA_AutomationFocusChangedEventId, @@ -60,8 +60,8 @@ impl AdapterChangeHandler<'_> { if self.text_changed.contains(&id) { return; } - let platform_node = PlatformNode::new(self.context, node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); // Text change events must come before selection change // events. It doesn't matter if text change events come // before other events. @@ -135,19 +135,25 @@ impl AdapterChangeHandler<'_> { if let Some(only_selected_child) = only_selected_child { self.queue.push(QueuedEvent::Simple { - element: PlatformNode::new(self.context, only_selected_child.id()).into(), + element: self + .context + .get_or_create_platform_node(only_selected_child.id()) + .into_interface(), event_id: UIA_SelectionItem_ElementSelectedEventId, }); self.queue.push(QueuedEvent::PropertyChanged { - element: PlatformNode::new(self.context, only_selected_child.id()).into(), + element: self + .context + .get_or_create_platform_node(only_selected_child.id()) + .into_interface(), property_id: UIA_SelectionItemIsSelectedPropertyId, old_value: false.into(), new_value: true.into(), }); for child_id in changes.removed_items.iter() { - let platform_node = PlatformNode::new(self.context, *child_id); + let platform_node = self.context.get_or_create_platform_node(*child_id); self.queue.push(QueuedEvent::PropertyChanged { - element: platform_node.into(), + element: platform_node.into_interface(), property_id: UIA_SelectionItemIsSelectedPropertyId, old_value: true.into(), new_value: false.into(), @@ -161,9 +167,9 @@ impl AdapterChangeHandler<'_> { if let Some(container) = container.filter(|_| { changes.added_items.len() + changes.removed_items.len() > INVALIDATE_LIMIT }) { - let platform_node = PlatformNode::new(self.context, container.id()); + let platform_node = self.context.get_or_create_platform_node(container.id()); self.queue.push(QueuedEvent::Simple { - element: platform_node.into(), + element: platform_node.into_interface(), event_id: UIA_Selection_InvalidatedEventId, }); } else { @@ -171,14 +177,20 @@ impl AdapterChangeHandler<'_> { container.is_some_and(|c| c.is_multiselectable()); for added_id in changes.added_items.iter() { self.queue.push(QueuedEvent::Simple { - element: PlatformNode::new(self.context, *added_id).into(), + element: self + .context + .get_or_create_platform_node(*added_id) + .into_interface(), event_id: match container_is_multiselectable { true => UIA_SelectionItem_ElementAddedToSelectionEventId, false => UIA_SelectionItem_ElementSelectedEventId, }, }); self.queue.push(QueuedEvent::PropertyChanged { - element: PlatformNode::new(self.context, *added_id).into(), + element: self + .context + .get_or_create_platform_node(*added_id) + .into_interface(), property_id: UIA_SelectionItemIsSelectedPropertyId, old_value: false.into(), new_value: true.into(), @@ -186,11 +198,17 @@ impl AdapterChangeHandler<'_> { } for removed_id in changes.removed_items.iter() { self.queue.push(QueuedEvent::Simple { - element: PlatformNode::new(self.context, *removed_id).into(), + element: self + .context + .get_or_create_platform_node(*removed_id) + .into_interface(), event_id: UIA_SelectionItem_ElementRemovedFromSelectionEventId, }); self.queue.push(QueuedEvent::PropertyChanged { - element: PlatformNode::new(self.context, *removed_id).into(), + element: self + .context + .get_or_create_platform_node(*removed_id) + .into_interface(), property_id: UIA_SelectionItemIsSelectedPropertyId, old_value: true.into(), new_value: false.into(), @@ -215,16 +233,16 @@ impl TreeChangeHandler for AdapterChangeHandler<'_> { } let wrapper = NodeWrapper(node); if node.is_dialog() { - let platform_node = PlatformNode::new(self.context, node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); self.queue.push(QueuedEvent::Simple { element, event_id: UIA_Window_WindowOpenedEventId, }); } if wrapper.name().is_some() && node.live() != Live::Off { - let platform_node = PlatformNode::new(self.context, node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); self.queue.push(QueuedEvent::Simple { element, event_id: UIA_LiveRegionChangedEventId, @@ -243,8 +261,8 @@ impl TreeChangeHandler for AdapterChangeHandler<'_> { if filter(new_node) != FilterResult::Include { if !old_node_was_filtered_out { if old_node.is_dialog() { - let platform_node = PlatformNode::new(self.context, old_node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(old_node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); self.queue.push(QueuedEvent::Simple { element, event_id: UIA_Window_WindowClosedEventId, @@ -257,16 +275,11 @@ impl TreeChangeHandler for AdapterChangeHandler<'_> { } return; } - let platform_node = PlatformNode::new(self.context, new_node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(new_node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); let old_wrapper = NodeWrapper(old_node); let new_wrapper = NodeWrapper(new_node); - new_wrapper.enqueue_property_changes( - &mut self.queue, - &PlatformNode::new(self.context, new_node.id()), - &element, - &old_wrapper, - ); + new_wrapper.enqueue_property_changes(&mut self.queue, self.context, &element, &old_wrapper); let new_name = new_wrapper.name(); if new_name.is_some() && new_node.live() != Live::Off @@ -280,8 +293,8 @@ impl TreeChangeHandler for AdapterChangeHandler<'_> { }); } if old_node_was_filtered_out && new_node.is_dialog() { - let platform_node = PlatformNode::new(self.context, new_node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(new_node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); self.queue.push(QueuedEvent::Simple { element, event_id: UIA_Window_WindowOpenedEventId, @@ -303,12 +316,13 @@ impl TreeChangeHandler for AdapterChangeHandler<'_> { fn node_removed(&mut self, node: &Node) { self.insert_text_change_if_needed(node); + self.context.remove_platform_node(node.id()); if filter(node) != FilterResult::Include { return; } if node.is_dialog() { - let platform_node = PlatformNode::new(self.context, node.id()); - let element: IRawElementProviderSimple = platform_node.into(); + let platform_node = self.context.get_or_create_platform_node(node.id()); + let element: IRawElementProviderSimple = platform_node.into_interface(); self.queue.push(QueuedEvent::Simple { element, event_id: UIA_Window_WindowClosedEventId, @@ -509,7 +523,7 @@ impl Adapter { let tree = Tree::new(initial_state, *is_window_focused); let context = Context::new(hwnd, tree, Arc::clone(action_handler), false); let node_id = context.read_tree().state().root_id(); - let platform_node = PlatformNode::new(&context, node_id); + let platform_node = context.get_or_create_platform_node(node_id); self.state = State::Active(context); (hwnd, platform_node) } @@ -532,10 +546,10 @@ impl Adapter { State::Placeholder(context) => (context.hwnd, PlatformNode::unspecified_root(context)), State::Active(context) => { let node_id = context.read_tree().state().root_id(); - (context.hwnd, PlatformNode::new(context, node_id)) + (context.hwnd, context.get_or_create_platform_node(node_id)) } }; - let el: IRawElementProviderSimple = platform_node.into(); + let el: IRawElementProviderSimple = platform_node.into_interface(); Some(WmGetObjectResult { hwnd, wparam, diff --git a/platforms/windows/src/context.rs b/platforms/windows/src/context.rs index 71e57cb2..72e5ce5c 100644 --- a/platforms/windows/src/context.rs +++ b/platforms/windows/src/context.rs @@ -4,11 +4,13 @@ // the LICENSE-MIT file), at your option. use accesskit::{ActionHandler, ActionRequest, Point}; -use accesskit_consumer::Tree; +use accesskit_consumer::{NodeId, Tree}; +use hashbrown::HashMap; use std::fmt::{Debug, Formatter}; use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, atomic::AtomicBool}; +use windows::core::ComObject; -use crate::{util::*, window_handle::WindowHandle}; +use crate::{node::PlatformNode, util::*, window_handle::WindowHandle}; pub(crate) trait ActionHandlerNoMut { fn do_action(&self, request: ActionRequest); @@ -33,6 +35,7 @@ pub(crate) struct Context { pub(crate) tree: RwLock, pub(crate) action_handler: Arc, pub(crate) is_placeholder: AtomicBool, + platform_nodes: Mutex>>, } impl Debug for Context { @@ -58,6 +61,7 @@ impl Context { tree: RwLock::new(tree), action_handler, is_placeholder: AtomicBool::new(is_placeholder), + platform_nodes: Mutex::new(HashMap::new()), }) } @@ -72,4 +76,23 @@ impl Context { pub(crate) fn do_action(&self, request: ActionRequest) { self.action_handler.do_action(request); } + + pub(crate) fn get_or_create_platform_node( + self: &Arc, + id: NodeId, + ) -> ComObject { + let mut platform_nodes = self.platform_nodes.lock().unwrap(); + if let Some(result) = platform_nodes.get(&id) { + return result.clone(); + } + + let result = PlatformNode::new(self, id); + platform_nodes.insert(id, result.clone()); + result + } + + pub(crate) fn remove_platform_node(&self, id: NodeId) { + let mut platform_nodes = self.platform_nodes.lock().unwrap(); + platform_nodes.remove(&id); + } } diff --git a/platforms/windows/src/node.rs b/platforms/windows/src/node.rs index 14c779be..cfa8815a 100644 --- a/platforms/windows/src/node.rs +++ b/platforms/windows/src/node.rs @@ -726,11 +726,11 @@ impl NodeWrapper<'_> { pub(crate) fn enqueue_property_changes( &self, queue: &mut Vec, - platform_node: &PlatformNode, + context: &Arc, element: &IRawElementProviderSimple, old: &NodeWrapper, ) { - self.enqueue_simple_property_changes(queue, platform_node, element, old); + self.enqueue_simple_property_changes(queue, context, element, old); self.enqueue_pattern_property_changes(queue, element, old); self.enqueue_property_implied_events(queue, element, old); } @@ -792,18 +792,20 @@ pub(crate) struct PlatformNode { } impl PlatformNode { - pub(crate) fn new(context: &Arc, node_id: NodeId) -> Self { + pub(crate) fn new(context: &Arc, node_id: NodeId) -> ComObject { Self { context: Arc::downgrade(context), node_id: Some(node_id), } + .into_object() } - pub(crate) fn unspecified_root(context: &Arc) -> Self { + pub(crate) fn unspecified_root(context: &Arc) -> ComObject { Self { context: Arc::downgrade(context), node_id: None, } + .into_object() } fn upgrade_context(&self) -> Result> { @@ -812,27 +814,20 @@ impl PlatformNode { fn with_tree_state_and_context(&self, f: F) -> Result where - F: FnOnce(&TreeState, &Context) -> Result, + F: FnOnce(&TreeState, &Arc) -> Result, { self.with_tree_and_context(|tree, context| f(tree.state(), context)) } fn with_tree_and_context(&self, f: F) -> Result where - F: FnOnce(&Tree, &Context) -> Result, + F: FnOnce(&Tree, &Arc) -> Result, { let context = self.upgrade_context()?; let tree = context.read_tree(); f(&tree, &context) } - fn with_tree_state(&self, f: F) -> Result - where - F: FnOnce(&TreeState) -> Result, - { - self.with_tree_state_and_context(|state, _| f(state)) - } - fn node<'a>(&self, tree: &'a Tree) -> Result> { let state = tree.state(); if let Some(id) = self.node_id { @@ -857,7 +852,7 @@ impl PlatformNode { fn resolve_with_context(&self, f: F) -> Result where - for<'a> F: FnOnce(Node<'a>, &Context) -> Result, + for<'a> F: FnOnce(Node<'a>, &Arc) -> Result, { self.with_tree_and_context(|tree, context| { let node = self.node(tree)?; @@ -867,7 +862,7 @@ impl PlatformNode { fn resolve_with_tree_state_and_context(&self, f: F) -> Result where - for<'a> F: FnOnce(Node<'a>, &TreeState, &Context) -> Result, + for<'a> F: FnOnce(Node<'a>, &TreeState, &Arc) -> Result, { self.with_tree_and_context(|tree, context| { let node = self.node(tree)?; @@ -884,7 +879,7 @@ impl PlatformNode { fn resolve_with_context_for_text_pattern(&self, f: F) -> Result where - for<'a> F: FnOnce(Node<'a>, &Context) -> Result, + for<'a> F: FnOnce(Node<'a>, &Arc) -> Result, { self.with_tree_and_context(|tree, context| { let node = self.node(tree)?; @@ -984,13 +979,6 @@ impl PlatformNode { }) } - fn relative(&self, node_id: NodeId) -> Self { - Self { - context: self.context.clone(), - node_id: Some(node_id), - } - } - fn is_root(&self, state: &TreeState) -> bool { self.node_id.is_some_and(|id| id == state.root_id()) } @@ -1029,8 +1017,8 @@ impl IRawElementProviderSimple_Impl for PlatformNode_Impl { let controlled: Vec = node .controls() .filter(|controlled| filter(controlled) == FilterResult::Include) - .map(|controlled| self.relative(controlled.id())) - .map(IRawElementProviderSimple::from) + .map(|controlled| context.get_or_create_platform_node(controlled.id())) + .map(|result| result.into_interface::()) .filter_map(|controlled| controlled.cast::().ok()) .collect(); result = controlled.into(); @@ -1056,7 +1044,7 @@ impl IRawElementProviderSimple_Impl for PlatformNode_Impl { #[allow(non_snake_case)] impl IRawElementProviderFragment_Impl for PlatformNode_Impl { fn Navigate(&self, direction: NavigateDirection) -> Result { - self.resolve(|node| { + self.resolve_with_context(|node, context| { let result = match direction { NavigateDirection_Parent => node.filtered_parent(&filter_with_root_exception), NavigateDirection_NextSibling => node.following_filtered_siblings(&filter).next(), @@ -1068,7 +1056,9 @@ impl IRawElementProviderFragment_Impl for PlatformNode_Impl { _ => None, }; match result { - Some(result) => Ok(self.relative(result.id()).into()), + Some(result) => Ok(context + .get_or_create_platform_node(result.id()) + .into_interface()), None => Err(Error::empty()), } }) @@ -1113,12 +1103,14 @@ impl IRawElementProviderFragment_Impl for PlatformNode_Impl { } fn FragmentRoot(&self) -> Result { - self.with_tree_state(|state| { + self.with_tree_state_and_context(|state, context| { if self.is_root(state) { Ok(self.to_interface()) } else { let root_id = state.root_id(); - Ok(self.relative(root_id).into()) + Ok(context + .get_or_create_platform_node(root_id) + .into_interface()) } }) } @@ -1133,21 +1125,26 @@ impl IRawElementProviderFragmentRoot_Impl for PlatformNode_Impl { let point = node.transform().inverse() * point; node.node_at_point(point, &filter).map_or_else( || Err(Error::empty()), - |node| Ok(self.relative(node.id()).into()), + |node| { + Ok(context + .get_or_create_platform_node(node.id()) + .into_interface()) + }, ) }) } fn GetFocus(&self) -> Result { - self.with_tree_state(|state| { + self.with_tree_state_and_context(|state, context| { if let Some(node) = state.focus() { + let id = node.id(); let self_id = if let Some(id) = self.node_id { id } else { state.root_id() }; - if node.id() != self_id { - return Ok(self.relative(node.id()).into()); + if id != self_id { + return Ok(context.get_or_create_platform_node(id).into_interface()); } } Err(Error::empty()) @@ -1169,7 +1166,7 @@ macro_rules! properties { fn enqueue_simple_property_changes( &self, queue: &mut Vec, - platform_node: &PlatformNode, + context: &Arc, element: &IRawElementProviderSimple, old: &NodeWrapper, ) { @@ -1197,7 +1194,7 @@ macro_rules! properties { match (old_controlled, new_controlled) { (Some(a), Some(b)) => { are_equal = are_equal && a.id() == b.id(); - controls.push(platform_node.relative(b.id()).into()); + controls.push(context.get_or_create_platform_node(b.id()).into_interface::().into()); } (None, None) => break, _ => are_equal = false, @@ -1372,9 +1369,9 @@ patterns! { }, fn SelectionContainer(&self) -> Result { - self.resolve(|node| { + self.resolve_with_context(|node, context| { if let Some(container) = node.selection_container(&filter) { - Ok(self.relative(container.id()).into()) + Ok(context.get_or_create_platform_node(container.id()).into_interface()) } else { Err(E_FAIL.into()) } @@ -1386,12 +1383,11 @@ patterns! { (UIA_SelectionIsSelectionRequiredPropertyId, IsSelectionRequired, is_required, BOOL) ), ( fn GetSelection(&self) -> Result<*mut SAFEARRAY> { - self.resolve(|node| { + self.resolve_with_context(|node, context| { let selection: Vec<_> = node .items(&filter) .filter(|item| item.is_selected() == Some(true)) - .map(|item| self.relative(item.id())) - .map(IRawElementProviderSimple::from) + .map(|item| context.get_or_create_platform_node(item.id()).into_interface::()) .filter_map(|item| item.cast::().ok()) .collect(); Ok(safe_array_from_com_slice(&selection)) diff --git a/platforms/windows/src/text.rs b/platforms/windows/src/text.rs index e3388cfb..46c645e6 100644 --- a/platforms/windows/src/text.rs +++ b/platforms/windows/src/text.rs @@ -18,7 +18,7 @@ use windows::{ core::*, }; -use crate::{context::Context, node::PlatformNode, util::*}; +use crate::{context::Context, util::*}; fn upgrade_range<'a>(weak: &WeakRange, tree_state: &'a TreeState) -> Result> { if let Some(range) = weak.upgrade(tree_state) { @@ -244,16 +244,6 @@ impl PlatformRange { upgrade_range_node(&state, tree_state) } - fn with_node(&self, f: F) -> Result - where - F: FnOnce(Node) -> Result, - { - self.with_tree_state(|tree_state| { - let node = self.upgrade_node(tree_state)?; - f(node) - }) - } - fn upgrade_for_read<'a>(&self, tree_state: &'a TreeState) -> Result> { let state = self.state.read().unwrap(); upgrade_range(&state, tree_state) @@ -513,14 +503,11 @@ impl ITextRangeProvider_Impl for PlatformRange_Impl { } fn GetEnclosingElement(&self) -> Result { - self.with_node(|node| { - // Revisit this if we eventually support embedded objects. - Ok(PlatformNode { - context: self.context.clone(), - node_id: Some(node.id()), - } - .into()) - }) + // Revisit this if we eventually support embedded objects. + let context = self.upgrade_context()?; + let tree = context.read_tree(); + let id = self.upgrade_node(tree.state())?.id(); + Ok(context.get_or_create_platform_node(id).into_interface()) } fn GetText(&self, _max_length: i32) -> Result {