1use std::{matches, sync::Arc, time::Duration};
16
17use matrix_sdk_common::locks::Mutex;
18use ruma::{
19 events::{
20 key::verification::{
21 accept::{
22 AcceptMethod, KeyVerificationAcceptEventContent, SasV1Content as AcceptV1Content,
23 SasV1ContentInit as AcceptV1ContentInit, ToDeviceKeyVerificationAcceptEventContent,
24 },
25 cancel::CancelCode,
26 done::{KeyVerificationDoneEventContent, ToDeviceKeyVerificationDoneEventContent},
27 key::{KeyVerificationKeyEventContent, ToDeviceKeyVerificationKeyEventContent},
28 start::{
29 KeyVerificationStartEventContent, SasV1Content, SasV1ContentInit, StartMethod,
30 ToDeviceKeyVerificationStartEventContent,
31 },
32 HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
33 ShortAuthenticationString,
34 },
35 relation::Reference,
36 AnyMessageLikeEventContent, AnyToDeviceEventContent,
37 },
38 serde::Base64,
39 time::Instant,
40 DeviceId, OwnedTransactionId, TransactionId, UserId,
41};
42use serde::{Deserialize, Serialize};
43use tracing::info;
44use vodozemac::{
45 sas::{EstablishedSas, Mac, Sas},
46 Curve25519PublicKey,
47};
48
49use super::{
50 helpers::{
51 calculate_commitment, get_decimal, get_emoji, get_emoji_index, get_mac_content,
52 receive_mac_event, SasIds,
53 },
54 OutgoingContent,
55};
56use crate::{
57 identities::{DeviceData, UserIdentityData},
58 olm::StaticAccountData,
59 verification::{
60 cache::RequestInfo,
61 event_enums::{
62 AcceptContent, DoneContent, KeyContent, MacContent, OwnedAcceptContent,
63 OwnedStartContent, StartContent,
64 },
65 Cancelled, Emoji, FlowId,
66 },
67 OwnUserIdentityData,
68};
69
70const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
71 &[KeyAgreementProtocol::Curve25519HkdfSha256];
72const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
73const STRINGS: &[ShortAuthenticationString] =
74 &[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
75
76fn the_protocol_definitions(
77 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
78) -> SasV1Content {
79 SasV1ContentInit {
80 short_authentication_string: short_auth_strings.unwrap_or_else(|| STRINGS.to_owned()),
81 key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(),
82 message_authentication_codes: vec![
83 #[allow(deprecated)]
84 MessageAuthenticationCode::HkdfHmacSha256,
85 MessageAuthenticationCode::HkdfHmacSha256V2,
86 MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
88 ],
89 hashes: HASHES.to_vec(),
90 }
91 .into()
92}
93
/// Upper bound on the total lifetime of a SAS flow: five minutes.
const MAX_AGE: Duration = Duration::from_secs(5 * 60);

/// Upper bound on the time since the last handled event of a flow: one minute.
const MAX_EVENT_TIMEOUT: Duration = Duration::from_secs(60);
99
/// MAC methods this implementation can both calculate and verify.
///
/// The serde names equal the wire identifiers returned by the `AsRef<str>`
/// implementation of this type.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum SupportedMacMethod {
    /// `hkdf-hmac-sha256`: the legacy method. Its MAC is produced with
    /// vodozemac's `calculate_mac_invalid_base64`.
    #[serde(rename = "hkdf-hmac-sha256")]
    HkdfHmacSha256,
    /// `hkdf-hmac-sha256.v2`.
    #[serde(rename = "hkdf-hmac-sha256.v2")]
    HkdfHmacSha256V2,
    /// `org.matrix.msc3783.hkdf-hmac-sha256`: the MSC3783 unstable identifier.
    /// It is calculated and verified exactly like `HkdfHmacSha256V2`.
    #[serde(rename = "org.matrix.msc3783.hkdf-hmac-sha256")]
    Msc3783HkdfHmacSha256V2,
}
112
113impl AsRef<str> for SupportedMacMethod {
114 fn as_ref(&self) -> &str {
115 match self {
116 SupportedMacMethod::HkdfHmacSha256 => "hkdf-hmac-sha256",
117 SupportedMacMethod::HkdfHmacSha256V2 => "hkdf-hmac-sha256.v2",
118 SupportedMacMethod::Msc3783HkdfHmacSha256V2 => "org.matrix.msc3783.hkdf-hmac-sha256",
119 }
120 }
121}
122
123impl From<SupportedMacMethod> for MessageAuthenticationCode {
124 fn from(m: SupportedMacMethod) -> Self {
125 MessageAuthenticationCode::from(m.as_ref())
126 }
127}
128
129impl TryFrom<&MessageAuthenticationCode> for SupportedMacMethod {
130 type Error = ();
131
132 fn try_from(value: &MessageAuthenticationCode) -> Result<Self, Self::Error> {
133 match value.as_str() {
134 "hkdf-hmac-sha256" => Ok(Self::HkdfHmacSha256),
135 "org.matrix.msc3783.hkdf-hmac-sha256" => Ok(Self::Msc3783HkdfHmacSha256V2),
136 "hkdf-hmac-sha256.v2" => Ok(Self::HkdfHmacSha256V2),
137 _ => Err(()),
138 }
139 }
140}
141
142impl SupportedMacMethod {
143 pub fn verify_mac(
149 &self,
150 sas: &EstablishedSas,
151 input: &str,
152 info: &str,
153 mac: &Base64,
154 ) -> Result<(), CancelCode> {
155 match self {
156 SupportedMacMethod::HkdfHmacSha256 => {
157 let calculated_mac = sas.calculate_mac_invalid_base64(input, info);
158 let calculated_mac = Base64::parse(calculated_mac)
159 .expect("We can always decode a Mac from vodozemac");
160
161 if calculated_mac != *mac {
162 Err(CancelCode::KeyMismatch)
163 } else {
164 Ok(())
165 }
166 }
167 SupportedMacMethod::HkdfHmacSha256V2 | SupportedMacMethod::Msc3783HkdfHmacSha256V2 => {
168 let mac = Mac::from_slice(mac.as_bytes());
169 sas.verify_mac(input, info, &mac).map_err(|_| CancelCode::MismatchedSas)
170 }
171 }
172 }
173
174 pub fn calculate_mac(&self, sas: &EstablishedSas, input: &str, info: &str) -> Base64 {
180 match self {
181 SupportedMacMethod::HkdfHmacSha256 => {
182 Base64::parse(sas.calculate_mac_invalid_base64(input, info))
183 .expect("We can always decode our newly generated Mac")
184 }
185 SupportedMacMethod::HkdfHmacSha256V2 | SupportedMacMethod::Msc3783HkdfHmacSha256V2 => {
186 let mac = sas.calculate_mac(input, info);
187 Base64::new(mac.as_bytes().to_vec())
188 }
189 }
190 }
191}
192
/// The protocols negotiated for a SAS verification flow.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AcceptedProtocols {
    /// The key agreement protocol.
    pub key_agreement_protocol: KeyAgreementProtocol,
    /// The hash algorithm.
    pub hash: HashAlgorithm,
    /// The MAC method used to calculate and verify the MAC events.
    pub message_auth_code: SupportedMacMethod,
    /// The short authentication string methods that may be presented.
    pub short_auth_string: Vec<ShortAuthenticationString>,
}
207
208impl TryFrom<AcceptV1Content> for AcceptedProtocols {
209 type Error = CancelCode;
210
211 fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
212 if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
213 || !HASHES.contains(&content.hash)
214 || (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
215 && !content
216 .short_authentication_string
217 .contains(&ShortAuthenticationString::Decimal))
218 {
219 Err(CancelCode::UnknownMethod)
220 } else {
221 let message_auth_code = (&content.message_authentication_code)
222 .try_into()
223 .map_err(|_| CancelCode::UnknownMethod)?;
224
225 Ok(Self {
226 hash: content.hash,
227 key_agreement_protocol: content.key_agreement_protocol,
228 message_auth_code,
229 short_auth_string: content.short_authentication_string,
230 })
231 }
232 }
233}
234
235impl TryFrom<&SasV1Content> for AcceptedProtocols {
236 type Error = CancelCode;
237
238 fn try_from(method_content: &SasV1Content) -> Result<Self, Self::Error> {
239 if !method_content
240 .key_agreement_protocols
241 .contains(&KeyAgreementProtocol::Curve25519HkdfSha256)
242 || !method_content.hashes.contains(&HashAlgorithm::Sha256)
243 || (!method_content
244 .short_authentication_string
245 .contains(&ShortAuthenticationString::Decimal)
246 && !method_content
247 .short_authentication_string
248 .contains(&ShortAuthenticationString::Emoji))
249 {
250 Err(CancelCode::UnknownMethod)
251 } else {
252 let mac_methods: Vec<SupportedMacMethod> = method_content
253 .message_authentication_codes
254 .iter()
255 .filter_map(|m| SupportedMacMethod::try_from(m).ok())
256 .collect();
257
258 let message_auth_code =
259 if mac_methods.contains(&SupportedMacMethod::HkdfHmacSha256V2) {
260 Some(SupportedMacMethod::HkdfHmacSha256V2)
261 } else if mac_methods.contains(&SupportedMacMethod::Msc3783HkdfHmacSha256V2) {
262 Some(SupportedMacMethod::Msc3783HkdfHmacSha256V2)
263 } else {
264 mac_methods.first().copied()
265 }
266 .ok_or(CancelCode::UnknownMethod)?;
267
268 let mut short_auth_string = vec![];
269
270 if method_content
271 .short_authentication_string
272 .contains(&ShortAuthenticationString::Decimal)
273 {
274 short_auth_string.push(ShortAuthenticationString::Decimal)
275 }
276
277 if method_content
278 .short_authentication_string
279 .contains(&ShortAuthenticationString::Emoji)
280 {
281 short_auth_string.push(ShortAuthenticationString::Emoji);
282 }
283
284 Ok(Self {
285 hash: HashAlgorithm::Sha256,
286 key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
287 message_auth_code,
288 short_auth_string,
289 })
290 }
291 }
292}
293
294#[cfg(not(tarpaulin_include))]
295impl Default for AcceptedProtocols {
296 fn default() -> Self {
297 AcceptedProtocols {
298 hash: HashAlgorithm::Sha256,
299 key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
300 message_auth_code: SupportedMacMethod::HkdfHmacSha256V2,
301 short_auth_string: vec![
302 ShortAuthenticationString::Decimal,
303 ShortAuthenticationString::Emoji,
304 ],
305 }
306 }
307}
308
/// A SAS verification flow in state `S`.
///
/// The fields shared by all states live here. `S` carries the per-state data
/// and decides which transitions are available.
#[derive(Clone)]
pub struct SasState<S: Clone> {
    /// Our `Sas` object. `handle_key_content` takes it out, leaving `None`,
    /// when the shared secret is established.
    inner: Arc<Mutex<Option<Sas>>>,

