1use std::{
2 convert::TryFrom,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::Result;
8use crate::{r2::js_object, Error};
9use futures_util::FutureExt;
10use js_sys::{
11 Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
12 Reflect, Uint8Array,
13};
14use std::convert::TryInto;
15use std::io::Error as IoError;
16use std::io::Result as IoResult;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use wasm_bindgen::{JsCast, JsValue};
19use wasm_bindgen_futures::JsFuture;
20use web_sys::{
21 ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
22};
23
24#[derive(Debug)]
25pub struct SocketInfo {
26 pub remote_address: Option<String>,
27 pub local_address: Option<String>,
28}
29
30impl TryFrom<JsValue> for SocketInfo {
31 type Error = Error;
32 fn try_from(value: JsValue) -> Result<Self> {
33 let remote_address_value =
34 js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?;
35 let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?;
36 Ok(Self {
37 remote_address: remote_address_value.as_string(),
38 local_address: local_address_value.as_string(),
39 })
40 }
41}
42
43#[derive(Default)]
44enum Reading {
45 #[default]
46 None,
47 Pending(JsFuture, ReadableStreamDefaultReader),
48 Ready(Vec<u8>),
49}
50
51#[derive(Default)]
52enum Writing {
53 Pending(JsFuture, WritableStreamDefaultWriter, usize),
54 #[default]
55 None,
56}
57
58#[derive(Default)]
59enum Closing {
60 Pending(JsFuture),
61 #[default]
62 None,
63}
64
/// A TCP socket opened through the Workers runtime, exposed to Rust through
/// tokio's `AsyncRead` / `AsyncWrite` traits.
pub struct Socket {
    /// The underlying runtime socket object.
    inner: worker_sys::Socket,
    /// Writable side of `inner`, fetched once at construction.
    writable: WritableStream,
    /// Readable side of `inner`, fetched once at construction.
    readable: ReadableStream,
    /// In-flight write state; `None` until the first write/flush poll.
    write: Option<Writing>,
    /// In-flight read state or leftover bytes; `None` until the first read poll.
    read: Option<Reading>,
    /// In-flight shutdown state; `None` until the first shutdown poll.
    close: Option<Closing>,
}

// SAFETY — NOTE(review): `Socket` holds wasm-bindgen JS handles, which is why
// `Send`/`Sync` are not derived automatically. These impls presumably rely on
// the Workers wasm target being single-threaded, so the handles never actually
// cross threads. Confirm this before building for a multi-threaded wasm target.
unsafe impl Send for Socket {}

unsafe impl Sync for Socket {}
78
79impl Socket {
80 fn new(inner: worker_sys::Socket) -> Self {
81 let writable = inner.writable().unwrap();
82 let readable = inner.readable().unwrap();
83 Socket {
84 inner,
85 writable,
86 readable,
87 read: None,
88 write: None,
89 close: None,
90 }
91 }
92
93 pub async fn close(&mut self) -> Result<()> {
95 JsFuture::from(self.inner.close()?).await?;
96 Ok(())
97 }
98
99 pub async fn closed(&self) -> Result<()> {
102 JsFuture::from(self.inner.closed()?).await?;
103 Ok(())
104 }
105
106 pub async fn opened(&self) -> Result<SocketInfo> {
107 let value = JsFuture::from(self.inner.opened()?).await?;
108 value.try_into()
109 }
110
111 pub fn start_tls(self) -> Socket {
117 let inner = self.inner.start_tls().unwrap();
118 Socket::new(inner)
119 }
120
121 pub fn builder() -> ConnectionBuilder {
122 ConnectionBuilder::default()
123 }
124
125 fn handle_write_future(
126 cx: &mut Context<'_>,
127 mut fut: JsFuture,
128 writer: WritableStreamDefaultWriter,
129 len: usize,
130 ) -> (Writing, Poll<IoResult<usize>>) {
131 match fut.poll_unpin(cx) {
132 Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
133 Poll::Ready(res) => {
134 writer.release_lock();
135 match res {
136 Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
137 Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
138 }
139 }
140 }
141 }
142}
143
144fn js_value_to_std_io_error(value: JsValue) -> IoError {
145 let s = if value.is_string() {
146 value.as_string().unwrap()
147 } else if let Some(value) = value.dyn_ref::<JsError>() {
148 value.to_string().into()
149 } else {
150 format!("Error interpreting JsError: {:?}", value)
151 };
152 IoError::other(s)
153}
154impl AsyncRead for Socket {
155 fn poll_read(
156 mut self: Pin<&mut Self>,
157 cx: &mut Context<'_>,
158 buf: &mut ReadBuf<'_>,
159 ) -> Poll<IoResult<()>> {
160 fn handle_future(
161 cx: &mut Context<'_>,
162 buf: &mut ReadBuf<'_>,
163 mut fut: JsFuture,
164 reader: ReadableStreamDefaultReader,
165 ) -> (Reading, Poll<IoResult<()>>) {
166 match fut.poll_unpin(cx) {
167 Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
168 Poll::Ready(res) => match res {
169 Ok(value) => {
170 reader.release_lock();
171 let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
172 Ok(value) => value.into(),
173 Err(error) => {
174 let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {:?}", error);
175 return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
176 }
177 };
178 if done.is_truthy() {
179 (Reading::None, Poll::Ready(Ok(())))
180 } else {
181 let arr: Uint8Array = match Reflect::get(
182 &value,
183 &JsValue::from("value"),
184 ) {
185 Ok(value) => value.into(),
186 Err(error) => {
187 let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {:?}", error);
188 return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
189 }
190 };
191 let data = arr.to_vec();
192 handle_data(buf, data)
193 }
194 }
195 Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
196 },
197 }
198 }
199
200 let (new_reading, poll) = match self.read.take().unwrap_or_default() {
201 Reading::None => {
202 let reader: ReadableStreamDefaultReader =
203 match self.readable.get_reader().dyn_into() {
204 Ok(reader) => reader,
205 Err(error) => {
206 let msg = format!(
207 "Unable to cast JsObject to ReadableStreamDefaultReader: {:?}",
208 error
209 );
210 return Poll::Ready(Err(IoError::other(msg)));
211 }
212 };
213
214 handle_future(cx, buf, JsFuture::from(reader.read()), reader)
215 }
216 Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
217 Reading::Ready(data) => handle_data(buf, data),
218 };
219 self.read = Some(new_reading);
220 poll
221 }
222}
223
224impl AsyncWrite for Socket {
225 fn poll_write(
226 mut self: Pin<&mut Self>,
227 cx: &mut Context<'_>,
228 buf: &[u8],
229 ) -> Poll<IoResult<usize>> {
230 let (new_writing, poll) = match self.write.take().unwrap_or_default() {
231 Writing::None => {
232 let obj = JsValue::from(Uint8Array::from(buf));
233 let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
234 Ok(writer) => writer,
235 Err(error) => {
236 let msg = format!("Could not retrieve Writer: {:?}", error);
237 return Poll::Ready(Err(IoError::other(msg)));
238 }
239 };
240 Self::handle_write_future(
241 cx,
242 JsFuture::from(writer.write_with_chunk(&obj)),
243 writer,
244 buf.len(),
245 )
246 }
247 Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
248 };
249 self.write = Some(new_writing);
250 poll
251 }
252
253 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
254 let (new_writing, poll) = match self.write.take().unwrap_or_default() {
256 Writing::Pending(fut, writer, len) => {
257 let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
258 (writing, poll.map(|res| res.map(|_| ())))
260 }
261 writing => (writing, Poll::Ready(Ok(()))),
262 };
263 self.write = Some(new_writing);
264 poll
265 }
266
267 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
268 fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
269 match fut.poll_unpin(cx) {
270 Poll::Pending => (Closing::Pending(fut), Poll::Pending),
271 Poll::Ready(res) => match res {
272 Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
273 Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
274 },
275 }
276 }
277 let (new_closing, poll) = match self.close.take().unwrap_or_default() {
278 Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
279 Closing::Pending(fut) => handle_future(cx, fut),
280 };
281 self.close = Some(new_closing);
282 poll
283 }
284}
285
/// Transport-security mode requested when opening a socket.
///
/// Forwarded to the runtime's `connect()` as the `secureTransport` option
/// (`"off"`, `"on"` or `"starttls"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SecureTransport {
    /// Plain-text connection (`"off"`); the default.
    #[default]
    Off,
    /// TLS from the start of the connection (`"on"`).
    On,
    /// Plain text initially, upgradable later with `Socket::start_tls` (`"starttls"`).
    StartTls,
}

