use std::collections::HashMap;
use std::rc::Rc;
use futures::{future::LocalBoxFuture, Future};
use matchit::{Match, Node};
use worker_kv::KvStore;
use crate::{
durable::ObjectNamespace,
env::{Env, Secret, Var},
http::Method,
request::Request,
response::Response,
Result,
};
/// A synchronous route handler: a plain function pointer that receives the
/// request plus its [`RouteContext`] and produces the response directly.
type HandlerFn<D> = fn(Request, RouteContext<D>) -> Result<Response>;
/// An asynchronous route handler: a reference-counted, type-erased closure
/// returning a `LocalBoxFuture` (a boxed future with no `Send` bound) that
/// lives at most as long as `'a`.
type AsyncHandlerFn<'a, D> =
    Rc<dyn Fn(Request, RouteContext<D>) -> LocalBoxFuture<'a, Result<Response>> + 'a>;
/// Named parameters captured from the matched route pattern, mapping each
/// parameter name to the path segment it matched.
#[derive(Debug, Clone, Default)]
pub struct RouteParams(HashMap<String, String>);

impl RouteParams {
    /// Returns the value captured for the parameter `key`, or `None` when the
    /// matched pattern has no parameter of that name.
    fn get(&self, key: &str) -> Option<&String> {
        self.0.get(key)
    }
}
enum Handler<'a, D> {
Async(AsyncHandlerFn<'a, D>),
Sync(HandlerFn<D>),
}
impl<D> Clone for Handler<'_, D> {
fn clone(&self) -> Self {
match self {
Self::Async(rc) => Self::Async(rc.clone()),
Self::Sync(func) => Self::Sync(*func),
}
}
}
/// Request router that dispatches on HTTP method and path pattern.
///
/// Built with `Router::new` / `Router::with_data` plus the registration
/// methods, then consumed by `Router::run`.
pub struct Router<'a, D> {
    // One route tree per HTTP method, filled by the method-specific registration calls.
    handlers: HashMap<Method, Node<Handler<'a, D>>>,
    // Fallback tree for `or_else_any_method*` routes; `run` consults it after the per-method trees.
    or_else_any_method: Node<Handler<'a, D>>,
    // User data moved into the `RouteContext` of the matched handler.
    data: D,
}
/// Per-request context handed to a matched handler: the router's user data,
/// the worker `Env`, and the parameters captured from the route pattern.
pub struct RouteContext<D> {
    // The router's user data, moved in by `Router::run`.
    data: D,
    // Environment passed to `Router::run`; used by `secret`, `var`, `kv` and `durable_object`.
    env: Env,
    // Named parameters captured by the matched route pattern.
    params: RouteParams,
}
impl<D> RouteContext<D> {
pub fn data(&self) -> &D {
&self.data
}
pub fn get_env(self) -> Env {
self.env
}
pub fn secret(&self, binding: &str) -> Result<Secret> {
self.env.secret(binding)
}
pub fn var(&self, binding: &str) -> Result<Var> {
self.env.var(binding)
}
pub fn kv(&self, binding: &str) -> Result<KvStore> {
KvStore::from_this(&self.env, binding).map_err(From::from)
}
pub fn durable_object(&self, binding: &str) -> Result<ObjectNamespace> {
self.env.durable_object(binding)
}
pub fn param(&self, key: &str) -> Option<&String> {
self.params.get(key)
}
}
impl<'a> Router<'a, ()> {
pub fn new() -> Self {
Self::with_data(())
}
}
impl<'a, D: 'a> Router<'a, D> {
    /// Creates an empty router whose `data` is moved into the
    /// [`RouteContext`] of the handler that `run` dispatches to.
    pub fn with_data(data: D) -> Self {
        Self {
            handlers: HashMap::new(),
            or_else_any_method: Node::new(),
            data,
        }
    }

    /// Registers a synchronous handler for `HEAD` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn head(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Head]);
        self
    }

    /// Registers a synchronous handler for `GET` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn get(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Get]);
        self
    }

    /// Registers a synchronous handler for `POST` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn post(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Post]);
        self
    }

    /// Registers a synchronous handler for `PUT` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn put(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Put]);
        self
    }

    /// Registers a synchronous handler for `PATCH` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn patch(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Patch]);
        self
    }

    /// Registers a synchronous handler for `DELETE` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn delete(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Delete]);
        self
    }

    /// Registers a synchronous handler for `OPTIONS` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn options(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), vec![Method::Options]);
        self
    }

    /// Registers a synchronous handler for `pattern` under every method
    /// returned by `Method::all()`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree of any method rejects `pattern`.
    pub fn on(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_handler(pattern, Handler::Sync(func), Method::all());
        self
    }

    /// Registers a synchronous fallback handler for `pattern`, independent of
    /// the request method.
    ///
    /// `run` consults fallback routes only after the request's own method had
    /// no match and no other method is routed for the path (that case yields
    /// a 405 instead).
    ///
    /// # Panics
    ///
    /// Panics if the fallback route tree rejects `pattern`.
    pub fn or_else_any_method(mut self, pattern: &str, func: HandlerFn<D>) -> Self {
        self.add_fallback_handler(pattern, Handler::Sync(func));
        self
    }

    /// Registers an async handler for `HEAD` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn head_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Head]);
        self
    }

    /// Registers an async handler for `GET` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn get_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Get]);
        self
    }

    /// Registers an async handler for `POST` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn post_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Post]);
        self
    }

    /// Registers an async handler for `PUT` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn put_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Put]);
        self
    }

    /// Registers an async handler for `PATCH` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn patch_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Patch]);
        self
    }

    /// Registers an async handler for `DELETE` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn delete_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Delete]);
        self
    }

    /// Registers an async handler for `OPTIONS` requests matching `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree rejects `pattern`.
    pub fn options_async<T>(
        mut self,
        pattern: &str,
        func: fn(Request, RouteContext<D>) -> T,
    ) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), vec![Method::Options]);
        self
    }

    /// Registers an async handler for `pattern` under every method returned
    /// by `Method::all()`.
    ///
    /// # Panics
    ///
    /// Panics if the route tree of any method rejects `pattern`.
    pub fn on_async<T>(mut self, pattern: &str, func: fn(Request, RouteContext<D>) -> T) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_handler(pattern, Self::async_handler(func), Method::all());
        self
    }

    /// Registers an async fallback handler for `pattern`, independent of the
    /// request method; see [`Router::or_else_any_method`] for when it runs.
    ///
    /// # Panics
    ///
    /// Panics if the fallback route tree rejects `pattern`.
    pub fn or_else_any_method_async<T>(
        mut self,
        pattern: &str,
        func: fn(Request, RouteContext<D>) -> T,
    ) -> Self
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        self.add_fallback_handler(pattern, Self::async_handler(func));
        self
    }

    /// Type-erases an async handler function into a [`Handler::Async`] whose
    /// future is boxed and pinned on every call.
    fn async_handler<T>(func: fn(Request, RouteContext<D>) -> T) -> Handler<'a, D>
    where
        T: Future<Output = Result<Response>> + 'a,
    {
        Handler::Async(Rc::new(move |req, ctx| Box::pin(func(req, ctx))))
    }

    /// Inserts `func` into the route tree of each method in `methods`,
    /// creating a method's tree on first use.
    ///
    /// # Panics
    ///
    /// Panics if any method's tree rejects `pattern`.
    fn add_handler(&mut self, pattern: &str, func: Handler<'a, D>, methods: Vec<Method>) {
        for method in methods {
            self.handlers
                .entry(method.clone())
                .or_insert_with(Node::new)
                .insert(pattern, func.clone())
                .unwrap_or_else(|e| {
                    panic!(
                        "failed to register {:?} route for {} pattern: {}",
                        method, pattern, e
                    )
                });
        }
    }

    /// Inserts `func` into the method-independent fallback route tree.
    ///
    /// # Panics
    ///
    /// Panics if the fallback tree rejects `pattern`.
    fn add_fallback_handler(&mut self, pattern: &str, func: Handler<'a, D>) {
        self.or_else_any_method
            .insert(pattern, func)
            .unwrap_or_else(|e| panic!("failed to register route for {} pattern: {}", pattern, e));
    }

    /// Consumes the router and produces the response for `req`.
    ///
    /// Resolution order:
    /// 1. a route registered for the request's own method;
    /// 2. otherwise a `405 Method Not Allowed` if the path is routed for any
    ///    method other than `HEAD`, `OPTIONS` or `TRACE`. This is checked
    ///    before the fallback routes, so an `or_else_any_method*` route never
    ///    handles a path that has a method-specific route;
    /// 3. otherwise a matching `or_else_any_method*` fallback route;
    /// 4. otherwise `404 Not Found`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the invoked handler, or by
    /// `Response::error` when building the 405/404 responses.
    pub async fn run(self, req: Request, env: Env) -> Result<Response> {
        let (handlers, data, or_else_any_method_handler) = self.split();
        // Derive the path once instead of once per route tree probed below.
        let path = req.path();

        if let Some(node) = handlers.get(&req.method()) {
            if let Ok(Match { value, params }) = node.at(&path) {
                let ctx = RouteContext {
                    data,
                    env,
                    params: params.into(),
                };
                return Self::dispatch(value, req, ctx).await;
            }
        }

        // NOTE(review): RFC 9110 requires a 405 response to carry an `Allow`
        // header listing the supported methods; none is set here.
        if Self::matches_any_method(&handlers, &path) {
            return Response::error("Method Not Allowed", 405);
        }

        if let Ok(Match { value, params }) = or_else_any_method_handler.at(&path) {
            let ctx = RouteContext {
                data,
                env,
                params: params.into(),
            };
            return Self::dispatch(value, req, ctx).await;
        }

        Response::error("Not Found", 404)
    }

    /// Invokes a matched handler, awaiting its future when it is async.
    async fn dispatch(
        handler: &Handler<'a, D>,
        req: Request,
        ctx: RouteContext<D>,
    ) -> Result<Response> {
        match handler {
            Handler::Sync(func) => func(req, ctx),
            Handler::Async(func) => func(req, ctx).await,
        }
    }

    /// Returns `true` if `path` matches a route registered under any method
    /// except `HEAD`, `OPTIONS` and `TRACE`, which are skipped.
    fn matches_any_method(handlers: &HashMap<Method, Node<Handler<'a, D>>>, path: &str) -> bool {
        Method::all()
            .into_iter()
            .filter(|method| ![Method::Head, Method::Options, Method::Trace].contains(method))
            .filter_map(|method| handlers.get(&method))
            .any(|node| node.at(path).is_ok())
    }
}
type NodeWithHandlers<'a, D> = Node<Handler<'a, D>>;
impl<'a, D: 'a> Router<'a, D> {
fn split(
self,
) -> (
HashMap<Method, NodeWithHandlers<'a, D>>,
D,
NodeWithHandlers<'a, D>,
) {
(self.handlers, self.data, self.or_else_any_method)
}
}
impl From<matchit::Params<'_, '_>> for RouteParams {
    /// Copies every captured `(name, value)` pair into an owned map, so the
    /// result no longer borrows from the matched path or route tree.
    fn from(params: matchit::Params<'_, '_>) -> Self {
        let owned = params
            .iter()
            .map(|(name, value)| (String::from(name), String::from(value)))
            .collect::<HashMap<String, String>>();
        RouteParams(owned)
    }
}