[go: up one dir, main page]

worker/
socket.rs

1use std::{
2    convert::TryFrom,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::Result;
8use crate::{r2::js_object, Error};
9use futures_util::FutureExt;
10use js_sys::{
11    Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
12    Reflect, Uint8Array,
13};
14use std::convert::TryInto;
15use std::io::Error as IoError;
16use std::io::Result as IoResult;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use wasm_bindgen::{JsCast, JsValue};
19use wasm_bindgen_futures::JsFuture;
20use web_sys::{
21    ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
22};
23
/// Addresses reported for a connected [`Socket`].
///
/// Built from the value that [`Socket::opened`] resolves with. Each field is
/// `None` when the corresponding JS property is missing or is not a string.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SocketInfo {
    /// Value of the `remoteAddress` property, if it is a string.
    pub remote_address: Option<String>,
    /// Value of the `localAddress` property, if it is a string.
    pub local_address: Option<String>,
}
29
30impl TryFrom<JsValue> for SocketInfo {
31    type Error = Error;
32    fn try_from(value: JsValue) -> Result<Self> {
33        let remote_address_value =
34            js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?;
35        let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?;
36        Ok(Self {
37            remote_address: remote_address_value.as_string(),
38            local_address: local_address_value.as_string(),
39        })
40    }
41}
42
43#[derive(Default)]
44enum Reading {
45    #[default]
46    None,
47    Pending(JsFuture, ReadableStreamDefaultReader),
48    Ready(Vec<u8>),
49}
50
51#[derive(Default)]
52enum Writing {
53    Pending(JsFuture, WritableStreamDefaultWriter, usize),
54    #[default]
55    None,
56}
57
58#[derive(Default)]
59enum Closing {
60    Pending(JsFuture),
61    #[default]
62    None,
63}
64
65/// Represents an outbound TCP connection from your Worker.
66pub struct Socket {
67    inner: worker_sys::Socket,
68    writable: WritableStream,
69    readable: ReadableStream,
70    write: Option<Writing>,
71    read: Option<Reading>,
72    close: Option<Closing>,
73}
74
75// This can only be done because workers are single threaded.
76unsafe impl Send for Socket {}
77unsafe impl Sync for Socket {}
78
79impl Socket {
80    fn new(inner: worker_sys::Socket) -> Self {
81        let writable = inner.writable().unwrap();
82        let readable = inner.readable().unwrap();
83        Socket {
84            inner,
85            writable,
86            readable,
87            read: None,
88            write: None,
89            close: None,
90        }
91    }
92
93    /// Closes the TCP socket. Both the readable and writable streams are forcibly closed.
94    pub async fn close(&mut self) -> Result<()> {
95        JsFuture::from(self.inner.close()?).await?;
96        Ok(())
97    }
98
99    /// This Future is resolved when the socket is closed
100    /// and is rejected if the socket encounters an error.
101    pub async fn closed(&self) -> Result<()> {
102        JsFuture::from(self.inner.closed()?).await?;
103        Ok(())
104    }
105
106    pub async fn opened(&self) -> Result<SocketInfo> {
107        let value = JsFuture::from(self.inner.opened()?).await?;
108        value.try_into()
109    }
110
111    /// Upgrades an insecure socket to a secure one that uses TLS,
112    /// returning a new Socket. Note that in order to call this method,
113    /// you must set [`secure_transport`](SocketOptions::secure_transport)
114    /// to [`StartTls`](SecureTransport::StartTls) when initially
115    /// calling [`connect`](connect) to create the socket.
116    pub fn start_tls(self) -> Socket {
117        let inner = self.inner.start_tls().unwrap();
118        Socket::new(inner)
119    }
120
121    pub fn builder() -> ConnectionBuilder {
122        ConnectionBuilder::default()
123    }
124
125    fn handle_write_future(
126        cx: &mut Context<'_>,
127        mut fut: JsFuture,
128        writer: WritableStreamDefaultWriter,
129        len: usize,
130    ) -> (Writing, Poll<IoResult<usize>>) {
131        match fut.poll_unpin(cx) {
132            Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
133            Poll::Ready(res) => {
134                writer.release_lock();
135                match res {
136                    Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
137                    Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
138                }
139            }
140        }
141    }
142}
143
144fn js_value_to_std_io_error(value: JsValue) -> IoError {
145    let s = if value.is_string() {
146        value.as_string().unwrap()
147    } else if let Some(value) = value.dyn_ref::<JsError>() {
148        value.to_string().into()
149    } else {
150        format!("Error interpreting JsError: {:?}", value)
151    };
152    IoError::other(s)
153}
154impl AsyncRead for Socket {
155    fn poll_read(
156        mut self: Pin<&mut Self>,
157        cx: &mut Context<'_>,
158        buf: &mut ReadBuf<'_>,
159    ) -> Poll<IoResult<()>> {
160        fn handle_future(
161            cx: &mut Context<'_>,
162            buf: &mut ReadBuf<'_>,
163            mut fut: JsFuture,
164            reader: ReadableStreamDefaultReader,
165        ) -> (Reading, Poll<IoResult<()>>) {
166            match fut.poll_unpin(cx) {
167                Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
168                Poll::Ready(res) => match res {
169                    Ok(value) => {
170                        reader.release_lock();
171                        let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
172                            Ok(value) => value.into(),
173                            Err(error) => {
174                                let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {:?}", error);
175                                return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
176                            }
177                        };
178                        if done.is_truthy() {
179                            (Reading::None, Poll::Ready(Ok(())))
180                        } else {
181                            let arr: Uint8Array = match Reflect::get(
182                                &value,
183                                &JsValue::from("value"),
184                            ) {
185                                Ok(value) => value.into(),
186                                Err(error) => {
187                                    let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {:?}", error);
188                                    return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
189                                }
190                            };
191                            let data = arr.to_vec();
192                            handle_data(buf, data)
193                        }
194                    }
195                    Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
196                },
197            }
198        }
199
200        let (new_reading, poll) = match self.read.take().unwrap_or_default() {
201            Reading::None => {
202                let reader: ReadableStreamDefaultReader =
203                    match self.readable.get_reader().dyn_into() {
204                        Ok(reader) => reader,
205                        Err(error) => {
206                            let msg = format!(
207                                "Unable to cast JsObject to ReadableStreamDefaultReader: {:?}",
208                                error
209                            );
210                            return Poll::Ready(Err(IoError::other(msg)));
211                        }
212                    };
213
214                handle_future(cx, buf, JsFuture::from(reader.read()), reader)
215            }
216            Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
217            Reading::Ready(data) => handle_data(buf, data),
218        };
219        self.read = Some(new_reading);
220        poll
221    }
222}
223
224impl AsyncWrite for Socket {
225    fn poll_write(
226        mut self: Pin<&mut Self>,
227        cx: &mut Context<'_>,
228        buf: &[u8],
229    ) -> Poll<IoResult<usize>> {
230        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
231            Writing::None => {
232                let obj = JsValue::from(Uint8Array::from(buf));
233                let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
234                    Ok(writer) => writer,
235                    Err(error) => {
236                        let msg = format!("Could not retrieve Writer: {:?}", error);
237                        return Poll::Ready(Err(IoError::other(msg)));
238                    }
239                };
240                Self::handle_write_future(
241                    cx,
242                    JsFuture::from(writer.write_with_chunk(&obj)),
243                    writer,
244                    buf.len(),
245                )
246            }
247            Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
248        };
249        self.write = Some(new_writing);
250        poll
251    }
252
253    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
254        // Poll existing write future if it exists.
255        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
256            Writing::Pending(fut, writer, len) => {
257                let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
258                // Map poll output to ()
259                (writing, poll.map(|res| res.map(|_| ())))
260            }
261            writing => (writing, Poll::Ready(Ok(()))),
262        };
263        self.write = Some(new_writing);
264        poll
265    }
266
267    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
268        fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
269            match fut.poll_unpin(cx) {
270                Poll::Pending => (Closing::Pending(fut), Poll::Pending),
271                Poll::Ready(res) => match res {
272                    Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
273                    Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
274                },
275            }
276        }
277        let (new_closing, poll) = match self.close.take().unwrap_or_default() {
278            Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
279            Closing::Pending(fut) => handle_future(cx, fut),
280        };
281        self.close = Some(new_closing);
282        poll
283    }
284}
285
/// Secure transport options for outbound TCP connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SecureTransport {
    /// Do not use TLS. This is the default.
    #[default]
    Off,
    /// Use TLS.
    On,
    /// Do not use TLS initially, but allow the socket to be upgraded to
    /// use TLS by calling [`Socket.start_tls`](Socket::start_tls).
    StartTls,
}

