/* * Telnet Proxy Daemon * Copyright (c) 2004 AwesomePlay Productions, Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH * DAMAGE. 
*/

#include <time.h>
#include <signal.h>
#include <ctype.h>
#include <stdlib.h>
#include <fcntl.h>
#include <unistd.h>
#include <stdio.h>
#include <inttypes.h>
#include <errno.h>
#include <limits.h>
#include <string.h>
#include <pthread.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/poll.h>
#include <netdb.h>
#include <stdarg.h>
#include <pwd.h>
#include <grp.h>
#include <assert.h>

/* ---- CONFIGURATION ---- */

#define BUFFER_SIZE 2048
#define DEFAULT_MAX_CLIENTS 50
#define DEFAULT_MAX_HOST_CLIENTS 3
#define DEFAULT_LISTEN_PORT 9596
#define DEFAULT_HOST_LIST "hosts.txt"
#define DEFAULT_BAN_LIST "deny.txt"
#define DEFAULT_CONNECT_TIMEOUT 30
#define DEFAULT_ACTIVITY_TIMEOUT (60 * 15)

/* ---- DATA TYPES ---- */

/* one "host:port" entry of the allowed server list (singly linked) */
struct AllowedServer {
	char* host;
	uint16_t port;
	struct AllowedServer* next;
};

/* one "address[/bits]" entry of the ban list (singly linked);
   addr is stored with the prefix mask already applied */
struct BannedClient {
	struct sockaddr_storage addr;
	uint8_t mask;
	struct BannedClient* next;
};

/* number of live connections from one client address (doubly linked) */
struct ClientAddr {
	struct sockaddr_storage addr;
	unsigned int count;
	struct ClientAddr* prev;
	struct ClientAddr* next;
};

/* handed from create_client() to client_main(); the client thread frees it */
struct ClientInfo {
	int sock;
	struct sockaddr_storage addr;
};

/* command line option descriptor; every non-NULL *_arg pointer receives
   the option's value (string, integer, or 1 for a flag) */
struct Option {
	char short_name;
	char* long_name;
	char** string_arg;
	int* int_arg;
	int* bool_arg;
};

/* per-connection state machine used by client_main() */
enum ClientState {
	CLIENT_INIT,     /* waiting for the "connect <host> <port>" line */
	CLIENT_ACTIVE,   /* relaying bytes between client and server */
	CLIENT_FINISH,   /* flushing remaining output to the client */
	CLIENT_SHUTDOWN  /* leave the loop and close everything */
};

enum LogLevel {
	LOG_NOTICE,
	LOG_WARNING,
	LOG_ERROR,
	LOG_DEBUG
};

/* ---- FUNCTIONS ---- */

int allowed_servers_load (char* filename);
int allowed_servers_check (char* host, int port);
int banned_clients_load (char* filename);
int banned_clients_check (struct sockaddr_storage* addr);
int client_add (struct sockaddr_storage* addr);
void client_remove (struct sockaddr_storage* addr);
void signal_sighup (int);
void signal_sigterm (int);
int parse_options (struct Option*, int argc, char** argv);
int log_open (char* filename);
void log_msg (enum LogLevel level, char* format, ...);
void log_close (void);
int write_string (int sock, char* string); /* not safe, but simple */
char* sockaddr_name_of (struct sockaddr_storage* addr, char* buffer, size_t len);

/* ---- GLOBALS ---- */

struct AllowedServer* allowed_servers = NULL;
pthread_mutex_t allowed_servers_lock = PTHREAD_MUTEX_INITIALIZER;

struct BannedClient* banned_clients = NULL;
pthread_mutex_t banned_clients_lock = PTHREAD_MUTEX_INITIALIZER;

int max_clients = DEFAULT_MAX_CLIENTS;
int max_host_clients = DEFAULT_MAX_HOST_CLIENTS;
int client_count = 0;
struct ClientAddr* client_list = NULL;
pthread_mutex_t client_list_lock = PTHREAD_MUTEX_INITIALIZER;

int connect_timeout = DEFAULT_CONNECT_TIMEOUT;
int activity_timeout = DEFAULT_ACTIVITY_TIMEOUT;

/* written from signal handlers, so volatile sig_atomic_t is the portable type */
volatile sig_atomic_t reload_flag = 0;
volatile sig_atomic_t shutdown_flag = 0;

FILE* log_file = NULL;

/* ---- DEBUG LOG ---- */

#ifdef NDEBUG
#define log_dbg(format,...)
#else
#define log_dbg(format,args...) log_msg(LOG_DEBUG, (format), ## args)
#endif /* NDEBUG */

/* ---- BEGIN CODE ---- */

/* write_string()
   very dumb function to use - doesn't guarantee any or even some of the
   string is actually written; returns write()'s result unchanged */
int write_string (int sock, char* string)
{
	return write(sock, string, strlen(string));
}

/* print_usage()
   print out a usage message listing every option in the table to stderr */
void print_usage (char* self, struct Option* options)
{
	int i;

	assert(self != NULL);
	assert(options != NULL);

	fprintf(stderr, "Usage: %s", self);
	for (i = 0; options[i].short_name != 0 || options[i].long_name != NULL; ++i) {
		if (options[i].long_name != NULL) {
			fprintf(stderr, " [--%s", options[i].long_name);
			if (options[i].short_name != 0)
				fprintf(stderr, "|-%c", options[i].short_name);
		} else {
			fprintf(stderr, " [-%c", options[i].short_name);
		}
		if (options[i].string_arg != NULL)
			fprintf(stderr, " <string>");
		else if (options[i].int_arg != NULL)
			fprintf(stderr, " <int>");
		fprintf(stderr, "]");
	}
	fprintf(stderr, "\n");
}

/* parse_options()
   read in options according to the NULL/0-terminated option table;
   integer values must be complete decimal numbers within int range
   return 0 on success, -1 on failure (after printing usage) */
