use std::thread;
use std::thread::JoinHandle;
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex};
/// A pool of worker threads that feed queued items to a shared callback.
///
/// Items passed to [`Worker::add`] are distributed round-robin across the
/// threads. Each thread has its own queue, so items routed to the same thread
/// are handled in the order they were added.
///
/// If the pool is dropped without calling [`Worker::stop`], each thread
/// finishes its queued items and then exits on its own (it is not joined).
pub struct Worker<T> {
    /// Per-thread job queues, indexed in parallel with `stopper` and `handler`.
    sender: Vec<Sender<T>>,
    /// Per-thread stop signals; a thread that receives one discards the rest
    /// of its queue.
    stopper: Vec<Sender<()>>,
    /// Join handles of the spawned threads.
    handler: Vec<JoinHandle<()>>,
    /// Index of the thread that will receive the next item.
    round_robin: usize,
}

impl<T> Worker<T>
where
    T: Send + 'static,
{
    /// Creates a pool with a single worker thread running `callback`.
    pub fn single<F>(callback: F) -> Self
    where
        F: Fn(T) + Send + Sync + 'static,
    {
        Self::multi(1, callback)
    }

    /// Creates a pool of `count` worker threads that all share `callback`.
    ///
    /// The callback is shared without a lock (it is `Sync`), so different
    /// threads may run it concurrently.
    ///
    /// # Panics
    ///
    /// Panics if `count` is zero: a pool without threads could never accept
    /// an item.
    pub fn multi<F>(count: usize, callback: F) -> Self
    where
        F: Fn(T) + Send + Sync + 'static,
    {
        assert!(count > 0, "Worker::multi requires at least one thread");
        let callback = Arc::new(callback);
        let mut sender = Vec::with_capacity(count);
        let mut stopper = Vec::with_capacity(count);
        let mut handler = Vec::with_capacity(count);
        for _ in 0..count {
            let (tx, stx, handle) = spawn_thread(Arc::clone(&callback));
            sender.push(tx);
            stopper.push(stx);
            handler.push(handle);
        }
        Worker {
            sender,
            stopper,
            handler,
            round_robin: 0,
        }
    }

    /// Queues `item` on the next thread in round-robin order.
    ///
    /// # Panics
    ///
    /// Panics if the target thread has already exited. While the pool is
    /// alive and not stopped, that only happens when the callback panicked
    /// on an earlier item.
    pub fn add(&mut self, item: T) {
        self.sender[self.round_robin]
            .send(item)
            .expect("worker thread has exited (did the callback panic?)");
        // `multi` guarantees at least one thread, so the modulus is non-zero.
        self.round_robin = (self.round_robin + 1) % self.sender.len();
    }

    /// Stops every worker thread and waits for all of them to finish.
    ///
    /// A thread completes the item it is currently handling, if any. Items
    /// still queued for it are discarded.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread panicked (i.e. the callback panicked).
    pub fn stop(self) {
        let Worker {
            sender,
            stopper,
            handler,
            ..
        } = self;
        for stx in &stopper {
            // A thread that already died has dropped its receiver. That failure
            // is reported by `join` below instead of aborting the shutdown here.
            let _ = stx.send(());
        }
        // Dropping the job queues wakes any thread blocked in `recv`, so it
        // can observe the stop signal (or the disconnect) and exit.
        drop(sender);
        for handle in handler {
            handle.join().expect("worker thread panicked");
        }
    }
}

/// Spawns one worker thread. Returns its job sender, its stop-signal sender
/// and its join handle.
///
/// The thread blocks while its queue is empty instead of polling. It exits
/// when:
/// - a stop signal has been sent; any remaining queued items are discarded;
/// - or every job sender is dropped and the queue has been drained.
fn spawn_thread<T, F>(callback: Arc<F>) -> (Sender<T>, Sender<()>, JoinHandle<()>)
where
    T: Send + 'static,
    F: Fn(T) + Send + Sync + 'static,
{
    let (tx, rx) = channel::<T>();
    let (stx, srx) = channel::<()>();
    let handle = thread::spawn(move || {
        // `recv` returns `Err` once all job senders are gone and the queue is
        // empty, which ends the loop.
        while let Ok(item) = rx.recv() {
            // Only an actual stop message counts. A disconnected stop channel
            // (pool dropped without `stop`) means "keep draining".
            if srx.try_recv().is_ok() {
                break;
            }
            (*callback)(item);
        }
    });
    (tx, stx, handle)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time;

    /// Polls `data` until it holds at least `expected` items or a generous
    /// deadline passes, then returns the observed length.
    ///
    /// Waiting on the condition instead of sleeping a fixed interval keeps
    /// the tests fast when the workers are quick. It also avoids spurious
    /// failures when the machine is loaded.
    fn wait_for_len(data: &Mutex<Vec<&'static str>>, expected: usize) -> usize {
        let deadline = time::Instant::now() + time::Duration::from_secs(5);
        loop {
            let len = data.lock().unwrap().len();
            if len >= expected || time::Instant::now() >= deadline {
                return len;
            }
            thread::sleep(time::Duration::from_millis(1));
        }
    }

    #[test]
    fn test_single_worker() {
        let data = Arc::new(Mutex::new(Vec::new()));
        let data_clone = Arc::clone(&data);
        let mut w = Worker::<&'static str>::single(move |r| {
            data_clone.lock().unwrap().push(r);
        });
        w.add("1");
        w.add("2");
        assert_eq!(2, wait_for_len(&data, 2));
        w.add("3");
        assert_eq!(3, wait_for_len(&data, 3));
        // A single thread handles its queue in FIFO order.
        assert_eq!(*data.lock().unwrap(), vec!["1", "2", "3"]);
        w.stop();
    }

    #[test]
    fn test_multiple_worker() {
        let data = Arc::new(Mutex::new(Vec::new()));
        let data_clone = Arc::clone(&data);
        let mut w = Worker::<&'static str>::multi(4, move |r| {
            data_clone.lock().unwrap().push(r);
        });
        w.add("1");
        w.add("2");
        w.add("3");
        w.add("4");
        assert_eq!(4, wait_for_len(&data, 4));
        w.add("5");
        assert_eq!(5, wait_for_len(&data, 5));
        w.stop();
    }
}