/// Used to configure outbound TCP connections.
///
/// The default uses no TLS ([`SecureTransport::Off`]) and sets
/// `allow_half_open` to `false`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SocketOptions {
    /// Specifies whether or not to use TLS when creating the TCP socket.
    pub secure_transport: SecureTransport,
    /// Controls what happens to the writable side of the TCP socket on
    /// end-of-file (EOF). When `false` (the default), the writable side closes
    /// automatically on EOF. When `true`, the writable side remains open on EOF.
    pub allow_half_open: bool,
}
316
/// The host and port that you wish to connect to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SocketAddress {
    /// The hostname to connect to. Example: `cloudflare.com`.
    pub hostname: String,
    /// The port number to connect to. Example: `5432`.
    pub port: u16,
}
324
325#[derive(Default)]
326pub struct ConnectionBuilder {
327    options: SocketOptions,
328}
329
330impl ConnectionBuilder {
331    /// Create a new `ConnectionBuilder` with default settings.
332    pub fn new() -> Self {
333        ConnectionBuilder {
334            options: SocketOptions::default(),
335        }
336    }
337
338    /// Set whether the writable side of the TCP socket will automatically
339    /// close on end-of-file (EOF).
340    pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
341        self.options.allow_half_open = allow_half_open;
342        self
343    }
344
345    // Specify whether or not to use TLS when creating the TCP socket.
346    pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
347        self.options.secure_transport = secure_transport;
348        self
349    }
350
351    /// Open the connection to `hostname` on port `port`, returning a [`Socket`](Socket).
352    pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
353        let address: JsValue = js_object!(
354            "hostname" => JsObject::from(JsString::from(hostname.into())),
355            "port" => JsNumber::from(port)
356        )
357        .into();
358
359        let options: JsValue = js_object!(
360            "allowHalfOpen" => JsBoolean::from(self.options.allow_half_open),
361            "secureTransport" => JsString::from(match self.options.secure_transport {
362                SecureTransport::On => "on",
363                SecureTransport::Off => "off",
364                SecureTransport::StartTls => "starttls",
365            })
366        )
367        .into();
368
369        let inner = worker_sys::connect(address, options)?;
370        Ok(Socket::new(inner))
371    }
372}
373
374// Writes as much as possible to buf, and stores the rest in internal buffer
375fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
376    let idx = buf.remaining().min(data.len());
377    let store = data.split_off(idx);
378    buf.put_slice(&data);
379    if store.is_empty() {
380        (Reading::None, Poll::Ready(Ok(())))
381    } else {
382        (Reading::Ready(store), Poll::Ready(Ok(())))
383    }
384}
385
#[cfg(feature = "tokio-postgres")]
/// Implements [`TlsConnect`](tokio_postgres::tls::TlsConnect) for
/// [`Socket`](crate::Socket) to enable `tokio_postgres` connections
/// to databases using TLS.
pub mod postgres_tls {
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// Supply this to `connect_raw` in place of `NoTls` to specify TLS
    /// when using Workers.
    ///
    /// Connecting upgrades the socket with [`Socket::start_tls`]. The socket
    /// must therefore have been created with
    /// [`SecureTransport::StartTls`](super::SecureTransport::StartTls).
    ///
    /// ```ignore
    /// // Runs inside an async fn whose error type both `?` operators convert into.
    /// let config = tokio_postgres::config::Config::new();
    /// let socket = Socket::builder()
    ///     .secure_transport(SecureTransport::StartTls)
    ///     .connect("database_url", 5432)?;
    /// let _ = config.connect_raw(socket, PassthroughTls).await?;
    /// ```
    #[derive(Debug, Clone, Copy)]
    pub struct PassthroughTls;

