1use std::{
42 collections::{BTreeMap, BTreeSet, HashMap, HashSet},
43 fmt::Debug,
44 ops::Deref,
45 pin::pin,
46 sync::{atomic::Ordering, Arc},
47 time::Duration,
48};
49
50use as_variant::as_variant;
51use futures_core::Stream;
52use futures_util::StreamExt;
53use matrix_sdk_common::locks::RwLock as StdRwLock;
54use ruma::{
55 encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
56 OwnedRoomId, OwnedUserId, RoomId, UserId,
57};
58use serde::{de::DeserializeOwned, Deserialize, Serialize};
59use thiserror::Error;
60use tokio::sync::{Mutex, MutexGuard, Notify, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
61use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
62use tracing::{info, warn};
63use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
64use zeroize::{Zeroize, ZeroizeOnDrop};
65
66#[cfg(doc)]
67use crate::{backups::BackupMachine, identities::OwnUserIdentity};
68use crate::{
69 gossiping::GossippedSecret,
70 identities::{user::UserIdentity, Device, DeviceData, UserDevices, UserIdentityData},
71 olm::{
72 Account, ExportedRoomKey, InboundGroupSession, OlmMessageHash, OutboundGroupSession,
73 PrivateCrossSigningIdentity, SenderData, Session, StaticAccountData,
74 },
75 types::{
76 events::room_key_withheld::RoomKeyWithheldEvent, BackupSecrets, CrossSigningSecrets,
77 EventEncryptionAlgorithm, MegolmBackupV1Curve25519AesSha2Secrets, RoomKeyExport,
78 SecretsBundle,
79 },
80 verification::VerificationMachine,
81 CrossSigningStatus, OwnUserIdentityData, RoomKeyImportResult,
82};
83
84pub mod caches;
85mod crypto_store_wrapper;
86mod error;
87mod memorystore;
88mod traits;
89
90#[cfg(any(test, feature = "testing"))]
91#[macro_use]
92#[allow(missing_docs)]
93pub mod integration_tests;
94
95use caches::{SequenceNumber, UsersForKeyQuery};
96pub(crate) use crypto_store_wrapper::CryptoStoreWrapper;
97pub use error::{CryptoStoreError, Result};
98use matrix_sdk_common::{
99 deserialized_responses::WithheldCode, store_locks::CrossProcessStoreLock, timeout::timeout,
100};
101pub use memorystore::MemoryStore;
102pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
103
104use crate::types::{
105 events::{room_key_bundle::RoomKeyBundleContent, room_key_withheld::RoomKeyWithheldContent},
106 room_history::RoomKeyBundle,
107};
108pub use crate::{
109 dehydrated_devices::DehydrationError,
110 gossiping::{GossipRequest, SecretInfo},
111};
112
/// Handle to the crypto store together with the in-memory state layered on
/// top of it.
///
/// Cloning is cheap: every clone shares the same inner state behind an
/// [`Arc`].
#[derive(Debug, Clone)]
pub struct Store {
    inner: Arc<StoreInner>,
}
123
/// Keeps track of which users need a `/keys/query` request and lets tasks
/// wait until a pending query for a given user has finished.
#[derive(Debug, Default)]
pub(crate) struct KeyQueryManager {
    /// Users whose device lists need to be queried, plus the bookkeeping for
    /// tasks waiting on a query for a particular user.
    users_for_key_query: Mutex<UsersForKeyQuery>,

    /// Notified whenever users are marked as up to date, so that waiting tasks
    /// can re-check whether the query they wait for has completed.
    users_for_key_query_notify: Notify,
}
132
impl KeyQueryManager {
    /// Ensure the tracked users are loaded into `cache`, then return a
    /// [`SyncedKeyQueryManager`] that operates on both this manager and
    /// `cache`.
    pub async fn synced<'a>(&'a self, cache: &'a StoreCache) -> Result<SyncedKeyQueryManager<'a>> {
        self.ensure_sync_tracked_users(cache).await?;
        Ok(SyncedKeyQueryManager { cache, manager: self })
    }

    /// Load the tracked users from the store into `cache`, unless that has
    /// already been done.
    ///
    /// Every loaded user is inserted into `cache.tracked_users`. Users persisted
    /// as `dirty` are also queued for a `/keys/query`. The
    /// `cache.loaded_tracked_users` flag is only set after a successful load,
    /// so a failed load is retried on the next call.
    async fn ensure_sync_tracked_users(&self, cache: &StoreCache) -> Result<()> {
        // Cheap check under the shared lock first.
        let loaded = cache.loaded_tracked_users.read().await;
        if *loaded {
            return Ok(());
        }

        // Release the read lock before asking for the write lock. Otherwise
        // this task's own read guard would keep the write lock from ever being
        // granted.
        drop(loaded);
        let mut loaded = cache.loaded_tracked_users.write().await;

        // Check again: another task may have finished loading while we were
        // waiting for the write lock.
        if *loaded {
            return Ok(());
        }

        let tracked_users = cache.store.load_tracked_users().await?;

        let mut query_users_lock = self.users_for_key_query.lock().await;
        let mut tracked_users_cache = cache.tracked_users.write();
        for user in tracked_users {
            tracked_users_cache.insert(user.user_id.to_owned());

            if user.dirty {
                query_users_lock.insert_user(&user.user_id);
            }
        }

        *loaded = true;

        Ok(())
    }

    /// Wait until a pending `/keys/query` for `user` has completed, or until
    /// `timeout_duration` has elapsed.
    ///
    /// The `cache` guard is released before any waiting happens.
    ///
    /// Returns [`UserKeyQueryResult::WasNotPending`] if no query for `user` was
    /// outstanding, [`UserKeyQueryResult::WasPending`] if one was and it
    /// completed in time, and [`UserKeyQueryResult::TimeoutExpired`] (after
    /// logging a warning) if the timeout elapsed first.
    pub async fn wait_if_user_key_query_pending(
        &self,
        cache: StoreCacheGuard,
        timeout_duration: Duration,
        user: &UserId,
    ) -> Result<UserKeyQueryResult> {
        {
            self.ensure_sync_tracked_users(&cache).await?;
            // Don't keep the store cache locked while waiting below.
            drop(cache);
        }

        let mut users_for_key_query = self.users_for_key_query.lock().await;
        let Some(waiter) = users_for_key_query.maybe_register_waiting_task(user) else {
            return Ok(UserKeyQueryResult::WasNotPending);
        };

        let wait_for_completion = async {
            while !waiter.completed.load(Ordering::Relaxed) {
                // Register for the notification *before* releasing the lock, so
                // that a `notify_waiters()` issued between the unlock and the
                // `.await` below is not missed.
                let mut notified = pin!(self.users_for_key_query_notify.notified());
                notified.as_mut().enable();
                drop(users_for_key_query);

                notified.await;

                // Take the lock again; the next iteration (if any) releases it
                // after re-registering for a notification.
                users_for_key_query = self.users_for_key_query.lock().await;
            }
        };

        match timeout(Box::pin(wait_for_completion), timeout_duration).await {
            Err(_) => {
                warn!(
                    user_id = ?user,
                    "The user has a pending `/keys/query` request which did \
                    not finish yet, some devices might be missing."
                );

                Ok(UserKeyQueryResult::TimeoutExpired)
            }
            _ => Ok(UserKeyQueryResult::WasPending),
        }
    }
}
240
/// A [`KeyQueryManager`] paired with a [`StoreCache`] whose tracked users have
/// already been loaded from the store (see `KeyQueryManager::synced`).
pub(crate) struct SyncedKeyQueryManager<'a> {
    /// The cache holding the tracked users.
    cache: &'a StoreCache,
    /// The manager holding the users that need a `/keys/query`.
    manager: &'a KeyQueryManager,
}
245
246impl SyncedKeyQueryManager<'_> {
247 pub async fn update_tracked_users(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
252 let mut store_updates = Vec::new();
253 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
254
255 {
256 let mut tracked_users = self.cache.tracked_users.write();
257 for user_id in users {
258 if tracked_users.insert(user_id.to_owned()) {
259 key_query_lock.insert_user(user_id);
260 store_updates.push((user_id, true))
261 }
262 }
263 }
264
265 self.cache.store.save_tracked_users(&store_updates).await
266 }
267
268 pub async fn mark_tracked_users_as_changed(
275 &self,
276 users: impl Iterator<Item = &UserId>,
277 ) -> Result<()> {
278 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
279 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
280
281 {
282 let tracked_users = &self.cache.tracked_users.read();
283 for user_id in users {
284 if tracked_users.contains(user_id) {
285 key_query_lock.insert_user(user_id);
286 store_updates.push((user_id, true));
287 }
288 }
289 }
290
291 self.cache.store.save_tracked_users(&store_updates).await
292 }
293
294 pub async fn mark_tracked_users_as_up_to_date(
300 &self,
301 users: impl Iterator<Item = &UserId>,
302 sequence_number: SequenceNumber,
303 ) -> Result<()> {
304 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
305 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
306
307 {
308 let tracked_users = self.cache.tracked_users.read();
309 for user_id in users {
310 if tracked_users.contains(user_id) {
311 let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
312 store_updates.push((user_id, !clean));
313 }
314 }
315 }
316
317 self.cache.store.save_tracked_users(&store_updates).await?;
318 self.manager.users_for_key_query_notify.notify_waiters();
320
321 Ok(())
322 }
323
324 pub async fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
336 self.manager.users_for_key_query.lock().await.users_for_key_query()
337 }
338
339 pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
341 self.cache.tracked_users.read().iter().cloned().collect()
342 }
343
344 pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
350 self.manager.users_for_key_query.lock().await.insert_user(user);
351 self.cache.tracked_users.write().insert(user.to_owned());
352
353 self.cache.store.save_tracked_users(&[(user, true)]).await
354 }
355}
356
/// In-memory cache sitting in front of the [`CryptoStoreWrapper`].
///
/// Shared access goes through [`StoreCacheGuard`]. Exclusive access goes
/// through [`StoreTransaction`].
#[derive(Debug)]
pub(crate) struct StoreCache {
    /// The backing store.
    store: Arc<CryptoStoreWrapper>,
    /// The users whose devices we track.
    tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
    /// Whether `tracked_users` has been populated from the store yet.
    loaded_tracked_users: RwLock<bool>,
    /// The account, loaded lazily. It is `None` before the first load and while
    /// a [`StoreTransaction`] has taken it out.
    account: Mutex<Option<Account>>,
}
364
365impl StoreCache {
366 pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
367 self.store.as_ref()
368 }
369
370 async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
382 let mut guard = self.account.lock().await;
383 if guard.is_some() {
384 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
385 } else {
386 match self.store.load_account().await? {
387 Some(account) => {
388 *guard = Some(account);
389 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
390 }
391 None => Err(CryptoStoreError::AccountUnset),
392 }
393 }
394 }
395}
396
/// Shared (read) access to the [`StoreCache`], as handed out by
/// `Store::cache`.
pub(crate) struct StoreCacheGuard {
    /// Read guard that keeps the cache locked for shared access.
    cache: OwnedRwLockReadGuard<StoreCache>,
}
406
407impl StoreCacheGuard {
408 pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
416 self.cache.account().await
417 }
418}
419
420impl Deref for StoreCacheGuard {
421 type Target = StoreCache;
422
423 fn deref(&self) -> &Self::Target {
424 &self.cache
425 }
426}
427
/// A transaction over the store.
///
/// Holds exclusive (write) access to the store cache while it is alive, and
/// accumulates changes that are persisted on commit.
#[allow(missing_debug_implementations)]
pub struct StoreTransaction {
    /// The store this transaction belongs to.
    store: Store,
    /// Changes to persist on commit.
    changes: PendingChanges,
    /// Write guard giving this transaction exclusive access to the cache.
    cache: OwnedRwLockWriteGuard<StoreCache>,
}
436
437impl StoreTransaction {
438 async fn new(store: Store) -> Self {
440 let cache = store.inner.cache.clone();
441
442 Self { store, changes: PendingChanges::default(), cache: cache.clone().write_owned().await }
443 }
444
445 pub(crate) fn cache(&self) -> &StoreCache {
446 &self.cache
447 }
448
449 pub fn store(&self) -> &Store {
451 &self.store
452 }
453
454 pub async fn account(&mut self) -> Result<&mut Account> {
461 if self.changes.account.is_none() {
462 let _ = self.cache.account().await?;
464 self.changes.account = self.cache.account.lock().await.take();
465 }
466 Ok(self.changes.account.as_mut().unwrap())
467 }
468
469 pub async fn commit(self) -> Result<()> {
472 if self.changes.is_empty() {
473 return Ok(());
474 }
475
476 let account = self.changes.account.as_ref().map(|acc| acc.deep_clone());
478
479 self.store.save_pending_changes(self.changes).await?;
480
481 if let Some(account) = account {
483 *self.cache.account.lock().await = Some(account);
484 }
485
486 Ok(())
487 }
488}
489
/// State shared by every clone of a [`Store`].
#[derive(Debug)]
struct StoreInner {
    /// Our private cross-signing identity.
    identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
    /// The backing store. `cache` holds a clone of the same `Arc`.
    store: Arc<CryptoStoreWrapper>,

