use js_sys::Array;
use std::convert::TryFrom;
use wasm_bindgen::{JsCast, JsValue};
use worker_sys::types::{SqlStorage as SqlStorageSys, SqlStorageCursor as SqlStorageCursorSys};
use serde::de::DeserializeOwned;
use serde_wasm_bindgen as swb;
use crate::Error;
use crate::Result;
/// A dynamically-typed SQL value.
///
/// Used both for positional query bindings ([`SqlStorage::exec`]) and for the
/// column values of raw rows ([`SqlCursor::raw`]).
#[derive(Clone, Debug, PartialEq)]
pub enum SqlStorageValue {
    /// SQL `NULL`; decoded from JS `null` or `undefined`.
    Null,
    /// A boolean; maps to a JS `boolean`.
    Boolean(bool),
    /// An integer; sent to JS as a number (`f64`), so it is only exact inside the
    /// JavaScript safe-integer range.
    Integer(i64),
    /// A floating-point number; maps to a JS `number`.
    Float(f64),
    /// Text; maps to a JS `string`.
    String(String),
    /// Raw bytes; sent as a `Uint8Array`, decoded from `Uint8Array` or `ArrayBuffer`.
    Blob(Vec<u8>),
}
impl From<bool> for SqlStorageValue {
fn from(value: bool) -> Self {
SqlStorageValue::Boolean(value)
}
}
impl From<i32> for SqlStorageValue {
fn from(value: i32) -> Self {
SqlStorageValue::Integer(value as i64)
}
}
impl From<i64> for SqlStorageValue {
fn from(value: i64) -> Self {
SqlStorageValue::Integer(value)
}
}
impl SqlStorageValue {
pub fn try_from_i64(value: i64) -> Result<Self> {
if value >= js_sys::Number::MIN_SAFE_INTEGER as i64
&& value <= js_sys::Number::MAX_SAFE_INTEGER as i64
{
Ok(SqlStorageValue::Integer(value))
} else {
Err(crate::Error::from(
"Value outside JavaScript safe integer range",
))
}
}
}
impl From<f64> for SqlStorageValue {
fn from(value: f64) -> Self {
SqlStorageValue::Float(value)
}
}
impl From<String> for SqlStorageValue {
fn from(value: String) -> Self {
SqlStorageValue::String(value)
}
}
impl From<&str> for SqlStorageValue {
fn from(value: &str) -> Self {
SqlStorageValue::String(value.to_string())
}
}
impl From<Vec<u8>> for SqlStorageValue {
fn from(value: Vec<u8>) -> Self {
SqlStorageValue::Blob(value)
}
}
impl<T> From<Option<T>> for SqlStorageValue
where
T: Into<SqlStorageValue>,
{
fn from(value: Option<T>) -> Self {
match value {
Some(v) => v.into(),
None => SqlStorageValue::Null,
}
}
}
impl From<SqlStorageValue> for JsValue {
    /// Converts a [`SqlStorageValue`] into the `JsValue` handed to the SQL binding.
    ///
    /// - `Null` becomes JS `null`.
    /// - `Boolean`, `Float` and `String` map to their JS counterparts.
    /// - `Integer` becomes a JS number (`f64`). Values outside the safe-integer range
    ///   are still converted, but a debug message is logged.
    /// - `Blob` is copied into a newly allocated `Uint8Array`.
    fn from(val: SqlStorageValue) -> Self {
        match val {
            SqlStorageValue::Null => JsValue::NULL,
            SqlStorageValue::Boolean(flag) => flag.into(),
            SqlStorageValue::Float(number) => number.into(),
            SqlStorageValue::String(text) => text.into(),
            SqlStorageValue::Integer(int) => {
                let as_number = JsValue::from(int as f64);
                // The `f64` cast can round large integers; let the user know.
                let is_exact = js_sys::Number::is_safe_integer(&as_number);
                if !is_exact {
                    crate::console_debug!(
                        "Warning: Converting {} to JsValue as Integer, \
                         but it is outside the JavaScript safe-integer range",
                        int
                    );
                }
                as_number
            }
            SqlStorageValue::Blob(bytes) => {
                let typed = js_sys::Uint8Array::new_with_length(bytes.len() as u32);
                typed.copy_from(&bytes);
                JsValue::from(typed)
            }
        }
    }
}
impl TryFrom<JsValue> for SqlStorageValue {
type Error = crate::Error;
fn try_from(js_val: JsValue) -> Result<Self> {
if js_val.is_null() || js_val.is_undefined() {
Ok(SqlStorageValue::Null)
} else if let Some(bool_val) = js_val.as_bool() {
Ok(SqlStorageValue::Boolean(bool_val))
} else if let Some(str_val) = js_val.as_string() {
Ok(SqlStorageValue::String(str_val))
} else if let Some(num_val) = js_val.as_f64() {
if js_sys::Number::is_safe_integer(&js_val) {
Ok(SqlStorageValue::Integer(num_val as i64))
} else {
Ok(SqlStorageValue::Float(num_val))
}
} else {
js_val
.dyn_into::<js_sys::Uint8Array>()
.map(|uint8_array| {
let mut bytes = vec![0u8; uint8_array.length() as usize];
uint8_array.copy_to(&mut bytes);
SqlStorageValue::Blob(bytes)
})
.or_else(|js_val| {
js_val
.dyn_into::<js_sys::ArrayBuffer>()
.map(|array_buffer| {
let uint8_array = js_sys::Uint8Array::new(&array_buffer);
let mut bytes = vec![0u8; uint8_array.length() as usize];
uint8_array.copy_to(&mut bytes);
SqlStorageValue::Blob(bytes)
})
})
.map_err(|_| Error::from("Unsupported JavaScript value type"))
}
}
}
/// Handle to the SQL storage binding from `worker_sys`.
///
/// It is created inside the crate (`new` is `pub(crate)`). Statements are run with
/// `exec` (typed bindings) or `exec_raw` (pre-built `JsValue` bindings).
#[derive(Clone, Debug)]
pub struct SqlStorage {
    // Raw binding object; every method on this type forwards to it.
    inner: SqlStorageSys,
}
// SAFETY: NOTE(review) — `SqlStorage` holds only a JS handle (`SqlStorageSys`). That handle
// is presumably not `Send`/`Sync` on its own, which is why these manual impls exist. They
// appear to rely on the handle never actually being used from more than one thread
// (e.g. single-threaded wasm execution). Confirm this; the impls would be unsound on a
// target where JS handles can cross threads.
unsafe impl Send for SqlStorage {}
unsafe impl Sync for SqlStorage {}
impl SqlStorage {
pub(crate) fn new(inner: SqlStorageSys) -> Self {
Self { inner }
}
pub fn database_size(&self) -> usize {
self.inner.database_size() as usize
}
pub fn exec(
&self,
query: &str,
bindings: impl Into<Option<Vec<SqlStorageValue>>>,
) -> Result<SqlCursor> {
let array = Array::new();
if let Some(bindings) = bindings.into() {
for v in bindings {
array.push(&v.into());
}
}
let cursor = self.inner.exec(query, array).map_err(Error::from)?;
Ok(SqlCursor { inner: cursor })
}
pub fn exec_raw(
&self,
query: &str,
bindings: impl Into<Option<Vec<JsValue>>>,
) -> Result<SqlCursor> {
let array = Array::new();
if let Some(bindings) = bindings.into() {
for v in bindings {
array.push(&v);
}
}
let cursor = self.inner.exec(query, array).map_err(Error::from)?;
Ok(SqlCursor { inner: cursor })
}
}
impl AsRef<JsValue> for SqlStorage {
    /// Exposes the underlying binding object as a plain `JsValue`.
    fn as_ref(&self) -> &JsValue {
        self.inner.as_ref()
    }
}
/// Cursor over the result of a statement run with `SqlStorage::exec` or
/// `SqlStorage::exec_raw`.
///
/// Rows can be consumed in several ways:
/// - as raw `JsValue`s through the `Iterator` impl;
/// - deserialized via `next::<T>()`, `to_array` or `one`;
/// - as positional `SqlStorageValue`s via `raw`.
///
/// Cloning duplicates the handle to the same JS cursor object, so clones share
/// iteration state.
#[derive(Clone, Debug)]
pub struct SqlCursor {
    // Raw cursor binding; every method on this type forwards to it.
    inner: SqlStorageCursorSys,
}
// SAFETY: NOTE(review) — `SqlCursor` holds only a JS handle (`SqlStorageCursorSys`). That
// handle is presumably not `Send`/`Sync` on its own, which is why these manual impls
// exist. They appear to rely on the handle never actually being used from more than one
// thread (e.g. single-threaded wasm execution). Confirm this; the impls would be unsound
// on a target where JS handles can cross threads.
unsafe impl Send for SqlCursor {}
unsafe impl Sync for SqlCursor {}
/// Iterator returned by `SqlCursor::next` that deserializes each row into `T`.
#[derive(Debug)]
pub struct SqlCursorIterator<T> {
    // Clone of the originating cursor; it shares the same JS cursor object, so
    // advancing this iterator also advances the original cursor.
    cursor: SqlCursor,
    // Records the target row type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> Iterator for SqlCursorIterator<T>
where
T: DeserializeOwned,
{
type Item = Result<T>;
fn next(&mut self) -> Option<Self::Item> {
let result = self.cursor.inner.next();
let done = js_sys::Reflect::get(&result, &JsValue::from("done"))
.ok()
.and_then(|v| v.as_bool())
.unwrap_or(true);
if done {
None
} else {
let value = js_sys::Reflect::get(&result, &JsValue::from("value"))
.map_err(Error::from)
.and_then(|js_val| swb::from_value(js_val).map_err(Error::from));
Some(value)
}
}
}
/// Iterator returned by `SqlCursor::raw`; yields each row as a `Vec<SqlStorageValue>`.
#[derive(Debug)]
pub struct SqlCursorRawIterator {
    // JS iterator obtained from the cursor binding's `raw()` method.
    inner: js_sys::Iterator,
}
impl Iterator for SqlCursorRawIterator {
type Item = Result<Vec<SqlStorageValue>>;
fn next(&mut self) -> Option<Self::Item> {
match self.inner.next() {
Ok(iterator_next) => {
if iterator_next.done() {
None
} else {
let js_val = iterator_next.value();
let array_result = js_array_to_sql_storage_values(js_val);
Some(array_result)
}
}
Err(e) => Some(Err(Error::from(e))),
}
}
}
/// Decodes one raw row into positional [`SqlStorageValue`]s.
///
/// `js_val` is passed through `js_sys::Array::from`. Each element is then converted
/// with `SqlStorageValue::try_from`.
///
/// # Errors
/// Stops at, and returns the error for, the first element that cannot be converted.
fn js_array_to_sql_storage_values(js_val: JsValue) -> Result<Vec<SqlStorageValue>> {
    js_sys::Array::from(&js_val)
        .iter()
        .map(SqlStorageValue::try_from)
        .collect()
}
impl SqlCursor {
pub fn to_array<T>(&self) -> Result<Vec<T>>
where
T: DeserializeOwned,
{
let arr = self.inner.to_array();
let mut out = Vec::with_capacity(arr.length() as usize);
for val in arr.iter() {
out.push(swb::from_value(val)?);
}
Ok(out)
}
pub fn one<T>(&self) -> Result<T>
where
T: DeserializeOwned,
{
let val = self.inner.one();
Ok(swb::from_value(val)?)
}
pub fn column_names(&self) -> Vec<String> {
self.inner
.column_names()
.iter()
.map(|v| v.as_string().unwrap_or_default())
.collect()
}
pub fn rows_read(&self) -> usize {
self.inner.rows_read() as usize
}
pub fn rows_written(&self) -> usize {
self.inner.rows_written() as usize
}
pub fn next<T>(&self) -> SqlCursorIterator<T>
where
T: DeserializeOwned,
{
SqlCursorIterator {
cursor: self.clone(),
_phantom: std::marker::PhantomData,
}
}
pub fn raw(&self) -> SqlCursorRawIterator {
SqlCursorRawIterator {
inner: self.inner.raw(),
}
}
}
impl Iterator for SqlCursor {
type Item = Result<JsValue>;
fn next(&mut self) -> Option<Self::Item> {
let result = self.inner.next();
let done = js_sys::Reflect::get(&result, &JsValue::from("done"))
.ok()
.and_then(|v| v.as_bool())
.unwrap_or(true);
if done {
None
} else {
let value = js_sys::Reflect::get(&result, &JsValue::from("value")).map_err(Error::from);
Some(value)
}
}
}