// [go: up one dir, main page]

// Documentation
use std::thread;
use std::thread::JoinHandle;
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex};

///
/// A thread-backed work queue.
///
/// A worker owns one or more background threads, each fed by its own channel.
/// Items passed to `add` are handed to the threads in round-robin order, and every
/// thread invokes the shared callback for each item it receives, in the order the
/// items were queued for that thread. Threads sleep while their queue is empty.
///
/// Call `stop` to wait until everything queued so far has been processed. If the
/// worker is dropped without calling `stop`, the threads still finish the items
/// that were already queued and then exit on their own, but nobody waits for them.
///
/// # Example
///
/// ```
/// let mut w = Worker::<&'static str>::single(move |r| {
///     println!("{:?}", r);
/// });
///
/// w.add("1");
/// w.add("2");
///
/// w.stop();
/// ```
///
pub struct Worker<T> {
    /// Sending half of each worker thread's queue; closing it ends that thread.
    senders: Vec<Sender<T>>,
    /// Join handle of each worker thread, index-aligned with `senders`.
    handles: Vec<JoinHandle<()>>,
    /// Index into `senders` of the thread that receives the next item.
    next: usize,
}

impl<T> Worker<T>
    where T: Send + 'static
{
    /// Creates a worker backed by exactly one thread.
    ///
    /// All items are processed sequentially, in the order they were added.
    pub fn single<F>(callback: F) -> Self
        where F: Fn(T) + Send + Sync + 'static
    {
        Self::multi(1, callback)
    }

    /// Creates a worker backed by `count` threads that all share `callback`.
    ///
    /// The callback may run on several threads at the same time, which is why it
    /// has to be `Sync`. A `count` of zero is treated as one, because a worker
    /// without any thread could never process an item.
    pub fn multi<F>(count: usize, callback: F) -> Self
        where F: Fn(T) + Send + Sync + 'static
    {
        // `F` is `Fn + Sync`, so a plain `Arc` is enough to share it; no lock is
        // needed and the threads can run the callback concurrently.
        let shared = Arc::new(callback);
        let threads = ::std::cmp::max(count, 1);

        let (txs, joins): (Vec<_>, Vec<_>) = (0..threads)
            .map(|_| spawn_thread(shared.clone()))
            .unzip();

        Worker {
            senders: txs,
            handles: joins,
            next: 0,
        }
    }

    /// Queues `item` on the next worker thread (round-robin).
    ///
    /// # Panics
    ///
    /// Panics if the selected worker thread is no longer running, which only
    /// happens when the callback panicked on that thread.
    pub fn add(&mut self, item: T) {
        // `senders` is never empty: `multi` always spawns at least one thread.
        self.senders[self.next]
            .send(item)
            .expect("worker thread terminated unexpectedly");

        self.next = (self.next + 1) % self.senders.len();
    }

    /// Stops the workers gracefully.
    ///
    /// Blocks until every item queued before this call has been passed to the
    /// callback and all worker threads have exited.
    ///
    /// # Panics
    ///
    /// If the callback panicked on a worker thread, the first such panic is
    /// re-raised here, but only after all threads have been joined.
    pub fn stop(self) {
        let Worker { senders, handles, .. } = self;

        // Closing the channels is the stop signal: each thread drains what is
        // still buffered in its queue and then leaves its receive loop.
        drop(senders);

        let mut first_panic = None;
        for handle in handles {
            if let Err(payload) = handle.join() {
                if first_panic.is_none() {
                    first_panic = Some(payload);
                }
            }
        }

        if let Some(payload) = first_panic {
            ::std::panic::resume_unwind(payload);
        }
    }
}

/// Spawns one worker thread and returns the sender feeding it plus its join handle.
///
/// The thread blocks on its queue while it is empty, calls `callback` for every
/// received item and exits once the returned sender has been dropped and the
/// queue has been emptied.
fn spawn_thread<T, F>(callback: Arc<F>) -> (Sender<T>, JoinHandle<()>)
    where T: Send + 'static,
          F: Fn(T) + Send + Sync + 'static
{
    let (tx, rx) = channel();
    let handle = thread::spawn(move || {
        // Iterating a receiver blocks between items and ends when the channel is
        // closed and empty, so no polling and no separate stop channel is needed.
        for item in rx {
            (*callback)(item);
        }
    });
    (tx, handle)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time;
    use std::sync::{Arc, Mutex};

    /// Shared log the worker callbacks append every received item to.
    type Log = Arc<Mutex<Vec<&'static str>>>;

    /// Gives the worker threads time to pick up and process queued items.
    fn settle() {
        thread::sleep(time::Duration::from_millis(100));
    }

    /// Returns how many items have been recorded in `log` so far.
    fn recorded(log: &Log) -> usize {
        log.lock().unwrap().len()
    }

    /// A single-threaded worker processes every item it is given.
    #[test]
    fn test_single_worker() {
        let log: Log = Arc::new(Mutex::new(Vec::new()));
        let sink = log.clone();
        let mut w = Worker::<&'static str>::single(move |item| {
            sink.lock().unwrap().push(item);
        });

        for item in &["1", "2"] {
            w.add(*item);
        }
        settle();
        assert_eq!(2, recorded(&log));

        w.add("3");
        settle();
        assert_eq!(3, recorded(&log));

        w.stop();
    }

    /// A four-threaded worker processes every item it is given.
    #[test]
    fn test_multiple_worker() {
        let log: Log = Arc::new(Mutex::new(Vec::new()));
        let sink = log.clone();
        let mut w = Worker::<&'static str>::multi(4, move |item| {
            sink.lock().unwrap().push(item);
        });

        for item in &["1", "2", "3", "4"] {
            w.add(*item);
        }
        settle();
        assert_eq!(4, recorded(&log));

        w.add("5");
        settle();
        assert_eq!(5, recorded(&log));

        w.stop();
    }
}