    /// In-memory cache in front of `store`. `Store::cache` takes shared access
    /// and `Store::transaction` takes exclusive access.
    cache: Arc<RwLock<StoreCache>>,

    /// Verification machine, handed to the devices and identities created by
    /// this store.
    verification_machine: VerificationMachine,

    /// Static data about our own account, including our user ID.
    static_account: StaticAccountData,
}
506
/// Changes accumulated by a `StoreTransaction`, persisted when the transaction
/// is committed.
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct PendingChanges {
    /// The account, if it was modified as part of the transaction.
    pub account: Option<Account>,
}
517
518impl PendingChanges {
519 pub fn is_empty(&self) -> bool {
521 self.account.is_none()
522 }
523}
524
/// A batch of changes to persist to the store in one go (see
/// `Store::save_changes`).
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct Changes {
    /// Our updated private cross-signing identity, if any.
    pub private_identity: Option<PrivateCrossSigningIdentity>,
    /// A new backup version, if any.
    pub backup_version: Option<String>,
    /// A new backup decryption key, if any.
    pub backup_decryption_key: Option<BackupDecryptionKey>,
    /// A new pickle key for the dehydrated device, if any.
    pub dehydrated_device_pickle_key: Option<DehydratedDeviceKey>,
    /// Olm sessions to save.
    pub sessions: Vec<Session>,
    /// Olm message hashes to save.
    pub message_hashes: Vec<OlmMessageHash>,
    /// Inbound group sessions to save.
    pub inbound_group_sessions: Vec<InboundGroupSession>,
    /// Outbound group sessions to save.
    pub outbound_group_sessions: Vec<OutboundGroupSession>,
    /// Gossip (key/secret) requests to save.
    pub key_requests: Vec<GossipRequest>,
    /// New, changed and unchanged user identities.
    pub identities: IdentityChanges,
    /// New, changed and deleted devices.
    pub devices: DeviceChanges,
    /// Withheld-key notices, keyed by room ID and then presumably by session
    /// ID. Confirm against the store implementations.
    pub withheld_session_info: BTreeMap<OwnedRoomId, BTreeMap<String, RoomKeyWithheldEvent>>,
    /// Per-room encryption settings to save.
    pub room_settings: HashMap<OwnedRoomId, RoomSettings>,
    /// Secrets received via gossiping.
    pub secrets: Vec<GossippedSecret>,
    /// A new sync `next_batch` token, if any.
    pub next_batch_token: Option<String>,

    /// Received room key bundles to save.
    pub received_room_key_bundles: Vec<StoredRoomKeyBundleData>,
}
551
/// A received room key bundle, as saved via
/// `Changes::received_room_key_bundles`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StoredRoomKeyBundleData {
    /// The user who sent the bundle.
    pub sender_user: OwnedUserId,

    /// Information about the sender (device/identity) at the time the bundle
    /// was received. This is presumed from the type; confirm with the code
    /// that stores it.
    pub sender_data: SenderData,

    /// The content of the bundle event.
    pub bundle_data: RoomKeyBundleContent,
}
567
/// A user whose devices we track, as persisted in the store.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrackedUser {
    /// The user's ID.
    pub user_id: OwnedUserId,
    /// Whether the user's device list is outdated. Dirty users are queued for a
    /// `/keys/query` when the tracked users are loaded from the store.
    pub dirty: bool,
}
579
580impl Changes {
581 pub fn is_empty(&self) -> bool {
583 self.private_identity.is_none()
584 && self.backup_version.is_none()
585 && self.backup_decryption_key.is_none()
586 && self.dehydrated_device_pickle_key.is_none()
587 && self.sessions.is_empty()
588 && self.message_hashes.is_empty()
589 && self.inbound_group_sessions.is_empty()
590 && self.outbound_group_sessions.is_empty()
591 && self.key_requests.is_empty()
592 && self.identities.is_empty()
593 && self.devices.is_empty()
594 && self.withheld_session_info.is_empty()
595 && self.room_settings.is_empty()
596 && self.secrets.is_empty()
597 && self.next_batch_token.is_none()
598 && self.received_room_key_bundles.is_empty()
599 }
600}
601
/// User identities observed while processing a `/keys/query` response or
/// similar, grouped by whether they are new, changed or unchanged.
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct IdentityChanges {
    /// Identities we did not know before.
    pub new: Vec<UserIdentityData>,
    /// Known identities that changed.
    pub changed: Vec<UserIdentityData>,
    /// Known identities that did not change. These are not counted by
    /// `is_empty`, but are used to look up device owners in
    /// `collect_device_updates`.
    pub unchanged: Vec<UserIdentityData>,
}
619
620impl IdentityChanges {
621 fn is_empty(&self) -> bool {
622 self.new.is_empty() && self.changed.is_empty()
623 }
624
625 fn into_maps(
628 self,
629 ) -> (
630 BTreeMap<OwnedUserId, UserIdentityData>,
631 BTreeMap<OwnedUserId, UserIdentityData>,
632 BTreeMap<OwnedUserId, UserIdentityData>,
633 ) {
634 let new: BTreeMap<_, _> = self
635 .new
636 .into_iter()
637 .map(|identity| (identity.user_id().to_owned(), identity))
638 .collect();
639
640 let changed: BTreeMap<_, _> = self
641 .changed
642 .into_iter()
643 .map(|identity| (identity.user_id().to_owned(), identity))
644 .collect();
645
646 let unchanged: BTreeMap<_, _> = self
647 .unchanged
648 .into_iter()
649 .map(|identity| (identity.user_id().to_owned(), identity))
650 .collect();
651
652 (new, changed, unchanged)
653 }
654}
655
/// Devices observed while processing a `/keys/query` response or similar,
/// grouped by whether they are new, changed or deleted.
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct DeviceChanges {
    /// Devices we did not know before.
    pub new: Vec<DeviceData>,
    /// Known devices that changed.
    pub changed: Vec<DeviceData>,
    /// Devices that were removed.
    pub deleted: Vec<DeviceData>,
}
663
664fn collect_device_updates(
670 verification_machine: VerificationMachine,
671 own_identity: Option<OwnUserIdentityData>,
672 identities: IdentityChanges,
673 devices: DeviceChanges,
674) -> DeviceUpdates {
675 let mut new: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
676 let mut changed: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
677
678 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
679
680 let map_device = |device: DeviceData| {
681 let device_owner_identity = new_identities
682 .get(device.user_id())
683 .or_else(|| changed_identities.get(device.user_id()))
684 .or_else(|| unchanged_identities.get(device.user_id()))
685 .cloned();
686
687 Device {
688 inner: device,
689 verification_machine: verification_machine.to_owned(),
690 own_identity: own_identity.to_owned(),
691 device_owner_identity,
692 }
693 };
694
695 for device in devices.new {
696 let device = map_device(device);
697
698 new.entry(device.user_id().to_owned())
699 .or_default()
700 .insert(device.device_id().to_owned(), device);
701 }
702
703 for device in devices.changed {
704 let device = map_device(device);
705
706 changed
707 .entry(device.user_id().to_owned())
708 .or_default()
709 .insert(device.device_id().to_owned(), device.to_owned());
710 }
711
712 DeviceUpdates { new, changed }
713}
714
/// New and changed devices, grouped by user ID and then device ID.
#[derive(Clone, Debug, Default)]
pub struct DeviceUpdates {
    /// Devices we did not know before.
    pub new: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
    /// Known devices that changed.
    pub changed: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
}
728
/// New, changed and unchanged user identities, keyed by user ID.
#[derive(Clone, Debug, Default)]
pub struct IdentityUpdates {
    /// Identities we did not know before.
    pub new: BTreeMap<OwnedUserId, UserIdentity>,
    /// Known identities that changed.
    pub changed: BTreeMap<OwnedUserId, UserIdentity>,
    /// Known identities that did not change.
    pub unchanged: BTreeMap<OwnedUserId, UserIdentity>,
}
744
/// The private key of the server-side room key backup.
///
/// `Store::export_secret` exports it (base64-encoded) as the
/// `RecoveryKey` secret. The key material is zeroized on drop, and the `Debug`
/// output never shows it.
#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
#[serde(transparent)]
pub struct BackupDecryptionKey {
    /// The raw key bytes, boxed so they live in a single heap allocation.
    pub(crate) inner: Box<[u8; BackupDecryptionKey::KEY_SIZE]>,
}
759
760impl BackupDecryptionKey {
761 pub const KEY_SIZE: usize = 32;
763
764 pub fn new() -> Result<Self, rand::Error> {
766 let mut rng = rand::thread_rng();
767
768 let mut key = Box::new([0u8; Self::KEY_SIZE]);
769 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
770
771 Ok(Self { inner: key })
772 }
773
774 pub fn to_base64(&self) -> String {
776 base64_encode(self.inner.as_slice())
777 }
778}
779
780#[cfg(not(tarpaulin_include))]
781impl Debug for BackupDecryptionKey {
782 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
783 f.debug_tuple("BackupDecryptionKey").field(&"...").finish()
784 }
785}
786
/// The pickle key used for the dehydrated device (see
/// `Changes::dehydrated_device_pickle_key`).
///
/// The key material is zeroized on drop, and the `Debug` output never shows
/// it.
#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
#[serde(transparent)]
pub struct DehydratedDeviceKey {
    /// The raw key bytes, boxed so they live in a single heap allocation.
    pub(crate) inner: Box<[u8; DehydratedDeviceKey::KEY_SIZE]>,
}
796
797impl DehydratedDeviceKey {
798 pub const KEY_SIZE: usize = 32;
800
801 pub fn new() -> Result<Self, rand::Error> {
803 let mut rng = rand::thread_rng();
804
805 let mut key = Box::new([0u8; Self::KEY_SIZE]);
806 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
807
808 Ok(Self { inner: key })
809 }
810
811 pub fn from_slice(slice: &[u8]) -> Result<Self, DehydrationError> {
815 if slice.len() == 32 {
816 let mut key = Box::new([0u8; 32]);
817 key.copy_from_slice(slice);
818 Ok(DehydratedDeviceKey { inner: key })
819 } else {
820 Err(DehydrationError::PickleKeyLength(slice.len()))
821 }
822 }
823
824 pub fn from_bytes(raw_key: &[u8; 32]) -> Self {
826 let mut inner = Box::new([0u8; Self::KEY_SIZE]);
827 inner.copy_from_slice(raw_key);
828
829 Self { inner }
830 }
831
832 pub fn to_base64(&self) -> String {
834 base64_encode(self.inner.as_slice())
835 }
836}
837
838impl From<&[u8; 32]> for DehydratedDeviceKey {
839 fn from(value: &[u8; 32]) -> Self {
840 DehydratedDeviceKey { inner: Box::new(*value) }
841 }
842}
843
844impl From<DehydratedDeviceKey> for Vec<u8> {
845 fn from(key: DehydratedDeviceKey) -> Self {
846 key.inner.to_vec()
847 }
848}
849
850#[cfg(not(tarpaulin_include))]
851impl Debug for DehydratedDeviceKey {
852 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
853 f.debug_tuple("DehydratedDeviceKey").field(&"...").finish()
854 }
855}
856
857impl DeviceChanges {
858 pub fn extend(&mut self, other: DeviceChanges) {
860 self.new.extend(other.new);
861 self.changed.extend(other.changed);
862 self.deleted.extend(other.deleted);
863 }
864
865 fn is_empty(&self) -> bool {
866 self.new.is_empty() && self.changed.is_empty() && self.deleted.is_empty()
867 }
868}
869
/// Room key counts, presumably as reported by the store for backup progress.
#[derive(Debug, Clone, Default)]
pub struct RoomKeyCounts {
    /// Total number of room keys.
    pub total: usize,
    /// Number of room keys that are marked as backed up.
    pub backed_up: usize,
}
878
/// The backup decryption key and backup version, as loaded from the store.
#[derive(Default, Clone, Debug)]
pub struct BackupKeys {
    /// The backup decryption key, if one is stored.
    pub decryption_key: Option<BackupDecryptionKey>,
    /// The backup version, if one is stored.
    pub backup_version: Option<String>,
}
887
/// Our private cross-signing keys, as exported by
/// `Store::export_cross_signing_keys`.
///
/// Zeroized on drop. The `Debug` output only shows which keys are present.
#[derive(Default, Zeroize, ZeroizeOnDrop)]
pub struct CrossSigningKeyExport {
    /// The exported master key, if available.
    pub master_key: Option<String>,
    /// The exported self-signing key, if available.
    pub self_signing_key: Option<String>,
    /// The exported user-signing key, if available.
    pub user_signing_key: Option<String>,
}
899
900#[cfg(not(tarpaulin_include))]
901impl Debug for CrossSigningKeyExport {
902 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
903 f.debug_struct("CrossSigningKeyExport")
904 .field("master_key", &self.master_key.is_some())
905 .field("self_signing_key", &self.self_signing_key.is_some())
906 .field("user_signing_key", &self.user_signing_key.is_some())
907 .finish_non_exhaustive()
908 }
909}
910
/// Error returned when importing private keys or secrets fails.
#[derive(Debug, Error)]
pub enum SecretImportError {
    /// A vodozemac key error, e.g. while decoding an imported key.
    #[error(transparent)]
    Key(#[from] vodozemac::KeyError),
    /// The imported private key does not match our known public key.
    #[error(
        "The public key of the imported private key doesn't match to the \
         public key that was uploaded to the server"
    )]
    MismatchedPublicKeys,
    /// The store returned an error.
    #[error(transparent)]
    Store(#[from] CryptoStoreError),
}
929
/// Error returned by `Store::export_secrets_bundle`.
#[derive(Debug, Error)]
pub enum SecretsBundleExportError {
    /// The store returned an error.
    #[error(transparent)]
    Store(#[from] CryptoStoreError),
    /// One cross-signing key, identified by its usage, is missing.
    #[error("The store is missing one or multiple cross-signing keys")]
    MissingCrossSigningKey(KeyUsage),
    /// None of the cross-signing keys are stored.
    #[error("The store doesn't contain any cross-signing keys")]
    MissingCrossSigningKeys,
    /// A backup decryption key is stored, but no backup version.
    #[error("The store contains a backup key, but no backup version")]
    MissingBackupVersion,
}
950
/// Outcome of `KeyQueryManager::wait_if_user_key_query_pending`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum UserKeyQueryResult {
    /// A query for the user was pending and completed while we waited.
    WasPending,
    /// No query for the user was pending.
    WasNotPending,

