/*
 * Telnet Proxy Daemon
 * Copyright (c) 2004  AwesomePlay Productions, Inc.
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 *  * Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
 * DAMAGE.
 */

/* standard C */
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/* POSIX */
#include <fcntl.h>
#include <grp.h>
#include <netdb.h>
#include <pthread.h>
#include <pwd.h>
#include <unistd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/* ---- CONFIGURATION ---- */

/* size in bytes of each per-connection relay buffer in client_main()
   (one holds client-to-server data, the other server-to-client data) */
#define BUFFER_SIZE 2048

/* initial values of the matching globals declared below */

/* maximum simultaneous clients overall (max_clients) */
#define DEFAULT_MAX_CLIENTS 50
/* maximum simultaneous clients from one address (max_host_clients) */
#define DEFAULT_MAX_HOST_CLIENTS 3
/* NOTE(review): presumably the default listening TCP port; its use is
   not in this part of the file - confirm against main() */
#define DEFAULT_LISTEN_PORT 9596
/* NOTE(review): presumably the default "host:port" allow list passed to
   allowed_servers_load() - confirm against main() */
#define DEFAULT_HOST_LIST "hosts.txt"
/* NOTE(review): presumably the default ban list passed to
   banned_clients_load() - confirm against main() */
#define DEFAULT_BAN_LIST "deny.txt"
/* seconds a new client has to send its connect command (connect_timeout) */
#define DEFAULT_CONNECT_TIMEOUT 30
/* seconds without client input before an active session is dropped
   (activity_timeout); 15 minutes */
#define DEFAULT_ACTIVITY_TIMEOUT (60 * 15)

/* ---- DATA TYPES ---- */

/* one permitted proxy destination read by allowed_servers_load();
   entries form a singly-linked list guarded by allowed_servers_lock */
struct AllowedServer {
	char* host;			/* host text from the file (heap copy, strdup()) */
	uint16_t port;			/* destination TCP port, 1-65535 */
	struct AllowedServer* next;
};

/* one banned address or network read by banned_clients_load();
   entries form a singly-linked list guarded by banned_clients_lock */
struct BannedClient {
	struct sockaddr_storage addr;	/* from parse_addr(): only ss_family and sin_addr/sin6_addr are set */
	uint8_t mask;			/* prefix length in bits (0-32 for IPv4, 0-128 for IPv6) */
	struct BannedClient* next;
};

/* connection count for one client address, kept on the doubly-linked
   client_list guarded by client_list_lock.  client_add() creates or
   increments a node; client_remove() decrements it and frees the node
   when the count reaches zero.  (count uses the standard "unsigned int";
   the former "uint" is a non-standard BSD/glibc typedef.) */
struct ClientAddr {
	struct sockaddr_storage addr;	/* client address; ports are ignored when comparing */
	unsigned int count;		/* active connections from this address */
	struct ClientAddr* prev;
	struct ClientAddr* next;
};

/* hand-off from the accepting code to a client thread: malloc()ed by
   create_client() and freed by client_main() after copying it */
struct ClientInfo {
	int sock;			/* accepted client socket */
	struct sockaddr_storage addr;	/* client address as returned by accept() */
};

/* one entry of a command line option table for parse_options() and
   print_usage().  The table ends with an entry whose short_name is 0 and
   long_name is NULL.  When an option matches, each non-NULL pointer
   below is written. */
struct Option {
	char short_name;	/* matched as "-x"; 0 if the option has no short form */
	char* long_name;	/* matched as "--name"; NULL if no long form */
	char** string_arg;	/* receives the next argv element */
	int* int_arg;		/* receives the next argv element as a number */
	int* bool_arg;		/* set to 1 when the option appears */
};

enum ClientState {
	CLIENT_INIT,
	CLIENT_ACTIVE,
	CLIENT_FINISH,
	CLIENT_SHUTDOWN
};

enum LogLevel {
	LOG_NOTICE,
	LOG_WARNING,
	LOG_ERROR,
	LOG_DEBUG
};

/* ---- FUNCTIONS ---- */

/* allow list of proxy destinations */
int allowed_servers_load (char* filename);
int allowed_servers_check (char* host, int port);

/* ban list of client addresses/networks */
int banned_clients_load (char* filename);
int banned_clients_check (struct sockaddr_storage* addr);

/* per-address connection accounting (these are the names actually
   defined in this file) */
int client_add (struct sockaddr_storage* addr);
void client_remove (struct sockaddr_storage* addr);

void signal_sighup (int);
void signal_sigterm (int);

int parse_options (struct Option*, int argc, char** argv);

/* logging; log_msg() takes a printf-style format, which GCC/Clang can
   then check at every call site */
int log_open (char* filename);
void log_msg (enum LogLevel level, char* format, ...)
#ifdef __GNUC__
	__attribute__((format(printf, 2, 3)))
#endif
	;
void log_close (void);

int write_string (int sock, char* string); /* not safe, but simple */

char* sockaddr_name_of (struct sockaddr_storage* addr, char* buffer, size_t len);

/* ---- GLOBALS ---- */

/* allowed proxy destinations; protected by allowed_servers_lock */
struct AllowedServer* allowed_servers = NULL;
pthread_mutex_t allowed_servers_lock = PTHREAD_MUTEX_INITIALIZER;

/* banned client addresses/networks; protected by banned_clients_lock */
struct BannedClient* banned_clients = NULL;
pthread_mutex_t banned_clients_lock = PTHREAD_MUTEX_INITIALIZER;

/* connection limits and accounting; client_count and client_list are
   protected by client_list_lock */
int max_clients = DEFAULT_MAX_CLIENTS;
int max_host_clients = DEFAULT_MAX_HOST_CLIENTS;
int client_count = 0;
struct ClientAddr* client_list = NULL;
pthread_mutex_t client_list_lock = PTHREAD_MUTEX_INITIALIZER;

/* timeouts, in seconds */
int connect_timeout = DEFAULT_CONNECT_TIMEOUT;
int activity_timeout = DEFAULT_ACTIVITY_TIMEOUT;

/* request flags.  volatile sig_atomic_t is the object type the C
   standard guarantees a signal handler may safely assign to.
   NOTE(review): presumably set by signal_sighup()/signal_sigterm() and
   polled by the main loop - confirm against their definitions. */
volatile sig_atomic_t reload_flag = 0;
volatile sig_atomic_t shutdown_flag = 0;

/* destination of log_msg(); set by log_open(), reset by log_close() */
FILE* log_file = NULL;

/* ---- DEBUG LOG ---- */
/* log_dbg()
   printf-style debug message through log_msg(LOG_DEBUG, ...).  With
   NDEBUG defined it expands to a no-op expression, so its arguments are
   not evaluated.  Uses ISO C99 __VA_ARGS__ instead of the GNU-only
   named variadic "args..." form. */
#ifdef NDEBUG
#define log_dbg(...) ((void)0)
#else
#define log_dbg(...) log_msg(LOG_DEBUG, __VA_ARGS__)
#endif /* NDEBUG */

/* ---- BEGIN CODE ---- */

/* write_string()
   send a NUL-terminated string on a descriptor using a single write().
   Returns write()'s result: -1 on error, otherwise the number of bytes
   accepted, which may be fewer than strlen(string).  Nothing retries
   the remainder, so callers get no guarantee the whole string went out.
*/
int
write_string (int sock, char* string)
{
	size_t length = strlen(string);
	ssize_t written = write(sock, string, length);

	return (int)written;
}

/* print_usage()
   print a one-line usage summary for an option table to stderr.
   self is the program name to show (parse_options() passes argv[0]);
   options is a table terminated by an entry with short_name 0 and
   long_name NULL.  Options taking a value are shown with <string> or
   <int>.
*/
void
print_usage (char* self, struct Option* options)
{
	int i;

	assert(self != NULL);
	assert(options != NULL);

	/* show the name the program was actually invoked as */
	fprintf(stderr, "Usage: %s", self);
	for (i = 0; options[i].short_name != 0 || options[i].long_name != NULL; ++i) {
		if (options[i].long_name != NULL) {
			fprintf(stderr, " [--%s", options[i].long_name);
			if (options[i].short_name != 0)
				fprintf(stderr, "|-%c", options[i].short_name);
		} else {
			fprintf(stderr, " [-%c", options[i].short_name);
		}

		if (options[i].string_arg != NULL)
			fprintf(stderr, " <string>");
		else if (options[i].int_arg != NULL)
			fprintf(stderr, " <int>");

		fprintf(stderr, "]");
	}
	fprintf(stderr, "\n");
}

