// matrix_sdk_base/room/knock.rs

// Copyright 2025 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::BTreeMap;
16
17use eyeball::{AsyncLock, ObservableWriteGuard};
18use ruma::{
19    events::{
20        room::member::{MembershipState, RoomMemberEventContent},
21        StateEventType, SyncStateEvent,
22    },
23    OwnedEventId, OwnedUserId,
24};
25use tracing::warn;
26
27use super::Room;
28use crate::{
29    deserialized_responses::{MemberEvent, RawMemberEvent, SyncOrStrippedState},
30    store::{Result as StoreResult, StateStoreExt},
31    StateStoreDataKey, StateStoreDataValue, StoreError,
32};
33
34impl Room {
35    /// Mark a list of requests to join the room as seen, given their state
36    /// event ids.
37    pub async fn mark_knock_requests_as_seen(&self, user_ids: &[OwnedUserId]) -> StoreResult<()> {
38        let raw_user_ids: Vec<&str> = user_ids.iter().map(|id| id.as_str()).collect();
39        let member_raw_events = self
40            .store
41            .get_state_events_for_keys(self.room_id(), StateEventType::RoomMember, &raw_user_ids)
42            .await?;
43        let mut event_to_user_ids = Vec::with_capacity(member_raw_events.len());
44
45        // Map the list of events ids to their user ids, if they are event ids for knock
46        // membership events. Log an error and continue otherwise.
47        for raw_event in member_raw_events {
48            let event = raw_event.cast::<RoomMemberEventContent>().deserialize()?;
49            match event {
50                SyncOrStrippedState::Sync(SyncStateEvent::Original(event)) => {
51                    if event.content.membership == MembershipState::Knock {
52                        event_to_user_ids.push((event.event_id, event.state_key))
53                    } else {
54                        warn!("Could not mark knock event as seen: event {} for user {} is not in Knock membership state.", event.event_id, event.state_key);
55                    }
56                }
57                _ => warn!(
58                    "Could not mark knock event as seen: event for user {} is not valid.",
59                    event.state_key()
60                ),
61            }
62        }
63
64        let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?;
65        let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default();
66
67        current_seen_events.extend(event_to_user_ids);
68
69        self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?;
70
71        Ok(())
72    }
73
74    /// Removes the seen knock request ids that are no longer valid given the
75    /// current room members.
76    pub async fn remove_outdated_seen_knock_requests_ids(&self) -> StoreResult<()> {
77        let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?;
78        let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default();
79
80        // Get and deserialize the member events for the seen knock requests
81        let keys: Vec<OwnedUserId> = current_seen_events.values().map(|id| id.to_owned()).collect();
82        let raw_member_events: Vec<RawMemberEvent> =
83            self.store.get_state_events_for_keys_static(self.room_id(), &keys).await?;
84        let member_events = raw_member_events
85            .into_iter()
86            .map(|raw| raw.deserialize())
87            .collect::<Result<Vec<MemberEvent>, _>>()?;
88
89        let mut ids_to_remove = Vec::new();
90
91        for (event_id, user_id) in current_seen_events.iter() {
92            // Check the seen knock request ids against the current room member events for
93            // the room members associated to them
94            let matching_member = member_events.iter().find(|event| event.user_id() == user_id);
95
96            if let Some(member) = matching_member {
97                let member_event_id = member.event_id();
98                // If the member event is not a knock or it's different knock, it's outdated
99                if *member.membership() != MembershipState::Knock
100                    || member_event_id.is_some_and(|id| id != event_id)
101                {
102                    ids_to_remove.push(event_id.to_owned());
103                }
104            } else {
105                ids_to_remove.push(event_id.to_owned());
106            }
107        }
108
109        // If there are no ids to remove, do nothing
110        if ids_to_remove.is_empty() {
111            return Ok(());
112        }
113
114        for event_id in ids_to_remove {
115            current_seen_events.remove(&event_id);
116        }
117
118        self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?;
119
120        Ok(())
121    }
122
123    /// Get the list of seen knock request event ids in this room.
124    pub async fn get_seen_knock_request_ids(
125        &self,
126    ) -> Result<BTreeMap<OwnedEventId, OwnedUserId>, StoreError> {
127        Ok(self.get_write_guarded_current_knock_request_ids().await?.clone().unwrap_or_default())
128    }
129
130    async fn get_write_guarded_current_knock_request_ids(
131        &self,
132    ) -> StoreResult<ObservableWriteGuard<'_, Option<BTreeMap<OwnedEventId, OwnedUserId>>, AsyncLock>>
133    {
134        let mut guard = self.seen_knock_request_ids_map.write().await;
135        // If there are no loaded request ids yet
136        if guard.is_none() {
137            // Load the values from the store and update the shared observable contents
138            let updated_seen_ids = self
139                .store
140                .get_kv_data(StateStoreDataKey::SeenKnockRequests(self.room_id()))
141                .await?
142                .and_then(|v| v.into_seen_knock_requests())
143                .unwrap_or_default();
144
145            ObservableWriteGuard::set(&mut guard, Some(updated_seen_ids));
146        }
147        Ok(guard)
148    }
149
150    async fn update_seen_knock_request_ids(
151        &self,
152        mut guard: ObservableWriteGuard<'_, Option<BTreeMap<OwnedEventId, OwnedUserId>>, AsyncLock>,
153        new_value: BTreeMap<OwnedEventId, OwnedUserId>,
154    ) -> StoreResult<()> {
155        // Save the new values to the shared observable
156        ObservableWriteGuard::set(&mut guard, Some(new_value.clone()));
157
158        // Save them into the store too
159        self.store
160            .set_kv_data(
161                StateStoreDataKey::SeenKnockRequests(self.room_id()),
162                StateStoreDataValue::SeenKnockRequests(new_value),
163            )
164            .await?;
165
166        Ok(())
167    }
168}