1#include "tpool.h"
  2#include <stdio.h>
  3#include <stdlib.h>
  4
  5typedef struct ThreadPoolJobNode {
  6	ThreadPoolJob job;
  7	struct ThreadPoolJobNode *next;
  8} ThreadPoolJobNode;
  9
 10struct ThreadPool {
 11	pthread_mutex_t lock;
 12	pthread_cond_t notify;
 13	pthread_cond_t working_cond;
 14
 15	pthread_t *threads;
 16	int num_threads;
 17
 18	ThreadPoolJobNode *queue_head;
 19	ThreadPoolJobNode *queue_tail;
 20
 21	int active_jobs; // Jobs currently running
 22	int queued_jobs; // Jobs waiting in queue
 23	bool stop;
 24};
 25
 26static void *tp_worker(void *arg) {
 27	ThreadPool *pool = (ThreadPool *)arg;
 28
 29	while (1) {
 30		pthread_mutex_lock(&pool->lock);
 31
 32		while (pool->queue_head == NULL && !pool->stop) {
 33			pthread_cond_wait(&pool->notify, &pool->lock);
 34		}
 35
 36		if (pool->stop && pool->queue_head == NULL) {
 37			pthread_mutex_unlock(&pool->lock);
 38			break;
 39		}
 40
 41		ThreadPoolJobNode *node = pool->queue_head;
 42		pool->queue_head = node->next;
 43		if (pool->queue_head == NULL) {
 44			pool->queue_tail = NULL;
 45		}
 46
 47		pool->queued_jobs--;
 48		pool->active_jobs++;
 49
 50		pthread_mutex_unlock(&pool->lock);
 51
 52		// Execute job
 53		if (node->job.function) {
 54			node->job.function(node->job.arg);
 55		}
 56		free(node);
 57
 58		pthread_mutex_lock(&pool->lock);
 59		pool->active_jobs--;
 60		if (pool->active_jobs == 0 && pool->queue_head == NULL) {
 61			pthread_cond_signal(&pool->working_cond);
 62		}
 63		pthread_mutex_unlock(&pool->lock);
 64	}
 65
 66	return NULL;
 67}
 68
 69ThreadPool *tp_create(int num_threads) {
 70	ThreadPool *pool = (ThreadPool *)malloc(sizeof(ThreadPool));
 71	if (pool == NULL)
 72		return NULL;
 73
 74	pool->num_threads = num_threads;
 75	pool->queue_head = NULL;
 76	pool->queue_tail = NULL;
 77	pool->active_jobs = 0;
 78	pool->queued_jobs = 0;
 79	pool->stop = false;
 80
 81	pthread_mutex_init(&pool->lock, NULL);
 82	pthread_cond_init(&pool->notify, NULL);
 83	pthread_cond_init(&pool->working_cond, NULL);
 84
 85	pool->threads = (pthread_t *)malloc(sizeof(pthread_t) * num_threads);
 86	for (int i = 0; i < num_threads; i++) {
 87		pthread_create(&pool->threads[i], NULL, tp_worker, pool);
 88	}
 89
 90	return pool;
 91}
 92
 93void tp_add_job(ThreadPool *pool, thread_func_t function, void *arg) {
 94	ThreadPoolJobNode *node = (ThreadPoolJobNode *)malloc(sizeof(ThreadPoolJobNode));
 95	if (node == NULL) {
 96		perror("malloc");
 97		exit(EXIT_FAILURE);
 98	}
 99	node->job.function = function;
100	node->job.arg = arg;
101	node->next = NULL;
102
103	pthread_mutex_lock(&pool->lock);
104
105	if (pool->queue_tail) {
106		pool->queue_tail->next = node;
107	} else {
108		pool->queue_head = node;
109	}
110	pool->queue_tail = node;
111
112	pool->queued_jobs++;
113	pthread_cond_signal(&pool->notify);
114
115	pthread_mutex_unlock(&pool->lock);
116}
117
118void tp_wait(ThreadPool *pool) {
119	pthread_mutex_lock(&pool->lock);
120	while (pool->active_jobs > 0 || pool->queue_head != NULL) {
121		pthread_cond_wait(&pool->working_cond, &pool->lock);
122	}
123	pthread_mutex_unlock(&pool->lock);
124}
125
126void tp_destroy(ThreadPool *pool) {
127	pthread_mutex_lock(&pool->lock);
128	pool->stop = true;
129	pthread_cond_broadcast(&pool->notify);
130	pthread_mutex_unlock(&pool->lock);
131
132	for (int i = 0; i < pool->num_threads; i++) {
133		pthread_join(pool->threads[i], NULL);
134	}
135
136	free(pool->threads);
137	pthread_mutex_destroy(&pool->lock);
138	pthread_cond_destroy(&pool->notify);
139	pthread_cond_destroy(&pool->working_cond);
140	free(pool);
141}