int parse_options (struct Option* options, int argc, char** argv)
{
	int opt;
	int i;

	assert(options != NULL);
	assert(argc >= 1);
	assert(argv != NULL);

	/* iterate through arguments */
	for (opt = 1; opt < argc; ++opt) {
		/* find corresponding option */
		for (i = 0; options[i].short_name != 0 || options[i].long_name != NULL; ++i) {
			if (
				/* match short opt? */
				(options[i].short_name != 0 && argv[opt][0] == '-' && argv[opt][1] == options[i].short_name && argv[opt][2] == 0) ||
				/* match long opt? */
				(options[i].long_name != NULL && argv[opt][0] == '-' && argv[opt][1] == '-' && !strcmp(argv[opt] + 2, options[i].long_name))
			) {
				/* need arg but have none? */
				if (opt == argc - 1 && (options[i].int_arg != NULL || options[i].string_arg != NULL)) {
					fprintf(stderr, "Error: No value given for option: %s\n", argv[opt]);
					print_usage(argv[0], options);
					return -1;
				}

				/* set string arg */
				if (options[i].string_arg != NULL)
					*options[i].string_arg = argv[++opt];

				/* set int arg - atol() silently accepted "80abc" or "abc",
				   so require the whole value to be an in-range number */
				if (options[i].int_arg != NULL) {
					char* end;
					long value;

					++opt;
					errno = 0;
					value = strtol(argv[opt], &end, 10);
					if (end == argv[opt] || *end != '\0' || errno == ERANGE || value < INT_MIN || value > INT_MAX) {
						fprintf(stderr, "Error: Invalid number given for option: %s\n", argv[opt - 1]);
						print_usage(argv[0], options);
						return -1;
					}
					*options[i].int_arg = (int)value;
				}

				/* set bool arg */
				if (options[i].bool_arg != NULL)
					*options[i].bool_arg = 1;

				break;
			}
		}

		/* hit end of options? bad juju */
		if (options[i].short_name == 0 && options[i].long_name == NULL) {
			fprintf(stderr, "Error: Invalid option: %s\n", argv[opt]);
			print_usage(argv[0], options);
			return -1;
		}
	}

	/* all found */
	return 0;
}

/* build_addr_mask()
   creates a bit mask from the given mask length
   used for network address masks (i.e., 192.168.1.2/24)
   buf must hold at least max / 8 bytes
   NOTE: code taken from FreeBSD 'route' command source
   Copyright (C) FreeBSD */
void build_addr_mask (uint8_t* buf, int mask, int max)
{
	int q, r;

	assert(buf != NULL);
	assert(mask >= 0);
	assert(mask <= max);
	assert(max > 0);

	memset(buf, 0, max / 8);
	q = mask >> 3;
	r = mask & 7;
	if (q > 0)
		memset(buf, 0xff, q);
	if (r > 0)
		*(buf + q) = (0xff00 >> r) & 0xff;
}

/* apply_addr_mask()
   ANDs the raw address bytes at addr (an in_addr or in6_addr, chosen by
   family) with a prefix mask of the given bit length; a length of 0 is
   valid and clears the whole address */
void apply_addr_mask (void* addr, uint8_t mask, int family)
{
	int i;
	uint8_t buf[32]; /* more than big enough */
	int size;

	assert(addr != NULL);
	assert(family == AF_INET || family == AF_INET6);

	/* 16 bytes in IPv6, 4 bytes in IPv4 */
	size = (family == AF_INET6) ? 16 : 4;

	build_addr_mask(buf, mask, size * 8);
	for (i = 0; i < size; ++i)
		((uint8_t*)addr)[i] &= buf[i];
}

/* mask_sockaddr()
   applies a prefix mask to the address inside an AF_INET/AF_INET6
   sockaddr_storage; other families are left untouched */
static void mask_sockaddr (struct sockaddr_storage* addr, uint8_t mask)
{
	assert(addr != NULL);

	if (addr->ss_family == AF_INET6)
		apply_addr_mask(&((struct sockaddr_in6*)addr)->sin6_addr, mask, AF_INET6);
	else if (addr->ss_family == AF_INET)
		apply_addr_mask(&((struct sockaddr_in*)addr)->sin_addr, mask, AF_INET);
}

/* addr_match()
   compares only family and IP address (ports are ignored)
   return 0 if both addresses match, -1 otherwise */
int addr_match (struct sockaddr_storage* addr1, struct sockaddr_storage* addr2)
{
	assert(addr1 != NULL);
	assert(addr2 != NULL);

	/* different family? not the same */
	if (addr1->ss_family != addr2->ss_family)
		return -1;

	/* check IPv6 */
	if (addr1->ss_family == AF_INET6) {
		if (IN6_ARE_ADDR_EQUAL(&((struct sockaddr_in6*)addr1)->sin6_addr, &((struct sockaddr_in6*)addr2)->sin6_addr))
			return 0;
	/* check IPv4 address */
	} else if (addr1->ss_family == AF_INET) {
		if (!memcmp(&((struct sockaddr_in*)addr1)->sin_addr, &((struct sockaddr_in*)addr2)->sin_addr, sizeof(((struct sockaddr_in*)addr2)->sin_addr)))
			return 0;
	}

	/* no match */
	return -1;
}

/* parse_addr()
   parses "address" or "address/bits" (IPv6 or IPv4) into host and mask;
   the mask defaults to the full address length (128 or 32)
   returns 0 on success, -1 on parse error */
int parse_addr (char* addr, struct sockaddr_storage* host, uint8_t* mask)
{
	char buffer[128];
	int inmask;
	char* slash;

	assert(addr != NULL);
	assert(host != NULL);
	assert(mask != NULL);

	/* clear address */
	memset(host, 0, sizeof(struct sockaddr_storage));

	/* work on a copy; snprintf() guarantees a NUL byte */
	snprintf(buffer, sizeof(buffer), "%s", addr);

	/* optional "/bits" suffix; -1 means none was given */
	inmask = -1;
	slash = strchr(buffer, '/');
	if (slash != NULL) {
		*slash = '\0';
		++slash;

		/* require at least one digit: a bare "addr/" would otherwise be
		   read as /0, i.e. the entire address family */
		if (*slash == '\0')
			return -1;

		/* parse by hand (not atoi/strtol) so only plain digits are accepted;
		   bail out once past 128 so the accumulator can never overflow */
		inmask = 0;
		for (; *slash != '\0'; ++slash) {
			if (!isdigit((unsigned char)*slash))
				return -1;
			inmask = inmask * 10 + (*slash - '0');
			if (inmask > 128)
				return -1;
		}
	}

	/* parse IPv6 first */
	if (inet_pton(AF_INET6, buffer, &((struct sockaddr_in6*)host)->sin6_addr) > 0) {
		/* mask must be <= 128 */
		if (inmask > 128)
			return -1;
		if (inmask == -1)
			inmask = 128;
		host->ss_family = AF_INET6;
		*mask = (uint8_t)inmask;
		return 0;
	}

	/* try IPv4 parsing */
	if (inet_pton(AF_INET, buffer, &((struct sockaddr_in*)host)->sin_addr) > 0) {
		/* mask must be <= 32 */
		if (inmask > 32)
			return -1;
		if (inmask == -1)
			inmask = 32;
		((struct sockaddr_in*)host)->sin_family = AF_INET;
		*mask = (uint8_t)inmask;
		return 0;
	}

	/* no match */
	return -1;
}

/* trim_whitespace()
   strips leading and trailing whitespace from str in place
   returns a pointer to the first non-whitespace character */
static char* trim_whitespace (char* str)
{
	size_t len;

	assert(str != NULL);

	while (isspace((unsigned char)*str))
		++str;

	len = strlen(str);
	while (len > 0 && isspace((unsigned char)str[len - 1]))
		str[--len] = '\0';

	return str;
}