    /// The public key of our `Sas` object.
    our_public_key: Curve25519PublicKey,

    /// Our own account data, the other device, and the known identities.
    ids: Box<SasIds>,

    /// When the flow was created; checked against `MAX_AGE`.
    creation_time: Arc<Instant>,

    /// When the flow last handled an event; checked against
    /// `MAX_EVENT_TIMEOUT`.
    last_event_time: Arc<Instant>,

    /// The flow identifier: a to-device transaction id, or a (room, event)
    /// pair for in-room flows.
    pub verification_flow_id: Arc<FlowId>,

    /// The per-state data.
    pub state: Arc<S>,

    /// Whether this flow was started from a verification request.
    pub started_from_request: bool,
}
345
346impl<S: Clone> SasState<S> {
347 fn handle_key_content(
348 &self,
349 sender: &UserId,
350 content: &KeyContent<'_>,
351 ) -> Result<EstablishedSas, CancelCode> {
352 self.check_event(sender, content.flow_id())?;
353
354 let their_public_key = Curve25519PublicKey::from_slice(content.public_key().as_bytes())
355 .map_err(|_| CancelCode::from("Invalid public key"))?;
356
357 if let Some(sas) = self.inner.lock().take() {
358 sas.diffie_hellman(their_public_key).map_err(|_| "Invalid public key".into())
359 } else {
360 Err(CancelCode::UnexpectedMessage)
361 }
362 }
363}
364
365#[cfg(not(tarpaulin_include))]
366impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 f.debug_struct("SasState")
369 .field("ids", &self.ids)
370 .field("flow_id", &self.verification_flow_id)
371 .field("state", &self.state)
372 .finish()
373 }
374}
375
/// Initial state of a flow that we start. Our start content is built from it.
#[derive(Clone, Debug)]
pub struct Created {
    /// The `m.sas.v1` parameters we offer in our start event.
    pub protocol_definitions: SasV1Content,
}
381
/// State after the other side started a flow with a start event.
#[derive(Clone, Debug)]
pub struct Started {
    /// Our commitment, calculated from our public key and the received start
    /// content.
    commitment: Base64,
    /// The `m.sas.v1` parameters the other side offered.
    pub protocol_definitions: SasV1Content,
    /// The protocols we selected from that offer.
    pub accepted_protocols: AcceptedProtocols,
}
389
/// State after the other side accepted our start event.
#[derive(Clone, Debug)]
pub struct Accepted {
    /// The protocols the other side chose in its accept event.
    pub accepted_protocols: AcceptedProtocols,
    /// Our start content. Used to recompute the other side's commitment once
    /// its key arrives.
    start_content: Arc<OwnedStartContent>,
    /// Request id used for sending our key event.
    pub request_id: OwnedTransactionId,
    /// The commitment the other side sent in its accept event.
    commitment: Base64,
}
399
/// State after we accepted a flow the other side started.
#[derive(Clone, Debug)]
pub struct WeAccepted {
    /// Whether we started the flow. Set to `false` by
    /// `SasState<Started>::into_we_accepted`.
    we_started: bool,
    /// The protocols we accepted.
    pub accepted_protocols: AcceptedProtocols,
    /// Our commitment, sent in our accept event.
    commitment: Base64,
}
408
/// State after we received the other side's key event. Our own key still
/// needs to be sent (see `SasState<KeyReceived>::as_content`).
#[derive(Clone, Debug)]
pub struct KeyReceived {
    /// The established SAS shared secret.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the flow.
    we_started: bool,
    /// Request id used for sending our key event.
    pub request_id: OwnedTransactionId,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
420
/// State after we sent our key. We are waiting for the other side's key.
#[derive(Clone, Debug)]
pub struct KeySent {
    /// Whether we started the flow.
    we_started: bool,
    /// Our start content, used to recompute the other side's commitment.
    start_content: Arc<OwnedStartContent>,
    /// The commitment the other side sent in its accept event.
    commitment: Base64,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
428
/// State after both keys were exchanged. The short authentication strings can
/// now be computed.
#[derive(Clone, Debug)]
pub struct KeysExchanged {
    /// The established SAS shared secret.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the flow.
    we_started: bool,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
435
/// State after `SasState<KeysExchanged>::confirm`. Our MAC content can be
/// produced, and the other side's MAC is still expected.
#[derive(Clone, Debug)]
pub struct Confirmed {
    /// The established SAS shared secret.
    sas: Arc<Mutex<EstablishedSas>>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
444
/// State after the other side's MAC was received and checked, before we
/// confirmed.
#[derive(Clone, Debug)]
pub struct MacReceived {
    /// The established SAS shared secret.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the flow.
    we_started: bool,
    /// Devices whose keys were covered by the other side's verified MAC.
    verified_devices: Arc<[DeviceData]>,
    /// Identities whose master keys were covered by the verified MAC.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
456
/// State after we confirmed and the other side's MAC was verified.
/// Presumably the flow is waiting for the other side's done event; see the
/// transitions that produce this state.
#[derive(Clone, Debug)]
pub struct WaitingForDone {
    /// The established SAS shared secret.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Devices whose keys were covered by the verified MAC.
    verified_devices: Arc<[DeviceData]>,
    /// Identities whose master keys were covered by the verified MAC.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
467
/// Final state of a flow that completed successfully.
#[derive(Clone, Debug)]
pub struct Done {
    /// The established SAS shared secret.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Devices whose keys were covered by the verified MAC.
    verified_devices: Arc<[DeviceData]>,
    /// Identities whose master keys were covered by the verified MAC.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
480
481impl<S: Clone> SasState<S> {
482 #[cfg(test)]
484 pub fn user_id(&self) -> &UserId {
485 &self.ids.account.user_id
486 }
487
488 pub fn device_id(&self) -> &DeviceId {
490 &self.ids.account.device_id
491 }
492
493 #[cfg(test)]
494 pub fn other_device(&self) -> DeviceData {
495 self.ids.other_device.clone()
496 }
497
498 pub fn cancel(self, cancelled_by_us: bool, cancel_code: CancelCode) -> SasState<Cancelled> {
499 SasState {
500 inner: self.inner,
501 our_public_key: self.our_public_key,
502 ids: self.ids,
503 creation_time: self.creation_time,
504 last_event_time: self.last_event_time,
505 verification_flow_id: self.verification_flow_id,
506 state: Arc::new(Cancelled::new(cancelled_by_us, cancel_code)),
507 started_from_request: self.started_from_request,
508 }
509 }
510
511 pub fn timed_out(&self) -> bool {
513 self.creation_time.elapsed() > MAX_AGE || self.last_event_time.elapsed() > MAX_EVENT_TIMEOUT
514 }
515
516 #[allow(dead_code)]
518 pub fn is_dm_verification(&self) -> bool {
519 matches!(&*self.verification_flow_id, FlowId::InRoom(_, _))
520 }
521
522 #[cfg(test)]
523 #[allow(dead_code)]
524 pub fn set_creation_time(&mut self, time: Instant) {
525 self.creation_time = Arc::new(time);
526 }
527
528 fn check_event(&self, sender: &UserId, flow_id: &str) -> Result<(), CancelCode> {
529 if *flow_id != *self.verification_flow_id.as_str() {
530 Err(CancelCode::UnknownTransaction)
531 } else if sender != self.ids.other_device.user_id() {
532 Err(CancelCode::UserMismatch)
533 } else if self.timed_out() {
534 Err(CancelCode::Timeout)
535 } else {
536 Ok(())
537 }
538 }
539}
540
541impl SasState<Created> {
542 pub fn new(
552 account: StaticAccountData,
553 other_device: DeviceData,
554 own_identity: Option<OwnUserIdentityData>,
555 other_identity: Option<UserIdentityData>,
556 flow_id: FlowId,
557 started_from_request: bool,
558 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
559 ) -> SasState<Created> {
560 Self::new_helper(
561 flow_id,
562 account,
563 other_device,
564 own_identity,
565 other_identity,
566 started_from_request,
567 short_auth_strings,
568 )
569 }
570
571 fn new_helper(
572 flow_id: FlowId,
573 account: StaticAccountData,
574 other_device: DeviceData,
575 own_identity: Option<OwnUserIdentityData>,
576 other_identity: Option<UserIdentityData>,
577 started_from_request: bool,
578 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
579 ) -> SasState<Created> {
580 let sas = Sas::new();
581 let our_public_key = sas.public_key();
582
583 let protocol_definitions = the_protocol_definitions(short_auth_strings);
584
585 SasState {
586 inner: Arc::new(Mutex::new(Some(sas))),
587 our_public_key,
588 ids: Box::new(SasIds { account, other_device, other_identity, own_identity }),
589 verification_flow_id: flow_id.into(),
590
591 creation_time: Arc::new(Instant::now()),
592 last_event_time: Arc::new(Instant::now()),
593 started_from_request,
594
595 state: Arc::new(Created { protocol_definitions }),
596 }
597 }
598
599 pub fn as_content(&self) -> OwnedStartContent {
600 match self.verification_flow_id.as_ref() {
601 FlowId::ToDevice(s) => {
602 OwnedStartContent::ToDevice(ToDeviceKeyVerificationStartEventContent::new(
603 self.device_id().into(),
604 s.clone(),
605 StartMethod::SasV1(self.state.protocol_definitions.clone()),
606 ))
607 }
608 FlowId::InRoom(r, e) => OwnedStartContent::Room(
609 r.clone(),
610 KeyVerificationStartEventContent::new(
611 self.device_id().into(),
612 StartMethod::SasV1(self.state.protocol_definitions.clone()),
613 Reference::new(e.clone()),
614 ),
615 ),
616 }
617 }
618
619 pub fn into_accepted(
627 self,
628 sender: &UserId,
629 content: &AcceptContent<'_>,
630 ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
631 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
632
633 let AcceptMethod::SasV1(content) = content.method() else {
634 return Err(self.cancel(true, CancelCode::UnknownMethod));
635 };
636
637 let accepted_protocols = AcceptedProtocols::try_from(content.clone())
638 .map_err(|c| self.clone().cancel(true, c))?;
639
640 let start_content = self.as_content().into();
641
642 Ok(SasState {
643 inner: self.inner,
644 our_public_key: self.our_public_key,
645 ids: self.ids,
646 verification_flow_id: self.verification_flow_id,
647 creation_time: self.creation_time,
648 last_event_time: Instant::now().into(),
649 started_from_request: self.started_from_request,
650 state: Arc::new(Accepted {
651 start_content,
652 commitment: content.commitment.clone(),
653 request_id: TransactionId::new(),
654 accepted_protocols,
655 }),
656 })
657 }
658}
659
660impl SasState<Started> {
661 pub fn from_start_event(
675 account: StaticAccountData,
676 other_device: DeviceData,
677 own_identity: Option<OwnUserIdentityData>,
678 other_identity: Option<UserIdentityData>,
679 flow_id: FlowId,
680 content: &StartContent<'_>,
681 started_from_request: bool,
682 ) -> Result<SasState<Started>, SasState<Cancelled>> {
683 let flow_id = Arc::new(flow_id);
684
685 let sas = Sas::new();
686 let our_public_key = sas.public_key();
687
688 let canceled = || SasState {
689 inner: Arc::new(Mutex::new(None)),
690 our_public_key,
691
692 creation_time: Arc::new(Instant::now()),
693 last_event_time: Arc::new(Instant::now()),
694 started_from_request,
695
696 ids: Box::new(SasIds {
697 account: account.clone(),
698 other_device: other_device.clone(),
699 own_identity: own_identity.clone(),
700 other_identity: other_identity.clone(),
701 }),
702
703 verification_flow_id: flow_id.clone(),
704 state: Arc::new(Cancelled::new(true, CancelCode::UnknownMethod)),
705 };
706
707 let state = match content.method() {
708 StartMethod::SasV1(method_content) => {
709 let commitment = calculate_commitment(our_public_key, content);
710
711 info!(
712 public_key = our_public_key.to_base64(),
713 ?commitment,
714 ?content,
715 "Calculated SAS commitment",
716 );
717
718 let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) else {
719 return Err(canceled());
720 };
721
722 Started {
723 protocol_definitions: method_content.to_owned(),
724 accepted_protocols,
725 commitment,
726 }
727 }
728 _ => return Err(canceled()),
729 };
730
731 Ok(SasState {
732 inner: Arc::new(Mutex::new(Some(sas))),
733 our_public_key,
734
735 ids: Box::new(SasIds { account, other_device, other_identity, own_identity }),
736
737 creation_time: Arc::new(Instant::now()),
738 last_event_time: Arc::new(Instant::now()),
739 started_from_request,
740
741 verification_flow_id: flow_id,
742
743 state: Arc::new(state),
744 })
745 }
746
747 #[cfg(test)]
748 fn into_we_accepted_with_mac_method(
749 self,
750 methods: Vec<ShortAuthenticationString>,
751 mac_method: Option<SupportedMacMethod>,
752 ) -> SasState<WeAccepted> {
753 let mut accepted_protocols = self.state.accepted_protocols.to_owned();
754
755 if let Some(mac_method) = mac_method {
756 accepted_protocols.message_auth_code = mac_method;
757 }
758
759 self.into_we_accepted_helper(accepted_protocols, methods)
760 }
761
762 fn into_we_accepted_helper(
763 self,
764 mut accepted_protocols: AcceptedProtocols,
765 methods: Vec<ShortAuthenticationString>,
766 ) -> SasState<WeAccepted> {
767 accepted_protocols.short_auth_string = methods;
768
769 if !accepted_protocols.short_auth_string.contains(&ShortAuthenticationString::Decimal) {
771 accepted_protocols.short_auth_string.push(ShortAuthenticationString::Decimal);
772 }
773
774 SasState {
775 inner: self.inner,
776 our_public_key: self.our_public_key,
777 ids: self.ids,
778 verification_flow_id: self.verification_flow_id,
779 creation_time: self.creation_time,
780 last_event_time: self.last_event_time,
781 started_from_request: self.started_from_request,
782 state: Arc::new(WeAccepted {
783 we_started: false,
784 accepted_protocols,
785 commitment: self.state.commitment.clone(),
786 }),
787 }
788 }
789
790 pub fn into_we_accepted(self, methods: Vec<ShortAuthenticationString>) -> SasState<WeAccepted> {
791 let accepted_protocols = self.state.accepted_protocols.to_owned();
792 self.into_we_accepted_helper(accepted_protocols, methods)
793 }
794
795 fn as_content(&self) -> OwnedStartContent {
796 match self.verification_flow_id.as_ref() {
797 FlowId::ToDevice(s) => {
798 OwnedStartContent::ToDevice(ToDeviceKeyVerificationStartEventContent::new(
799 self.device_id().into(),
800 s.clone(),
801 StartMethod::SasV1(self.state.protocol_definitions.to_owned()),
802 ))
803 }
804 FlowId::InRoom(r, e) => OwnedStartContent::Room(
805 r.clone(),
806 KeyVerificationStartEventContent::new(
807 self.device_id().into(),
808 StartMethod::SasV1(self.state.protocol_definitions.to_owned()),
809 Reference::new(e.clone()),
810 ),
811 ),
812 }
813 }
814
815 pub fn into_accepted(
828 self,
829 sender: &UserId,
830 content: &AcceptContent<'_>,
831 ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
832 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
833
834 let AcceptMethod::SasV1(content) = content.method() else {
835 return Err(self.cancel(true, CancelCode::UnknownMethod));
836 };
837
838 let accepted_protocols = AcceptedProtocols::try_from(content.clone())
839 .map_err(|c| self.clone().cancel(true, c))?;
840
841 let start_content = self.as_content().into();
842
843 Ok(SasState {
844 inner: self.inner,
845 our_public_key: self.our_public_key,
846 ids: self.ids,
847 verification_flow_id: self.verification_flow_id,
848 creation_time: self.creation_time,
849 last_event_time: Instant::now().into(),
850 started_from_request: self.started_from_request,
851 state: Arc::new(Accepted {
852 start_content,
853 commitment: content.commitment.clone(),
854 request_id: TransactionId::new(),
855 accepted_protocols,
856 }),
857 })
858 }
859}
860
861impl SasState<WeAccepted> {
862 pub fn as_content(&self) -> OwnedAcceptContent {
870 let method = AcceptMethod::SasV1(
871 AcceptV1ContentInit {
872 commitment: self.state.commitment.clone(),
873 hash: self.state.accepted_protocols.hash.clone(),
874 key_agreement_protocol: self
875 .state
876 .accepted_protocols
877 .key_agreement_protocol
878 .clone(),
879 message_authentication_code: self.state.accepted_protocols.message_auth_code.into(),
880 short_authentication_string: self
881 .state
882 .accepted_protocols
883 .short_auth_string
884 .clone(),
885 }
886 .into(),
887 );
888
889 match self.verification_flow_id.as_ref() {
890 FlowId::ToDevice(s) => {
891 ToDeviceKeyVerificationAcceptEventContent::new(s.clone(), method).into()
892 }
893 FlowId::InRoom(r, e) => (
894 r.clone(),
895 KeyVerificationAcceptEventContent::new(method, Reference::new(e.clone())),
896 )
897 .into(),
898 }
899 }
900
901 pub fn into_key_received(
910 self,
911 sender: &UserId,
912 content: &KeyContent<'_>,
913 ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
914 let established =
915 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
916
917 Ok(SasState {
918 inner: self.inner,
919 our_public_key: self.our_public_key,
920 ids: self.ids,
921 verification_flow_id: self.verification_flow_id,
922 creation_time: self.creation_time,
923 last_event_time: Instant::now().into(),
924 started_from_request: self.started_from_request,
925 state: Arc::new(KeyReceived {
926 sas: Mutex::new(established).into(),
927 we_started: self.state.we_started,
928 request_id: TransactionId::new(),
929 accepted_protocols: self.state.accepted_protocols.clone(),
930 }),
931 })
932 }
933}
934
935impl SasState<Accepted> {
936 pub fn into_key_received(
945 self,
946 sender: &UserId,
947 content: &KeyContent<'_>,
948 ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
949 let established =
950 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
951
952 let their_public_key = established.their_public_key();
953
954 let commitment =
955 calculate_commitment(their_public_key, &self.state.start_content.as_start_content());
956
957 if self.state.commitment == commitment {
958 Ok(SasState {
959 inner: self.inner,
960 our_public_key: self.our_public_key,
961 ids: self.ids,
962 verification_flow_id: self.verification_flow_id,
963 creation_time: self.creation_time,
964 last_event_time: Instant::now().into(),
965 started_from_request: self.started_from_request,
966 state: Arc::new(KeyReceived {
967 sas: Mutex::new(established).into(),
968 we_started: true,
969 request_id: self.state.request_id.to_owned(),
970 accepted_protocols: self.state.accepted_protocols.clone(),
971 }),
972 })
973 } else {
974 Err(self.cancel(true, CancelCode::KeyMismatch))
975 }
976 }
977
978 pub fn into_key_sent(self, request_id: &TransactionId) -> Option<SasState<KeySent>> {
979 (self.state.request_id == request_id).then(|| SasState {
980 inner: self.inner,
981 our_public_key: self.our_public_key,
982 ids: self.ids,
983 verification_flow_id: self.verification_flow_id,
984 creation_time: self.creation_time,
985 last_event_time: Instant::now().into(),
986 started_from_request: self.started_from_request,
987 state: Arc::new(KeySent {
988 we_started: true,
989 start_content: self.state.start_content.clone(),
990 commitment: self.state.commitment.clone(),
991 accepted_protocols: self.state.accepted_protocols.clone(),
992 }),
993 })
994 }
995
996 pub fn as_content(&self) -> (OutgoingContent, RequestInfo) {
1000 let content = match &*self.verification_flow_id {
1001 FlowId::ToDevice(s) => AnyToDeviceEventContent::KeyVerificationKey(
1002 ToDeviceKeyVerificationKeyEventContent::new(
1003 s.clone(),
1004 Base64::new(self.our_public_key.to_vec()),
1005 ),
1006 )
1007 .into(),
1008 FlowId::InRoom(r, e) => (
1009 r.clone(),
1010 AnyMessageLikeEventContent::KeyVerificationKey(
1011 KeyVerificationKeyEventContent::new(
1012 Base64::new(self.our_public_key.to_vec()),
1013 Reference::new(e.clone()),
1014 ),
1015 ),
1016 )
1017 .into(),
1018 };
1019
1020 (
1021 content,
1022 RequestInfo {
1023 flow_id: (*self.verification_flow_id).to_owned(),
1024 request_id: self.state.request_id.to_owned(),
1025 },
1026 )
1027 }
1028}
1029
1030impl SasState<KeySent> {
1031 pub fn into_keys_exchanged(
1032 self,
1033 sender: &UserId,
1034 content: &KeyContent<'_>,
1035 ) -> Result<SasState<KeysExchanged>, SasState<Cancelled>> {
1036 let established =
1037 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
1038
1039 let their_public_key = established.their_public_key();
1040 let commitment =
1041 calculate_commitment(their_public_key, &self.state.start_content.as_start_content());
1042
1043 if self.state.commitment == commitment {
1044 Ok(SasState {
1045 inner: self.inner,
1046 our_public_key: self.our_public_key,
1047 ids: self.ids,
1048 verification_flow_id: self.verification_flow_id,
1049 creation_time: self.creation_time,
1050 last_event_time: Instant::now().into(),
1051 started_from_request: self.started_from_request,
1052 state: Arc::new(KeysExchanged {
1053 sas: Mutex::new(established).into(),
1054 we_started: self.state.we_started,
1055 accepted_protocols: self.state.accepted_protocols.clone(),
1056 }),
1057 })
1058 } else {
1059 Err(self.cancel(true, CancelCode::KeyMismatch))
1060 }
1061 }
1062}
1063
1064impl SasState<KeyReceived> {
1065 pub fn as_content(&self) -> (OutgoingContent, RequestInfo) {
1070 let content = match &*self.verification_flow_id {
1071 FlowId::ToDevice(s) => AnyToDeviceEventContent::KeyVerificationKey(
1072 ToDeviceKeyVerificationKeyEventContent::new(
1073 s.clone(),
1074 Base64::new(self.our_public_key.to_vec()),
1075 ),
1076 )
1077 .into(),
1078 FlowId::InRoom(r, e) => (
1079 r.clone(),
1080 AnyMessageLikeEventContent::KeyVerificationKey(
1081 KeyVerificationKeyEventContent::new(
1082 Base64::new(self.our_public_key.to_vec()),
1083 Reference::new(e.clone()),
1084 ),
1085 ),
1086 )
1087 .into(),
1088 };
1089
1090 (
1091 content,
1092 RequestInfo {
1093 flow_id: (*self.verification_flow_id).to_owned(),
1094 request_id: self.state.request_id.to_owned(),
1095 },
1096 )
1097 }
1098
1099 pub fn into_keys_exchanged(
1100 self,
1101 request_id: &TransactionId,
1102 ) -> Option<SasState<KeysExchanged>> {
1103 (self.state.request_id == request_id).then(|| SasState {
1104 inner: self.inner,
1105 our_public_key: self.our_public_key,
1106 ids: self.ids,
1107 verification_flow_id: self.verification_flow_id,
1108 creation_time: self.creation_time,
1109 last_event_time: Instant::now().into(),
1110 started_from_request: self.started_from_request,
1111 state: KeysExchanged {
1112 sas: self.state.sas.clone(),
1113 we_started: self.state.we_started,
1114 accepted_protocols: self.state.accepted_protocols.clone(),
1115 }
1116 .into(),
1117 })
1118 }
1119}
1120
1121impl SasState<KeysExchanged> {
1122 pub fn get_emoji(&self) -> [Emoji; 7] {
1127 get_emoji(
1128 &self.state.sas.lock(),
1129 &self.ids,
1130 self.verification_flow_id.as_str(),
1131 self.state.we_started,
1132 )
1133 }
1134
1135 pub fn get_emoji_index(&self) -> [u8; 7] {
1140 get_emoji_index(
1141 &self.state.sas.lock(),
1142 &self.ids,
1143 self.verification_flow_id.as_str(),
1144 self.state.we_started,
1145 )
1146 }
1147
1148 pub fn get_decimal(&self) -> (u16, u16, u16) {
1153 get_decimal(
1154 &self.state.sas.lock(),
1155 &self.ids,
1156 self.verification_flow_id.as_str(),
1157 self.state.we_started,
1158 )
1159 }
1160
1161 pub fn into_mac_received(
1169 self,
1170 sender: &UserId,
1171 content: &MacContent<'_>,
1172 ) -> Result<SasState<MacReceived>, SasState<Cancelled>> {
1173 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1174
1175 let (devices, master_keys) = receive_mac_event(
1176 &self.state.sas.lock(),
1177 &self.ids,
1178 self.verification_flow_id.as_str(),
1179 sender,
1180 self.state.accepted_protocols.message_auth_code,
1181 content,
1182 )
1183 .map_err(|c| self.clone().cancel(true, c))?;
1184
1185 Ok(SasState {
1186 inner: self.inner,
1187 our_public_key: self.our_public_key,
1188 verification_flow_id: self.verification_flow_id,
1189 creation_time: self.creation_time,
1190 last_event_time: Instant::now().into(),
1191 ids: self.ids,
1192 started_from_request: self.started_from_request,
1193 state: Arc::new(MacReceived {
1194 sas: self.state.sas.clone(),
1195 we_started: self.state.we_started,
1196 verified_devices: devices.into(),
1197 verified_master_keys: master_keys.into(),
1198 accepted_protocols: self.state.accepted_protocols.clone(),
1199 }),
1200 })
1201 }
1202
1203 pub fn confirm(self) -> SasState<Confirmed> {
1208 SasState {
1209 inner: self.inner,
1210 our_public_key: self.our_public_key,
1211 started_from_request: self.started_from_request,
1212 verification_flow_id: self.verification_flow_id,
1213 creation_time: self.creation_time,
1214 last_event_time: self.last_event_time,
1215 ids: self.ids,
1216 state: Arc::new(Confirmed {
1217 sas: self.state.sas.clone(),
1218 accepted_protocols: self.state.accepted_protocols.clone(),
1219 }),
1220 }
1221 }
1222}
1223
1224impl SasState<Confirmed> {
1225 pub fn into_done(
1233 self,
1234 sender: &UserId,
1235 content: &MacContent<'_>,
1236 ) -> Result<SasState<Done>, SasState<Cancelled>> {
1237 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1238
1239 let (devices, master_keys) = receive_mac_event(
1240 &self.state.sas.lock(),
1241 &self.ids,
1242 self.verification_flow_id.as_str(),
1243 sender,
1244 self.state.accepted_protocols.message_auth_code,
1245 content,
1246 )
1247 .map_err(|c| self.clone().cancel(true, c))?;
1248
1249 Ok(SasState {
1250 inner: self.inner,
1251 our_public_key: self.our_public_key,
1252 creation_time: self.creation_time,
1253 last_event_time: Instant::now().into(),
1254 verification_flow_id: self.verification_flow_id,
1255 started_from_request: self.started_from_request,
1256 ids: self.ids,
1257
1258 state: Arc::new(Done {
1259 sas: self.state.sas.clone(),
1260 verified_devices: devices.into(),
1261 verified_master_keys: master_keys.into(),
1262 accepted_protocols: self.state.accepted_protocols.clone(),
1263 }),
1264 })
1265 }
1266
1267 pub fn into_waiting_for_done(
1277 self,
1278 sender: &UserId,
1279 content: &MacContent<'_>,
1280 ) -> Result<SasState<WaitingForDone>, SasState<Cancelled>> {
1281 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1282
1283 let (devices, master_keys) = receive_mac_event(
1284 &self.state.sas.lock(),
1285 &self.ids,
1286 self.verification_flow_id.as_str(),
1287 sender,
1288 self.state.accepted_protocols.message_auth_code,
1289 content,
1290 )
1291 .map_err(|c| self.clone().cancel(true, c))?;
1292
1293 Ok(SasState {
1294 inner: self.inner,
1295 our_public_key: self.our_public_key,
1296 creation_time: self.creation_time,
1297 last_event_time: Instant::now().into(),
1298 verification_flow_id: self.verification_flow_id,
1299 started_from_request: self.started_from_request,
1300 ids: self.ids,
1301
1302 state: Arc::new(WaitingForDone {
1303 sas: self.state.sas.clone(),
1304 verified_devices: devices.into(),
1305 verified_master_keys: master_keys.into(),
1306 accepted_protocols: self.state.accepted_protocols.clone(),
1307 }),
1308 })
1309 }
1310
1311 pub fn as_content(&self) -> OutgoingContent {
1315 get_mac_content(
1316 &self.state.sas.lock(),
1317 &self.ids,
1318 &self.verification_flow_id,
1319 self.state.accepted_protocols.message_auth_code,
1320 )
1321 }
1322}
1323
1324impl SasState<MacReceived> {
1325 pub fn confirm(self) -> SasState<Done> {
1330 SasState {
1331 inner: self.inner,
1332 our_public_key: self.our_public_key,
1333 verification_flow_id: self.verification_flow_id,
1334 creation_time: self.creation_time,
1335 started_from_request: self.started_from_request,
1336 last_event_time: self.last_event_time,
1337 ids: self.ids,
1338 state: Arc::new(Done {
1339 sas: self.state.sas.clone(),
1340 verified_devices: self.state.verified_devices.clone(),
1341 verified_master_keys: self.state.verified_master_keys.clone(),
1342 accepted_protocols: self.state.accepted_protocols.clone(),
1343 }),
1344 }
1345 }
1346
1347 pub fn confirm_and_wait_for_done(self) -> SasState<WaitingForDone> {
1354 SasState {
1355 inner: self.inner,
1356 our_public_key: self.our_public_key,
1357 verification_flow_id: self.verification_flow_id,
1358 creation_time: self.creation_time,
1359 started_from_request: self.started_from_request,
1360 last_event_time: self.last_event_time,
1361 ids: self.ids,
1362 state: Arc::new(WaitingForDone {
1363 sas: self.state.sas.clone(),
1364 verified_devices: self.state.verified_devices.clone(),
1365 verified_master_keys: self.state.verified_master_keys.clone(),
1366 accepted_protocols: self.state.accepted_protocols.clone(),
1367 }),
1368 }
1369 }
1370
1371 pub fn get_emoji(&self) -> [Emoji; 7] {
1376 get_emoji(
1377 &self.state.sas.lock(),
1378 &self.ids,
1379 self.verification_flow_id.as_str(),
1380 self.state.we_started,
1381 )
1382 }
1383
1384 pub fn get_emoji_index(&self) -> [u8; 7] {
1389 get_emoji_index(
1390 &self.state.sas.lock(),
1391 &self.ids,
1392 self.verification_flow_id.as_str(),
1393 self.state.we_started,
1394 )
1395 }
1396
1397 pub fn get_decimal(&self) -> (u16, u16, u16) {
1402 get_decimal(
1403 &self.state.sas.lock(),
1404 &self.ids,
1405 self.verification_flow_id.as_str(),
1406 self.state.we_started,
1407 )
1408 }
1409}
1410
1411impl SasState<WaitingForDone> {
1412 pub fn as_content(&self) -> OutgoingContent {
1417 get_mac_content(
1418 &self.state.sas.lock(),
1419 &self.ids,
1420 &self.verification_flow_id,
1421 self.state.accepted_protocols.message_auth_code,
1422 )
1423 }
1424
1425 pub fn done_content(&self) -> OutgoingContent {
1426 match self.verification_flow_id.as_ref() {
1427 FlowId::ToDevice(t) => AnyToDeviceEventContent::KeyVerificationDone(
1428 ToDeviceKeyVerificationDoneEventContent::new(t.to_owned()),
1429 )
1430 .into(),
1431 FlowId::InRoom(r, e) => (
1432 r.clone(),
1433 AnyMessageLikeEventContent::KeyVerificationDone(
1434 KeyVerificationDoneEventContent::new(Reference::new(e.clone())),
1435 ),
1436 )
1437 .into(),
1438 }
1439 }
1440
1441 pub fn into_done(
1449 self,
1450 sender: &UserId,
1451 content: &DoneContent<'_>,
1452 ) -> Result<SasState<Done>, SasState<Cancelled>> {
1453 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1454
1455 Ok(SasState {
1456 inner: self.inner,
1457 our_public_key: self.our_public_key,
1458 creation_time: self.creation_time,
1459 last_event_time: Instant::now().into(),
1460 verification_flow_id: self.verification_flow_id,
1461 started_from_request: self.started_from_request,
1462 ids: self.ids,
1463
1464 state: Arc::new(Done {
1465 sas: self.state.sas.clone(),
1466 verified_devices: self.state.verified_devices.clone(),
1467 verified_master_keys: self.state.verified_master_keys.clone(),
1468 accepted_protocols: self.state.accepted_protocols.clone(),
1469 }),
1470 })
1471 }
1472}
1473
1474impl SasState<Done> {
1475 pub fn as_content(&self) -> OutgoingContent {
1480 get_mac_content(
1481 &self.state.sas.lock(),
1482 &self.ids,
1483 &self.verification_flow_id,
1484 self.state.accepted_protocols.message_auth_code,
1485 )
1486 }
1487
1488 pub fn verified_devices(&self) -> Arc<[DeviceData]> {
1490 self.state.verified_devices.clone()
1491 }
1492
1493 pub fn verified_identities(&self) -> Arc<[UserIdentityData]> {
1495 self.state.verified_master_keys.clone()
1496 }
1497}
1498
1499impl SasState<Cancelled> {
1500 pub fn as_content(&self) -> OutgoingContent {
1501 self.state.as_content(&self.verification_flow_id)
1502 }
1503}
1504
#[cfg(test)]
mod tests {
    use ruma::{
        device_id,
        events::key::verification::{
            accept::{AcceptMethod, ToDeviceKeyVerificationAcceptEventContent},
            start::{
                SasV1Content, SasV1ContentInit, StartMethod,
                ToDeviceKeyVerificationStartEventContent,
            },
            HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
            ShortAuthenticationString,
        },
        serde::Base64,
        user_id, DeviceId, TransactionId, UserId,
    };
    use serde_json::json;