/// Options applied when opening a socket.
///
/// The default is `SecureTransport::Off` with half-open connections disallowed.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SocketOptions {
    /// Transport-security mode, sent as `secureTransport`.
    pub secure_transport: SecureTransport,
    /// Sent as the runtime's `allowHalfOpen` connect option.
    pub allow_half_open: bool,
}

/// A host name and port pair identifying a remote endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SocketAddress {
    /// Host name (or address literal) to connect to.
    pub hostname: String,
    /// TCP port number.
    pub port: u16,
}
324
325#[derive(Default)]
326pub struct ConnectionBuilder {
327 options: SocketOptions,
328}
329
330impl ConnectionBuilder {
331 pub fn new() -> Self {
333 ConnectionBuilder {
334 options: SocketOptions::default(),
335 }
336 }
337
338 pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
341 self.options.allow_half_open = allow_half_open;
342 self
343 }
344
345 pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
347 self.options.secure_transport = secure_transport;
348 self
349 }
350
351 pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
353 let address: JsValue = js_object!(
354 "hostname" => JsObject::from(JsString::from(hostname.into())),
355 "port" => JsNumber::from(port)
356 )
357 .into();
358
359 let options: JsValue = js_object!(
360 "allowHalfOpen" => JsBoolean::from(self.options.allow_half_open),
361 "secureTransport" => JsString::from(match self.options.secure_transport {
362 SecureTransport::On => "on",
363 SecureTransport::Off => "off",
364 SecureTransport::StartTls => "starttls",
365 })
366 )
367 .into();
368
369 let inner = worker_sys::connect(address, options)?;
370 Ok(Socket::new(inner))
371 }
372}
373
374fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
376 let idx = buf.remaining().min(data.len());
377 let store = data.split_off(idx);
378 buf.put_slice(&data);
379 if store.is_empty() {
380 (Reading::None, Poll::Ready(Ok(())))
381 } else {
382 (Reading::Ready(store), Poll::Ready(Ok(())))
383 }
384}
385
#[cfg(feature = "tokio-postgres")]
pub mod postgres_tls {
    //! `tokio-postgres` TLS support for [`Socket`], delegating the upgrade to
    //! the runtime via [`Socket::start_tls`].
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// `TlsConnect` implementation that upgrades the socket in place with
    /// [`Socket::start_tls`] rather than performing a handshake in Rust.
    pub struct PassthroughTls;