/* allowed_servers_check()
   exact (case-sensitive) match of host and port against the allowed list
   return 0 if the host/port is allowed, -1 if it's not */
int allowed_servers_check (char* host, int port)
{
	struct AllowedServer* hinfo;

	assert(host != NULL);
	assert(port > 0 && port <= UINT16_MAX);

	pthread_mutex_lock(&allowed_servers_lock);

	/* scan list */
	hinfo = allowed_servers;
	while (hinfo != NULL) {
		/* match? return success (0) */
		if (port == hinfo->port && !strcmp(host, hinfo->host)) {
			pthread_mutex_unlock(&allowed_servers_lock);
			return 0;
		}
		hinfo = hinfo->next;
	}

	pthread_mutex_unlock(&allowed_servers_lock);

	/* no matches, failure (-1) */
	return -1;
}

/* allowed_servers_load()
   replace the allowed server list with the "host:port" lines of filename;
   blank lines are skipped, malformed lines are logged and skipped
   returns 0 on success, -1 if the file cannot be opened (list untouched) */
int allowed_servers_load (char* filename)
{
	FILE* file;
	char line[512];
	char* entry;
	int lineno;
	long port;
	char* sep;
	char* end;
	struct AllowedServer* hinfo;

	assert(filename != NULL);

	/* open file before touching the current list */
	file = fopen(filename, "r");
	if (file == NULL) {
		log_msg(LOG_ERROR, "allowed_servers_load(): fopen() failed for '%s': %s\n", filename, strerror(errno));
		return -1;
	}

	pthread_mutex_lock(&allowed_servers_lock);

	/* clear list */
	while (allowed_servers != NULL) {
		hinfo = allowed_servers->next;
		free(allowed_servers->host);
		free(allowed_servers);
		allowed_servers = hinfo;
	}

	/* read in file */
	lineno = 0;
	while (fgets(line, sizeof(line), file) != NULL) {
		++lineno;

		/* trim surrounding spaces/newlines; empty? */
		entry = trim_whitespace(line);
		if (*entry == '\0')
			continue;

		/* split on the LAST ':' so IPv6 literals ("::1:23") keep their colons */
		sep = strrchr(entry, ':');
		if (sep == NULL) {
			log_msg(LOG_WARNING, "allowed_servers_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}
		*sep = '\0';

		/* read and sanity check the port number */
		port = strtol(sep + 1, &end, 10);
		if (end == sep + 1 || *end != '\0' || port < 1 || port > UINT16_MAX) {
			log_msg(LOG_WARNING, "allowed_servers_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}

		/* sanity check the host name */
		if (*entry == '\0') {
			log_msg(LOG_WARNING, "allowed_servers_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}

		hinfo = malloc(sizeof(*hinfo));
		if (hinfo == NULL) {
			log_msg(LOG_WARNING, "allowed_servers_load(): malloc() failed: %s\n", strerror(errno));
			continue;
		}

		hinfo->host = strdup(entry);
		if (hinfo->host == NULL) {
			log_msg(LOG_WARNING, "allowed_servers_load(): strdup() failed: %s\n", strerror(errno));
			free(hinfo);
			continue;
		}
		hinfo->port = (uint16_t)port;

		/* make new list head */
		hinfo->next = allowed_servers;
		allowed_servers = hinfo;
	}

	pthread_mutex_unlock(&allowed_servers_lock);

	fclose(file);
	return 0;
}

/* banned_clients_load()
   replace the ban list with the "address[/bits]" lines of filename;
   each entry is stored pre-masked so banned_clients_check() can compare a
   masked client address against it directly
   returns 0 on success, -1 if the file cannot be opened (list untouched) */
int banned_clients_load (char* filename)
{
	FILE* file;
	char line[512];
	char* entry;
	int lineno;
	struct sockaddr_storage addr;
	uint8_t mask;
	struct BannedClient* cinfo;

	assert(filename != NULL);

	/* open file before touching the current list */
	file = fopen(filename, "r");
	if (file == NULL) {
		log_msg(LOG_ERROR, "banned_clients_load(): fopen() failed for '%s': %s\n", filename, strerror(errno));
		return -1;
	}

	pthread_mutex_lock(&banned_clients_lock);

	/* clear list */
	while (banned_clients != NULL) {
		cinfo = banned_clients->next;
		free(banned_clients);
		banned_clients = cinfo;
	}

	/* read in file */
	lineno = 0;
	while (fgets(line, sizeof(line), file) != NULL) {
		++lineno;

		/* trim surrounding spaces/newlines; empty? */
		entry = trim_whitespace(line);
		if (*entry == '\0')
			continue;

		if (parse_addr(entry, &addr, &mask) == -1) {
			log_msg(LOG_WARNING, "banned_clients_load(): malformed entry at %s:%d\n", filename, lineno);
			continue;
		}

		/* reduce "192.168.1.5/24" to its network "192.168.1.0"; an unmasked
		   entry never matched a masked client address */
		mask_sockaddr(&addr, mask);

		cinfo = malloc(sizeof(*cinfo));
		if (cinfo == NULL) {
			log_msg(LOG_WARNING, "banned_clients_load(): malloc() failed: %s\n", strerror(errno));
			continue;
		}

		cinfo->addr = addr;
		cinfo->mask = mask;

		/* make new list head */
		cinfo->next = banned_clients;
		banned_clients = cinfo;
	}

	pthread_mutex_unlock(&banned_clients_lock);

	fclose(file);
	return 0;
}

/* banned_clients_check()
   returns -1 if addr falls inside any banned network, 0 otherwise */
int banned_clients_check (struct sockaddr_storage* addr)
{
	struct BannedClient* cinfo;
	struct sockaddr_storage temp;
	int banned = 0;

	assert(addr != NULL);

	pthread_mutex_lock(&banned_clients_lock);

	for (cinfo = banned_clients; cinfo != NULL; cinfo = cinfo->next) {
		/* not same family? skip */
		if (cinfo->addr.ss_family != addr->ss_family)
			continue;

		/* mask a copy of the client address down to the banned network */
		temp = *addr;
		mask_sockaddr(&temp, cinfo->mask);

		if (addr_match(&temp, &cinfo->addr) == 0) {
			banned = -1;
			break;
		}
	}

	pthread_mutex_unlock(&banned_clients_lock);

	return banned;
}

/* client_add()
   account for one more connection from the given address
   returns 0 on success
   returns -1 on failure (too many total clients)
   returns -2 on failure (too many connections from this address, or the
   per-address record could not be allocated) */
int client_add (struct sockaddr_storage* addr)
{
	struct ClientAddr* cinfo;

	assert(addr != NULL);

	pthread_mutex_lock(&client_list_lock);

	/* already at max? fail. */
	if (client_count >= max_clients) {
		pthread_mutex_unlock(&client_list_lock);
		return -1;
	}

	/* find existing record for this address */
	for (cinfo = client_list; cinfo != NULL; cinfo = cinfo->next) {
		if (addr_match(&cinfo->addr, addr) == 0) {
			if (cinfo->count >= (unsigned int)max_host_clients) {
				pthread_mutex_unlock(&client_list_lock);
				return -2;
			}
			++cinfo->count;
			break;
		}
	}

	/* first connection from this address? create a record */
	if (cinfo == NULL) {
		cinfo = malloc(sizeof(*cinfo));
		if (cinfo == NULL) {
			log_msg(LOG_ERROR, "client_add(): malloc() failed: %s\n", strerror(errno));
			pthread_mutex_unlock(&client_list_lock);
			return -2;
		}

		cinfo->addr = *addr;
		cinfo->count = 1;

		cinfo->prev = NULL;
		cinfo->next = client_list;
		if (client_list != NULL)
			client_list->prev = cinfo;
		client_list = cinfo;
	}

	/* count the connection globally as well; client_remove() decrements
	   this, and without the increment max_clients was never enforced */
	++client_count;

	pthread_mutex_unlock(&client_list_lock);
	return 0;
}

/* client_remove()
   undo one client_add() for the given address: decrements the per-address
   and global counts, freeing the per-address record when it reaches zero */
void client_remove (struct sockaddr_storage* addr)
{
	struct ClientAddr* cinfo;

	assert(addr != NULL);

	pthread_mutex_lock(&client_list_lock);

	for (cinfo = client_list; cinfo != NULL; cinfo = cinfo->next) {
		if (addr_match(&cinfo->addr, addr) != 0)
			continue;

		--cinfo->count;
		--client_count;

		/* last of our kind (eek ;-) remove structure */
		if (cinfo->count == 0) {
			if (cinfo->prev != NULL)
				cinfo->prev->next = cinfo->next;
			else
				client_list = cinfo->next;
			if (cinfo->next != NULL)
				cinfo->next->prev = cinfo->prev;
			free(cinfo);
		}
		break;
	}

	pthread_mutex_unlock(&client_list_lock);
}

/* log_open()
   open (and truncate, mode "w+") the log file; NULL means log to stderr
   returns 0 on success, -1 on failure */
int log_open (char* filename)
{
	/* no file? log to stderr */
	if (filename == NULL) {
		log_file = stderr;
		return 0;
	}

	log_file = fopen(filename, "w+");
	if (log_file == NULL) {
		fprintf(stderr, "ERROR: log_open(): could not open %s: %s\n", filename, strerror(errno));
		return -1;
	}

	return 0;
}

/* log_close()
   close the log file (stderr is never closed); safe to call twice */
void log_close (void)
{
	if (log_file != NULL && log_file != stderr)
		fclose(log_file);
	log_file = NULL;
}

/* log_msg()
   write one timestamped, level-prefixed message to the log file; the stream
   is locked so lines from concurrent client threads do not interleave;
   falls back to stderr if the log is not (or no longer) open */
void log_msg (enum LogLevel level, char* format, ...)
{
	va_list va;
	time_t t;
	struct tm lt;
	char time_buf[128];
	FILE* out;

	assert(format != NULL);

	out = (log_file != NULL) ? log_file : stderr;

	flockfile(out);

	/* time message */
	time(&t);
	localtime_r(&t, &lt);
	strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", &lt);
	fprintf(out, "%s - ", time_buf);

	/* print out level prefix */
	switch (level) {
	case LOG_NOTICE:
		break;
	case LOG_WARNING:
		fprintf(out, "WARNING: ");
		break;
	case LOG_ERROR:
		fprintf(out, "**ERROR** ");
		break;
	case LOG_DEBUG:
		fprintf(out, "[debug]: ");
		break;
	}

	/* print out message */
	va_start(va, format);
	vfprintf(out, format, va);
	va_end(va);

	fflush(out);

	funlockfile(out);
}

/* sockaddr_name_of()
   format a socket address numerically as "a.b.c.d:port" or "[v6]:port"
   into buffer; returns buffer ("(unknown address)" if it cannot be printed) */
char* sockaddr_name_of (struct sockaddr_storage* addr, char* buffer, size_t len)
{
	char host_buf[NI_MAXHOST];
	char serv_buf[NI_MAXSERV];
	socklen_t addr_len;
	int err;

	assert(addr != NULL);
	assert(buffer != NULL);
	assert(len > 0);

	/* pass the real structure size; some systems reject an oversized salen */
	switch (addr->ss_family) {
	case AF_INET:
		addr_len = sizeof(struct sockaddr_in);
		break;
	case AF_INET6:
		addr_len = sizeof(struct sockaddr_in6);
		break;
	default:
		addr_len = sizeof(struct sockaddr_storage);
		break;
	}

	/* on failure host_buf/serv_buf are not filled in, so don't touch them */
	err = getnameinfo((struct sockaddr*)addr, addr_len, host_buf, sizeof(host_buf), serv_buf, sizeof(serv_buf), NI_NUMERICHOST | NI_NUMERICSERV);
	if (err != 0) {
		snprintf(buffer, len, "(unknown address)");
		return buffer;
	}

	if (strchr(host_buf, ':'))
		snprintf(buffer, len, "[%s]:%s", host_buf, serv_buf);
	else
		snprintf(buffer, len, "%s:%s", host_buf, serv_buf);

	return buffer;
}

/* connect_to_server()
   resolve host:port and try each returned address in order
   returns a connected TCP socket, or -1 if every address failed
   NOTE(review): connect() is blocking with no timeout, so an unreachable
   host stalls the calling client thread until the OS gives up */
int connect_to_server (char* host, int port)
{
	int sock;
	int err;
	int count;
	struct addrinfo hints;
	struct addrinfo* res;
	struct addrinfo* res_head;
	char portstr[16];

	assert(host != NULL);
	assert(port > 0 && port <= UINT16_MAX);

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	snprintf(portstr, sizeof(portstr), "%d", port);

	if ((err = getaddrinfo(host, portstr, &hints, &res_head)) != 0) {
		log_msg(LOG_WARNING, "connect_to_server(): getaddrinfo() failed: %s\n", gai_strerror(err));
		return -1;
	}

	count = 0;
	for (res = res_head; res != NULL; res = res->ai_next) {
		++count;

		sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (sock == -1)
			continue;

		if (connect(sock, res->ai_addr, res->ai_addrlen) == 0) {
			freeaddrinfo(res_head);
			return sock;
		}
		close(sock);
	}

	freeaddrinfo(res_head);

	/* none successful */
	if (count == 0)
		log_msg(LOG_WARNING, "connect_to_server(): no addresses found for %s:%d\n", host, port);
	else
		log_msg(LOG_NOTICE, "host %s:%d is currently unavailable\n", host, port);

	return -1;
}

/* parse_connect_command()
   parses "connect <host> <port>" (trailing whitespace ignored); command is
   modified in place and *host points into it
   returns 0 if successful, -1 on parse error */
int parse_connect_command (char* command, char** host, int* port)
{
	char* sep;
	char* end;
	char* portstr;
	long portval;
	size_t len;

	assert(command != NULL);
	assert(host != NULL);
	assert(port != NULL);

	*host = NULL;
	*port = 0;

	/* erase whitespace at end (clients normally send "\r\n"); bytes come
	   straight off the network, so cast before handing them to isspace() */
	len = strlen(command);
	while (len > 0 && isspace((unsigned char)command[len - 1]))
		command[--len] = '\0';

	/* must begin with 'connect ' */
	if (strncmp(command, "connect ", 8))
		return -1;

	/* host follows, separated from the port by a single space */
	*host = command + 8;
	sep = strchr(*host, ' ');
	if (sep == NULL)
		return -1;

	*sep = '\0';
	portstr = sep + 1;

	/* port must be entirely numeric and in range */
	portval = strtol(portstr, &end, 10);
	if (*end != '\0')
		return -1;
	if (portval < 1 || portval > UINT16_MAX)
		return -1;

	/* make sure host has data */
	if (**host == '\0')
		return -1;

	*port = (int)portval;
	return 0;
}

/* send_buffered()
   send as much of buf[0..*len) to sock as it accepts without blocking and
   move any unsent tail to the front of buf, updating *len
   returns 0 on success (including "would block"), -1 on a fatal error
   with errno set by send() */
static int send_buffered (int sock, char* buf, int* len)
{
	ssize_t sent;

	sent = send(sock, buf, *len, MSG_DONTWAIT);
	if (sent < 0) {
		/* a full socket buffer is normal with MSG_DONTWAIT; retry later */
		if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
			return 0;
		return -1;
	}

	/* partial send: keep the unsent bytes, in order, for the next attempt */
	if (sent < *len)
		memmove(buf, buf + sent, *len - sent);
	*len -= (int)sent;
	return 0;
}

/* format_reply()
   vsnprintf() into buf and return the number of bytes actually stored;
   snprintf() returns the untruncated length, which must never be used as a
   buffer length (a long host name in a reply made slen exceed the buffer
   and sent stack bytes to the client) */
static int format_reply (char* buf, size_t size, const char* format, ...)
{
	va_list va;
	int len;

	assert(buf != NULL);
	assert(size > 0);

	va_start(va, format);
	len = vsnprintf(buf, size, format, va);
	va_end(va);

	if (len < 0)
		return 0;
	if ((size_t)len >= size)
		return (int)(size - 1);
	return len;
}

/* client_main()
   thread body for one client: waits for a "connect <host> <port>" line,
   checks it against the allowed server list, connects, then relays bytes in
   both directions until a side closes or a timeout fires; arg is the
   heap-allocated ClientInfo from create_client(), freed here */
void* client_main (void* arg)
{
	int client;
	int server;
	int err;
	char cbuffer[BUFFER_SIZE];   /* client -> server bytes not yet sent */
	int clen;
	char sbuffer[BUFFER_SIZE];   /* bytes for the client not yet sent */
	int slen;
	struct pollfd pollfds[2];    /* [0] = client, [1] = server */
	char* end;
	char* host;
	int port;
	enum ClientState state;
	struct sockaddr_storage caddr;
	char client_name[128];
	time_t start;
	time_t activity;

	assert(arg != NULL);

	/* take over the client socket/address and release the handoff struct */
	client = ((struct ClientInfo*)arg)->sock;
	caddr = ((struct ClientInfo*)arg)->addr;
	free(arg);

	sockaddr_name_of(&caddr, client_name, sizeof(client_name));

	/* initialize client data */
	clen = 0;
	pollfds[0].fd = client;
	pollfds[0].events = POLLIN;

	/* initialize server data */
	server = -1;
	slen = 0;
	pollfds[1].fd = -1;
	pollfds[1].events = 0;

	start = time(NULL);
	activity = time(NULL);

	state = CLIENT_INIT;
	do {
		/* poll fds; the server slot is only polled once it is connected */
		pollfds[0].revents = 0;
		pollfds[1].revents = 0;
		err = poll(pollfds, (server == -1) ? 1 : 2, 1000);
		log_dbg("poll() for %s: %d %hx %hx %d %d\n", client_name, err, pollfds[0].revents, pollfds[1].revents, clen, slen);

		if (err < 0 && errno != EINTR) {
			log_msg(LOG_ERROR, "client_main(%d): poll() failed: %s\n", client, strerror(errno));
			state = CLIENT_SHUTDOWN;
		}

		/* server hung up or errored? flush what is buffered, then finish */
		if (server != -1 && (pollfds[1].revents & (POLLHUP | POLLERR | POLLNVAL))) {
			close(server);
			server = -1;
			if (state != CLIENT_SHUTDOWN)
				state = CLIENT_FINISH;
			pollfds[1].revents = 0;
		}

		/* client hung up or errored? nothing more can be delivered */
		if (client != -1 && (pollfds[0].revents & (POLLHUP | POLLERR | POLLNVAL))) {
			close(client);
			client = -1;
			state = CLIENT_SHUTDOWN;
			pollfds[0].revents = 0;
		}

		/* read in from client */
		if ((pollfds[0].revents & POLLIN) && clen < BUFFER_SIZE) {
			err = recv(pollfds[0].fd, cbuffer + clen, BUFFER_SIZE - clen, 0);
			if (err < 0 && errno != EINTR) {
				log_msg(LOG_ERROR, "client_main(%d): recv(client) failed: %s\n", client, strerror(errno));
				state = CLIENT_SHUTDOWN;
			}

			/* orderly close is reported as EOF, not POLLHUP */
			if (err == 0) {
				close(client);
				client = -1;
				state = CLIENT_SHUTDOWN;
			}

			if (err > 0)
				clen += err;

			/* only client input resets the inactivity timer */
			activity = time(NULL);
		}

		/* read in from server */
		if ((pollfds[1].revents & POLLIN) && slen < BUFFER_SIZE) {
			err = recv(pollfds[1].fd, sbuffer + slen, BUFFER_SIZE - slen, 0);
			if (err < 0 && errno != EINTR) {
				log_msg(LOG_ERROR, "client_main(%d): recv(server) failed: %s\n", client, strerror(errno));
				state = CLIENT_SHUTDOWN;
			}

			/* EOF: flush remaining output, but never undo a shutdown decided
			   earlier in this pass (that left the thread spinning forever on
			   a closed client socket) */
			if (err == 0) {
				close(server);
				server = -1;
				if (state != CLIENT_SHUTDOWN)
					state = CLIENT_FINISH;
			}

			if (err > 0)
				slen += err;
		}

		/* write out to client */
		if (client != -1 && slen > 0 && send_buffered(client, sbuffer, &slen) == -1) {
			log_msg(LOG_ERROR, "client_main(%d): send(client) failed: %s\n", client, strerror(errno));
			state = CLIENT_SHUTDOWN;
		}

		/* write out to server */
		if (server != -1 && clen > 0 && send_buffered(server, cbuffer, &clen) == -1) {
			log_msg(LOG_ERROR, "client_main(%d): send(server) failed: %s\n", client, strerror(errno));
			state = CLIENT_SHUTDOWN;
		}

		/* process client states */
		switch (state) {
		case CLIENT_INIT:
			/* wait until a complete line is buffered */
			if (clen == 0)
				break;
			end = memchr(cbuffer, '\n', clen);
			if (end == NULL)
				break;
			*end = '\0';

			log_dbg("CLIENT %s INPUT: %s\n", client_name, cbuffer);

			if (parse_connect_command(cbuffer, &host, &port)) {
				log_msg(LOG_NOTICE, "client %s sent invalid connect command\n", client_name);
				slen = format_reply(sbuffer, sizeof(sbuffer), "Invalid connect command.\r\n");
				state = CLIENT_FINISH;
				break;
			}

			log_dbg("CLIENT %s REQUEST: %s %d\n", client_name, host, port);

			/* verify host:port is allowed */
			if (allowed_servers_check(host, port)) {
				log_msg(LOG_NOTICE, "client %s attempted connection to server %s:%d - access denied\n", client_name, host, port);
				slen = format_reply(sbuffer, sizeof(sbuffer), "Access denied to %s:%d.\r\n", host, port);
				state = CLIENT_FINISH;
				break;
			}

			log_dbg("CLIENT %s CONNECTING: %s %d\n", client_name, host, port);

			server = connect_to_server(host, port);
			if (server == -1) {
				slen = format_reply(sbuffer, sizeof(sbuffer), "Failed to connect to %s:%d.\r\n", host, (int)port);
				state = CLIENT_FINISH;
				break;
			}

			log_dbg("CLIENT %s READY: %s %d\n", client_name, host, port);

			/* make active */
			pollfds[1].fd = server;
			log_msg(LOG_NOTICE, "client %s connected to server %s:%d\n", client_name, host, port);
			state = CLIENT_ACTIVE;

			/* keep any bytes the client sent after the command line */
			if (clen > end - cbuffer + 1) {
				memmove(cbuffer, end + 1, clen - (end - cbuffer + 1));
				clen -= (end - cbuffer + 1);
			} else {
				clen = 0;
			}

			activity = time(NULL);
			break;
		case CLIENT_ACTIVE:
			break;
		case CLIENT_FINISH:
			/* done once everything is sent, or once the client is gone and
			   nothing more can be sent */
			if (slen == 0 || client == -1)
				state = CLIENT_SHUTDOWN;
			break;
		case CLIENT_SHUTDOWN:
			break;
		}

		/* set client poll events */
		pollfds[0].events = 0;
		if (clen < BUFFER_SIZE)
			pollfds[0].events |= POLLIN;
		if (slen > 0)
			pollfds[0].events |= POLLOUT;

		/* set server poll events */
		if (server != -1) {
			pollfds[1].events = 0;
			if (slen < BUFFER_SIZE)
				pollfds[1].events = POLLIN;
			if (clen > 0)
				pollfds[1].events |= POLLOUT;
		}

		/* more than the connect timeout in seconds has passed? */
		if (state == CLIENT_INIT && time(NULL) - start >= connect_timeout) {
			log_msg(LOG_NOTICE, "client %s connect timeout (%d seconds)\n", client_name, connect_timeout);
			slen = format_reply(sbuffer, sizeof(sbuffer), "Connection timeout: no connect command given within %d seconds.\r\n", connect_timeout);
			state = CLIENT_FINISH;
		}

		/* no recent client activity? shutdown */
		if (state == CLIENT_ACTIVE && time(NULL) - activity >= activity_timeout) {
			log_msg(LOG_NOTICE, "disconnecting %s due to lack of activity (%d seconds)\n", client_name, activity_timeout);
			close(server);
			server = -1;
			slen = format_reply(sbuffer, sizeof(sbuffer), "Disconnecting due to lack of activity.\r\n");
			state = CLIENT_FINISH;
		}
	} while (state != CLIENT_SHUTDOWN);

	log_msg(LOG_NOTICE, "disconnecting client %s\n", client_name);

	/* shutdown sockets */
	if (client != -1)
		close(client);
	if (server != -1)
		close(server);

	log_dbg("decrementing client count for %s\n", client_name);

	/* release this connection's slot taken by create_client() */
	client_remove(&caddr);

	log_dbg("ending thread for %s\n", client_name);

	return NULL;
}

/* create_listen_socket()
   create a TCP socket bound to the wildcard address of family on port and
   start listening (IPv6 sockets are IPv6-only)
   returns the socket, or -1 on failure */
int create_listen_socket (int family, int port)
{
	int sock;
	int opt;
	socklen_t addr_len;
	struct sockaddr_storage sockaddr;

	assert(family == AF_INET || family == AF_INET6);
	assert(port > 0 && port <= UINT16_MAX);

	sock = socket(family, SOCK_STREAM, 0);
	if (sock == -1) {
		log_msg(LOG_ERROR, "create_listen_socket(): socket() failed: %s\n", strerror(errno));
		return -1;
	}

	/* allow reuse of server port */
	opt = 1;
	setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

	/* zero everything: this is also the wildcard address, and it clears
	   fields such as sin6_flowinfo/sin6_scope_id that were left as garbage */
	memset(&sockaddr, 0, sizeof(sockaddr));

	if (family == AF_INET6) {
		/* only listen for IPv6 addresses */
		opt = 1;
		setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt));

		sockaddr.ss_family = AF_INET6;
		((struct sockaddr_in6*)&sockaddr)->sin6_port = htons(port);
		addr_len = sizeof(struct sockaddr_in6);
	} else if (family == AF_INET) {
		sockaddr.ss_family = AF_INET;
		((struct sockaddr_in*)&sockaddr)->sin_port = htons(port);
		addr_len = sizeof(struct sockaddr_in);
	} else {
		log_msg(LOG_ERROR, "create_listen_socket(): unsupported family: %d\n", family);
		close(sock);
		return -1;
	}

	/* bind with the family's real address length (not sockaddr_storage) */
	if (bind(sock, (struct sockaddr*)&sockaddr, addr_len) == -1) {
		log_msg(LOG_ERROR, "create_listen_socket(): bind() failed: %s\n", strerror(errno));
		close(sock);
		return -1;
	}

	if (listen(sock, 5) == -1) {
		log_msg(LOG_ERROR, "create_listen_socket(): listen() failed: %s\n", strerror(errno));
		close(sock);
		return -1;
	}

	return sock;
}

/* create_client()
   accept one connection on sock, apply the ban list and connection limits,
   and hand the socket to a new detached client_main() thread
   returns the client socket (now owned by that thread) or -1 on rejection
   or failure, in which case the socket has been closed */
int create_client (int sock)
{
	int err;
	int saved_errno;
	int client;
	struct sockaddr_storage sockaddr;
	socklen_t addr_len;
	struct ClientInfo* cinfo;
	char client_name[128];
	pthread_attr_t p_attr;
	pthread_t p_thread;

	assert(sock >= 0);

	log_dbg("accepting client socket\n");

	addr_len = sizeof(sockaddr);
	client = accept(sock, (struct sockaddr*)&sockaddr, &addr_len);
	if (client == -1) {
		log_msg(LOG_WARNING, "create_client(): accept() failed: %s\n", strerror(errno));
		return -1;
	}

	log_dbg("getting client name\n");

	sockaddr_name_of(&sockaddr, client_name, sizeof(client_name));

	log_dbg("checking client %s with ban list\n", client_name);

	if (banned_clients_check(&sockaddr) == -1) {
		log_msg(LOG_NOTICE, "rejected client %s, banned host/network\n", client_name);
		write_string(client, "Your client host or network is banned.\r\n");
		close(client);
		return -1;
	}

	log_dbg("checking client %s connection count\n", client_name);

	err = client_add(&sockaddr);
	if (err == -1) {
		log_msg(LOG_NOTICE, "rejected client %s, too many users\n", client_name);
		/* FIXME: this may not all get written out */
		write_string(client, "Maximum user count exceeded.\r\n");
		close(client);
		return -1;
	}
	if (err == -2) {
		log_msg(LOG_NOTICE, "rejected client %s, too many connections from that address\n", client_name);
		/* FIXME: this may not all get written out */
		write_string(client, "Maximum client count for your address exceeded.\r\n");
		close(client);
		return -1;
	}

	log_dbg("allocating client %s thread storage\n", client_name);

	cinfo = malloc(sizeof(*cinfo));
	if (cinfo == NULL) {
		saved_errno = errno;
		client_remove(&sockaddr);
		log_msg(LOG_WARNING, "create_client(): malloc() failed: %s\n", strerror(saved_errno));
		log_msg(LOG_NOTICE, "disconnecting client %s\n", client_name);
		close(client);
		return -1;
	}
	cinfo->sock = client;
	cinfo->addr = sockaddr;

	log_msg(LOG_NOTICE, "connection from %s\n", client_name);

	/* all client threads are detached; nobody joins them */
	pthread_attr_init(&p_attr);
	pthread_attr_setdetachstate(&p_attr, PTHREAD_CREATE_DETACHED);

	log_dbg("creating client %s thread\n", client_name);

	/* pthread_create() returns its error code instead of setting errno */
	err = pthread_create(&p_thread, &p_attr, client_main, cinfo);
	pthread_attr_destroy(&p_attr);
	if (err != 0) {
		client_remove(&sockaddr);
		log_msg(LOG_WARNING, "create_client(): pthread_create() failed: %s\n", strerror(err));
		log_msg(LOG_NOTICE, "disconnecting client %s\n", client_name);
		close(client);
		free(cinfo);
		return -1;
	}

	log_dbg("done creating client %s\n", client_name);

	return client;
}

/* parse_id()
   parse a non-negative decimal user/group id; "" and "-1" are rejected
   returns 0 and stores the value on success, -1 if str is not such a number */
static int parse_id (const char* str, unsigned long* out)
{
	char* end;
	long value;

	errno = 0;
	value = strtol(str, &end, 10);
	if (end == str || *end != '\0' || errno == ERANGE || value < 0)
		return -1;

	*out = (unsigned long)value;
	return 0;
}

/* drop_privileges()
   switch to the given group (numeric id or name, dropping supplementary
   groups) and then to the given user (numeric id or name); either may be NULL
   returns 0 on success, -1 on failure
   NOTE(review): with only a user given, the current gid and supplementary
   groups are kept - confirm that is intended */
int drop_privileges (char* user, char* group)
{
	struct group* grp;
	struct passwd* pwd;
	unsigned long id;
	uid_t uid;
	gid_t gid;

	/* handle group */
	if (group != NULL) {
		/* numeric gid, otherwise look it up by name (strtol() used to turn ""
		   into gid 0 and "-1" into setregid()'s "leave unchanged" value) */
		if (parse_id(group, &id) == 0 && id < (unsigned long)(gid_t)-1) {
			gid = (gid_t)id;
		} else {
			grp = getgrnam(group);
			if (grp == NULL) {
				log_msg(LOG_ERROR, "drop_privileges(): group not found: %s\n", group);
				return -1;
			}
			gid = grp->gr_gid;
		}

		/* set real/effective gid */
		if (setregid(gid, gid)) {
			log_msg(LOG_ERROR, "drop_privileges(): setregid() failed: %s\n", strerror(errno));
			return -1;
		}
		log_msg(LOG_NOTICE, "changed gid to %lu\n", (unsigned long)gid);

		/* drop supplementary groups */
		if (setgroups(0, NULL)) {
			log_msg(LOG_ERROR, "drop_privileges(): setgroups() failed: %s\n", strerror(errno));
			return -1;
		}
		log_msg(LOG_NOTICE, "dropped supplementary groups\n");
	}

	/* handle user */
	if (user != NULL) {
		/* numeric uid, otherwise look it up by name */
		if (parse_id(user, &id) == 0 && id < (unsigned long)(uid_t)-1) {
			uid = (uid_t)id;
		} else {
			pwd = getpwnam(user);
			if (pwd == NULL) {
				log_msg(LOG_ERROR, "drop_privileges(): user not found: %s\n", user);
				return -1;
			}
			uid = pwd->pw_uid;
		}

		/* set real/effective uid */
		if (setreuid(uid, uid)) {
			log_msg(LOG_ERROR, "drop_privileges(): setreuid() failed: %s\n", strerror(errno));
			return -1;
		}
		log_msg(LOG_NOTICE, "changed UID to %lu\n", (unsigned long)uid);
	}

	return 0;
}

/* signal_sighup()
   only sets reload_flag; the actual reload happens outside the handler */
void signal_sighup (int signo)
{
	(void)signo;
	reload_flag = 1;
}

/* signal_sigterm()
   only sets shutdown_flag; used for both SIGTERM and SIGINT */
void signal_sigterm (int signo)
{
	(void)signo;
	shutdown_flag = 1;
}

int main (int argc, char** argv)
{
	int sock;
	int sock6;
	int err;
	struct pollfd sock_fds[2];
	struct sigaction sigact;
	FILE* pid;

	/* configuration */
	int port = DEFAULT_LISTEN_PORT;
	char* host_list = DEFAULT_HOST_LIST;
	char* ban_list = DEFAULT_BAN_LIST;
	char* log_file = NULL;
	char* pid_file = NULL;
	int do_daemon = 0;
	int do_ipv6 = 0;
	int do_help = 0;
	char* user = NULL;
	char* group = NULL;

	/* hello! */
	printf("Telnet Proxy Daemon v%s\n", PROXY_VERSION);
	printf("Copyright (C) 2004 AwesomePlay Productions, Inc.\n");
	printf("See the file COPYING for license details.\n");

	/* options */
	struct Option options[] = {
		{ 0, "help", NULL, NULL, &do_help },
		{ 'd', "daemon", NULL, NULL, &do_daemon },
		{ '6', "ipv6", NULL, NULL, &do_ipv6 },
		{ 'p', "port", NULL, &port, NULL },
		{ 'l', "log", &log_file, NULL, NULL },
		{ 'h', "hostlist", &host_list, NULL, NULL },
		{ 'b', "banlist", &ban_list, NULL, NULL },
		{ 'w', "pid", &pid_file, NULL, NULL },
		{ 'c', "maxclients", NULL, &max_clients, NULL },
		{ 'm', "maxhost", NULL, &max_host_clients, NULL },
		{ 't', "ctime", NULL, &connect_timeout, NULL },
		{ 'a', "atime", NULL, &activity_timeout, NULL },
		{ 'u', "user", &user, NULL, NULL },
		{ 'g', "group", &group, NULL, NULL },
		{ 0, NULL, NULL, NULL, NULL }
	};

	/* parse options */
	if (parse_options(options, argc, argv) == -1)
		return 1;

	/* help? */
	if (do_help) {
		print_usage(argv[0], options);
		return 0;
	}

	/* open log file */
	if (log_open(log_file))
		return 1;

	/* verify port */
	if (port < 1 || port > UINT16_MAX) {
		log_msg(LOG_ERROR, "invalid port %d\n", port);
		return 1;
	}

	/* verify max clients */
	if (max_clients < 1) {
		log_msg(LOG_ERROR, "invalid client max %d\n", max_clients);
		return 1;
	}
	if (max_host_clients < 1) {
		log_msg(LOG_ERROR, "invalid client host max %d\n", max_host_clients);
		return 1;
	}

	/* verify timeouts */
	if (connect_timeout < 1) {
		log_msg(LOG_ERROR, "invalid connect timeout %d\n", connect_timeout);
		return 1;
	}
	if (activity_timeout < 1) {
		log_msg(LOG_ERROR, "invalid activity timeout %d\n", activity_timeout);
		return 1;
	}

	/* load initial set of valid hosts */
	if (allowed_servers_load(host_list))
		return 1;

	/* load initial set of banned clients */
	if (banned_clients_load(ban_list))
		return 1;

	/* do daemonization */
	if (do_daemon) {
		/* requires a log file */
		if (log_file == NULL) {
			log_msg(LOG_ERROR, "must specify a log file when using daemon mode\n");
			return 1;
		}

		/* fork */
		err = fork();
		if (err == -1) {
			log_msg(LOG_ERROR, "main(): fork() failed: %s\n", strerror(errno));
			return 1;
		}

		/* main process - exit */
		if (err != 0)
			return 0;

		/* close stdin, stdout, stderr */
		close(0);
		close(1);
		close(2);

		/* set session id */
		setsid();
	}

	/* create IPV6 socket */
	sock6 = -1;
	if (do_ipv6) {
		sock6 = create_listen_socket(AF_INET6, port);
		if (sock6 == -1)
			return 1;
		log_msg(LOG_NOTICE, "IPv6 enabled\n");
	}

	/* create IPV4 socket */
	sock = create_listen_socket(AF_INET, port);
	if (sock == -1)
		return 1;

	/* set signals */
	sigemptyset (&sigact.sa_mask);
	sigact.sa_flags = 0;
	sigact.sa_handler = SIG_IGN;
	sigaction(SIGPIPE, &sigact, NULL);
	sigact.sa_handler = signal_sighup;
	sigaction(SIGHUP, &sigact, NULL);
	sigact.sa_handler = signal_sigterm;
	sigaction(SIGTERM, &sigact, NULL);
	sigact.sa_handler = signal_sigterm;
	sigaction(SIGINT, &sigact, NULL);

	/* write pid file */
	if (pid_file != NULL) {
		pid = fopen(pid_file, "wt");
		if
(pid == NULL) { log_msg(LOG_ERROR, "failed to open %s: %s\n", pid_file, strerror(errno)); return 1; } fprintf(pid, "%d\n", getpid()); fclose(pid); } /* drop privileges */ if (drop_privileges(user, group)) return 1; /* begin server loop */ log_msg(LOG_NOTICE, "listening on port %d\n", port); do { /* poll server socket(s) */ sock_fds[0].fd = sock; sock_fds[0].events = POLLIN; sock_fds[0].revents = 0; sock_fds[1].fd = sock6; sock_fds[1].events = POLLIN; sock_fds[1].revents = 0; err = poll(sock_fds, 2, 1000); /* 1000 for DEBUG - set back to -1 */ log_dbg("poll() for main: %d %hx %hx\n", err, sock_fds[0].revents, sock_fds[1].revents); /* check for error */ if (err == -1 && errno != EINTR) { log_msg(LOG_WARNING, "main(): poll() failed: %s\n", strerror(errno)); break; } /* reload data on SIGHUP */ if (reload_flag) { reload_flag = 0; log_msg(LOG_NOTICE, "reloading host list\n"); allowed_servers_load(host_list); log_msg(LOG_NOTICE, "reloading ban list\n"); banned_clients_load(ban_list); } /* shutdown on SIGTERM */ if (shutdown_flag) { log_msg(LOG_NOTICE, "received terminating signal\n"); break; } /* something ready on IPv4 socket */ if (sock_fds[0].revents & POLLIN) { log_dbg("creating ipv4 client\n"); create_client(sock_fds[0].fd); log_dbg("ipv4 client created\n"); } /* something ready on IPv6 socket */ if (sock_fds[1].revents & POLLIN) { log_dbg("creating ipv6 client\n"); create_client(sock_fds[1].fd); log_dbg("ipv6 client created\n"); } } while(1); /* shutdown sockets */ close(sock); if (sock6 != -1) close(sock6); /* unlink pid file */ if (pid_file != NULL) unlink(pid_file); /* finish up */ log_msg(LOG_NOTICE, "terminating proxy\n"); log_close(); return 0; }