    use super::{Accepted, Created, SasState, Started, SupportedMacMethod, WeAccepted};
    use crate::{
        verification::{
            event_enums::{AcceptContent, KeyContent, MacContent, StartContent},
            FlowId,
        },
        AcceptedProtocols, Account, DeviceData,
    };

    fn alice_id() -> &'static UserId {
        user_id!("@alice:example.org")
    }

    fn alice_device_id() -> &'static DeviceId {
        device_id!("JLAFKJWSCS")
    }

    fn bob_id() -> &'static UserId {
        user_id!("@bob:example.org")
    }

    fn bob_device_id() -> &'static DeviceId {
        device_id!("BOBDEVICE")
    }

    /// Create a SAS flow started by Alice and accepted by Bob.
    ///
    /// `mac_method` is forwarded to Bob's accept step; `None` lets Bob pick
    /// the MAC method himself.
    fn get_sas_pair(
        mac_method: Option<SupportedMacMethod>,
    ) -> (SasState<Created>, SasState<WeAccepted>) {
        let alice = Account::with_device_id(alice_id(), alice_device_id());
        let alice_device = DeviceData::from_account(&alice);

        let bob = Account::with_device_id(bob_id(), bob_device_id());
        let bob_device = DeviceData::from_account(&bob);

        let flow_id = TransactionId::new().into();
        let alice_sas = SasState::<Created>::new(
            alice.static_data().clone(),
            bob_device,
            None,
            None,
            flow_id,
            false,
            None,
        );

        let start_content = alice_sas.as_content();
        let flow_id = start_content.flow_id();

        let bob_sas = SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device,
            None,
            None,
            flow_id,
            &start_content.as_start_content(),
            false,
        );
        let bob_sas = bob_sas
            .unwrap()
            .into_we_accepted_with_mac_method(vec![ShortAuthenticationString::Emoji], mac_method);