    /// Error type required by `TlsConnect`. [`PassthroughTls::connect`] always
    /// resolves to `Ok`, so this value is never produced there.
    #[derive(Debug)]
    pub struct PassthroughTlsError;

    impl Error for PassthroughTlsError {}

    impl Display for PassthroughTlsError {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("PassthroughTlsError")
        }
    }

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` immediately and returns an already-resolved future.
        fn connect(self, s: Self::Stream) -> Self::Future {
            ready(Ok(s.start_tls()))
        }
    }

    impl TlsStream for Socket {
        /// No channel-binding data is exposed for the upgraded socket.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `len` bytes into a fresh 32-byte `ReadBuf` via `handle_data`.
    ///
    /// Returns the resulting state together with the buffer's
    /// `(remaining, filled)` sizes.
    fn fill_32(len: usize) -> (Reading, usize, usize) {
        let mut backing = [0u8; 32];
        let mut buf = ReadBuf::new(&mut backing);
        let (state, _) = handle_data(&mut buf, vec![1u8; len]);
        (state, buf.remaining(), buf.filled().len())
    }

    #[test]
    fn test_handle_data() {
        // Exact fit: everything is copied and nothing is kept back.
        let (state, remaining, filled) = fill_32(32);
        assert!(matches!(state, Reading::None));
        assert_eq!((remaining, filled), (0, 32));
    }

    #[test]
    fn test_handle_large_data() {
        // Overflow: the buffer fills up and the excess 32 bytes are retained.
        let (state, remaining, filled) = fill_32(64);
        assert!(matches!(state, Reading::Ready(store) if store.len() == 32));
        assert_eq!((remaining, filled), (0, 32));
    }

    #[test]
    fn test_handle_small_data() {
        // Underflow: all input is copied, leaving spare capacity.
        let (state, remaining, filled) = fill_32(16);
        assert!(matches!(state, Reading::None));
        assert_eq!((remaining, filled), (16, 16));
    }
}