matrix_sdk_crypto/olm/group_sessions/outbound.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cmp::max,
17    collections::{BTreeMap, BTreeSet},
18    fmt,
19    ops::Bound,
20    sync::{
21        atomic::{AtomicBool, AtomicU64, Ordering},
22        Arc, RwLockReadGuard,
23    },
24    time::Duration,
25};
26
27use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock};
28use ruma::{
29    events::{
30        room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility},
31        AnyMessageLikeEventContent,
32    },
33    serde::Raw,
34    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId,
35    SecondsSinceUnixEpoch, TransactionId, UserId,
36};
37use serde::{Deserialize, Serialize};
38use tokio::sync::RwLock;
39use tracing::{debug, error, info};
40use vodozemac::{megolm::SessionConfig, Curve25519PublicKey};
41pub use vodozemac::{
42    megolm::{GroupSession, GroupSessionPickle, MegolmMessage, SessionKey},
43    olm::IdentityKeys,
44    PickleError,
45};
46
47use super::SessionCreationError;
48#[cfg(feature = "experimental-algorithms")]
49use crate::types::events::room::encrypted::MegolmV2AesSha2Content;
50use crate::{
51    olm::account::shared_history_from_history_visibility,
52    session_manager::CollectStrategy,
53    store::caches::SequenceNumber,
54    types::{
55        events::{
56            room::encrypted::{
57                MegolmV1AesSha2Content, RoomEncryptedEventContent, RoomEventEncryptionScheme,
58            },
59            room_key::{MegolmV1AesSha2Content as MegolmV1AesSha2RoomKeyContent, RoomKeyContent},
60            room_key_withheld::RoomKeyWithheldContent,
61        },
62        requests::ToDeviceRequest,
63        EventEncryptionAlgorithm,
64    },
65    DeviceData,
66};
67
/// One hour; also the smallest time-based rotation period a session honours
/// unless the minimum is disabled by feature flag.
const ONE_HOUR: Duration = Duration::from_secs(3_600);
/// Seven days.
const ONE_WEEK: Duration = Duration::from_secs(7 * 24 * 3_600);