    /// A query for the user was pending but did not complete before the
    /// timeout.
    TimeoutExpired,
}
961
/// Per-room encryption settings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct RoomSettings {
    /// The encryption algorithm used in the room.
    pub algorithm: EventEncryptionAlgorithm,

    /// Whether room keys should only be shared with trusted devices. What
    /// counts as "trusted" is decided by the key-sharing code; confirm there.
    pub only_allow_trusted_devices: bool,

    /// Optional time limit after which the room's outbound session should be
    /// rotated (presumably; verify against the session-rotation code).
    pub session_rotation_period: Option<Duration>,

    /// Optional message-count limit after which the room's outbound session
    /// should be rotated (presumably; verify against the session-rotation
    /// code).
    pub session_rotation_period_messages: Option<usize>,
}
980
981impl Default for RoomSettings {
982 fn default() -> Self {
983 Self {
984 algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
985 only_allow_trusted_devices: false,
986 session_rotation_period: None,
987 session_rotation_period_messages: None,
988 }
989 }
990}
991
/// Identifying information about a room key (an inbound group session).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct RoomKeyInfo {
    /// The encryption algorithm of the session.
    pub algorithm: EventEncryptionAlgorithm,

    /// The room the session belongs to.
    pub room_id: OwnedRoomId,

    /// The Curve25519 key of the session's sender.
    pub sender_key: Curve25519PublicKey,

    /// The session ID.
    pub session_id: String,
}
1010
1011impl From<&InboundGroupSession> for RoomKeyInfo {
1012 fn from(group_session: &InboundGroupSession) -> Self {
1013 RoomKeyInfo {
1014 algorithm: group_session.algorithm().clone(),
1015 room_id: group_session.room_id().to_owned(),
1016 sender_key: group_session.sender_key(),
1017 session_id: group_session.session_id().to_owned(),
1018 }
1019 }
1020}
1021
/// Information about a room key that was withheld from us.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RoomKeyWithheldInfo {
    /// The room the withheld key belongs to.
    pub room_id: OwnedRoomId,

    /// The session ID of the withheld key.
    pub session_id: String,

    /// The withheld event that was received.
    pub withheld_event: RoomKeyWithheldEvent,
}
1035
1036impl Store {
1037 pub(crate) fn new(
1039 account: StaticAccountData,
1040 identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
1041 store: Arc<CryptoStoreWrapper>,
1042 verification_machine: VerificationMachine,
1043 ) -> Self {
1044 Self {
1045 inner: Arc::new(StoreInner {
1046 static_account: account,
1047 identity,
1048 store: store.clone(),
1049 verification_machine,
1050 cache: Arc::new(RwLock::new(StoreCache {
1051 store,
1052 tracked_users: Default::default(),
1053 loaded_tracked_users: Default::default(),
1054 account: Default::default(),
1055 })),
1056 }),
1057 }
1058 }
1059
1060 pub(crate) fn user_id(&self) -> &UserId {
1062 &self.inner.static_account.user_id
1063 }
1064
1065 pub(crate) fn device_id(&self) -> &DeviceId {
1067 self.inner.verification_machine.own_device_id()
1068 }
1069
1070 pub(crate) fn static_account(&self) -> &StaticAccountData {
1072 &self.inner.static_account
1073 }
1074
1075 pub(crate) async fn cache(&self) -> Result<StoreCacheGuard> {
1076 Ok(StoreCacheGuard { cache: self.inner.cache.clone().read_owned().await })
1081 }
1082
1083 pub(crate) async fn transaction(&self) -> StoreTransaction {
1084 StoreTransaction::new(self.clone()).await
1085 }
1086
1087 pub(crate) async fn with_transaction<
1090 T,
1091 Fut: futures_core::Future<Output = Result<(StoreTransaction, T), crate::OlmError>>,
1092 F: FnOnce(StoreTransaction) -> Fut,
1093 >(
1094 &self,
1095 func: F,
1096 ) -> Result<T, crate::OlmError> {
1097 let tr = self.transaction().await;
1098 let (tr, res) = func(tr).await?;
1099 tr.commit().await?;
1100 Ok(res)
1101 }
1102
1103 #[cfg(test)]
1104 pub(crate) async fn reset_cross_signing_identity(&self) {
1106 self.inner.identity.lock().await.reset();
1107 }
1108
1109 pub(crate) fn private_identity(&self) -> Arc<Mutex<PrivateCrossSigningIdentity>> {
1111 self.inner.identity.clone()
1112 }
1113
1114 pub(crate) async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
1116 let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
1117
1118 self.save_changes(changes).await
1119 }
1120
1121 pub(crate) async fn get_sessions(
1122 &self,
1123 sender_key: &str,
1124 ) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
1125 self.inner.store.get_sessions(sender_key).await
1126 }
1127
1128 pub(crate) async fn save_changes(&self, changes: Changes) -> Result<()> {
1129 self.inner.store.save_changes(changes).await
1130 }
1131
1132 pub(crate) async fn compare_group_session(
1139 &self,
1140 session: &InboundGroupSession,
1141 ) -> Result<SessionOrdering> {
1142 let old_session = self
1143 .inner
1144 .store
1145 .get_inbound_group_session(session.room_id(), session.session_id())
1146 .await?;
1147
1148 Ok(if let Some(old_session) = old_session {
1149 session.compare(&old_session).await
1150 } else {
1151 SessionOrdering::Better
1152 })
1153 }
1154
1155 #[cfg(test)]
1156 pub(crate) async fn save_device_data(&self, devices: &[DeviceData]) -> Result<()> {
1158 let changes = Changes {
1159 devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
1160 ..Default::default()
1161 };
1162
1163 self.save_changes(changes).await
1164 }
1165
1166 pub(crate) async fn save_inbound_group_sessions(
1168 &self,
1169 sessions: &[InboundGroupSession],
1170 ) -> Result<()> {
1171 let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
1172
1173 self.save_changes(changes).await
1174 }
1175
1176 pub(crate) async fn device_display_name(&self) -> Result<Option<String>, CryptoStoreError> {
1178 Ok(self
1179 .inner
1180 .store
1181 .get_device(self.user_id(), self.device_id())
1182 .await?
1183 .and_then(|d| d.display_name().map(|d| d.to_owned())))
1184 }
1185
1186 pub(crate) async fn get_device_data(
1191 &self,
1192 user_id: &UserId,
1193 device_id: &DeviceId,
1194 ) -> Result<Option<DeviceData>> {
1195 self.inner.store.get_device(user_id, device_id).await
1196 }
1197
1198 pub(crate) async fn get_device_data_for_user_filtered(
1206 &self,
1207 user_id: &UserId,
1208 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1209 self.inner.store.get_user_devices(user_id).await.map(|mut d| {
1210 if user_id == self.user_id() {
1211 d.remove(self.device_id());
1212 }
1213 d
1214 })
1215 }
1216
1217 pub(crate) async fn get_device_data_for_user(
1226 &self,
1227 user_id: &UserId,
1228 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1229 self.inner.store.get_user_devices(user_id).await
1230 }
1231
1232 pub(crate) async fn get_device_from_curve_key(
1238 &self,
1239 user_id: &UserId,
1240 curve_key: Curve25519PublicKey,
1241 ) -> Result<Option<Device>> {
1242 self.get_user_devices(user_id)
1243 .await
1244 .map(|d| d.devices().find(|d| d.curve25519_key() == Some(curve_key)))
1245 }
1246
1247 pub(crate) async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
1257 let devices = self.get_device_data_for_user(user_id).await?;
1258
1259 let own_identity = self
1260 .inner
1261 .store
1262 .get_user_identity(self.user_id())
1263 .await?
1264 .and_then(|i| i.own().cloned());
1265 let device_owner_identity = self.inner.store.get_user_identity(user_id).await?;
1266
1267 Ok(UserDevices {
1268 inner: devices,
1269 verification_machine: self.inner.verification_machine.clone(),
1270 own_identity,
1271 device_owner_identity,
1272 })
1273 }
1274
1275 pub(crate) async fn get_device(
1285 &self,
1286 user_id: &UserId,
1287 device_id: &DeviceId,
1288 ) -> Result<Option<Device>> {
1289 if let Some(device_data) = self.inner.store.get_device(user_id, device_id).await? {
1290 Ok(Some(self.wrap_device_data(device_data).await?))
1291 } else {
1292 Ok(None)
1293 }
1294 }
1295
1296 pub(crate) async fn wrap_device_data(&self, device_data: DeviceData) -> Result<Device> {
1301 let own_identity = self
1302 .inner
1303 .store
1304 .get_user_identity(self.user_id())
1305 .await?
1306 .and_then(|i| i.own().cloned());
1307
1308 let device_owner_identity =
1309 self.inner.store.get_user_identity(device_data.user_id()).await?;
1310
1311 Ok(Device {
1312 inner: device_data,
1313 verification_machine: self.inner.verification_machine.clone(),
1314 own_identity,
1315 device_owner_identity,
1316 })
1317 }
1318
1319 pub(crate) async fn get_identity(&self, user_id: &UserId) -> Result<Option<UserIdentity>> {
1321 let own_identity = self
1322 .inner
1323 .store
1324 .get_user_identity(self.user_id())
1325 .await?
1326 .and_then(as_variant!(UserIdentityData::Own));
1327
1328 Ok(self.inner.store.get_user_identity(user_id).await?.map(|i| {
1329 UserIdentity::new(
1330 self.clone(),
1331 i,
1332 self.inner.verification_machine.to_owned(),
1333 own_identity,
1334 )
1335 }))
1336 }
1337
1338 pub async fn export_secret(
1347 &self,
1348 secret_name: &SecretName,
1349 ) -> Result<Option<String>, CryptoStoreError> {
1350 Ok(match secret_name {
1351 SecretName::CrossSigningMasterKey
1352 | SecretName::CrossSigningUserSigningKey
1353 | SecretName::CrossSigningSelfSigningKey => {
1354 self.inner.identity.lock().await.export_secret(secret_name).await
1355 }
1356 SecretName::RecoveryKey => {
1357 if let Some(key) = self.load_backup_keys().await?.decryption_key {
1358 let exported = key.to_base64();
1359 Some(exported)
1360 } else {
1361 None
1362 }
1363 }
1364 name => {
1365 warn!(secret = ?name, "Unknown secret was requested");
1366 None
1367 }
1368 })
1369 }
1370
1371 pub async fn export_cross_signing_keys(
1379 &self,
1380 ) -> Result<Option<CrossSigningKeyExport>, CryptoStoreError> {
1381 let master_key = self.export_secret(&SecretName::CrossSigningMasterKey).await?;
1382 let self_signing_key = self.export_secret(&SecretName::CrossSigningSelfSigningKey).await?;
1383 let user_signing_key = self.export_secret(&SecretName::CrossSigningUserSigningKey).await?;
1384
1385 Ok(if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() {
1386 None
1387 } else {
1388 Some(CrossSigningKeyExport { master_key, self_signing_key, user_signing_key })
1389 })
1390 }
1391
    /// Import private cross-signing keys into our private identity.
    ///
    /// Our own public identity is loaded from the store and handed to
    /// `import_secrets` together with the keys. If the private identity's
    /// public form does not differ from the stored public identity afterwards,
    /// the stored identity is marked as verified and saved as changed. The
    /// private identity is saved in any case.
    ///
    /// If no public identity is stored yet, nothing is imported and a warning
    /// suggesting a `/keys/query` is logged.
    ///
    /// Returns the cross-signing status of the private identity afterwards.
    pub async fn import_cross_signing_keys(
        &self,
        export: CrossSigningKeyExport,
    ) -> Result<CrossSigningStatus, SecretImportError> {
        if let Some(public_identity) =
            self.get_identity(self.user_id()).await?.and_then(|i| i.own())
        {
            let identity = self.inner.identity.lock().await;

            identity
                .import_secrets(
                    public_identity.to_owned(),
                    export.master_key.as_deref(),
                    export.self_signing_key.as_deref(),
                    export.user_signing_key.as_deref(),
                )
                .await?;

            let status = identity.status().await;

            let diff = identity.get_public_identity_diff(&public_identity.inner).await;

            let mut changes =
                Changes { private_identity: Some(identity.clone()), ..Default::default() };

            if diff.none_differ() {
                public_identity.mark_as_verified();
                changes.identities.changed.push(UserIdentityData::Own(public_identity.inner));
            }

            info!(?status, "Successfully imported the private cross-signing keys");

            self.save_changes(changes).await?;
            // `identity` is dropped at the end of this block, before the
            // identity is locked again for the final status below.
        } else {
            warn!("No public identity found while importing cross-signing keys, a /keys/query needs to be done");
        }

        Ok(self.inner.identity.lock().await.status().await)
    }
