matrix_sdk_common/
executor.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Abstraction over an executor so that tasks can be spawned under WASM the
//! same way as on other targets, where tokio's `spawn` is used directly.
17
18#[cfg(target_arch = "wasm32")]
19use std::{
20    future::Future,
21    pin::Pin,
22    task::{Context, Poll},
23};
24
25#[cfg(target_arch = "wasm32")]
26pub use futures_util::future::Aborted as JoinError;
27#[cfg(target_arch = "wasm32")]
28use futures_util::{
29    future::{AbortHandle, Abortable, RemoteHandle},
30    FutureExt,
31};
32#[cfg(not(target_arch = "wasm32"))]
33pub use tokio::task::{spawn, JoinError, JoinHandle};
34
/// Spawns `future` with `wasm_bindgen_futures::spawn_local` and returns a
/// handle to it.
///
/// This mirrors `tokio::task::spawn`, which is re-exported on other targets.
/// The returned [`JoinHandle`] can be awaited for the future's output and can
/// [`abort`](JoinHandle::abort) the task. Dropping the handle does not abort
/// the task.
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + 'static,
{
    let (future, remote_handle) = future.remote_handle();
    let (abort_handle, abort_registration) = AbortHandle::new_pair();
    let future = Abortable::new(future, abort_registration);

    // Shared with the spawned task, which sets it once it stops running.
    let finished = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let task_finished = std::sync::Arc::clone(&finished);

    wasm_bindgen_futures::spawn_local(async move {
        // Poll the future, and ignore the result (either it's `Ok(())`, or it's
        // `Err(Aborted)`).
        let _ = future.await;
        // Whether it completed or was aborted, the future is done now: record
        // that so `JoinHandle::is_finished` can report it.
        task_finished.store(true, std::sync::atomic::Ordering::SeqCst);
    });

    JoinHandle { remote_handle: Some(remote_handle), abort_handle, finished }
}

/// A handle to a task spawned with [`spawn`], standing in for
/// `tokio::task::JoinHandle` on WASM.
///
/// Awaiting it yields the task's output, or [`JoinError`] if the task was
/// aborted.
#[cfg(target_arch = "wasm32")]
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// Receives the spawned future's output. Set to `Some` by [`spawn`] and
    /// only taken when the handle is dropped.
    remote_handle: Option<RemoteHandle<T>>,
    /// Aborts the spawned future.
    abort_handle: AbortHandle,
    /// Set by the spawned task once it has stopped running, whether its future
    /// completed or was aborted.
    finished: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

#[cfg(target_arch = "wasm32")]
impl<T> JoinHandle<T> {
    /// Aborts the spawned task.
    ///
    /// Awaiting this handle afterwards resolves to [`JoinError`].
    pub fn abort(&self) {
        self.abort_handle.abort();
    }

    /// Returns `true` once the task has stopped running, either because its
    /// future completed or because it was [aborted](Self::abort).
    pub fn is_finished(&self) -> bool {
        // An abort is reported immediately, even before the spawned task has
        // been polled again and had a chance to set `finished` itself.
        self.finished.load(std::sync::atomic::Ordering::SeqCst) || self.abort_handle.is_aborted()
    }
}
70
#[cfg(target_arch = "wasm32")]
impl<T> Drop for JoinHandle<T> {
    /// Detaches the spawned task instead of cancelling it.
    fn drop(&mut self) {
        // Forgetting the remote handle, rather than letting it drop, is what
        // keeps the spawned future from being aborted.
        if let Some(remote) = self.remote_handle.take() {
            remote.forget();
        }
    }
}
80
#[cfg(target_arch = "wasm32")]
impl<T: 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Resolves to the task's output, or to [`JoinError`] when the task was
    /// aborted or no remote handle is held any more.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // After an abort the remote future cannot be polled again, so report
        // the abort straight away.
        if self.abort_handle.is_aborted() {
            return Poll::Ready(Err(JoinError));
        }

        match self.remote_handle.as_mut() {
            Some(remote) => Pin::new(remote).poll(cx).map(Ok),
            None => Poll::Ready(Err(JoinError)),
        }
    }
}
96
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test_macros::async_test;

    use super::spawn;

    #[async_test]
    async fn test_spawn() {
        // The spawned future's output is delivered through its join handle.
        let handle = spawn(async { 42 });

        assert_matches!(handle.await, Ok(42));
    }

    #[async_test]
    async fn test_abort() {
        // Awaiting a handle whose task was aborted must yield an error.
        let handle = spawn(async { 42 });
        handle.abort();

        let outcome = handle.await;
        assert!(outcome.is_err());
    }
}