/* parse_options()
   read command line options into the locations named by an option table.
   Every argv element must match a table entry as "-x" or "--name".
   Options with a string_arg or int_arg consume the following argv element
   as their value.  An int value must be a complete base-10 number that
   fits in an int.  On any error a message and the usage summary are
   printed to stderr.
   return 0 on success, -1 on failure
*/
int
parse_options (struct Option* options, int argc, char** argv)
{
	int opt;
	int i;
	long value;
	char* end;

	assert(options != NULL);
	assert(argc >= 1);
	assert(argv != NULL);

	/* iterate through options */
	for (opt = 1; opt < argc; ++opt) {
		/* find corresponding option */
		for (i = 0; options[i].short_name != 0 || options[i].long_name != NULL; ++i) {
			if (
				/* match short opt? */
				(options[i].short_name != 0 && argv[opt][0] == '-' && argv[opt][1] == options[i].short_name && argv[opt][2] == 0) ||
				/* match long opt? */
				(options[i].long_name != NULL && argv[opt][0] == '-' && argv[opt][1] == '-' && !strcmp(argv[opt] + 2, options[i].long_name))
			/* found a match! */
			) {
				/* need arg but have none? */
				if (opt == argc - 1 && (options[i].int_arg != NULL || options[i].string_arg != NULL)) {
					fprintf(stderr, "Error: No value given for option: %s\n", argv[opt]);
					print_usage(argv[0], options);
					return -1;
				}

				/* set string arg */
				if (options[i].string_arg != NULL)
					*options[i].string_arg = argv[++opt];

				/* set int arg; strtol() (unlike atol()) lets us reject
				   non-numeric or out-of-range values instead of silently
				   storing 0 or a truncated number */
				if (options[i].int_arg != NULL) {
					++opt;
					errno = 0;
					value = strtol(argv[opt], &end, 10);
					if (end == argv[opt] || *end != '\0' || errno == ERANGE || value < INT_MIN || value > INT_MAX) {
						fprintf(stderr, "Error: Invalid integer value for option %s: %s\n", argv[opt - 1], argv[opt]);
						print_usage(argv[0], options);
						return -1;
					}
					*options[i].int_arg = (int)value;
				}

				/* set bool arg */
				if (options[i].bool_arg != NULL)
					*options[i].bool_arg = 1;

				break;
			}
		}

		/* hit end of options?  bad juju */
		if (options[i].short_name == 0 && options[i].long_name == NULL) {
			/* error! */
			fprintf(stderr, "Error: Invalid option: %s\n", argv[opt]);
			print_usage(argv[0], options);
			return -1;
		}
	}

	/* all found */
	return 0;
}

/* build_addr_mask()
   write a network mask whose leading `mask` bits are 1 and whose other
   bits are 0, e.g. mask 20 with max 32 yields ff ff f0 00 (the /20 of
   192.168.1.2/20).  max is the address width in bits; the first
   max / 8 bytes of buf are cleared before the mask bits are set.
   NOTE: based on the mask construction in the FreeBSD 'route' command
   source, Copyright (C) FreeBSD
*/
void
build_addr_mask (uint8_t* buf, int mask, int max)
{
	/* byte value for a prefix ending 0..7 bits into a byte */
	static const uint8_t partial_byte[8] = {
		0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe
	};
	int whole_bytes;
	int leftover_bits;

	assert(buf != NULL);
	assert(mask <= max);
	assert(max > 0);

	whole_bytes = mask / 8;
	leftover_bits = mask % 8;

	/* start from an all-zero mask */
	memset(buf, 0, max / 8);

	/* one 0xff byte per complete group of eight prefix bits */
	if (whole_bytes > 0)
		memset(buf, 0xff, whole_bytes);

	/* then the partially-set byte when the prefix ends mid-byte */
	if (leftover_bits > 0)
		buf[whole_bytes] = partial_byte[leftover_bits];
}

/* apply_addr_mask()
   clear every bit of an address after the first `mask` bits, in place.
   addr points at the raw address bytes: a struct in_addr for AF_INET or
   a struct in6_addr for AF_INET6.  mask is a prefix length from 0
   (clears everything, i.e. "any address") up to the address width.
*/
void
apply_addr_mask (void* addr, uint8_t mask, int family)
{
	int i;
	uint8_t buf[16]; /* widest case: the 16 bytes of an IPv6 address */
	int size;

	assert(addr != NULL);
	assert(family == AF_INET || family == AF_INET6);

	if (family == AF_INET6)
		size = 16; /* 16 bytes in IPv6 */
	else
		size = 4; /* 4 bytes in IPv4 */

	/* /0 is valid (parse_addr() accepts it), so only prefixes wider
	   than the address are wrong.  Clamp them too, so builds without
	   assert() can never write past buf. */
	assert(mask <= size * 8);
	if (mask > size * 8)
		mask = (uint8_t)(size * 8);

	build_addr_mask(buf, mask, size * 8);
	for (i = 0; i < size; ++i)
		((uint8_t*)addr)[i] &= buf[i];
}

/* addr_match()
   compare only the IP addresses (ports are ignored) of two socket
   addresses.  Differing families never match, and only AF_INET and
   AF_INET6 addresses are compared at all.
   return 0 if both addresses match, -1 otherwise
*/
int
addr_match (struct sockaddr_storage* addr1, struct sockaddr_storage* addr2)
{
	const void* lhs;
	const void* rhs;
	size_t width;

	assert(addr1 != NULL);
	assert(addr2 != NULL);

	if (addr1->ss_family != addr2->ss_family)
		return -1;

	/* locate the raw address bytes for the shared family */
	switch (addr1->ss_family) {
		case AF_INET6:
			lhs = &((struct sockaddr_in6*)addr1)->sin6_addr;
			rhs = &((struct sockaddr_in6*)addr2)->sin6_addr;
			width = sizeof(struct in6_addr);
			break;
		case AF_INET:
			lhs = &((struct sockaddr_in*)addr1)->sin_addr;
			rhs = &((struct sockaddr_in*)addr2)->sin_addr;
			width = sizeof(struct in_addr);
			break;
		default:
			/* unsupported family: treated as no match */
			return -1;
	}

	return (memcmp(lhs, rhs, width) == 0) ? 0 : -1;
}

/* parse_addr()
   parse a textual IPv4 or IPv6 address with an optional "/prefix"
   suffix, e.g. "192.168.1.0/24", "10.1.2.3" or "2001:db8::/32".
   On success *host holds the address (ss_family plus sin_addr/sin6_addr,
   everything else zeroed) and *mask the prefix length.  A missing prefix
   means a single host (/32 or /128).  The prefix must be one or more
   decimal digits within the family's width.
   returns 0 on success, -1 on parse error
*/
int
parse_addr (char* addr, struct sockaddr_storage* host, uint8_t* mask)
{
	char buffer[128];
	int inmask;
	int len;
	char* slash;

	assert(addr != NULL);
	assert(host != NULL);
	assert(mask != NULL);

	/* clear address */
	memset(host, 0, sizeof(struct sockaddr_storage));

	/* put in a writable buffer; snprintf() guarantees the NUL byte.
	   Text that does not fit cannot be a valid address. */
	len = snprintf(buffer, sizeof(buffer), "%s", addr);
	if (len < 0 || (size_t)len >= sizeof(buffer))
		return -1;

	/* get mask - have we a mask? */
	inmask = -1;
	slash = strchr(buffer, '/');
	if (slash != NULL) {
		*slash = '\0';
		++ slash;

		/* a bare trailing '/' must not silently become /0, which
		   would match every address */
		if (*slash == '\0')
			return -1;

		/* don't use atoi or strtol, guarantee we parse it right */
		inmask = 0;
		while (*slash != '\0') {
			/* only digits; ctype needs an unsigned char value */
			if (!isdigit((unsigned char)*slash))
				return -1;
			inmask = inmask * 10 + (*slash - '0');
			/* no family allows more than 128; stop before int overflow */
			if (inmask > 128)
				return -1;
			++ slash;
		}
	}

	/* parse IPv6 first */
	if (inet_pton(AF_INET6, buffer, &((struct sockaddr_in6*)host)->sin6_addr) > 0) {
		/* mask must be <= 128 */
		if (inmask > 128)
			return -1;
		/* default? */
		if (inmask == -1)
			inmask = 128;

		/* set family and mask */
		host->ss_family = AF_INET6;
		*mask = inmask;
		return 0;

	/* try IPv4 parsing */
	} else if (inet_pton(AF_INET, buffer, &((struct sockaddr_in*)host)->sin_addr) > 0) {
		/* mask must be <= 32 */
		if (inmask > 32)
			return -1;
		/* default? */
		if (inmask == -1)
			inmask = 32;

		/* set family and mask */
		((struct sockaddr_in*)host)->sin_family = AF_INET;
		*mask = inmask;
		return 0;
	}

	/* no match */
	return -1;
}