        (alice_sas, bob_sas)
    }

    #[test]
    fn test_start_content_accepting() {
        let mut start_content: SasV1Content = SasV1ContentInit {
            key_agreement_protocols: vec![
                KeyAgreementProtocol::Curve25519HkdfSha256,
                KeyAgreementProtocol::Curve25519,
            ],
            hashes: vec![HashAlgorithm::Sha256],
            message_authentication_codes: vec![
                // The legacy MAC variant is deprecated in ruma, but peers may
                // still offer it, so it must keep being accepted.
                #[allow(deprecated)]
                MessageAuthenticationCode::HkdfHmacSha256,
                MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
                MessageAuthenticationCode::HkdfHmacSha256V2,
            ],
            short_authentication_string: vec![
                ShortAuthenticationString::Emoji,
                ShortAuthenticationString::Decimal,
            ],
        }
        .into();

        let accepted_protocols = AcceptedProtocols::try_from(&start_content).unwrap();

        assert_eq!(accepted_protocols.message_auth_code, SupportedMacMethod::HkdfHmacSha256V2);
        assert_eq!(
            accepted_protocols.key_agreement_protocol,
            KeyAgreementProtocol::Curve25519HkdfSha256
        );

        start_content.message_authentication_codes = vec![
            // See above: the deprecated variant is still offered by peers.
            #[allow(deprecated)]
            MessageAuthenticationCode::HkdfHmacSha256,
            MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
        ];
        let accepted_protocols = AcceptedProtocols::try_from(&start_content).unwrap();
        assert_eq!(
            accepted_protocols.message_auth_code,
            SupportedMacMethod::Msc3783HkdfHmacSha256V2
        );

        start_content.key_agreement_protocols = vec![KeyAgreementProtocol::Curve25519];
        AcceptedProtocols::try_from(&start_content)
            .expect_err("We don't support the old Curve25519 key agreement protocol");
    }