1435
1436 pub async fn export_secrets_bundle(&self) -> Result<SecretsBundle, SecretsBundleExportError> {
1448 let Some(cross_signing) = self.export_cross_signing_keys().await? else {
1449 return Err(SecretsBundleExportError::MissingCrossSigningKeys);
1450 };
1451
1452 let Some(master_key) = cross_signing.master_key.clone() else {
1453 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::Master));
1454 };
1455
1456 let Some(user_signing_key) = cross_signing.user_signing_key.clone() else {
1457 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::UserSigning));
1458 };
1459
1460 let Some(self_signing_key) = cross_signing.self_signing_key.clone() else {
1461 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::SelfSigning));
1462 };
1463
1464 let backup_keys = self.load_backup_keys().await?;
1465
1466 let backup = if let Some(key) = backup_keys.decryption_key {
1467 if let Some(backup_version) = backup_keys.backup_version {
1468 Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(
1469 MegolmBackupV1Curve25519AesSha2Secrets { key, backup_version },
1470 ))
1471 } else {
1472 return Err(SecretsBundleExportError::MissingBackupVersion);
1473 }
1474 } else {
1475 None
1476 };
1477
1478 Ok(SecretsBundle {
1479 cross_signing: CrossSigningSecrets { master_key, user_signing_key, self_signing_key },
1480 backup,
1481 })
1482 }
1483
1484 pub async fn import_secrets_bundle(
1497 &self,
1498 bundle: &SecretsBundle,
1499 ) -> Result<(), SecretImportError> {
1500 let mut changes = Changes::default();
1501
1502 if let Some(backup_bundle) = &bundle.backup {
1503 match backup_bundle {
1504 BackupSecrets::MegolmBackupV1Curve25519AesSha2(bundle) => {
1505 changes.backup_decryption_key = Some(bundle.key.clone());
1506 changes.backup_version = Some(bundle.backup_version.clone());
1507 }
1508 }
1509 }
1510
1511 let identity = self.inner.identity.lock().await;
1512
1513 identity
1514 .import_secrets_unchecked(
1515 Some(&bundle.cross_signing.master_key),
1516 Some(&bundle.cross_signing.self_signing_key),
1517 Some(&bundle.cross_signing.user_signing_key),
1518 )
1519 .await?;
1520
1521 let public_identity = identity.to_public_identity().await.expect(
1522 "We should be able to create a new public identity since we just imported \
1523 all the private cross-signing keys",
1524 );
1525
1526 changes.private_identity = Some(identity.clone());
1527 changes.identities.new.push(UserIdentityData::Own(public_identity));
1528
1529 Ok(self.save_changes(changes).await?)
1530 }
1531
1532 pub async fn import_secret(&self, secret: &GossippedSecret) -> Result<(), SecretImportError> {
1534 match &secret.secret_name {
1535 SecretName::CrossSigningMasterKey
1536 | SecretName::CrossSigningUserSigningKey
1537 | SecretName::CrossSigningSelfSigningKey => {
1538 if let Some(public_identity) =
1539 self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1540 {
1541 let identity = self.inner.identity.lock().await;
1542
1543 identity
1544 .import_secret(
1545 public_identity,
1546 &secret.secret_name,
1547 &secret.event.content.secret,
1548 )
1549 .await?;
1550 info!(
1551 secret_name = ?secret.secret_name,
1552 "Successfully imported a private cross signing key"
1553 );
1554
1555 let changes =
1556 Changes { private_identity: Some(identity.clone()), ..Default::default() };
1557
1558 self.save_changes(changes).await?;
1559 }
1560 }
1561 SecretName::RecoveryKey => {
1562 }
1568 name => {
1569 warn!(secret = ?name, "Tried to import an unknown secret");
1570 }
1571 }
1572
1573 Ok(())
1574 }
1575
1576 pub async fn get_only_allow_trusted_devices(&self) -> Result<bool> {
1579 let value = self.get_value("only_allow_trusted_devices").await?.unwrap_or_default();
1580 Ok(value)
1581 }
1582
1583 pub async fn set_only_allow_trusted_devices(
1586 &self,
1587 block_untrusted_devices: bool,
1588 ) -> Result<()> {
1589 self.set_value("only_allow_trusted_devices", &block_untrusted_devices).await
1590 }
1591
1592 pub async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
1594 let Some(value) = self.get_custom_value(key).await? else {
1595 return Ok(None);
1596 };
1597 let deserialized = self.deserialize_value(&value)?;
1598 Ok(Some(deserialized))
1599 }
1600
1601 pub async fn set_value(&self, key: &str, value: &impl Serialize) -> Result<()> {
1603 let serialized = self.serialize_value(value)?;
1604 self.set_custom_value(key, serialized).await?;
1605 Ok(())
1606 }
1607
1608 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
1609 let serialized =
1610 rmp_serde::to_vec_named(value).map_err(|x| CryptoStoreError::Backend(x.into()))?;
1611 Ok(serialized)
1612 }
1613
1614 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
1615 let deserialized =
1616 rmp_serde::from_slice(value).map_err(|e| CryptoStoreError::Backend(e.into()))?;
1617 Ok(deserialized)
1618 }
1619
1620 pub fn room_keys_received_stream(
1632 &self,
1633 ) -> impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>> {
1634 self.inner.store.room_keys_received_stream()
1635 }
1636
1637 pub fn room_keys_withheld_received_stream(
1646 &self,
1647 ) -> impl Stream<Item = Vec<RoomKeyWithheldInfo>> {
1648 self.inner.store.room_keys_withheld_received_stream()
1649 }
1650
1651 pub fn user_identities_stream(&self) -> impl Stream<Item = IdentityUpdates> {
1682 let verification_machine = self.inner.verification_machine.to_owned();
1683
1684 let this = self.clone();
1685 self.inner.store.identities_stream().map(move |(own_identity, identities, _)| {
1686 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
1687
1688 let map_identity = |(user_id, identity)| {
1689 (
1690 user_id,
1691 UserIdentity::new(
1692 this.clone(),
1693 identity,
1694 verification_machine.to_owned(),
1695 own_identity.to_owned(),
1696 ),
1697 )
1698 };
1699
1700 let new = new_identities.into_iter().map(map_identity).collect();
1701 let changed = changed_identities.into_iter().map(map_identity).collect();
1702 let unchanged = unchanged_identities.into_iter().map(map_identity).collect();
1703
1704 IdentityUpdates { new, changed, unchanged }
1705 })
1706 }
1707
1708 pub fn devices_stream(&self) -> impl Stream<Item = DeviceUpdates> {
1740 let verification_machine = self.inner.verification_machine.to_owned();
1741
1742 self.inner.store.identities_stream().map(move |(own_identity, identities, devices)| {
1743 collect_device_updates(
1744 verification_machine.to_owned(),
1745 own_identity,
1746 identities,
1747 devices,
1748 )
1749 })
1750 }
1751
1752 pub fn identities_stream_raw(&self) -> impl Stream<Item = (IdentityChanges, DeviceChanges)> {
1762 self.inner.store.identities_stream().map(|(_, identities, devices)| (identities, devices))
1763 }
1764
1765 pub fn create_store_lock(
1768 &self,
1769 lock_key: String,
1770 lock_value: String,
1771 ) -> CrossProcessStoreLock<LockableCryptoStore> {
1772 self.inner.store.create_store_lock(lock_key, lock_value)
1773 }
1774
1775 pub fn secrets_stream(&self) -> impl Stream<Item = GossippedSecret> {
1815 self.inner.store.secrets_stream()
1816 }
1817
1818 pub async fn import_room_keys(
1831 &self,
1832 exported_keys: Vec<ExportedRoomKey>,
1833 from_backup_version: Option<&str>,
1834 progress_listener: impl Fn(usize, usize),
1835 ) -> Result<RoomKeyImportResult> {
1836 let exported_keys: Vec<&ExportedRoomKey> = exported_keys.iter().collect();
1837 self.import_sessions_impl(exported_keys, from_backup_version, progress_listener).await
1838 }
1839
1840 pub async fn import_exported_room_keys(
1867 &self,
1868 exported_keys: Vec<ExportedRoomKey>,
1869 progress_listener: impl Fn(usize, usize),
1870 ) -> Result<RoomKeyImportResult> {
1871 self.import_room_keys(exported_keys, None, progress_listener).await
1872 }
1873
1874 async fn import_sessions_impl<T>(
1875 &self,
1876 room_keys: Vec<T>,
1877 from_backup_version: Option<&str>,
1878 progress_listener: impl Fn(usize, usize),
1879 ) -> Result<RoomKeyImportResult>
1880 where
1881 T: TryInto<InboundGroupSession> + RoomKeyExport + Copy,
1882 T::Error: Debug,
1883 {
1884 let mut sessions = Vec::new();
1885
1886 async fn new_session_better(
1887 session: &InboundGroupSession,
1888 old_session: Option<InboundGroupSession>,
1889 ) -> bool {
1890 if let Some(old_session) = &old_session {
1891 session.compare(old_session).await == SessionOrdering::Better
1892 } else {
1893 true
1894 }
1895 }
1896
1897 let total_count = room_keys.len();
1898 let mut keys = BTreeMap::new();
1899
1900 for (i, key) in room_keys.into_iter().enumerate() {
1901 match key.try_into() {
1902 Ok(session) => {
1903 let old_session = self
1904 .inner
1905 .store
1906 .get_inbound_group_session(session.room_id(), session.session_id())
1907 .await?;
1908
1909 if new_session_better(&session, old_session).await {
1912 if from_backup_version.is_some() {
1913 session.mark_as_backed_up();
1914 }
1915
1916 keys.entry(session.room_id().to_owned())
1917 .or_insert_with(BTreeMap::new)
1918 .entry(session.sender_key().to_base64())
1919 .or_insert_with(BTreeSet::new)
1920 .insert(session.session_id().to_owned());
1921
1922 sessions.push(session);
1923 }
1924 }
1925 Err(e) => {
1926 warn!(
1927 sender_key = key.sender_key().to_base64(),
1928 room_id = ?key.room_id(),
1929 session_id = key.session_id(),
1930 error = ?e,
1931 "Couldn't import a room key from a file export."
1932 );
1933 }
1934 }
1935
1936 progress_listener(i, total_count);
1937 }
1938
1939 let imported_count = sessions.len();
1940
1941 self.inner.store.save_inbound_group_sessions(sessions, from_backup_version).await?;
1942
1943 info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys");
1944
1945 Ok(RoomKeyImportResult::new(imported_count, total_count, keys))
1946 }
1947
1948 pub(crate) fn crypto_store(&self) -> Arc<CryptoStoreWrapper> {
1949 self.inner.store.clone()
1950 }
1951
1952 pub async fn export_room_keys(
1975 &self,
1976 predicate: impl FnMut(&InboundGroupSession) -> bool,
1977 ) -> Result<Vec<ExportedRoomKey>> {
1978 let mut exported = Vec::new();
1979
1980 let mut sessions = self.get_inbound_group_sessions().await?;
1981 sessions.retain(predicate);
1982
1983 for session in sessions {
1984 let export = session.export().await;
1985 exported.push(export);
1986 }
1987
1988 Ok(exported)
1989 }
1990
1991 pub async fn export_room_keys_stream(
2024 &self,
2025 predicate: impl FnMut(&InboundGroupSession) -> bool,
2026 ) -> Result<impl Stream<Item = ExportedRoomKey>> {
2027 let sessions = self.get_inbound_group_sessions().await?;
2029 Ok(futures_util::stream::iter(sessions.into_iter().filter(predicate))
2030 .then(|session| async move { session.export().await }))
2031 }
2032
2033 pub async fn build_room_key_bundle(
2038 &self,
2039 room_id: &RoomId,
2040 ) -> std::result::Result<RoomKeyBundle, CryptoStoreError> {
2041 let mut sessions = self.get_inbound_group_sessions().await?;
2044 sessions.retain(|session| session.room_id == room_id);
2045
2046 let mut bundle = RoomKeyBundle::default();
2047 for session in sessions {
2048 if session.shared_history() {
2049 bundle.room_keys.push(session.export().await.into());
2050 } else {
2051 bundle.withheld.push(RoomKeyWithheldContent::new(
2052 session.algorithm().to_owned(),
2053 WithheldCode::Unauthorised,
2054 session.room_id().to_owned(),
2055 session.session_id().to_owned(),
2056 session.sender_key().to_owned(),
2057 self.device_id().to_owned(),
2058 ));
2059 }
2060 }
2061
2062 Ok(bundle)
2063 }
2064}
2065
2066impl Deref for Store {
2067 type Target = DynCryptoStore;
2068
2069 fn deref(&self) -> &Self::Target {
2070 self.inner.store.deref().deref()
2071 }
2072}
2073
#[derive(Clone, Debug)]
/// A shared handle to a crypto store, usable as the backing store of a
/// [`CrossProcessStoreLock`] (see `Store::create_store_lock`).
///
/// Lock attempts are forwarded to the wrapped store's
/// `try_take_leased_lock` through the
/// [`matrix_sdk_common::store_locks::BackingStore`] implementation.
pub struct LockableCryptoStore(Arc<dyn CryptoStore<Error = CryptoStoreError>>);
2077
2078#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
2079#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
2080impl matrix_sdk_common::store_locks::BackingStore for LockableCryptoStore {
2081 type LockError = CryptoStoreError;
2082
2083 async fn try_lock(
2084 &self,
2085 lease_duration_ms: u32,
2086 key: &str,
2087 holder: &str,
2088 ) -> std::result::Result<bool, Self::LockError> {
2089 self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
2090 }
2091}
2092
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_util::StreamExt;
    use insta::{_macro_support::Content, assert_json_snapshot, internals::ContentPath};
    use matrix_sdk_test::async_test;
    use ruma::{device_id, room_id, user_id, RoomId};
    use vodozemac::megolm::SessionKey;

    use crate::{
        machine::test_helpers::get_machine_pair,
        olm::{InboundGroupSession, SenderData},
        store::DehydratedDeviceKey,
        types::EventEncryptionAlgorithm,
        OlmMachine,
    };

    /// Importing room keys notifies `room_keys_received_stream` subscribers.
    #[async_test]
    async fn test_import_room_keys_notifies_stream() {
        use futures_util::FutureExt;

        let (alice, bob, _) =
            get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;

        let room1_id = room_id!("!room1:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        let exported_sessions = alice.store().export_room_keys(|_| true).await.unwrap();

        // Subscribe before importing, so the stream observes the import.
        let mut room_keys_received_stream = Box::pin(bob.store().room_keys_received_stream());
        bob.store().import_room_keys(exported_sessions, None, |_, _| {}).await.unwrap();

        // The update must already be available: poll the stream exactly once.
        let room_keys = room_keys_received_stream
            .next()
            .now_or_never()
            .flatten()
            .expect("We should have received an update of room key infos")
            .unwrap();
        assert_eq!(room_keys.len(), 1);
        assert_eq!(room_keys[0].room_id, "!room1:localhost");
    }

    /// `export_room_keys` only exports the sessions accepted by the predicate.
    #[async_test]
    async fn test_export_room_keys_provides_selected_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        // Given sessions in three rooms
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        let room3_id = room_id!("!room3:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room3_id).await.unwrap();

        // When exporting only the keys of rooms 2 and 3
        let keys = alice
            .store()
            .export_room_keys(|s| s.room_id() == room2_id || s.room_id() == room3_id)
            .await
            .unwrap();

        assert_eq!(keys.len(), 2);
        // Then exactly those keys are exported (this relies on the store's order)
        assert_eq!(keys[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(keys[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(keys[0].room_id, "!room2:localhost");
        assert_eq!(keys[1].room_id, "!room3:localhost");
        assert_eq!(keys[0].session_key.to_base64().len(), 220);
        assert_eq!(keys[1].session_key.to_base64().len(), 220);
    }

    /// `export_room_keys_stream` yields every session for an always-true predicate.
    #[async_test]
    async fn test_export_room_keys_stream_can_provide_all_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        // Given sessions in two rooms
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();

        // When streaming all of the keys
        let mut keys = pin!(alice.store().export_room_keys_stream(|_| true).await.unwrap());

        let mut collected = vec![];
        // Drain the stream completely
        while let Some(key) = keys.next().await {
            collected.push(key);
        }

        assert_eq!(collected.len(), 2);
        // Then both rooms' keys are yielded (this relies on the store's order)
        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[0].room_id, "!room1:localhost");
        assert_eq!(collected[1].room_id, "!room2:localhost");
        assert_eq!(collected[0].session_key.to_base64().len(), 220);
        assert_eq!(collected[1].session_key.to_base64().len(), 220);
    }

    /// `export_room_keys_stream` only yields the sessions accepted by the predicate.
    #[async_test]
    async fn test_export_room_keys_stream_can_provide_a_subset_of_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        // Given sessions in two rooms
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();

        // When streaming only the keys of room 1
        let mut keys =
            pin!(alice.store().export_room_keys_stream(|s| s.room_id() == room1_id).await.unwrap());

        let mut collected = vec![];
        // Drain the stream completely
        while let Some(key) = keys.next().await {
            collected.push(key);
        }

        assert_eq!(collected.len(), 1);
        // Then only room 1's key is yielded
        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[0].room_id, "!room1:localhost");
        assert_eq!(collected[0].session_key.to_base64().len(), 220);
    }

    /// A secrets bundle exported by one device can be imported by another
    /// device of the same user, leaving it with complete cross-signing.
    #[async_test]
    async fn test_export_secrets_bundle() {
        let user_id = user_id!("@alice:example.com");
        let (first, second, _) = get_machine_pair(user_id, user_id, false).await;

        let _ = first
            .bootstrap_cross_signing(false)
            .await
            .expect("We should be able to bootstrap cross-signing");

        let bundle = first.store().export_secrets_bundle().await.expect(
            "We should be able to export the secrets bundle, now that we \
             have the cross-signing keys",
        );

        assert!(bundle.backup.is_none(), "The bundle should not contain a backup key");

        second
            .store()
            .import_secrets_bundle(&bundle)
            .await
            .expect("We should be able to import the secrets bundle");

        let status = second.cross_signing_status().await;
        let identity = second.get_identity(user_id, None).await.unwrap().unwrap().own().unwrap();

        assert!(identity.is_verified(), "The public identity should be marked as verified.");

        assert!(status.is_complete(), "We should have imported all the cross-signing keys");
    }

    /// A random dehydrated device key round-trips through its raw bytes.
    #[async_test]
    async fn test_create_dehydrated_device_key() {
        let pickle_key = DehydratedDeviceKey::new()
            .expect("Should be able to create a random dehydrated device key");

        let to_vec = pickle_key.inner.to_vec();
        let pickle_key_from_slice = DehydratedDeviceKey::from_slice(to_vec.as_slice())
            .expect("Should be able to create a dehydrated device key from slice");

        assert_eq!(pickle_key_from_slice.to_base64(), pickle_key.to_base64());
    }

    /// `DehydratedDeviceKey::from_slice` rejects both a 22-byte and a 40-byte input.
    #[async_test]
    async fn test_create_dehydrated_errors() {
        let too_small = [0u8; 22];
        let pickle_key = DehydratedDeviceKey::from_slice(&too_small);

        assert!(pickle_key.is_err());

        let too_big = [0u8; 40];
        let pickle_key = DehydratedDeviceKey::from_slice(&too_big);

        assert!(pickle_key.is_err());
    }

    /// Snapshot test for `build_room_key_bundle`: sessions with shared history
    /// are exported, others are listed as withheld, and sessions of other
    /// rooms are left out.
    #[async_test]
    async fn test_build_room_key_bundle() {
        // Alice only provides the sender identity keys; Bob's store holds the
        // sessions and builds the bundle.
        let alice = OlmMachine::new(user_id!("@a:s.co"), device_id!("ALICE")).await;
        let bob = OlmMachine::new(user_id!("@b:s.co"), device_id!("BOB")).await;

        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");

        // Fixed session keys keep the snapshot stable between runs.
        let session_key1 = "AgAAAAC2XHVzsMBKs4QCRElJ92CJKyGtknCSC8HY7cQ7UYwndMKLQAejXLh5UA0l6s736mgctcUMNvELScUWrObdflrHo+vth/gWreXOaCnaSxmyjjKErQwyIYTkUfqbHy40RJfEesLwnN23on9XAkch/iy8R2+Jz7B8zfG01f2Ow2SxPQFnAndcO1ZSD2GmXgedy6n4B20MWI1jGP2wiexOWbFSya8DO/VxC9m5+/mF+WwYqdpKn9g4Y05Yw4uz7cdjTc3rXm7xK+8E7hI//5QD1nHPvuKYbjjM9u2JSL+Bzp61Cw";
        let session_key2 = "AgAAAAC1BXreFTUQQSBGekTEuYxhdytRKyv4JgDGcG+VOBYdPNGgs807SdibCGJky4lJ3I+7ZDGHoUzZPZP/4ogGu4kxni0PWdtWuN7+5zsuamgoFF/BkaGeUUGv6kgIkx8pyPpM5SASTUEP9bN2loDSpUPYwfiIqz74DgC4WQ4435sTBctYvKz8n+TDJwdLXpyT6zKljuqADAioud+s/iqx9LYn9HpbBfezZcvbg67GtE113pLrvde3IcPI5s6dNHK2onGO2B2eoaobcen18bbEDnlUGPeIivArLya7Da6us14jBQ";
        let session_key3 = "AgAAAAAM9KFsliaUUhGSXgwOzM5UemjkNH4n8NHgvC/y8hhw13zTF+ooGD4uIYEXYX630oNvQm/EvgZo+dkoc0re+vsqsx4sQeNODdSjcBsWOa0oDF+irQn9oYoLUDPI1IBtY1rX+FV99Zm/xnG7uFOX7aTVlko2GSdejy1w9mfobmfxu5aUc04A9zaKJP1pOthZvRAlhpymGYHgsDtWPrrjyc/yypMflE4kIUEEEtu1kT6mrAmcl615XYRAHYK9G2+fZsGvokwzbkl4nulGwcZMpQEoM0nD2o3GWgX81HW3nGfKBg";
        let session_key4 = "AgAAAAA4Kkesxq2h4v9PLD6Sm3Smxspz1PXTqytQPCMQMkkrHNmzV2bHlJ+6/Al9cu8vh1Oj69AK0WUAeJOJuaiskEeg/PI3P03+UYLeC379RzgqwSHdBgdQ41G2vD6zpgmE/8vYToe+qpCZACtPOswZxyqxHH+T/Iq0nv13JmlFGIeA6fEPfr5Y28B49viG74Fs9rxV9EH5PfjbuPM/p+Sz5obShuaBPKQBX1jT913nEXPoIJ06exNZGr0285nw/LgVvNlmWmbqNnbzO2cNZjQWA+xZYz5FSfyCxwqEBbEdUCuRCQ";

        // Room 1: two shared-history sessions and one without shared
        // history. Room 2: one shared-history session that must not appear.
        let sessions = [
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key1).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key2).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key3).unwrap(),
                false,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room2_id,
                &SessionKey::from_base64(session_key4).unwrap(),
                true,
            ),
        ];
        bob.store().save_inbound_group_sessions(&sessions).await.unwrap();

        let mut bundle = bob.store().build_room_key_bundle(room1_id).await.unwrap();

        // Sort the exported keys so the snapshot doesn't depend on the order
        // in which the store returns sessions.
        bundle.room_keys.sort_by_key(|session| session.session_id.clone());

        // Alice's identity keys are random per run: check each redacted value
        // is Alice's key and replace it with a fixed placeholder.
        let alice_curve_key = alice.identity_keys().curve25519.to_base64();
        let map_alice_curve_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_curve_key);
            "[alice curve key]"
        };
        let alice_ed25519_key = alice.identity_keys().ed25519.to_base64();
        let map_alice_ed25519_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_ed25519_key);
            "[alice ed25519 key]"
        };

        insta::with_settings!({ sort_maps => true }, {
            assert_json_snapshot!(bundle, {
                ".room_keys[].sender_key" => insta::dynamic_redaction(map_alice_curve_key.clone()),
                ".withheld[].sender_key" => insta::dynamic_redaction(map_alice_curve_key),
                ".room_keys[].sender_claimed_keys.ed25519" => insta::dynamic_redaction(map_alice_ed25519_key),
            });
        });
    }

    /// Create an [`InboundGroupSession`] in `room_id` from `session_key`,
    /// with `olm_machine`'s own Curve25519/Ed25519 identity keys as sender
    /// keys, unknown sender data, the MegolmV1AesSha2 algorithm, and the
    /// given `shared_history` flag.
    ///
    /// NOTE(review): the `None` argument is an optional parameter of
    /// `InboundGroupSession::new` (presumably history visibility) — confirm.
    fn create_inbound_group_session_with_visibility(
        olm_machine: &OlmMachine,
        room_id: &RoomId,
        session_key: &SessionKey,
        shared_history: bool,
    ) -> InboundGroupSession {
        let identity_keys = &olm_machine.store().static_account().identity_keys;
        InboundGroupSession::new(
            identity_keys.curve25519,
            identity_keys.ed25519,
            room_id,
            session_key,
            SenderData::unknown(),
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            None,
            shared_history,
        )
        .unwrap()
    }
}