/* allowed_servers_check()
   look a destination up in the allowed-server list (under
   allowed_servers_lock).  host must equal an entry's host string exactly
   (strcmp, so case-sensitive) and port must equal that entry's port.
   return 0 if the host/port is allowed, -1 if it's not
*/
int
allowed_servers_check (char* host, int port)
{
	struct AllowedServer* entry;
	int verdict = -1;

	assert(host != NULL);
	assert(port > 0 && port <= UINT16_MAX);

	pthread_mutex_lock(&allowed_servers_lock);

	for (entry = allowed_servers; entry != NULL; entry = entry->next) {
		if (entry->port == port && strcmp(entry->host, host) == 0) {
			verdict = 0;
			break;
		}
	}

	pthread_mutex_unlock(&allowed_servers_lock);

	return verdict;
}

/* allowed_servers_load()
   (re)load the list of allowed proxy destinations from filename,
   replacing the current list under allowed_servers_lock.  Each non-empty
   line holds one "host:port" entry.  Leading whitespace is skipped and
   the entry ends at the next whitespace character.  Malformed lines and
   per-entry allocation failures are logged and skipped.
   returns 0 on success, -1 if the file cannot be opened (the current
   list is then left untouched)
*/
int
allowed_servers_load (char* filename)
{
	FILE* file;
	char line[512];
	int lineno;
	long port;
	char* entry;
	char* sep;
	char* end;
	struct AllowedServer* hinfo;

	assert(filename != NULL);

	/* open file ("t" is not a standard fopen() mode letter; "r" is
	   already text mode) */
	file = fopen(filename, "r");
	if (file == NULL) {
		log_msg(LOG_ERROR, "allowed_servers_load(): fopen() failed for '%s': %s\n", filename, strerror(errno));
		return -1;
	}

	/* lock resource */
	pthread_mutex_lock(&allowed_servers_lock);

	/* clear list */
	while (allowed_servers != NULL) {
		hinfo = allowed_servers->next;
		free(allowed_servers->host);
		free(allowed_servers);
		allowed_servers = hinfo;
	}

	/* read in file */
	lineno = 0;
	while (fgets(line, sizeof(line), file) != NULL) {
		++lineno;

		/* isolate the entry: skip leading whitespace, then cut at the
		   first whitespace character (this also drops the newline).
		   ctype functions need unsigned char values. */
		entry = line;
		while (isspace((unsigned char)*entry))
			++entry;
		for (end = entry; *end != '\0' && !isspace((unsigned char)*end); ++end)
			;
		*end = '\0';

		/* empty? */
		if (*entry == '\0')
			continue;

		/* find the host:port separator */
		sep = strchr(entry, ':');
		if (sep == NULL) {
			log_msg(LOG_WARNING, "allowed_servers_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}
		*sep = 0;

		/* read and sanity check the port number */
		port = strtol(sep + 1, &end, 10);
		if (*end != 0 || port < 1 || port > UINT16_MAX) {
			log_msg(LOG_WARNING, "allowed_servers_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}

		/* sanity check the host name */
		if (*entry == '\0') {
			log_msg(LOG_WARNING, "allowed_servers_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}

		/* allocate */
		hinfo = (struct AllowedServer*)malloc(sizeof(struct AllowedServer));
		if (hinfo == NULL) {
			log_msg(LOG_WARNING, "allowed_servers_load(): malloc() failed: %s\n", strerror(errno));
			continue;
		}

		/* copy host string */
		hinfo->host = strdup(entry);
		if (hinfo->host == NULL) {
			log_msg(LOG_WARNING, "allowed_servers_load(): strdup() failed: %s\n", strerror(errno));
			free(hinfo);
			continue;
		}

		/* set port */
		hinfo->port = port;

		/* make new list head */
		hinfo->next = allowed_servers;
		allowed_servers = hinfo;
	}

	/* unlock */
	pthread_mutex_unlock(&allowed_servers_lock);

	/* finish up */
	fclose(file);

	return 0;
}

/* banned_clients_load()
   (re)load the list of banned clients from filename, replacing the
   current list under banned_clients_lock.  Each non-empty line holds one
   address with an optional "/prefix", in the syntax parse_addr() accepts.
   Leading whitespace is skipped and the entry ends at the next whitespace
   character.  Malformed lines and allocation failures are logged and
   skipped.
   returns 0 on success, -1 if the file cannot be opened (the current
   list is then left untouched)
*/
int
banned_clients_load (char* filename)
{
	FILE* file;
	char line[512];
	char* entry;
	char* end;
	int lineno;
	struct sockaddr_storage addr;
	uint8_t mask;
	struct BannedClient* cinfo;

	assert(filename != NULL);

	/* open file ("t" is not a standard fopen() mode letter; "r" is
	   already text mode) */
	file = fopen(filename, "r");
	if (file == NULL) {
		log_msg(LOG_ERROR, "banned_clients_load(): fopen() failed for '%s': %s\n", filename, strerror(errno));
		return -1;
	}

	/* lock resource */
	pthread_mutex_lock(&banned_clients_lock);

	/* clear list */
	while (banned_clients != NULL) {
		cinfo = banned_clients->next;
		free(banned_clients);
		banned_clients = cinfo;
	}

	/* read in file */
	lineno = 0;
	while (fgets(line, sizeof(line), file) != NULL) {
		++lineno;

		/* isolate the entry: skip leading whitespace, then cut at the
		   first whitespace character (this also drops the newline).
		   ctype functions need unsigned char values. */
		entry = line;
		while (isspace((unsigned char)*entry))
			++entry;
		for (end = entry; *end != '\0' && !isspace((unsigned char)*end); ++end)
			;
		*end = '\0';

		/* empty? */
		if (*entry == '\0')
			continue;

		/* parse */
		if (parse_addr(entry, &addr, &mask) == -1) {
			log_msg(LOG_WARNING, "banned_clients_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}

		/* allocate */
		cinfo = (struct BannedClient*)malloc(sizeof(struct BannedClient));
		if (cinfo == NULL) {
			log_msg(LOG_WARNING, "banned_clients_load(): malloc() failed: %s\n", strerror(errno));
			continue;
		}

		/* set data */
		cinfo->addr = addr;
		cinfo->mask = mask;

		/* make new list head */
		cinfo->next = banned_clients;
		banned_clients = cinfo;
	}

	/* unlock */
	pthread_mutex_unlock(&banned_clients_lock);

	/* finish up */
	fclose(file);

	return 0;
}

/* banned_clients_check()
   test a client address against the ban list (under banned_clients_lock).
   An entry matches when the first `mask` bits of the client address equal
   the first `mask` bits of the entry.  Both sides are masked, so
   "192.168.1.0/24" and "192.168.1.77/24" ban the same network.
   returns -1 (non-zero) if address is banned, 0 if it is not
*/
int
banned_clients_check (struct sockaddr_storage* addr)
{
	struct BannedClient* cinfo;
	struct sockaddr_storage client_net;
	struct sockaddr_storage banned_net;

	assert(addr != NULL);

	/* lock */
	pthread_mutex_lock(&banned_clients_lock);

	/* iterate over all banned clients */
	for (cinfo = banned_clients; cinfo != NULL; cinfo = cinfo->next) {
		/* not same family?  skip */
		if (cinfo->addr.ss_family != addr->ss_family)
			continue;

		/* reduce both addresses to their network part; masking only
		   the client (as before) left entries with host bits set
		   unable to match anything, not even themselves */
		client_net = *addr;
		banned_net = cinfo->addr;
		if (cinfo->addr.ss_family == AF_INET6) {
			apply_addr_mask(&((struct sockaddr_in6*)&client_net)->sin6_addr, cinfo->mask, AF_INET6);
			apply_addr_mask(&((struct sockaddr_in6*)&banned_net)->sin6_addr, cinfo->mask, AF_INET6);
		} else {
			apply_addr_mask(&((struct sockaddr_in*)&client_net)->sin_addr, cinfo->mask, AF_INET);
			apply_addr_mask(&((struct sockaddr_in*)&banned_net)->sin_addr, cinfo->mask, AF_INET);
		}

		/* check equality */
		if (addr_match(&client_net, &banned_net) == 0) {
			/* match - banned!  unlock and return */
			pthread_mutex_unlock(&banned_clients_lock);
			return -1;
		}
	}

	/* unlock */
	pthread_mutex_unlock(&banned_clients_lock);

	/* no matches, all good */
	return 0;
}

/* client_add()
   count one more connection from the given address, enforcing the total
   limit (max_clients) and the per-address limit (max_host_clients).
   Runs under client_list_lock.  Every successful call must be balanced
   by one client_remove() with the same address.
   returns 0 on success
   returns -1 on failure (too many total clients)
   returns -2 on failure (too many connections from this address, or no
   memory to track a new address)
*/
int
client_add (struct sockaddr_storage* addr)
{
	struct ClientAddr* cinfo;

	assert(addr != NULL);

	/* lock */
	pthread_mutex_lock(&client_list_lock);

	/* already at max? fail. */
	if (client_count >= max_clients) {
		pthread_mutex_unlock(&client_list_lock);
		return -1;
	}

	/* find client */
	cinfo = client_list;
	while (cinfo != NULL) {
		if (addr_match(&cinfo->addr, addr) == 0) {
			/* at max already? */
			if (cinfo->count >= max_host_clients) {
				pthread_mutex_unlock(&client_list_lock);
				return -2;
			}

			/* increment count, break */
			++cinfo->count;
			break;
		}

		cinfo = cinfo->next;
	}

	/* no client found? */
	if (cinfo == NULL) {
		/* allocate */
		cinfo = (struct ClientAddr*)malloc(sizeof(struct ClientAddr));
		if (cinfo == NULL) {
			log_msg(LOG_ERROR, "client_add(): malloc() failed: %s\n", strerror(errno));
			pthread_mutex_unlock(&client_list_lock);
			return -2;
		}

		/* initialize */
		cinfo->addr = *addr;
		cinfo->count = 1;

		/* put on list */
		cinfo->prev = NULL;
		cinfo->next = client_list;
		if (client_list != NULL)
			client_list->prev = cinfo;
		client_list = cinfo;
	}

	/* count the connection globally as well.  client_remove() decrements
	   client_count, so without this increment the max_clients check above
	   could never trigger and the count would drift negative. */
	++client_count;

	/* unlock */
	pthread_mutex_unlock(&client_list_lock);

	return 0;
}

/* client_remove()
   undo one client_add() for the given address (under client_list_lock).
   Decrements that address's connection count and the global
   client_count, and unlinks and frees the address's node once its count
   reaches zero.  Addresses not on the list are ignored.
*/
void
client_remove (struct sockaddr_storage* addr)
{
	struct ClientAddr* node;

	pthread_mutex_lock(&client_list_lock);

	/* locate the node tracking this address */
	node = client_list;
	while (node != NULL && addr_match(&node->addr, addr) != 0)
		node = node->next;

	if (node != NULL) {
		--node->count;
		--client_count;

		/* last connection from this address: drop its node */
		if (node->count == 0) {
			if (node->prev == NULL)
				client_list = node->next;
			else
				node->prev->next = node->next;

			if (node->next != NULL)
				node->next->prev = node->prev;

			free(node);
		}
	}

	pthread_mutex_unlock(&client_list_lock);
}

/* log_open()
   choose where log_msg() writes.  A NULL filename selects stderr.
   Otherwise the named file is opened with mode "w+" (created or
   truncated).  If that open fails, log_file ends up NULL and the error is
   reported on stderr.
   returns 0 on success, -1 on failure
*/
int
log_open (char* filename)
{
	FILE* stream = (filename != NULL) ? fopen(filename, "w+") : stderr;

	log_file = stream;
	if (stream != NULL)
		return 0;

	fprintf(stderr, "ERROR: log_open(): could not open %s: %s\n", filename, strerror(errno));
	return -1;
}

/* log_close()
   close the log file opened by log_open() (stderr is left open) and
   reset log_file to NULL.  Safe to call when no log is open.
*/
void
log_close (void)
{
	/* fclose(NULL) is undefined; log_file is NULL after a failed
	   log_open() or a previous log_close() */
	if (log_file != NULL && log_file != stderr)
		fclose(log_file);

	log_file = NULL;
}

/* log_msg()
   write one message as "YYYY-MM-DD HH:MM:SS - <prefix><message>", where
   the prefix depends on level and format/... are printf-style.  No
   newline is added; callers supply it.  The stream is locked for the
   whole message, so concurrent messages do not interleave, and it is
   flushed afterwards.  Writes to stderr when no log is open (log_file
   is NULL).
*/
void
log_msg (enum LogLevel level, char* format, ...)
{
	va_list va;
	time_t t;
	struct tm lt;
	char time_buf[128];
	FILE* out;

	assert(format != NULL);

	/* read log_file once so the lock, writes and unlock all use the same
	   stream, and never hand NULL to stdio (before log_open() or after
	   log_close()) */
	out = log_file;
	if (out == NULL)
		out = stderr;

	/* lock */
	flockfile(out);

	/* time message; fall back to raw seconds if conversion fails */
	time(&t);
	if (localtime_r(&t, &lt) != NULL)
		strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", &lt);
	else
		snprintf(time_buf, sizeof(time_buf), "%ld", (long)t);
	fprintf(out, "%s - ", time_buf);

	/* print out level prefix */
	switch (level) {
		case LOG_NOTICE:
			break;
		case LOG_WARNING:
			fprintf(out, "WARNING: ");
			break;
		case LOG_ERROR:
			fprintf(out, "**ERROR** ");
			break;
		case LOG_DEBUG:
			fprintf(out, "[debug]: ");
			break;
	}

	/* print out message */
	va_start(va, format);
	vfprintf(out, format, va);
	va_end(va);

	/* flush */
	fflush(out);

	/* unlock */
	funlockfile(out);
}

/* sockaddr_name_of()
   format a socket address numerically into buffer (at most len bytes,
   always NUL-terminated): "a.b.c.d:port" for IPv4, "[v6addr]:port" for
   IPv6.  If getnameinfo() cannot convert the address, the text "unknown"
   is stored instead.
   returns buffer
*/
char*
sockaddr_name_of (struct sockaddr_storage* addr, char* buffer, size_t len)
{
	char host_buf[NI_MAXHOST];
	char serv_buf[NI_MAXSERV];
	socklen_t addr_len;
	int err;

	assert(addr != NULL);
	assert(buffer != NULL);
	assert(len > 0);

	/* give getnameinfo() the real length for the family; some
	   BSD-derived implementations reject sizeof(struct sockaddr_storage) */
	switch (addr->ss_family) {
		case AF_INET:
			addr_len = sizeof(struct sockaddr_in);
			break;
		case AF_INET6:
			addr_len = sizeof(struct sockaddr_in6);
			break;
		default:
			addr_len = sizeof(struct sockaddr_storage);
			break;
	}

	/* get info */
	err = getnameinfo((struct sockaddr*)addr, addr_len, host_buf, sizeof(host_buf), serv_buf, sizeof(serv_buf), NI_NUMERICHOST | NI_NUMERICSERV);
	if (err != 0) {
		/* host_buf/serv_buf are not filled in on failure; don't read them */
		snprintf(buffer, len, "unknown");
		return buffer;
	}

	/* format out; IPv6 literals are bracketed so the port is unambiguous */
	if (strchr(host_buf, ':'))
		snprintf(buffer, len, "[%s]:%s", host_buf, serv_buf);
	else
		snprintf(buffer, len, "%s:%s", host_buf, serv_buf);

	return buffer;
}

/* connect_to_server()
   resolve host (name or numeric address) and open a TCP connection to
   port.  Each address getaddrinfo() returns is tried in order with a
   blocking connect() until one succeeds.
   returns the connected socket, or -1 (after logging) on failure
*/
int
connect_to_server (char* host, int port)
{
	struct addrinfo hints;
	struct addrinfo* list;
	struct addrinfo* ai;
	char service[16];
	int status;
	int tried = 0;
	int fd = -1;

	assert(host != NULL);
	assert(port > 0 && port <= UINT16_MAX);

	/* TCP over any address family */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	/* getaddrinfo() takes the port as a service string */
	snprintf(service, sizeof(service), "%d", port);

	status = getaddrinfo(host, service, &hints, &list);
	if (status != 0) {
		log_msg(LOG_WARNING, "connect_to_server(): getaddrinfo() failed: %s\n", gai_strerror(status));
		return -1;
	}

	/* stop at the first address that accepts the connection */
	for (ai = list; ai != NULL && fd == -1; ai = ai->ai_next) {
		++tried;

		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1)
			continue;

		if (connect(fd, ai->ai_addr, ai->ai_addrlen) != 0) {
			close(fd);
			fd = -1;
		}
	}

	freeaddrinfo(list);

	if (fd != -1)
		return fd;

	/* none successful */
	if (tried == 0)
		log_msg(LOG_WARNING, "connect_to_server(): no addresses found for %s:%d\n", host, port);
	else
		log_msg(LOG_NOTICE, "host %s:%d is currently unavailable\n", host, port);
	return -1;
}

/* parse_connect_command()
   parse the client's request line, "connect <host> <port>", in place.
   command comes straight from the network, so any bytes must be handled.
   Trailing whitespace (such as the "\r" telnet sends) is removed first.
   On success *host points into command (NUL-terminated there) and *port
   is 1-65535.  On failure *host is NULL and *port is 0.
   returns 0 if successful, -1 on parse error
*/
int
parse_connect_command(char* command, char** host, int* port)
{
	char* name;
	char* sep;
	char* end;
	size_t len;
	long portval;

	assert(command != NULL);
	assert(host != NULL);
	assert(port != NULL);

	/* initialize */
	*host = NULL;
	*port = 0;

	/* erase whitespace at end.  Index-based so an empty line never forms
	   a pointer before the buffer; ctype needs unsigned char values
	   (bytes >= 0x80 are negative as plain char). */
	len = strlen(command);
	while (len > 0 && isspace((unsigned char)command[len - 1]))
		command[--len] = '\0';

	/* must begin with 'connect ' */
	if (strncmp(command, "connect ", 8) != 0)
		return -1;

	/* host follows */
	name = command + 8;

	/* find space separating host and port */
	sep = strchr(name, ' ');
	if (sep == NULL) {
		/* no separator - failure */
		return -1;
	}

	/* end host and start port */
	*sep = '\0';

	/* parse port for validity: it must be all number */
	portval = strtol(sep + 1, &end, 10);
	if (end == sep + 1 || *end != '\0')
		return -1;

	/* make sure port is in a valid range */
	if (portval < 1 || portval > UINT16_MAX)
		return -1;

	/* make sure host has data */
	if (*name == '\0')
		return -1;

	/* all good */
	*host = name;
	*port = (int)portval;
	return 0;
}

/* client_msg_len()
   convert an snprintf() result into the number of bytes actually stored
   in a buffer of `size` bytes.  snprintf() returns the length the full
   message would have had.  That can exceed the buffer when the message
   embeds long client-supplied text (e.g. the requested host name), and
   sending that many bytes would read past the end of the buffer.
*/
static int
client_msg_len (int written, size_t size)
{
	/* output error: nothing usable was stored */
	if (written < 0)
		return 0;

	/* truncated: size - 1 characters plus the NUL were stored */
	if ((size_t)written >= size)
		return (int)(size - 1);

	return written;
}

/* client_main()
   main loop for the client threads.
   arg is a malloc()ed struct ClientInfo, freed here.  The thread waits
   for a "connect <host> <port>" line and checks the destination against
   the allowed-server list.  It then connects and relays data both ways
   until either side closes, an error occurs, or a timeout fires.  Status
   messages for the client are queued in sbuffer and flushed before the
   connection is closed.  The client's slot is released with
   client_remove() on exit.
   Always returns NULL.
*/
void*
client_main (void* arg)
{
	int client;
	int server;
	int err;
	char cbuffer[BUFFER_SIZE];
	int clen;
	char sbuffer[BUFFER_SIZE];
	int slen;
	struct pollfd pollfds[2];
	char* end;
	char* host;
	int port;
	enum ClientState state;
	struct sockaddr_storage caddr;
	char client_name[128];
	time_t start;
	time_t activity;

	assert(arg != NULL);

	/* take the socket and address out of the hand-off structure */
	client = ((struct ClientInfo*)arg)->sock;
	caddr = ((struct ClientInfo*)arg)->addr;
	free(arg);

	/* get client_name */
	sockaddr_name_of(&caddr, client_name, sizeof(client_name));

	/* initialize client data: cbuffer holds bytes bound for the server */
	clen = 0;
	pollfds[0].fd = client;
	pollfds[0].events = POLLIN;

	/* initialize server data: sbuffer holds bytes bound for the client */
	server = -1;
	slen = 0;
	pollfds[1].fd = -1;
	pollfds[1].events = 0;

	/* start time */
	start = time(NULL);
	activity = time(NULL);

	/* client loop */
	state = CLIENT_INIT;
	do {
		/* poll fds; the server slot only takes part once connected */
		pollfds[0].revents = 0;
		pollfds[1].revents = 0;
		err = poll(pollfds, (server == -1) ? 1 : 2, 1000);

		log_dbg("poll() for %s: %d %hx %hx %d %d\n", client_name, err, pollfds[0].revents, pollfds[1].revents, clen, slen);

		/* handle error */
		if (err < 0 && errno != EINTR) {
			log_msg(LOG_ERROR, "client_main(%d): poll() failed: %s\n", client, strerror(errno));
			state = CLIENT_SHUTDOWN;
		}

		/* server disconnected?  Never downgrade a SHUTDOWN to FINISH:
		   with the client gone, FINISH could never complete. */
		if (pollfds[1].revents & POLLHUP) {
			close(server);
			server = -1;
			if (state != CLIENT_SHUTDOWN)
				state = CLIENT_FINISH;
			pollfds[1].revents = 0;
		}

		/* client disconnected? */
		if (pollfds[0].revents & POLLHUP) {
			close(client);
			client = -1;
			state = CLIENT_SHUTDOWN;
			pollfds[0].revents = 0;
		}

		/* read in from client */
		if (pollfds[0].revents & POLLIN && clen < BUFFER_SIZE) {
			/* receive data */
			err = recv(pollfds[0].fd, cbuffer + clen, BUFFER_SIZE - clen, 0);
			if (err < 0 && errno != EINTR) {
				log_msg(LOG_ERROR, "client_main(%d): recv(client) failed: %s\n", client, strerror(errno));
				state = CLIENT_SHUTDOWN;
			}

			/* eof?  why no POLLHUP?  grr! */
			if (err == 0) {
				close(client);
				client = -1;
				state = CLIENT_SHUTDOWN;
			}

			/* increment buffer count */
			if (err > 0)
				clen += err;

			/* we have activity */
			activity = time(NULL);
		}

		/* read in from server */
		if (pollfds[1].revents & POLLIN && slen < BUFFER_SIZE) {
			/* receive data */
			err = recv(pollfds[1].fd, sbuffer + slen, BUFFER_SIZE - slen, 0);
			if (err < 0 && errno != EINTR) {
				log_msg(LOG_ERROR, "client_main(%d): recv(server) failed: %s\n", client, strerror(errno));
				state = CLIENT_SHUTDOWN;
			}

			/* eof?  why no POLLHUP?  grr!  As above, keep a SHUTDOWN
			   already decided on this pass. */
			if (err == 0) {
				close(server);
				server = -1;
				if (state != CLIENT_SHUTDOWN)
					state = CLIENT_FINISH;
			}

			/* increment buffer count */
			if (err > 0)
				slen += err;
		}

		/* write out to client */
		if (client != -1 && slen > 0) {
			/* send data; EAGAIN/EWOULDBLOCK only means the socket's send
			   buffer is full right now, so retry on a later pass */
			err = send(pollfds[0].fd, sbuffer, slen, MSG_DONTWAIT);
			if (err < 0 && errno != EINTR && errno != EAGAIN && errno != EWOULDBLOCK) {
				log_msg(LOG_ERROR, "client_main(%d): send(client) failed: %s\n", client, strerror(errno));
				state = CLIENT_SHUTDOWN;
			}

			/* drop the sent bytes and move any unsent tail to the front:
			   the next send() starts at sbuffer[0] and recv() appends at
			   sbuffer + slen */
			if (err > 0) {
				slen -= err;
				if (slen > 0)
					memmove(sbuffer, sbuffer + err, slen);
			}
		}

		/* write out to server */
		if (server != -1 && clen > 0) {
			/* send data; EAGAIN/EWOULDBLOCK is retried later as above */
			err = send(pollfds[1].fd, cbuffer, clen, MSG_DONTWAIT);
			if (err < 0 && errno != EINTR && errno != EAGAIN && errno != EWOULDBLOCK) {
				log_msg(LOG_ERROR, "client_main(%d): send(server) failed: %s\n", client, strerror(errno));
				state = CLIENT_SHUTDOWN;
			}

			/* drop the sent bytes, keeping the unsent tail at the front */
			if (err > 0) {
				clen -= err;
				if (clen > 0)
					memmove(cbuffer, cbuffer + err, clen);
			}
		}

		/* process client states */
		switch (state) {
			case CLIENT_INIT:
				/* check if we have a line ready for consumption */
				if (clen == 0)
					break;
				for (end = cbuffer; end - cbuffer < clen; ++end)
					if (*end == '\n')
						break;
				if (end - cbuffer >= clen)
					break;
				*end = 0;

				log_dbg("CLIENT %s INPUT: %s\n", client_name, cbuffer);

				/* parse connect command */
				if (parse_connect_command(cbuffer, &host, &port)) {
					log_msg(LOG_NOTICE, "client %s sent invalid connect command\n", client_name);
					slen = client_msg_len(snprintf(sbuffer, sizeof(sbuffer), "Invalid connect command.\r\n"), sizeof(sbuffer));
					state = CLIENT_FINISH;
					break;
				}

				log_dbg("CLIENT %s REQUEST: %s %d\n", client_name, host, port);

				/* verify host:port is allowed */
				if (allowed_servers_check(host, port)) {
					log_msg(LOG_NOTICE, "client %s attempted connection to server %s:%d - access denied\n", client_name, host, port);
					/* host is client-supplied and can be nearly
					   BUFFER_SIZE long, so this may truncate */
					slen = client_msg_len(snprintf(sbuffer, sizeof(sbuffer), "Access denied to %s:%d.\r\n", host, port), sizeof(sbuffer));
					state = CLIENT_FINISH;
					break;
				}

				log_dbg("CLIENT %s CONNECTING: %s %d\n", client_name, host, port);

				/* connect to server */
				server = connect_to_server(host, port);
				if (server == -1) {
					slen = client_msg_len(snprintf(sbuffer, sizeof(sbuffer), "Failed to connect to %s:%d.\r\n", host, (int)port), sizeof(sbuffer));
					state = CLIENT_FINISH;
					break;
				}

				log_dbg("CLIENT %s READY: %s %d\n", client_name, host, port);

				/* make active */
				pollfds[1].fd = server;
				log_msg(LOG_NOTICE, "client %s connected to server %s:%d\n", client_name, host, port);
				state = CLIENT_ACTIVE;

				/* clean up cbuffer: keep whatever followed the command
				   line so it is forwarded to the server */
				if (clen > end - cbuffer + 1) {
					memmove(cbuffer, end + 1, clen - (end - cbuffer + 1));
					clen -= (end - cbuffer + 1);
				} else
					clen = 0;

				/* reset activity timer */
				activity = time(NULL);

				break;
			case CLIENT_ACTIVE:
				break;
			case CLIENT_FINISH:
				/* shut down once nothing is left to send to the client,
				   or when the client is gone and nothing can be sent */
				if (slen == 0 || client == -1)
					state = CLIENT_SHUTDOWN;
				break;
			case CLIENT_SHUTDOWN:
				break;
		}

		/* set client poll events */
		pollfds[0].events = 0;
		if (clen < BUFFER_SIZE)
			pollfds[0].events |= POLLIN;
		if (slen > 0)
			pollfds[0].events |= POLLOUT;

		/* set server poll events */
		if (server != -1) {
			pollfds[1].events = 0;
			if (slen < BUFFER_SIZE)
				pollfds[1].events = POLLIN;
			if (clen > 0)
				pollfds[1].events |= POLLOUT;
		}

		/* more than the connect timeout in seconds has passed? */
		if (state == CLIENT_INIT && time(NULL) - start >= connect_timeout) {
			log_msg(LOG_NOTICE, "client %s connect timeout (%d seconds)\n", client_name, connect_timeout);
			slen = client_msg_len(snprintf(sbuffer, sizeof(sbuffer), "Connection timeout: no connect command given within %d seconds.\r\n", connect_timeout), sizeof(sbuffer));
			state = CLIENT_FINISH;
		}


		/* no recent activity?  shutdown */
		if (state == CLIENT_ACTIVE && time(NULL) - activity >= activity_timeout) {
			log_msg(LOG_NOTICE, "disconnecting %s due to lack of activity (%d seconds)\n", client_name, activity_timeout);
			close(server);
			server = -1;
			slen = client_msg_len(snprintf(sbuffer, sizeof(sbuffer), "Disconnecting to due lack of activity.\r\n"), sizeof(sbuffer));
			state = CLIENT_FINISH;
		}

	} while (state != CLIENT_SHUTDOWN);

	/* disconnection message */
	log_msg(LOG_NOTICE, "disconnecting client %s\n", client_name);

	/* shutdown sockets */
	if (client != -1)
		close(client);
	if (server != -1)
		close(server);

	log_dbg("decrementing client count for %s\n", client_name);

	/* decrement current client count */
	client_remove(&caddr);

	log_dbg("ending thread for %s\n", client_name);

	return NULL;
}

/* create_listen_socket()
   open a TCP socket listening on the wildcard address of the given family
   (AF_INET or AF_INET6; IPv6 sockets are set IPv6-only) and port, with
   SO_REUSEADDR enabled.
   returns the listening socket, or -1 (after logging) on failure
*/
int
create_listen_socket (int family, int port)
{
	int sock;
	int opt;
	socklen_t addr_len;
	struct sockaddr_storage sockaddr;

	assert(family == AF_INET || family == AF_INET6);
	assert(port > 0 && port <= UINT16_MAX);

	/* create socket */
	sock = socket(family, SOCK_STREAM, 0);
	if (sock == -1) {
		log_msg(LOG_ERROR, "create_listen_socket(): socket() failed: %s\n", strerror(errno));
		return -1;
	}

	/* allow reuse of server port (best effort) */
	opt = 1;
	setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

	/* start from all-zero so sin6_flowinfo, sin6_scope_id and sin_zero
	   are not uninitialized stack contents */
	memset(&sockaddr, 0, sizeof(sockaddr));

	if (family == AF_INET6) {

		/* only listen for IPv6 addresses */
		opt = 1;
		setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt));

		/* initialize IPv6 wildcard addr (all zero bits, set above) */
		((struct sockaddr_in6*)&sockaddr)->sin6_family = AF_INET6;
		((struct sockaddr_in6*)&sockaddr)->sin6_port = htons(port);
		addr_len = sizeof(struct sockaddr_in6);

	} else if (family == AF_INET) {

		/* initialize IPv4 wildcard addr */
		((struct sockaddr_in*)&sockaddr)->sin_family = AF_INET;
		((struct sockaddr_in*)&sockaddr)->sin_port = htons(port);
		((struct sockaddr_in*)&sockaddr)->sin_addr.s_addr = htonl(INADDR_ANY);
		addr_len = sizeof(struct sockaddr_in);

	} else {

		log_msg(LOG_ERROR, "create_listen_socket(): unsupported family: %d\n", family);
		close(sock);
		return -1;
	}

	/* bind addr to socket, passing the family's own address length:
	   BSD-derived stacks reject a length larger than the structure */
	if (bind(sock, (struct sockaddr*)&sockaddr, addr_len) == -1) {
		log_msg(LOG_ERROR, "create_listen_socket(): bind() failed: %s\n", strerror(errno));
		close(sock);
		return -1;
	}

	/* begin listening on socket */
	if (listen(sock, 5) == -1) {
		log_msg(LOG_ERROR, "create_listen_socket(): listen() failed: %s\n", strerror(errno));
		close(sock);
		return -1;
	}

	return sock;
}

/*
 * Accept one pending connection on the listening socket SOCK and hand it
 * to a new detached thread running client_main().
 *
 * The connection is rejected (a short notice is written to the peer and
 * the socket is closed) when:
 *   - banned_clients_check() reports the address as banned,
 *   - client_add() returns -1 (global user limit reached), or
 *   - client_add() returns -2 (per-address connection limit reached).
 *
 * Returns the accepted descriptor on success, -1 on any failure.  On
 * success the ClientInfo (and presumably the descriptor) is owned by the
 * client thread, so the returned value should be treated as informational.
 */
int
create_client (int sock)
{
	int err;
	int client;
	struct sockaddr_storage sockaddr;
	socklen_t addr_len;
	struct ClientInfo* cinfo;
	char client_name[128];
	pthread_attr_t p_attr;
	pthread_t p_thread;

	assert(sock >= 0);

	log_dbg("accepting client socket\n");

	/* get client socket */
	addr_len = sizeof(sockaddr);
	client = accept(sock, (struct sockaddr*)&sockaddr, &addr_len);
	if (client == -1) {
		log_msg(LOG_WARNING, "create_client(): accept() failed: %s\n", strerror(errno));
		return -1;
	}

	log_dbg("getting client name\n");

	/* get client_name */
	sockaddr_name_of(&sockaddr, client_name, sizeof(client_name));

	log_dbg("checking client %s with ban list\n", client_name);

	/* check if banned */
	if (banned_clients_check(&sockaddr) == -1) {
		log_msg(LOG_NOTICE, "rejected client %s, banned host/network\n", client_name);
		write_string(client, "Your client host or network is banned.\r\n");
		close(client);
		return -1;
	}

	log_dbg("checking client %s connection count\n", client_name);

	/* register the client; -1 = global limit hit, -2 = per-address limit hit */
	err = client_add(&sockaddr);

	if (err == -1) {
		log_msg(LOG_NOTICE, "rejected client %s, too many users\n", client_name);
		/* FIXME: this may not all get written out */
		write_string(client, "Maximum user count exceeded.\r\n");
		close(client);
		return -1;
	}
	if (err == -2) {
		log_msg(LOG_NOTICE, "rejected client %s, too many connections from that address\n", client_name);
		/* FIXME: this may not all get written out */
		write_string(client, "Maximum client count for your address exceeded.\r\n");
		close(client);
		return -1;
	}

	log_dbg("allocating client %s thread storage\n", client_name);

	/* allocate space to store client socket */
	cinfo = malloc(sizeof(*cinfo));
	if (cinfo == NULL) {
		client_remove(&sockaddr);
		log_msg(LOG_WARNING, "create_client(): malloc() failed: %s\n", strerror(errno));
		log_msg(LOG_NOTICE, "disconnecting client %s\n", client_name);
		/* the accepted socket must not outlive this failure path */
		close(client);
		return -1;
	}
	cinfo->sock = client;
	cinfo->addr = sockaddr;

	/* connection message */
	log_msg(LOG_NOTICE, "connection from %s\n", client_name);

	/* initialize pthread attributes -
	   all threads should be detached */
	pthread_attr_init(&p_attr);
	pthread_attr_setdetachstate(&p_attr, PTHREAD_CREATE_DETACHED);

	log_dbg("creating client %s thread\n", client_name);

	/* create thread; pthread_create() returns an error number and does
	   not set errno, so the result itself is what gets reported */
	err = pthread_create(&p_thread, &p_attr, client_main, cinfo);
	pthread_attr_destroy(&p_attr);
	if (err != 0) {
		client_remove(&sockaddr);
		log_msg(LOG_WARNING, "create_client(): pthread_create() failed: %s\n", strerror(err));
		log_msg(LOG_NOTICE, "disconnecting client %s\n", client_name);
		close(client);
		free(cinfo);
		return -1;
	}

	log_dbg("done creating client %s\n", client_name);

	return client;
}

/*
 * Parse STR as a plain non-negative decimal number.
 * Stores the value in *OUT and returns 0 on success.  Returns -1 if STR is
 * empty, has trailing non-digit characters, is negative, or overflows long.
 */
static int
parse_privilege_id (const char* str, long* out)
{
	char* end;
	long val;

	errno = 0;
	val = strtol(str, &end, 10);
	if (end == str || *end != 0 || errno == ERANGE || val < 0)
		return -1;

	*out = val;
	return 0;
}

/*
 * Permanently switch the process to GROUP and then USER.
 *
 * Each argument may be NULL (leave that id alone), a numeric id, or a
 * name looked up with getgrnam()/getpwnam().  The group is changed first,
 * because once the uid is no longer root the gid can no longer be changed.
 * Setting the group also clears all supplementary groups.
 *
 * Returns 0 on success, -1 on failure (the reason is logged).
 *
 * NOTE(review): when only USER is given, the process keeps its current gid
 * and supplementary groups (e.g. root's); consider deriving them from the
 * user's passwd entry.
 */
int
drop_privileges (char* user, char* group)
{
	struct group* grp;
	struct passwd* pwd;
	uid_t uid;
	gid_t gid;
	long num;

	/* handle group */
	if (group != NULL) {
		if (parse_privilege_id(group, &num) == 0) {
			/* numeric gid: reject values that do not fit gid_t, and
			   (gid_t)-1, which setregid() treats as "leave unchanged" */
			gid = (gid_t)num;
			if ((long)gid != num || gid == (gid_t)-1) {
				log_msg(LOG_ERROR, "drop_privileges(): invalid group id: %s\n", group);
				return -1;
			}
		} else {
			/* not numeric?  lookup as name */
			grp = getgrnam(group);
			if (grp == NULL) {
				log_msg(LOG_ERROR, "drop_privileges(): group not found: %s\n", group);
				return -1;
			}
			gid = grp->gr_gid;
		}

		/* set real/effective gid */
		if (setregid(gid, gid)) {
			log_msg(LOG_ERROR, "drop_privileges(): setregid() failed: %s\n", strerror(errno));
			return -1;
		}

		log_msg(LOG_NOTICE, "changed gid to %lu\n", (unsigned long)gid);

		/* drop supplementary groups */
		if (setgroups(0, NULL)) {
			log_msg(LOG_ERROR, "drop_privileges(): setgroups() failed: %s\n", strerror(errno));
			return -1;
		}

		log_msg(LOG_NOTICE, "dropped supplementary groups\n");
	}

	/* handle user */
	if (user != NULL) {
		if (parse_privilege_id(user, &num) == 0) {
			/* numeric uid: reject values that do not fit uid_t, and
			   (uid_t)-1, which setreuid() treats as "leave unchanged" */
			uid = (uid_t)num;
			if ((long)uid != num || uid == (uid_t)-1) {
				log_msg(LOG_ERROR, "drop_privileges(): invalid user id: %s\n", user);
				return -1;
			}
		} else {
			/* not numeric?  lookup as name */
			pwd = getpwnam(user);
			if (pwd == NULL) {
				log_msg(LOG_ERROR, "drop_privileges(): user not found: %s\n", user);
				return -1;
			}
			uid = pwd->pw_uid;
		}

		/* set real/effective uid */
		if (setreuid(uid, uid)) {
			log_msg(LOG_ERROR, "drop_privileges(): setreuid() failed: %s\n", strerror(errno));
			return -1;
		}

		log_msg(LOG_NOTICE, "changed UID to %lu\n", (unsigned long)uid);
	}

	return 0;
}

/*
 * SIGHUP handler: only raises reload_flag.  The host and ban lists are
 * reloaded later by the main loop, outside of signal context.
 * NOTE(review): reload_flag should ideally be volatile sig_atomic_t;
 * confirm its declaration.
 */
void
signal_sighup (int signal)
{
	(void)signal;	/* the signal number is not needed */
	reload_flag = 1;
}

/*
 * SIGTERM/SIGINT handler: only raises shutdown_flag.  The main loop polls
 * the flag and performs the orderly shutdown itself.
 * NOTE(review): shutdown_flag should ideally be volatile sig_atomic_t;
 * confirm its declaration.
 */
void
signal_sigterm (int signal)
{
	(void)signal;	/* the signal number is not needed */
	shutdown_flag = 1;
}

/*
 * Program entry point.
 *
 * Parses and validates the command-line options, then loads the
 * allowed-host and ban lists and optionally daemonizes.  Next it opens the
 * IPv4 listening socket (plus an IPv6 one with --ipv6), installs the
 * signal handlers, writes the PID file and drops privileges.  Finally it
 * runs the accept loop.
 *
 * SIGHUP reloads the host and ban lists; SIGTERM and SIGINT stop the loop.
 * Returns 0 on clean shutdown (or after --help), 1 on a startup failure.
 */
int
main (int argc, char** argv)
{
	int sock;
	int sock6;
	int err;
	int devnull;
	struct pollfd sock_fds[2];
	struct sigaction sigact;
	FILE* pid;

	/* configuration */
	int port = DEFAULT_LISTEN_PORT;
	char* host_list = DEFAULT_HOST_LIST;
	char* ban_list = DEFAULT_BAN_LIST;
	char* log_file = NULL;
	char* pid_file = NULL;
	int do_daemon = 0;
	int do_ipv6 = 0;
	int do_help = 0;
	char* user = NULL;
	char* group = NULL;

	/* hello! */
	printf("Telnet Proxy Daemon v%s\n", PROXY_VERSION);
	printf("Copyright (C) 2004  AwesomePlay Productions, Inc.\n");
	printf("See the file COPYING for license details.\n");

	/* options */
	struct Option options[] = {
		{ 0, "help", NULL, NULL, &do_help },
		{ 'd', "daemon", NULL, NULL, &do_daemon },
		{ '6', "ipv6", NULL, NULL, &do_ipv6 },
		{ 'p', "port", NULL, &port, NULL },
		{ 'l', "log", &log_file, NULL, NULL },
		{ 'h', "hostlist", &host_list, NULL, NULL },
		{ 'b', "banlist", &ban_list, NULL, NULL },
		{ 'w', "pid", &pid_file, NULL, NULL },
		{ 'c', "maxclients", NULL, &max_clients, NULL },
		{ 'm', "maxhost", NULL, &max_host_clients, NULL },
		{ 't', "ctime", NULL, &connect_timeout, NULL },
		{ 'a', "atime", NULL, &activity_timeout, NULL },
		{ 'u', "user", &user, NULL, NULL },
		{ 'g', "group", &group, NULL, NULL },
		{ 0, NULL, NULL, NULL, NULL }
	};

	/* parse options */
	if (parse_options(options, argc, argv) == -1)
		return 1;

	/* help? */
	if (do_help) {
		print_usage(argv[0], options);
		return 0;
	}

	/* open log file */
	if (log_open(log_file))
		return 1;

	/* verify port */
	if (port < 1 || port > UINT16_MAX) {
		log_msg(LOG_ERROR, "invalid port %d\n", port);
		return 1;
	}

	/* verify max clients */
	if (max_clients < 1) {
		log_msg(LOG_ERROR, "invalid client max %d\n", max_clients);
		return 1;
	}
	if (max_host_clients < 1) {
		log_msg(LOG_ERROR, "invalid client host max %d\n", max_host_clients);
		return 1;
	}

	/* verify timeouts */
	if (connect_timeout < 1) {
		log_msg(LOG_ERROR, "invalid connect timeout %d\n", connect_timeout);
		return 1;
	}
	if (activity_timeout < 1) {
		log_msg(LOG_ERROR, "invalid activity timeout %d\n", activity_timeout);
		return 1;
	}

	/* load initial set of valid hosts */
	if (allowed_servers_load(host_list))
		return 1;

	/* load initial set of banned clients */
	if (banned_clients_load(ban_list))
		return 1;

	/* do daemonization */
	if (do_daemon) {
		/* requires a log file */
		if (log_file == NULL) {
			log_msg(LOG_ERROR, "must specify a log file when using daemon mode\n");
			return 1;
		}

		/* flush all buffered stdio output (the banner, any buffered log
		   lines) so it is not emitted twice when both parent and child
		   later flush their copies of the buffers */
		fflush(NULL);

		/* fork */
		err = fork();
		if (err == -1) {
			log_msg(LOG_ERROR, "main(): fork() failed: %s\n", strerror(errno));
			return 1;
		}

		/* main process - exit */
		if (err != 0)
			return 0;

		/* detach stdin, stdout, stderr.  They are pointed at /dev/null
		   rather than just closed so that descriptors 0-2 are never
		   reused by the listening sockets created below, and any stray
		   stdio output is discarded instead of sent to a socket. */
		devnull = open("/dev/null", O_RDWR);
		if (devnull == -1) {
			log_msg(LOG_ERROR, "main(): open() of /dev/null failed: %s\n", strerror(errno));
			return 1;
		}
		dup2(devnull, 0);
		dup2(devnull, 1);
		dup2(devnull, 2);
		if (devnull > 2)
			close(devnull);

		/* set session id */
		if (setsid() == -1)
			log_msg(LOG_WARNING, "main(): setsid() failed: %s\n", strerror(errno));
	}

	/* create IPV6 socket */
	sock6 = -1;
	if (do_ipv6) {
		sock6 = create_listen_socket(AF_INET6, port);
		if (sock6 == -1)
			return 1;

		log_msg(LOG_NOTICE, "IPv6 enabled\n");
	}

	/* create IPV4 socket */
	sock = create_listen_socket(AF_INET, port);
	if (sock == -1)
		return 1;

	/* set signals; sa_flags = 0 (no SA_RESTART) so a signal delivered to
	   this thread interrupts poll() with EINTR */
	sigemptyset (&sigact.sa_mask);
	sigact.sa_flags = 0;
	sigact.sa_handler = SIG_IGN;
	sigaction(SIGPIPE, &sigact, NULL);
	sigact.sa_handler = signal_sighup;
	sigaction(SIGHUP, &sigact, NULL);
	sigact.sa_handler = signal_sigterm;
	sigaction(SIGTERM, &sigact, NULL);
	sigact.sa_handler = signal_sigterm;
	sigaction(SIGINT, &sigact, NULL);

	/* write pid file (after fork(), so it holds the daemon's pid) */
	if (pid_file != NULL) {
		pid = fopen(pid_file, "w");
		if (pid == NULL) {
			log_msg(LOG_ERROR, "failed to open %s: %s\n", pid_file, strerror(errno));
			return 1;
		}
		fprintf(pid, "%ld\n", (long)getpid());
		if (fclose(pid) == EOF) {
			log_msg(LOG_ERROR, "failed to write %s: %s\n", pid_file, strerror(errno));
			unlink(pid_file);
			return 1;
		}
	}

	/* drop privileges */
	if (drop_privileges(user, group)) {
		if (pid_file != NULL)
			unlink(pid_file);
		return 1;
	}

	/* begin server loop */
	log_msg(LOG_NOTICE, "listening on port %d\n", port);
	do {
		/* poll server socket(s); poll() ignores the negative fd used
		   when IPv6 is disabled */
		sock_fds[0].fd = sock;
		sock_fds[0].events = POLLIN;
		sock_fds[0].revents = 0;
		sock_fds[1].fd = sock6;
		sock_fds[1].events = POLLIN;
		sock_fds[1].revents = 0;
		/* wake at least once per second: a signal may be delivered to a
		   client thread instead of interrupting this poll(), so the
		   flags below must also be checked periodically */
		err = poll(sock_fds, 2, 1000);

		log_dbg("poll() for main: %d %hx %hx\n", err, sock_fds[0].revents, sock_fds[1].revents);

		/* check for error */
		if (err == -1 && errno != EINTR) {
			log_msg(LOG_WARNING, "main(): poll() failed: %s\n", strerror(errno));
			break;
		}

		/* reload data on SIGHUP */
		if (reload_flag) {
			reload_flag = 0;
			log_msg(LOG_NOTICE, "reloading host list\n");
			allowed_servers_load(host_list);
			log_msg(LOG_NOTICE, "reloading ban list\n");
			banned_clients_load(ban_list);
		}

		/* shutdown on SIGTERM */
		if (shutdown_flag) {
			log_msg(LOG_NOTICE, "received terminating signal\n");
			break;
		}

		/* something ready on IPv4 socket */
		if (sock_fds[0].revents & POLLIN) {
			log_dbg("creating ipv4 client\n");

			create_client(sock_fds[0].fd);

			log_dbg("ipv4 client created\n");
		}

		/* something ready on IPv6 socket */
		if (sock_fds[1].revents & POLLIN) {
			log_dbg("creating ipv6 client\n");

			create_client(sock_fds[1].fd);

			log_dbg("ipv6 client created\n");
		}

	} while(1);

	/* shutdown sockets */
	close(sock);
	if (sock6 != -1)
		close(sock6);

	/* unlink pid file.  NOTE(review): after drop_privileges() this may
	   fail if the file's directory is not writable by the new user. */
	if (pid_file != NULL)
		unlink(pid_file);

	/* finish up */
	log_msg(LOG_NOTICE, "terminating proxy\n");
	log_close();

	return 0;
}