/// Default age after which an outbound group session gets rotated.
const ROTATION_PERIOD: Duration = ONE_WEEK;
/// Default number of encrypted messages after which a session gets rotated.
const ROTATION_MESSAGES: u64 = 100;
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
75/// Information about whether a session was shared with a device.
76pub(crate) enum ShareState {
77    /// The session was not shared with the device.
78    NotShared,
79    /// The session was shared with the device with the given device ID, but
80    /// with a different curve25519 key.
81    SharedButChangedSenderKey,
82    /// The session was shared with the device, at the given message index. The
83    /// `olm_wedging_index` is the value of the `olm_wedging_index` from the
84    /// [`DeviceData`] at the time that we last shared the session with the
85    /// device, and indicates whether we need to re-share the session with the
86    /// device.
87    Shared { message_index: u32, olm_wedging_index: SequenceNumber },
88}
89
90/// Settings for an encrypted room.
91///
92/// This determines the algorithm and rotation periods of a group session.
93#[derive(Clone, Debug, Deserialize, Serialize)]
94pub struct EncryptionSettings {
95    /// The encryption algorithm that should be used in the room.
96    pub algorithm: EventEncryptionAlgorithm,
97    /// How long the session should be used before changing it.
98    pub rotation_period: Duration,
99    /// How many messages should be sent before changing the session.
100    pub rotation_period_msgs: u64,
101    /// The history visibility of the room when the session was created.
102    pub history_visibility: HistoryVisibility,
103    /// The strategy used to distribute the room keys to participant.
104    /// Default will send to all devices.
105    #[serde(default)]
106    pub sharing_strategy: CollectStrategy,
107}
108
109impl Default for EncryptionSettings {
110    fn default() -> Self {
111        Self {
112            algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
113            rotation_period: ROTATION_PERIOD,
114            rotation_period_msgs: ROTATION_MESSAGES,
115            history_visibility: HistoryVisibility::Shared,
116            sharing_strategy: CollectStrategy::default(),
117        }
118    }
119}
120
121impl EncryptionSettings {
122    /// Create new encryption settings using an `RoomEncryptionEventContent`,
123    /// a history visibility, and key sharing strategy.
124    pub fn new(
125        content: RoomEncryptionEventContent,
126        history_visibility: HistoryVisibility,
127        sharing_strategy: CollectStrategy,
128    ) -> Self {
129        let rotation_period: Duration =
130            content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
131        let rotation_period_msgs: u64 =
132            content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
133
134        Self {
135            algorithm: EventEncryptionAlgorithm::from(content.algorithm.as_str()),
136            rotation_period,
137            rotation_period_msgs,
138            history_visibility,
139            sharing_strategy,
140        }
141    }
142}
143
144/// Outbound group session.
145///
146/// Outbound group sessions are used to exchange room messages between a group
147/// of participants. Outbound group sessions are used to encrypt the room
148/// messages.
149#[derive(Clone)]
150pub struct OutboundGroupSession {
151    inner: Arc<RwLock<GroupSession>>,
152    device_id: OwnedDeviceId,
153    account_identity_keys: Arc<IdentityKeys>,
154    session_id: Arc<str>,
155    room_id: OwnedRoomId,
156    pub(crate) creation_time: SecondsSinceUnixEpoch,
157    message_count: Arc<AtomicU64>,
158    shared: Arc<AtomicBool>,
159    invalidated: Arc<AtomicBool>,
160    settings: Arc<EncryptionSettings>,
161    shared_with_set: Arc<StdRwLock<ShareInfoSet>>,
162    to_share_with_set: Arc<StdRwLock<ToShareMap>>,
163}
164
165/// A a map of userid/device it to a `ShareInfo`.
166///
167/// Holds the `ShareInfo` for all the user/device pairs that will receive the
168/// room key.
169pub type ShareInfoSet = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>;
170
171type ToShareMap = BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>;
172
173/// Struct holding info about the share state of a outbound group session.
174#[derive(Clone, Debug, Serialize, Deserialize)]
175pub enum ShareInfo {
176    /// When the key has been shared
177    Shared(SharedWith),
178    /// When the session has been withheld
179    Withheld(WithheldCode),
180}
181
182impl ShareInfo {
183    /// Helper to create a SharedWith info
184    pub fn new_shared(
185        sender_key: Curve25519PublicKey,
186        message_index: u32,
187        olm_wedging_index: SequenceNumber,
188    ) -> Self {
189        ShareInfo::Shared(SharedWith { sender_key, message_index, olm_wedging_index })
190    }
191
192    /// Helper to create a Withheld info
193    pub fn new_withheld(code: WithheldCode) -> Self {
194        ShareInfo::Withheld(code)
195    }
196}
197
198#[derive(Clone, Debug, Serialize, Deserialize)]
199pub struct SharedWith {
200    /// The sender key of the device that was used to encrypt the room key.
201    pub sender_key: Curve25519PublicKey,
202    /// The message index that the device received.
203    pub message_index: u32,
204    /// The Olm wedging index of the device at the time the session was shared.
205    #[serde(default)]
206    pub olm_wedging_index: SequenceNumber,
207}
208
209/// A read-only view into the device sharing state of an
210/// [`OutboundGroupSession`].
211pub(crate) struct SharingView<'a> {
212    shared_with_set: RwLockReadGuard<'a, ShareInfoSet>,
213    to_share_with_set: RwLockReadGuard<'a, ToShareMap>,
214}
215
216impl SharingView<'_> {
217    /// Has the session been shared with the given user/device pair (or if not,
218    /// is there such a request pending).
219    pub(crate) fn get_share_state(&self, device: &DeviceData) -> ShareState {
220        self.iter_shares(Some(device.user_id()), Some(device.device_id()))
221            .map(|(_, _, info)| match info {
222                ShareInfo::Shared(info) => {
223                    if device.curve25519_key() == Some(info.sender_key) {
224                        ShareState::Shared {
225                            message_index: info.message_index,
226                            olm_wedging_index: info.olm_wedging_index,
227                        }
228                    } else {
229                        ShareState::SharedButChangedSenderKey
230                    }
231                }
232                ShareInfo::Withheld(_) => ShareState::NotShared,
233            })
234            // Return the most "definitive" ShareState found (in case there
235            // are multiple entries for the same device).
236            .max()
237            .unwrap_or(ShareState::NotShared)
238    }
239
240    /// Has the session been withheld for the given user/device pair (or if not,
241    /// is there such a request pending).
242    pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
243        self.iter_shares(Some(device.user_id()), Some(device.device_id()))
244            .any(|(_, _, info)| matches!(info, ShareInfo::Withheld(c) if c == code))
245    }
246
247    /// Enumerate all sent or pending sharing requests for the given device (or
248    /// for all devices if not specified).  This can yield the same device
249    /// multiple times.
250    pub(crate) fn iter_shares<'b, 'c>(
251        &self,
252        user_id: Option<&'b UserId>,
253        device_id: Option<&'c DeviceId>,
254    ) -> impl Iterator<Item = (&UserId, &DeviceId, &ShareInfo)> + use<'_, 'b, 'c> {
255        fn iter_share_info_set<'a, 'b, 'c>(
256            set: &'a ShareInfoSet,
257            user_ids: (Bound<&'b UserId>, Bound<&'b UserId>),
258            device_ids: (Bound<&'c DeviceId>, Bound<&'c DeviceId>),
259        ) -> impl Iterator<Item = (&'a UserId, &'a DeviceId, &'a ShareInfo)> + use<'a, 'b, 'c>
260        {
261            set.range::<UserId, _>(user_ids).flat_map(move |(uid, d)| {
262                d.range::<DeviceId, _>(device_ids)
263                    .map(|(id, info)| (uid.as_ref(), id.as_ref(), info))
264            })
265        }
266
267        let user_ids = user_id
268            .map(|u| (Bound::Included(u), Bound::Included(u)))
269            .unwrap_or((Bound::Unbounded, Bound::Unbounded));
270        let device_ids = device_id
271            .map(|d| (Bound::Included(d), Bound::Included(d)))
272            .unwrap_or((Bound::Unbounded, Bound::Unbounded));
273
274        let already_shared = iter_share_info_set(&self.shared_with_set, user_ids, device_ids);
275        let pending = self
276            .to_share_with_set
277            .values()
278            .flat_map(move |(_, set)| iter_share_info_set(set, user_ids, device_ids));
279        already_shared.chain(pending)
280    }
281
282    /// Enumerate all users that have received the session, or have pending
283    /// requests to receive it.  This can yield the same user multiple times,
284    /// so you may want to `collect()` the result into a `BTreeSet`.
285    pub(crate) fn shared_with_users(&self) -> impl Iterator<Item = &UserId> {
286        self.iter_shares(None, None).filter_map(|(u, _, info)| match info {
287            ShareInfo::Shared(_) => Some(u),
288            ShareInfo::Withheld(_) => None,
289        })
290    }
291}
292
293impl OutboundGroupSession {
294    pub(super) fn session_config(
295        algorithm: &EventEncryptionAlgorithm,
296    ) -> Result<SessionConfig, SessionCreationError> {
297        match algorithm {
298            EventEncryptionAlgorithm::MegolmV1AesSha2 => Ok(SessionConfig::version_1()),
299            #[cfg(feature = "experimental-algorithms")]
300            EventEncryptionAlgorithm::MegolmV2AesSha2 => Ok(SessionConfig::version_2()),
301            _ => Err(SessionCreationError::Algorithm(algorithm.to_owned())),
302        }
303    }
304
305    /// Create a new outbound group session for the given room.
306    ///
307    /// Outbound group sessions are used to encrypt room messages.
308    ///
309    /// # Arguments
310    ///
311    /// * `device_id` - The id of the device that created this session.
312    ///
313    /// * `identity_keys` - The identity keys of the account that created this
314    ///   session.
315    ///
316    /// * `room_id` - The id of the room that the session is used in.
317    ///
318    /// * `settings` - Settings determining the algorithm and rotation period of
319    ///   the outbound group session.
320    pub fn new(
321        device_id: OwnedDeviceId,
322        identity_keys: Arc<IdentityKeys>,
323        room_id: &RoomId,
324        settings: EncryptionSettings,
325    ) -> Result<Self, SessionCreationError> {
326        let config = Self::session_config(&settings.algorithm)?;
327
328        let session = GroupSession::new(config);
329        let session_id = session.session_id();
330
331        Ok(OutboundGroupSession {
332            inner: RwLock::new(session).into(),
333            room_id: room_id.into(),
334            device_id,
335            account_identity_keys: identity_keys,
336            session_id: session_id.into(),
337            creation_time: SecondsSinceUnixEpoch::now(),
338            message_count: Arc::new(AtomicU64::new(0)),
339            shared: Arc::new(AtomicBool::new(false)),
340            invalidated: Arc::new(AtomicBool::new(false)),
341            settings: Arc::new(settings),
342            shared_with_set: Default::default(),
343            to_share_with_set: Default::default(),
344        })
345    }
346
347    /// Add a to-device request that is sending the session key (or room key)
348    /// belonging to this [`OutboundGroupSession`] to other members of the
349    /// group.
350    ///
351    /// The request will get persisted with the session which allows seamless
352    /// session reuse across application restarts.
353    ///
354    /// **Warning** this method is only exposed to be used in integration tests
355    /// of crypto-store implementations. **Do not use this outside of tests**.
356    pub fn add_request(
357        &self,
358        request_id: OwnedTransactionId,
359        request: Arc<ToDeviceRequest>,
360        share_infos: ShareInfoSet,
361    ) {
362        self.to_share_with_set.write().insert(request_id, (request, share_infos));
363    }
364
365    /// Create a new `m.room_key.withheld` event content with the given code for
366    /// this outbound group session.
367    pub fn withheld_code(&self, code: WithheldCode) -> RoomKeyWithheldContent {
368        RoomKeyWithheldContent::new(
369            self.settings().algorithm.to_owned(),
370            code,
371            self.room_id().to_owned(),
372            self.session_id().to_owned(),
373            self.sender_key().to_owned(),
374            (*self.device_id).to_owned(),
375        )
376    }
377
378    /// This should be called if an the user wishes to rotate this session.
379    pub fn invalidate_session(&self) {
380        self.invalidated.store(true, Ordering::Relaxed)
381    }
382
383    /// Get the encryption settings of this outbound session.
384    pub fn settings(&self) -> &EncryptionSettings {
385        &self.settings
386    }
387
388    /// Mark the request with the given request id as sent.
389    ///
390    /// This removes the request from the queue and marks the set of
391    /// users/devices that received the session.
392    pub fn mark_request_as_sent(
393        &self,
394        request_id: &TransactionId,
395    ) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
396        let mut no_olm_devices = BTreeMap::new();
397
398        let removed = self.to_share_with_set.write().remove(request_id);
399        if let Some((to_device, request)) = removed {
400            let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
401                .iter()
402                .map(|(u, d)| (u.as_ref(), d.keys().map(|d| d.as_ref()).collect()))
403                .collect();
404
405            info!(
406                ?request_id,
407                ?recipients,
408                ?to_device.event_type,
409                "Marking to-device request carrying a room key or a withheld as sent"
410            );
411
412            for (user_id, info) in request {
413                let no_olms: BTreeSet<OwnedDeviceId> = info
414                    .iter()
415                    .filter(|(_, info)| matches!(info, ShareInfo::Withheld(WithheldCode::NoOlm)))
416                    .map(|(d, _)| d.to_owned())
417                    .collect();
418                no_olm_devices.insert(user_id.to_owned(), no_olms);
419
420                self.shared_with_set.write().entry(user_id).or_default().extend(info);
421            }
422
423            if self.to_share_with_set.read().is_empty() {
424                debug!(
425                    session_id = self.session_id(),
426                    room_id = ?self.room_id,
427                    "All m.room_key and withheld to-device requests were sent out, marking \
428                     session as shared.",
429                );
430
431                self.mark_as_shared();
432            }
433        } else {
434            let request_ids: Vec<String> =
435                self.to_share_with_set.read().keys().map(|k| k.to_string()).collect();
436
437            error!(
438                all_request_ids = ?request_ids,
439                request_id = ?request_id,
440                "Marking to-device request carrying a room key as sent but no \
441                 request found with the given id"
442            );
443        }
444
445        no_olm_devices
446    }
447
448    /// Encrypt the given plaintext using this session.
449    ///
450    /// Returns the encrypted ciphertext.
451    ///
452    /// # Arguments
453    ///
454    /// * `plaintext` - The plaintext that should be encrypted.
455    pub(crate) async fn encrypt_helper(&self, plaintext: String) -> MegolmMessage {
456        let mut session = self.inner.write().await;
457        self.message_count.fetch_add(1, Ordering::SeqCst);
458        session.encrypt(&plaintext)
459    }
460
461    /// Encrypt a room message for the given room.
462    ///
463    /// Beware that a room key needs to be shared before this method
464    /// can be called using the `share_room_key()` method.
465    ///
466    /// # Arguments
467    ///
468    /// * `event_type` - The plaintext type of the event, the outer type of the
469    ///   event will become `m.room.encrypted`.
470    ///
471    /// * `content` - The plaintext content of the message that should be
472    ///   encrypted in raw JSON form.
473    ///
474    /// # Panics
475    ///
476    /// Panics if the content can't be serialized.
477    pub async fn encrypt(
478        &self,
479        event_type: &str,
480        content: &Raw<AnyMessageLikeEventContent>,
481    ) -> Raw<RoomEncryptedEventContent> {
482        #[derive(Serialize)]
483        struct Payload<'a> {
484            #[serde(rename = "type")]
485            event_type: &'a str,
486            content: &'a Raw<AnyMessageLikeEventContent>,
487            room_id: &'a RoomId,
488        }
489
490        let payload = Payload { event_type, content, room_id: &self.room_id };
491        let payload_json =
492            serde_json::to_string(&payload).expect("payload serialization never fails");
493
494        let relates_to = content
495            .get_field::<serde_json::Value>("m.relates_to")
496            .expect("serde_json::Value deserialization with valid JSON input never fails");
497
498        let ciphertext = self.encrypt_helper(payload_json).await;
499        let scheme: RoomEventEncryptionScheme = match self.settings.algorithm {
500            EventEncryptionAlgorithm::MegolmV1AesSha2 => MegolmV1AesSha2Content {
501                ciphertext,
502                sender_key: self.account_identity_keys.curve25519,
503                session_id: self.session_id().to_owned(),
504                device_id: (*self.device_id).to_owned(),
505            }
506            .into(),
507            #[cfg(feature = "experimental-algorithms")]
508            EventEncryptionAlgorithm::MegolmV2AesSha2 => {
509                MegolmV2AesSha2Content { ciphertext, session_id: self.session_id().to_owned() }
510                    .into()
511            }
512            _ => unreachable!(
513                "An outbound group session is always using one of the supported algorithms"
514            ),
515        };
516
517        let content = RoomEncryptedEventContent { scheme, relates_to, other: Default::default() };
518
519        Raw::new(&content).expect("m.room.encrypted event content can always be serialized")
520    }
521
522    fn elapsed(&self) -> bool {
523        let creation_time = Duration::from_secs(self.creation_time.get().into());
524        let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
525        now.checked_sub(creation_time)
526            .map(|elapsed| elapsed >= self.safe_rotation_period())
527            .unwrap_or(true)
528    }
529
530    /// Returns the rotation_period_ms that was set for this session, clamped
531    /// to be no less than one hour.
532    ///
533    /// This is to prevent a malicious or careless user causing sessions to be
534    /// rotated very frequently.
535    ///
536    /// The feature flag `_disable-minimum-rotation-period-ms` can
537    /// be used to prevent this behaviour (which can be useful for tests).
538    fn safe_rotation_period(&self) -> Duration {
539        if cfg!(feature = "_disable-minimum-rotation-period-ms") {
540            self.settings.rotation_period
541        } else {
542            max(self.settings.rotation_period, ONE_HOUR)
543        }
544    }
545
546    /// Check if the session has expired and if it should be rotated.
547    ///
548    /// A session will expire after some time or if enough messages have been
549    /// encrypted using it.
550    pub fn expired(&self) -> bool {
551        let count = self.message_count.load(Ordering::SeqCst);
552        // We clamp the rotation period for message counts to be between 1 and
553        // 10000. The Megolm session should be usable for at least 1 message,
554        // and at most 10000 messages. Realistically Megolm uses u32 for it's
555        // internal counter and one could use the Megolm session for up to
556        // u32::MAX messages, but we're staying on the safe side of things.
557        let rotation_period_msgs = self.settings.rotation_period_msgs.clamp(1, 10_000);
558
559        count >= rotation_period_msgs || self.elapsed()
560    }
561
562    /// Has the session been invalidated.
563    pub fn invalidated(&self) -> bool {
564        self.invalidated.load(Ordering::Relaxed)
565    }
566
567    /// Mark the session as shared.
568    ///
569    /// Messages shouldn't be encrypted with the session before it has been
570    /// shared.
571    pub fn mark_as_shared(&self) {
572        self.shared.store(true, Ordering::Relaxed);
573    }
574
575    /// Check if the session has been marked as shared.
576    pub fn shared(&self) -> bool {
577        self.shared.load(Ordering::Relaxed)
578    }
579
580    /// Get the session key of this session.
581    ///
582    /// A session key can be used to to create an `InboundGroupSession`.
583    pub async fn session_key(&self) -> SessionKey {
584        let session = self.inner.read().await;
585        session.session_key()
586    }
587
588    /// Gets the Sender Key
589    pub fn sender_key(&self) -> Curve25519PublicKey {
590        self.account_identity_keys.as_ref().curve25519.to_owned()
591    }
592
593    /// Get the room id of the room this session belongs to.
594    pub fn room_id(&self) -> &RoomId {
595        &self.room_id
596    }
597
598    /// Returns the unique identifier for this session.
599    pub fn session_id(&self) -> &str {
600        &self.session_id
601    }
602
603    /// Get the current message index for this session.
604    ///
605    /// Each message is sent with an increasing index. This returns the
606    /// message index that will be used for the next encrypted message.
607    pub async fn message_index(&self) -> u32 {
608        let session = self.inner.read().await;
609        session.message_index()
610    }
611
612    pub(crate) async fn as_content(&self) -> RoomKeyContent {
613        let session_key = self.session_key().await;
614        let shared_history =
615            shared_history_from_history_visibility(&self.settings.history_visibility);
616
617        RoomKeyContent::MegolmV1AesSha2(
618            MegolmV1AesSha2RoomKeyContent::new(
619                self.room_id().to_owned(),
620                self.session_id().to_owned(),
621                session_key,
622                shared_history,
623            )
624            .into(),
625        )
626    }
627
628    /// Create a read-only view into the device sharing state of this session.
629    /// This view includes pending requests, so it is not guaranteed that the
630    /// represented state has been fully propagated yet.
631    pub(crate) fn sharing_view(&self) -> SharingView<'_> {
632        SharingView {
633            shared_with_set: self.shared_with_set.read(),
634            to_share_with_set: self.to_share_with_set.read(),
635        }
636    }
637
638    /// Mark the session as shared with the given user/device pair, starting
639    /// from some message index.
640    #[cfg(test)]
641    pub fn mark_shared_with_from_index(
642        &self,
643        user_id: &UserId,
644        device_id: &DeviceId,
645        sender_key: Curve25519PublicKey,
646        index: u32,
647    ) {
648        self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert(
649            device_id.to_owned(),
650            ShareInfo::new_shared(sender_key, index, Default::default()),
651        );
652    }
653
654    /// Mark the session as shared with the given user/device pair, starting
655    /// from the current index.
656    #[cfg(test)]
657    pub async fn mark_shared_with(
658        &self,
659        user_id: &UserId,
660        device_id: &DeviceId,
661        sender_key: Curve25519PublicKey,
662    ) {
663        let share_info =
664            ShareInfo::new_shared(sender_key, self.message_index().await, Default::default());
665        self.shared_with_set
666            .write()
667            .entry(user_id.to_owned())
668            .or_default()
669            .insert(device_id.to_owned(), share_info);
670    }
671
672    /// Get the list of requests that need to be sent out for this session to be
673    /// marked as shared.
674    pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
675        self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect()
676    }
677
678    /// Get the list of request ids this session is waiting for to be sent out.
679    pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
680        self.to_share_with_set.read().keys().cloned().collect()
681    }
682
683    /// Restore a Session from a previously pickled string.
684    ///
685    /// Returns the restored group session or a `OlmGroupSessionError` if there
686    /// was an error.
687    ///
688    /// # Arguments
689    ///
690    /// * `device_id` - The device ID of the device that created this session.
691    ///   Put differently, our own device ID.
692    ///
693    /// * `identity_keys` - The identity keys of the device that created this
694    ///   session, our own identity keys.
695    ///
696    /// * `pickle` - The pickled version of the `OutboundGroupSession`.
697    ///
698    /// * `pickle_mode` - The mode that was used to pickle the session, either
699    ///   an unencrypted mode or an encrypted using passphrase.
700    pub fn from_pickle(
701        device_id: OwnedDeviceId,
702        identity_keys: Arc<IdentityKeys>,
703        pickle: PickledOutboundGroupSession,
704    ) -> Result<Self, PickleError> {
705        let inner: GroupSession = pickle.pickle.into();
706        let session_id = inner.session_id();
707
708        Ok(Self {
709            inner: Arc::new(RwLock::new(inner)),
710            device_id,
711            account_identity_keys: identity_keys,
712            session_id: session_id.into(),
713            room_id: pickle.room_id,
714            creation_time: pickle.creation_time,
715            message_count: AtomicU64::from(pickle.message_count).into(),
716            shared: AtomicBool::from(pickle.shared).into(),
717            invalidated: AtomicBool::from(pickle.invalidated).into(),
718            settings: pickle.settings,
719            shared_with_set: Arc::new(StdRwLock::new(pickle.shared_with_set)),
720            to_share_with_set: Arc::new(StdRwLock::new(pickle.requests)),
721        })
722    }
723
724    /// Store the group session as a base64 encoded string and associated data
725    /// belonging to the session.
726    ///
727    /// # Arguments
728    ///
729    /// * `pickle_mode` - The mode that should be used to pickle the group
730    ///   session, either an unencrypted mode or an encrypted using passphrase.
731    pub async fn pickle(&self) -> PickledOutboundGroupSession {
732        let pickle = self.inner.read().await.pickle();
733
734        PickledOutboundGroupSession {
735            pickle,
736            room_id: self.room_id.clone(),
737            settings: self.settings.clone(),
738            creation_time: self.creation_time,
739            message_count: self.message_count.load(Ordering::SeqCst),
740            shared: self.shared(),
741            invalidated: self.invalidated(),
742            shared_with_set: self.shared_with_set.read().clone(),
743            requests: self.to_share_with_set.read().clone(),
744        }
745    }
746}
747
748#[derive(Clone, Debug, Serialize, Deserialize)]
749pub struct OutboundGroupSessionPickle(String);
750
751impl From<String> for OutboundGroupSessionPickle {
752    fn from(p: String) -> Self {
753        Self(p)
754    }
755}
756
757#[cfg(not(tarpaulin_include))]
758impl fmt::Debug for OutboundGroupSession {
759    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
760        f.debug_struct("OutboundGroupSession")
761            .field("session_id", &self.session_id)
762            .field("room_id", &self.room_id)
763            .field("creation_time", &self.creation_time)
764            .field("message_count", &self.message_count)
765            .finish()
766    }
767}
768
/// A pickled version of an `OutboundGroupSession`.
///
/// Holds all the information that needs to be stored in a database to restore
/// an `OutboundGroupSession`.
///
/// `#[derive(Serialize)]` emits fields in declaration order. Formats that
/// encode structs positionally depend on that order, so do not reorder fields
/// without considering already-stored data.
#[derive(Deserialize, Serialize)]
// NOTE(review): no `Debug` impl is provided; presumably deliberate so the
// pickled key material cannot be printed by accident — confirm.
#[allow(missing_debug_implementations)]
pub struct PickledOutboundGroupSession {
    /// The vodozemac pickle of the underlying megolm `GroupSession`.
    pub pickle: GroupSessionPickle,
    /// The settings this session adheres to.
    pub settings: Arc<EncryptionSettings>,
    /// The room id this session is used for.
    pub room_id: OwnedRoomId,
    /// The timestamp when this session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// The number of messages this session has already encrypted.
    pub message_count: u64,
    /// Whether the session has been shared.
    pub shared: bool,
    /// Whether the session has been invalidated.
    pub invalidated: bool,
    /// The devices the session has already been shared with, keyed by user ID
    /// and then by device ID.
    pub shared_with_set: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
    /// To-device requests that still need to be sent out to share the session,
    /// keyed by transaction ID.
    pub requests: BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>,
}
795
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ruma::{
        events::room::{
            encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility,
        },
        uint, EventEncryptionAlgorithm,
    };

    use super::{EncryptionSettings, ShareState, ROTATION_MESSAGES, ROTATION_PERIOD};
    use crate::CollectStrategy;

    /// `EncryptionSettings::new` uses the default rotation limits for a freshly
    /// constructed encryption event, and converts explicit
    /// `rotation_period_ms` / `rotation_period_msgs` values when they are set.
    #[test]
    fn test_encryption_settings_conversion() {
        let mut event_content =
            RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);

        let defaults = EncryptionSettings::new(
            event_content.clone(),
            HistoryVisibility::Joined,
            CollectStrategy::AllDevices,
        );
        assert_eq!(defaults.rotation_period, ROTATION_PERIOD);
        assert_eq!(defaults.rotation_period_msgs, ROTATION_MESSAGES);

        event_content.rotation_period_ms = Some(uint!(3600));
        event_content.rotation_period_msgs = Some(uint!(500));

        let custom = EncryptionSettings::new(
            event_content,
            HistoryVisibility::Shared,
            CollectStrategy::AllDevices,
        );
        assert_eq!(custom.rotation_period, Duration::from_millis(3600));
        assert_eq!(custom.rotation_period_msgs, 500);
    }

    /// Ensure that the `ShareState` PartialOrd instance orders according to
    /// specificity of the value.
    #[test]
    fn test_share_state_ordering() {
        let least_to_most_specific = [
            ShareState::NotShared,
            ShareState::SharedButChangedSenderKey,
            ShareState::Shared { message_index: 1, olm_wedging_index: Default::default() },
        ];

        // Compile-time guard: adding a `ShareState` variant makes this match
        // non-exhaustive, which is the reminder to extend the array above.
        match least_to_most_specific[0] {
            ShareState::NotShared
            | ShareState::SharedButChangedSenderKey
            | ShareState::Shared { .. } => {}
        }

        assert!(least_to_most_specific.is_sorted());
    }

    /// Tests for `OutboundGroupSession::expired`, driven either by the number
    /// of encrypted messages or by the session's age.
    ///
    /// Only compiled for Linux, macOS and wasm32 targets.
    #[cfg(any(target_os = "linux", target_os = "macos", target_arch = "wasm32"))]
    mod expiration {
        use std::{sync::atomic::Ordering, time::Duration};

        use matrix_sdk_test::async_test;
        use ruma::{
            device_id, events::room::message::RoomMessageEventContent, room_id, serde::Raw, uint,
            user_id, SecondsSinceUnixEpoch,
        };

        use crate::{
            olm::{OutboundGroupSession, SenderData},
            Account, EncryptionSettings, MegolmError,
        };

        const TWO_HOURS: Duration = Duration::from_secs(60 * 60 * 2);

        #[async_test]
        async fn test_session_is_not_expired_if_no_messages_sent_and_no_time_passed() {
            // Given a session that should rotate after a single message...
            let session = create_session(EncryptionSettings {
                rotation_period_msgs: 1,
                ..Default::default()
            })
            .await;

            // ...and no message encrypted with it, it is still usable.
            assert!(!session.expired());
        }

        #[async_test]
        async fn test_session_is_expired_if_we_rotate_every_message_and_one_was_sent(
        ) -> Result<(), MegolmError> {
            // Given a session that should rotate after a single message...
            let session = create_session(EncryptionSettings {
                rotation_period_msgs: 1,
                ..Default::default()
            })
            .await;

            // ...encrypting one message uses it up.
            encrypt_test_message(&session).await?;
            assert!(session.expired());

            Ok(())
        }

        #[async_test]
        async fn test_session_with_rotation_period_is_not_expired_after_no_time() {
            // Given a session with a two hour lifetime...
            let session = create_session(EncryptionSettings {
                rotation_period: TWO_HOURS,
                ..Default::default()
            })
            .await;

            // ...it is not expired straight after creation.
            assert!(!session.expired());
        }

        #[async_test]
        async fn test_session_is_expired_after_rotation_period() {
            // Given a session with a two hour lifetime...
            let mut session = create_session(EncryptionSettings {
                rotation_period: TWO_HOURS,
                ..Default::default()
            })
            .await;

            // ...that was created three hours ago...
            let three_hours_ago = SecondsSinceUnixEpoch::now().get() - uint!(10800);
            session.creation_time = SecondsSinceUnixEpoch(three_hours_ago);

            // ...it has expired.
            assert!(session.expired());
        }

        #[async_test]
        #[cfg(not(feature = "_disable-minimum-rotation-period-ms"))]
        async fn test_session_does_not_expire_under_one_hour_even_if_we_ask_for_shorter() {
            // Given a session that asks for a 100ms lifetime...
            let mut session = create_session(EncryptionSettings {
                rotation_period: Duration::from_millis(100),
                ..Default::default()
            })
            .await;

            // Both backdated creation times are measured from the same instant.
            let now = SecondsSinceUnixEpoch::now().get();

            // ...after half an hour it is still alive: a one hour floor applies.
            session.creation_time = SecondsSinceUnixEpoch(now - uint!(1800));
            assert!(!session.expired());

            // ...but just past the hour it has expired.
            session.creation_time = SecondsSinceUnixEpoch(now - uint!(3601));
            assert!(session.expired());
        }

        #[async_test]
        #[cfg(feature = "_disable-minimum-rotation-period-ms")]
        async fn test_with_disable_minrotperiod_feature_sessions_can_expire_quickly() {
            // Given a session that asks for a 100ms lifetime...
            let mut session = create_session(EncryptionSettings {
                rotation_period: Duration::from_millis(100),
                ..Default::default()
            })
            .await;

            // ...that is half an hour old...
            let half_an_hour_ago = SecondsSinceUnixEpoch::now().get() - uint!(1800);
            session.creation_time = SecondsSinceUnixEpoch(half_an_hour_ago);

            // ...it has expired, because the feature flag removes the minimum.
            assert!(session.expired());
        }

        #[async_test]
        async fn test_session_with_zero_msgs_rotation_is_not_expired_initially() {
            // Given a session configured with the nonsensical "rotate after zero
            // messages"...
            let session = create_session(EncryptionSettings {
                rotation_period_msgs: 0,
                ..Default::default()
            })
            .await;

            // ...it is still usable before anything was encrypted.
            assert!(!session.expired());
        }

        #[async_test]
        async fn test_session_with_zero_msgs_rotation_expires_after_one_message(
        ) -> Result<(), MegolmError> {
            // Given a session configured with "rotate after zero messages"...
            let session = create_session(EncryptionSettings {
                rotation_period_msgs: 0,
                ..Default::default()
            })
            .await;

            // ...one encrypted message expires it, exactly as if the limit were 1.
            encrypt_test_message(&session).await?;
            assert!(session.expired());

            Ok(())
        }

        #[async_test]
        async fn test_session_expires_after_10k_messages_even_if_we_ask_for_more() {
            // Given a session that asks to rotate only after 100K messages...
            let session = create_session(EncryptionSettings {
                rotation_period_msgs: 100_000,
                ..Default::default()
            })
            .await;

            // ...it stays usable while fewer than 10K messages were sent...
            assert!(!session.expired());
            for sent_so_far in [1000, 9999] {
                session.message_count.store(sent_so_far, Ordering::SeqCst);
                assert!(!session.expired());
            }

            // ...but 10K messages is the hard upper limit.
            session.message_count.store(10_000, Ordering::SeqCst);
            assert!(session.expired());
        }

        /// Encrypts a fixed plain-text `m.room.message` with `session`; the
        /// encrypted content is discarded.
        ///
        /// Returns an error only if the plain-text content fails to serialize.
        async fn encrypt_test_message(session: &OutboundGroupSession) -> Result<(), MegolmError> {
            let _ = session
                .encrypt(
                    "m.room.message",
                    &Raw::new(&RoomMessageEventContent::text_plain("Test message"))?.cast(),
                )
                .await;
            Ok(())
        }

        /// Creates an outbound group session for a fixed test room, owned by
        /// `@alice:example.org` / `DEVICEID`, using `settings`.
        async fn create_session(settings: EncryptionSettings) -> OutboundGroupSession {
            let static_data =
                Account::with_device_id(user_id!("@alice:example.org"), device_id!("DEVICEID"))
                    .static_data;
            let (outbound, _) = static_data
                .create_group_session_pair(
                    room_id!("!test_room:example.org"),
                    settings,
                    SenderData::unknown(),
                )
                .await
                .unwrap();
            outbound
        }
    }
}