    #[test]
    fn test_create_sas() {
        let (_, _) = get_sas_pair(None);
    }

    #[test]
    fn test_sas_accept() {
        let (alice, bob) = get_sas_pair(None);
        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        alice.into_accepted(bob.user_id(), &content).unwrap();
    }

    #[test]
    fn test_sas_key_share() {
        let (alice, bob) = get_sas_pair(None);

        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();
        let content = alice.as_content();
        let transaction_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();
        let alice = alice.into_key_sent(&transaction_id).unwrap();

        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();

        let content = bob.as_content();
        let transaction_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();

        let bob = bob.into_keys_exchanged(&transaction_id).unwrap();

        let alice = alice.into_keys_exchanged(bob.user_id(), &content).unwrap();

        assert_eq!(alice.get_decimal(), bob.get_decimal());
        assert_eq!(alice.get_emoji(), bob.get_emoji());
    }

    /// Run a complete SAS flow between Alice and Bob using `mac_method`, and
    /// check that both sides end up having verified each other's device.
    fn full_flow_helper(mac_method: SupportedMacMethod) {
        let (alice, bob) = get_sas_pair(Some(mac_method));

        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        assert_eq!(
            bob.state.accepted_protocols.message_auth_code, mac_method,
            "Bob should be using the specified MAC method."
        );

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();

        assert_eq!(
            alice.state.accepted_protocols.message_auth_code, mac_method,
            "Alice should be using the specified MAC method.",
        );

        let content = alice.as_content();
        let request_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();

        let alice = alice.into_key_sent(&request_id).unwrap();
        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();

        let (content, request_info) = bob.as_content();
        let request_id = request_info.request_id;
        let content = KeyContent::try_from(&content).unwrap();
        let bob = bob.into_keys_exchanged(&request_id).unwrap();

        let alice = alice.into_keys_exchanged(bob.user_id(), &content).unwrap();

        assert_eq!(alice.get_decimal(), bob.get_decimal());
        assert_eq!(alice.get_emoji(), bob.get_emoji());

        let bob_decimals = bob.get_decimal();

        let bob = bob.confirm();

        let content = bob.as_content();
        let content = MacContent::try_from(&content).unwrap();

        let alice = alice.into_mac_received(bob.user_id(), &content).unwrap();
        assert!(!alice.get_emoji().is_empty());
        assert_eq!(alice.get_decimal(), bob_decimals);
        let alice = alice.confirm();

        let content = alice.as_content();
        let content = MacContent::try_from(&content).unwrap();
        let bob = bob.into_done(alice.user_id(), &content).unwrap();

        assert!(bob.verified_devices().contains(&bob.other_device()));
        assert!(alice.verified_devices().contains(&alice.other_device()));
    }