    /// Error type for [`PassthroughTls`].
    ///
    /// Should never be returned. [`PassthroughTls`] always resolves to `Ok`;
    /// a failing TLS upgrade panics inside [`Socket::start_tls`] instead.
    #[derive(Debug)]
    pub struct PassthroughTlsError;

    impl Error for PassthroughTlsError {}

    impl Display for PassthroughTlsError {
        fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
            fmt.write_str("PassthroughTlsError")
        }
    }

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` to TLS in place; the Workers runtime performs the handshake.
        fn connect(self, s: Self::Stream) -> Self::Future {
            let tls = s.start_tls();
            ready(Ok(tls))
        }
    }

    impl TlsStream for Socket {
        /// No channel binding data is exposed by the Workers socket API.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
439
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_handle_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 32];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::None));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled().len(), 32);
    }

    #[test]
    fn test_handle_large_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 64];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::Ready(store) if store.len() == 32));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled().len(), 32);
    }

    #[test]
    fn test_handle_small_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 16];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::None));
        assert_eq!(buf.remaining(), 16);
        assert_eq!(buf.filled().len(), 16);
    }

    // An empty chunk fills nothing and leaves no buffered state behind.
    #[test]
    fn test_handle_empty_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let (reading, poll) = handle_data(&mut buf, Vec::new());

        assert!(matches!(reading, Reading::None));
        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert_eq!(buf.remaining(), 32);
        assert_eq!(buf.filled().len(), 0);
    }

    // With a zero-capacity buffer, every byte must be stashed for the next read.
    #[test]
    fn test_handle_data_zero_capacity() {
        let mut arr: Vec<u8> = Vec::new();
        let mut buf = ReadBuf::new(&mut arr);
        let (reading, _) = handle_data(&mut buf, vec![7u8; 8]);

        assert!(matches!(reading, Reading::Ready(store) if store == vec![7u8; 8]));
        assert_eq!(buf.filled().len(), 0);
    }
}
478}