1use std::{fmt, sync::Arc};
16
17use ruma::{serde::Raw, SecondsSinceUnixEpoch};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use tokio::sync::Mutex;
21use tracing::{debug, Span};
22use vodozemac::{
23 olm::{DecryptionError, OlmMessage, Session as InnerSession, SessionConfig, SessionPickle},
24 Curve25519PublicKey,
25};
26
27#[cfg(feature = "experimental-algorithms")]
28use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content;
29use crate::{
30 error::{EventError, OlmResult, SessionUnpickleError},
31 types::{
32 events::{
33 olm_v1::{DecryptedOlmV1Event, OlmV1Keys},
34 room::encrypted::{OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent},
35 EventType,
36 },
37 DeviceKeys, EventEncryptionAlgorithm,
38 },
39 DeviceData,
40};
41
42#[derive(Clone)]
45pub struct Session {
46 pub inner: Arc<Mutex<InnerSession>>,
48 pub session_id: Arc<str>,
50 pub sender_key: Curve25519PublicKey,
52 pub our_device_keys: DeviceKeys,
54 pub created_using_fallback_key: bool,
56 pub creation_time: SecondsSinceUnixEpoch,
58 pub last_use_time: SecondsSinceUnixEpoch,
60}
61
62#[cfg(not(tarpaulin_include))]
63impl fmt::Debug for Session {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("Session")
66 .field("session_id", &self.session_id())
67 .field("sender_key", &self.sender_key)
68 .finish()
69 }
70}
71
72impl Session {
73 pub async fn decrypt(&mut self, message: &OlmMessage) -> Result<String, DecryptionError> {
82 let mut inner = self.inner.lock().await;
83 Span::current().record("session_id", inner.session_id());
84
85 let plaintext = inner.decrypt(message)?;
86 debug!(session=?inner, "Decrypted an Olm message");
87
88 let plaintext = String::from_utf8_lossy(&plaintext).to_string();
89
90 self.last_use_time = SecondsSinceUnixEpoch::now();
91
92 Ok(plaintext)
93 }
94
95 pub fn sender_key(&self) -> Curve25519PublicKey {
97 self.sender_key
98 }
99
100 pub async fn session_config(&self) -> SessionConfig {
102 self.inner.lock().await.session_config()
103 }
104
105 #[allow(clippy::unused_async)] pub async fn algorithm(&self) -> EventEncryptionAlgorithm {
108 #[cfg(feature = "experimental-algorithms")]
109 if self.session_config().await.version() == 2 {
110 EventEncryptionAlgorithm::OlmV2Curve25519AesSha2
111 } else {
112 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
113 }
114
115 #[cfg(not(feature = "experimental-algorithms"))]
116 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
117 }
118
119 pub(crate) async fn encrypt_helper(&mut self, plaintext: &str) -> OlmMessage {
127 let mut session = self.inner.lock().await;
128 let message = session.encrypt(plaintext);
129 self.last_use_time = SecondsSinceUnixEpoch::now();
130 debug!(?session, "Successfully encrypted an event");
131 message
132 }
133
134 pub async fn encrypt(
147 &mut self,
148 recipient_device: &DeviceData,
149 event_type: &str,
150 content: impl Serialize,
151 message_id: Option<String>,
152 ) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
153 #[derive(Debug)]
154 struct Content<'a> {
155 event_type: &'a str,
156 content: Raw<Value>,
157 }
158
159 impl EventType for Content<'_> {
160 const EVENT_TYPE: &'static str = "";
170
171 fn event_type(&self) -> &str {
172 self.event_type
173 }
174 }
175
176 impl Serialize for Content<'_> {
177 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178 where
179 S: serde::Serializer,
180 {
181 self.content.serialize(serializer)
182 }
183 }
184
185 let plaintext = {
186 let content = serde_json::to_value(content)?;
187 let content = Content { event_type, content: Raw::new(&content)? };
188
189 let recipient_signing_key =
190 recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?;
191
192 let content = DecryptedOlmV1Event {
193 sender: self.our_device_keys.user_id.clone(),
194 recipient: recipient_device.user_id().into(),
195 keys: OlmV1Keys {
196 ed25519: self
197 .our_device_keys
198 .ed25519_key()
199 .expect("Our own device should have an Ed25519 public key"),
200 },
201 recipient_keys: OlmV1Keys { ed25519: recipient_signing_key },
202 sender_device_keys: Some(self.our_device_keys.clone()),
203 content,
204 };
205
206 serde_json::to_string(&content)?
207 };
208
209 let ciphertext = self.encrypt_helper(&plaintext).await;
210
211 let content = self.build_encrypted_event(ciphertext, message_id).await?;
212 let content = Raw::new(&content)?;
213 Ok(content)
214 }
215
216 pub(crate) async fn build_encrypted_event(
225 &self,
226 ciphertext: OlmMessage,
227 message_id: Option<String>,
228 ) -> OlmResult<ToDeviceEncryptedEventContent> {
229 let content = match self.algorithm().await {
230 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content {
231 ciphertext,
232 recipient_key: self.sender_key,
233 sender_key: self
234 .our_device_keys
235 .curve25519_key()
236 .expect("Device doesn't have curve25519 key"),
237 message_id,
238 }
239 .into(),
240 #[cfg(feature = "experimental-algorithms")]
241 EventEncryptionAlgorithm::OlmV2Curve25519AesSha2 => OlmV2Curve25519AesSha2Content {
242 ciphertext,
243 sender_key: self
244 .our_device_keys
245 .curve25519_key()
246 .expect("Device doesn't have curve25519 key"),
247 message_id,
248 }
249 .into(),
250 _ => unreachable!(),
251 };
252
253 Ok(content)
254 }
255
256 pub fn session_id(&self) -> &str {
258 &self.session_id
259 }
260
261 pub async fn pickle(&self) -> PickledSession {
268 let pickle = self.inner.lock().await.pickle();
269
270 PickledSession {
271 pickle,
272 sender_key: self.sender_key,
273 created_using_fallback_key: self.created_using_fallback_key,
274 creation_time: self.creation_time,
275 last_use_time: self.last_use_time,
276 }
277 }
278
279 pub fn from_pickle(
290 our_device_keys: DeviceKeys,
291 pickle: PickledSession,
292 ) -> Result<Self, SessionUnpickleError> {
293 if our_device_keys.curve25519_key().is_none() {
294 return Err(SessionUnpickleError::MissingIdentityKey);
295 }
296 if our_device_keys.ed25519_key().is_none() {
297 return Err(SessionUnpickleError::MissingSigningKey);
298 }
299
300 let session: vodozemac::olm::Session = pickle.pickle.into();
301 let session_id = session.session_id();
302
303 Ok(Session {
304 inner: Arc::new(Mutex::new(session)),
305 session_id: session_id.into(),
306 created_using_fallback_key: pickle.created_using_fallback_key,
307 sender_key: pickle.sender_key,
308 our_device_keys,
309 creation_time: pickle.creation_time,
310 last_use_time: pickle.last_use_time,
311 })
312 }
313}
314
315impl PartialEq for Session {
316 fn eq(&self, other: &Self) -> bool {
317 self.session_id() == other.session_id()
318 }
319}
320
/// A serializable snapshot of a [`Session`], produced by [`Session::pickle`]
/// and turned back into a session by [`Session::from_pickle`].
#[derive(Serialize, Deserialize)]
// NOTE(review): presumably `Debug` is omitted so the pickled session state
// cannot be logged by accident — confirm.
#[allow(missing_debug_implementations)]
pub struct PickledSession {
    /// The pickled vodozemac session.
    pub pickle: SessionPickle,
    /// The Curve25519 identity key of the other side of the session.
    pub sender_key: Curve25519PublicKey,
    /// Whether the session was created using a fallback key. Deserializes as
    /// `false` when the field is absent, so pickles lacking it still load.
    #[serde(default)]
    pub created_using_fallback_key: bool,
    /// When the session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// When the session was last used to encrypt or decrypt.
    pub last_use_time: SecondsSinceUnixEpoch,
}
340
#[cfg(test)]
mod tests {
    use assert_matches2::assert_let;
    use matrix_sdk_test::async_test;
    use ruma::{device_id, user_id};
    use serde_json::{self, Value};
    use vodozemac::olm::{OlmMessage, SessionConfig};