    #[test]
    fn test_full_flow() {
        full_flow_helper(SupportedMacMethod::HkdfHmacSha256);
    }

    #[test]
    fn test_full_flow_hkdf_hmac_sha_v2() {
        full_flow_helper(SupportedMacMethod::HkdfHmacSha256V2);
    }

    #[test]
    fn test_full_flow_hkdf_msc3783() {
        full_flow_helper(SupportedMacMethod::Msc3783HkdfHmacSha256V2);
    }

    #[test]
    fn test_sas_invalid_commitment() {
        let (alice, bob) = get_sas_pair(None);

        let mut content = bob.as_content();
        let mut method = content.method_mut();

        // Blank out Bob's commitment so it can't match the key he sends later.
        match &mut method {
            AcceptMethod::SasV1(c) => {
                c.commitment = Base64::empty();
            }
            _ => panic!("Unknown accept event content"),
        }

        let content = AcceptContent::from(&content);

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();

        let content = alice.as_content();
        let content = KeyContent::try_from(&content.0).unwrap();
        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();
        let content = bob.as_content();
        let content = KeyContent::try_from(&content.0).unwrap();

        alice
            .into_key_received(bob.user_id(), &content)
            .expect_err("Didn't cancel on invalid commitment");
    }

    #[test]
    fn test_sas_invalid_sender() {
        let (alice, bob) = get_sas_pair(None);

        let content = bob.as_content();
        let content = AcceptContent::from(&content);
        let sender = user_id!("@malory:example.org");
        alice.into_accepted(sender, &content).expect_err("Didn't cancel on an invalid sender");
    }