    use crate::{
        identities::DeviceData,
        olm::Account,
        types::events::{
            dummy::DummyEventContent, olm_v1::DecryptedOlmV1Event,
            room::encrypted::ToDeviceEncryptedEventContent,
        },
    };

    /// Alice encrypts a to-device event for Bob over a fresh outbound session.
    /// Bob then creates the matching inbound session from the pre-key message
    /// and checks the decrypted payload's sender device keys and recipient.
    #[async_test]
    async fn test_encryption_and_decryption() {
        use ruma::events::dummy::ToDeviceDummyEventContent;

        let alice =
            Account::with_device_id(user_id!("@alice:localhost"), device_id!("ALICEDEVICE"));
        let mut bob = Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"));

        // Alice needs one of Bob's one-time keys to open the outbound session.
        bob.generate_one_time_keys(1);
        let one_time_key = *bob.one_time_keys().values().next().unwrap();
        let sender_key = bob.identity_keys().curve25519;
        let mut alice_session = alice.create_outbound_session_helper(
            SessionConfig::default(),
            sender_key,
            one_time_key,
            false,
            alice.device_keys(),
        );

        let alice_device = DeviceData::from_account(&alice);
        // The session was established with Bob's keys, so Bob's device is the
        // recipient of the encrypted event.
        let bob_device = DeviceData::from_account(&bob);

        let message = alice_session
            .encrypt(&bob_device, "m.dummy", ToDeviceDummyEventContent::new(), None)
            .await
            .unwrap()
            .deserialize()
            .unwrap();

        #[cfg(feature = "experimental-algorithms")]
        assert_let!(ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(content) = message);
        #[cfg(not(feature = "experimental-algorithms"))]
        assert_let!(ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(content) = message);

        // The first message of a new session must be a pre-key message.
        assert_let!(OlmMessage::PreKey(prekey) = content.ciphertext);

        let bob_session_result = bob
            .create_inbound_session(
                alice_device.curve25519_key().unwrap(),
                bob.device_keys(),
                &prekey,
            )
            .unwrap();

        let plaintext: Value = serde_json::from_str(&bob_session_result.plaintext).unwrap();
        assert_eq!(plaintext["sender_device_keys"]["user_id"].as_str(), Some("@alice:localhost"));
        assert_eq!(plaintext["recipient"].as_str(), Some("@bob:localhost"));

        let event: DecryptedOlmV1Event<DummyEventContent> =
            serde_json::from_str(&bob_session_result.plaintext).unwrap();
        assert_eq!(event.sender_device_keys.unwrap(), alice.device_keys());
    }
}