    #[test]
    fn test_sas_unknown_sas_method() {
        let (alice, bob) = get_sas_pair(None);

        let mut content = bob.as_content();
        let mut method = content.method_mut();

        match &mut method {
            AcceptMethod::SasV1(ref mut c) => {
                c.short_authentication_string = vec![];
            }
            _ => panic!("Unknown accept event content"),
        }

        let content = AcceptContent::from(&content);

        alice
            .into_accepted(bob.user_id(), &content)
            .expect_err("Didn't cancel on an invalid SAS method");
    }

    #[test]
    fn test_sas_unknown_method() {
        let (alice, bob) = get_sas_pair(None);

        let content = json!({
            "method": "m.sas.custom",
            "method_data": "something",
            "transaction_id": "some_id",
        });

        let content: ToDeviceKeyVerificationAcceptEventContent =
            serde_json::from_value(content).unwrap();
        let content = AcceptContent::from(&content);

        alice
            .into_accepted(bob.user_id(), &content)
            .expect_err("Didn't cancel on an unknown SAS method");
    }

    // Nothing in this test awaits, so a plain synchronous test suffices.
    #[test]
    fn test_sas_from_start_unknown_method() {
        let alice = Account::with_device_id(alice_id(), alice_device_id());
        let alice_device = DeviceData::from_account(&alice);

        let bob = Account::with_device_id(bob_id(), bob_device_id());
        let bob_device = DeviceData::from_account(&bob);

        let flow_id = TransactionId::new().into();
        let alice_sas = SasState::<Created>::new(
            alice.static_data().clone(),
            bob_device,
            None,
            None,
            flow_id,
            false,
            None,
        );

        let mut start_content = alice_sas.as_content();
        let method = start_content.method_mut();

        match method {
            StartMethod::SasV1(ref mut c) => {
                c.message_authentication_codes = vec![];
            }
            _ => panic!("Unknown SAS start method"),
        }

        let flow_id = start_content.flow_id();
        let content = StartContent::from(&start_content);

        SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device.clone(),
            None,
            None,
            flow_id,
            &content,
            false,
        )
        .expect_err("Didn't cancel on invalid MAC method");

        let content = json!({
            "method": "m.sas.custom",
            "from_device": "DEVICEID",
            "method_data": "something",
            "transaction_id": "some_id",
        });

        let content: ToDeviceKeyVerificationStartEventContent =
            serde_json::from_value(content).unwrap();
        let content = StartContent::from(&content);
        let flow_id = content.flow_id().to_owned();

        SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device,
            None,
            None,
            FlowId::ToDevice(flow_id.into()),
            &content,
            false,
        )
        .expect_err("Didn't cancel on unknown sas method");
    }
}