From: Sam Moore Date: Sat, 23 Feb 2013 08:31:49 +0000 (+0800) Subject: Things seem to work... X-Git-Url: https://git.ucc.asn.au/?a=commitdiff_plain;h=4e2127d6576cea3f54c619d0bb20a22006567206;p=matches%2Fswarm.git Things seem to work... So I'll commit before I break everything! --- diff --git a/src/Makefile b/src/Makefile index a9f6729..da2cf0d 100644 --- a/src/Makefile +++ b/src/Makefile @@ -1,10 +1,10 @@ # Makefile for swarm CXX = gcc -LIBRARIES = -lm -lpthread #-lGL -lglut -lGLU -lpthread +LIBRARIES = /usr/local/lib/libssh2.a -lm -lpthread -lssl -lcrypto -lz #-lGL -lglut -lGLU -lpthread FLAGS = --std=c99 -D_POSIX_C_SOURCE=200112L -Wall -pedantic -g PREPROCESSOR_FLAGS = -LINK_OBJ = options.o log.o task.o network.o master.o daemon.o slave.o main.o +LINK_OBJ = options.o log.o task.o network.o ssh.o master.o daemon.o slave.o main.o BIN = swarm diff --git a/src/debug.c b/src/debug.c new file mode 100644 index 0000000..83682bd --- /dev/null +++ b/src/debug.c @@ -0,0 +1,10 @@ +#include + +void libssh2_debug_fuck(LIBSSH2_SESSION * session, int level, char * fmt, ...) +{ + va_list va; + va_start(va, fmt); + vfprintf(stderr, fmt, va); + va_end(va); + fprintf(stderr, "\n"); +} diff --git a/src/log.c b/src/log.c index 867ab64..3a8c55c 100644 --- a/src/log.c +++ b/src/log.c @@ -14,17 +14,17 @@ void log_print(int level, char * funct, char * fmt, ...) char severity[BUFSIZ]; switch (level) { - case 0: - sprintf(severity, "Error"); + case LOGERR: + sprintf(severity, "ERROR"); break; - case 1: - sprintf(severity, "Warning"); + case LOGWARN: + sprintf(severity, "WARNING"); break; - case 2: - sprintf(severity, "Notice"); + case LOGNOTE: + sprintf(severity, "NOTICE"); break; - case 3: - sprintf(severity, "Info"); + case LOGINFO: + sprintf(severity, "INFO"); break; default: sprintf(severity, "DEBUG"); @@ -32,7 +32,7 @@ void log_print(int level, char * funct, char * fmt, ...) } if (funct != NULL) - last_len = fprintf(stderr, "%s [%d] : %s in %s - ", options.program, getpid(), severity, funct); + last_len = fprintf(stderr, "%s [%d] : %s : %s - ", options.program, getpid(), severity, funct); else { for (int i = 0; i < last_len; ++i); diff --git a/src/log.h b/src/log.h index 24e5947..75e2780 100644 --- a/src/log.h +++ b/src/log.h @@ -7,6 +7,8 @@ #include +enum {LOGERR=0, LOGWARN=1, LOGNOTE=2, LOGINFO=3,LOGDEBUG=4}; + extern void log_print(int level, char * funct, char * fmt,...); extern void error(char * funct, char * fmt, ...); diff --git a/src/main.c b/src/main.c index a7666dc..baab2b7 100644 --- a/src/main.c +++ b/src/main.c @@ -39,7 +39,10 @@ int main(int argc, char ** argv) Master_main(&options); } else + { + fprintf(stderr, "%p %s", options.master_addr, options.master_addr); Slave_main(&options); + } exit(EXIT_SUCCESS); return 0; diff --git a/src/master.c b/src/master.c index e61f5fa..2734d07 100644 --- a/src/master.c +++ b/src/master.c @@ -12,7 +12,10 @@ #include #include #include "slave.h" +#include #include +#include +#include #include #include @@ -22,12 +25,15 @@ #include #include #include +#include "ssh.h" //#define THREAD_SENDING // You decided to solve a problem with threads; now you have two problems // probably not that great to use threads anyway, since it eats one of your cores // It probably spends 90% of its time sleeping, and 9.9% unlocking mutexes +// the signal handler now breaks threads... don't use them + #ifdef THREAD_SENDING pthread_t sender_thread; pthread_mutex_t sender_lock = PTHREAD_MUTEX_INITIALIZER; @@ -69,12 +75,19 @@ void Master_main(Options * o) void Master_setup(Options * o) { + int err = libssh2_init(0); + if (err != 0) + { + error("Master_setup", "Initialising libssh2 - error code %d", err); + } + signal(SIGCHLD, sigchld_handler); master.o = o; master.barrier_number = -1; master.last_number = -1; master.nSlaves = o->nCPU; master.running = master.nSlaves; + master.nRemote = 0; master.remote_err = NULL; if (master.nSlaves == 0) error("Master_setup", "No CPUs to start slaves with!"); @@ -119,7 +132,6 @@ void Make_slave(int i) master.slave[i].in = sv[1]; master.slave[i].out = sv[1]; master.slave[i].running = true; - master.slave[i].ssh_pid = 0; } void Master_input(char c) @@ -174,9 +186,9 @@ void Master_input(char c) log_print(0, "Master_input", "No host specified for ABSORB directive"); char * np = strtok(NULL, " "); if (np != NULL) - Master_absorb(cmd, options.port, atoi(np)); + Master_absorb(cmd, atoi(np)); else - Master_absorb(cmd, options.port, 0); + Master_absorb(cmd, 0); } else if (strcmp(cmd, "OUTPUT") == 0) { @@ -367,8 +379,11 @@ void Master_output(int i, char c) #endif //THREAD_SENDING if (t == NULL) { - log_print(0, "Master_output", "Read input from %s, but no task assigned!",master.slave[i].name); - error(NULL, "Please refrain from echoing three bell characters in a row."); + log_print(3, "Master_output", "Echo %c back to slave %d", c, i); + write(master.slave[i].in, &c, sizeof(char)); + //log_print(0, "Master_output", "Read input from %s, but no task assigned!",master.slave[i].name); + //error(NULL, "Please refrain from echoing three bell characters in a row."); + return; } @@ -395,8 +410,10 @@ void Master_output(int i, char c) { fprintf(stdout, "%d:\n", t->number); } + /* else if (t->output[t->outlen-1] == '\f') { + log_print(2, "Master_output", "Slave %d requests name (%s)", i, master.slave[i].name); static int bufsiz = BUFSIZ; char * buffer = (char*)(calloc(bufsiz, sizeof(char))); @@ -420,6 +437,7 @@ void Master_output(int i, char c) master.slave[i].task_pool = t2; master.last_number = t2->number; } + */ else { fprintf(stdout, "%d: %s", t->number, t->output); @@ -522,9 +540,9 @@ Task * Master_tasker(int i) void Master_loop() { - if (sigsetjmp(env,true) != 0) + if (sigsetjmp(env,true) != 0) // completely necessary evil { - log_print(2, "Master_loop", "Restored from longjmp"); + //log_print(2, "Master_loop", "Restored from longjmp"); } fd_set readSet; //fd_set writeSet; @@ -539,6 +557,7 @@ void Master_loop() bool quit = false; bool input = true; + char buffer[BUFSIZ]; while (!quit) { @@ -569,6 +588,11 @@ void Master_loop() { if (master.slave[i].running) FD_SET(master.slave[i].out, &readSet); } + + for (int i = 0; i < master.nRemote; ++i) + { + FD_SET(master.remote_err[i], &readSet); + } select(master.fd_max+1, &readSet, NULL, NULL, NULL); @@ -611,6 +635,16 @@ void Master_loop() } } + + for (int i = 0; i < master.nRemote; ++i) + { + if (FD_ISSET(master.remote_err[i], &readSet)) + { + int len = read(master.remote_err[i], buffer, sizeof(buffer)); + buffer[len] = '\0'; + fprintf(stderr, "%s", buffer); + } + } } @@ -648,7 +682,7 @@ void Master_send() } write(send_task.slave_fd, "\n", 1*sizeof(char)); master.commands_active++; - log_print(3, "Master_sender", "Sent task %d \"%s\" - %d tasks active", send_task.task->number, send_task.task->message, master.commands_active); + log_print(3, "Master_sender", "Sent task %d \"%s\" on socket %d - %d tasks active", send_task.task->number, send_task.task->message, send_task.slave_fd, master.commands_active); } #ifdef THREAD_SENDING @@ -702,6 +736,16 @@ void Master_cleanup() signal(SIGCHLD, SIG_IGN); // ignore child exits now + // tell all remote nodes to exit + for (int i = 0; i < master.nRemote; ++i) + { + FILE * f = fdopen(master.remote_err[i], "r+"); setbuf(f, NULL); + + fprintf(f, "exit\n"); + + fclose(f); + } + for (int i = 0; i < master.nSlaves; ++i) { @@ -717,13 +761,6 @@ void Master_cleanup() if (master.slave[i].pid <= 0) { Network_close(master.slave[i].in); - if (master.slave[i].ssh_pid > 0) - { - log_print(2, "Master_cleanup", "Killing ssh instance %d", master.slave[i].ssh_pid); - kill(master.slave[i].ssh_pid, 15); - if (kill(master.slave[i].ssh_pid, 0) == 0) - kill(master.slave[i].ssh_pid, 9); - } } else { @@ -740,83 +777,109 @@ void Master_cleanup() free(master.buffer); if (master.outfile != NULL) free(master.outfile); -} -void * start_server(void * args) -{ + libssh2_exit(); - *(int*)(args) = Network_server(*(int*)args); - log_print(2, "start_server", "started network server"); - return NULL; } -int Secure_connection(char * addr, int port); -void Master_absorb(char * addr, int port, int np) + + +void Master_absorb(char * addr, int np) { + int port = 0; + + char * user = strstr(addr, "@"); + if (user != NULL) + { + *(user-1) = '\0'; + char * t = user; + user = addr; + addr = t; + } + else + { + user = getpwuid(geteuid())->pw_name; + } + log_print(3, "Master_absorb", "User %s at address %s", user, addr); + + // ssh to the host on port 22 + ssh * s = ssh_new(user, addr, 22); + if (s == NULL) + { + log_print(0, "Master_absorb", "Couldn't ssh to %s@%s", user, addr); + return; + } + + + // work out the name to give to the shells char * name = strtok(addr, ":"); name = strtok(NULL, ":"); if (name == NULL) + name = addr; // default is host:X + else + *(name-1) = '\0'; // otherwise use name:X + + // setup array of remote stderr file descriptors + if (master.nRemote++ == 0) { - name = addr; + master.remote_err = (int*)(calloc(master.nRemote, sizeof(int))); + master.remote_reserved = master.nRemote; } - else + else if (master.nRemote >= master.remote_reserved) { - *(name-1) = '\0'; + // resize dynamically + master.remote_reserved *= 2; + master.remote_err = (int*)(realloc(master.remote_err,master.remote_reserved * sizeof(int))); } - //log_print(0, "name is %s\n", name); - int first_ssh = 0; + int sfd = -1; if (master.o->encrypt) - first_ssh = Secure_connection(addr, port); - - //pthread_t ss; - //int net_fd = port; - //pthread_create(&ss, NULL, start_server, (void*)(&net_fd)); - - char buffer[BUFSIZ]; - if (fork() == 0) { - // The alternative to this kind of terrible hack is OpenMPI's "opal" - // This involves >1000 lines of operating system independent ways to get an IP address from a local interface - // Which will then be completely useless if there is any sort of NAT involved - - //freopen("/dev/null", "r", stdin); - //freopen("/dev/null", "w", stdout); - //freopen("/dev/null", "w", stderr); + int sv[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0) + error("Master_absorb", "Couldn't create socket for remote swarm"); + sfd = sv[0]; - char * cmd = buffer+sprintf(buffer, "swarm -p "); - if (master.o->encrypt) - cmd = cmd+sprintf(cmd, "%d -e", port+1000); - else - cmd = cmd+sprintf(cmd, "%d -u", port); - - if (np > 0) - cmd = cmd+sprintf(cmd, " -n %d", np); - sprintf(cmd, " -m $(echo $SSH_CONNECTION | awk \'{print $1}\')"); - log_print(3, "Master_absorb", "Execing %s", buffer); - execlp("ssh", "ssh", "-f", addr, buffer, NULL); + ssh_exec_swarm(s, NULL, sv+1, np); // start swarm remotely forward it to the socket + ssh_thread_add(s); // add the ssh to the thread } - log_print(3, "Master_absorb", "Listening on port %d", port); - int net_fd = Network_server(port); - log_print(3, "Master_absorb", "Created network server on port %d", port); + else + { + sfd = Network_server_bind(0, &port); // dynamically bind to a port + ssh_exec_swarm(s, &port, NULL, np); // start swarm remotely, have it connect on the port + ssh_destroy(s); // don't need the ssh anymore + sfd = Network_server_listen(sfd, NULL); // accept connections and pray + } + master.remote_err[master.nRemote-1] = sfd; + if (sfd > master.fd_max) + master.fd_max = sfd; + + char buffer[BUFSIZ]; + + int newSlaves = 0; + + int len = sprintf(buffer, "%s\n", name); + int w = 0; + while (w < len) + w += write(sfd, buffer+w, len-w); - char * s = buffer; - while (read(net_fd, s, sizeof(char)) != 0) + + len = 0; + do { - if (*s == '\n') - { - *s = '\0'; - break; - } - ++s; + len = read(sfd, buffer+len, sizeof(buffer)); + buffer[len] = '\0'; } + while (buffer[len-1] != '\n'); + buffer[len-1] = '\0'; + newSlaves = atoi(buffer); + + - int newSlaves = atoi(buffer); - log_print(3, "Master_absorb", "Absorbing %d slaves from %s\n", newSlaves, addr); if (newSlaves == 0) { error("Master_absorb", "No slaves to absorb from %s", addr); @@ -825,62 +888,70 @@ void Master_absorb(char * addr, int port, int np) master.slave = (Slave*)(realloc(master.slave, (master.nSlaves + newSlaves) * sizeof(Slave))); if (master.slave == NULL) { - error("Master_absorb", "Resizing slave array from %d to %d slaves : %s", master.nSlaves, master.nSlaves + newSlaves, strerror(errno)); + error("Master_absorb", "Resizing slave array from %d to %d slaves : %s", + master.nSlaves, master.nSlaves + newSlaves, strerror(errno)); } - - if (master.o->encrypt) - { - for (int i = 0; i < newSlaves-1; ++i) - master.slave[master.nSlaves+i].ssh_pid = Secure_connection(addr, port+i+1); - - } - master.slave[master.nSlaves+newSlaves-1].ssh_pid = first_ssh; + for (int i = 0; i < newSlaves; ++i) + { + int ii = master.nSlaves + i; + if (master.o->encrypt) + { + int sv[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0) + error("Master_absorb", "Couldn't create socket for remote swarm"); + + LIBSSH2_LISTENER * listener = NULL; + // libssh2 can't finalise the connection when the port is dynamic (ie: port = 0) + while (listener == NULL) + { + port = 20000 + rand() % 30000; // so pick ports at random until binding succeeds + listener = ssh_get_listener(s, &port); // port forward to the socket + } - - - char c = '\a'; - log_print(3, "Master_absorb", "Writing bell to slave"); - write(net_fd, &c, sizeof(char)); - + log_print(4,"Master_absorb", "Chose port %d", port); + int len = sprintf(buffer, "%d\n", port); + + int w = 0; + while (w < len) + { + w += write(sfd, buffer+w, len-w); + } + usleep(200000); // give ssh_thread a chance to actually send the data + log_print(4, "Master_absorb", "Creating tunnel..."); + ssh_add_tunnel(s, listener, sv[1]); + master.slave[ii].in = sv[0]; + master.slave[ii].out = sv[0]; - for (int i = 0; i < newSlaves; ++i) - { - log_print(3, NULL, "Absorbing slave %d...", i); - int ii = master.nSlaves + i; - if (i == newSlaves-1) - { - master.slave[ii].out = net_fd; + log_print(4, "Master_absorb", "Tunnel for slave %d using socket %d<->%d setup", ii, sv[0], sv[1]); } else { - log_print(3, "Master_absorb", "Creating server %d at time %d", i, time(NULL)); - write(net_fd, &c, sizeof(char)); - master.slave[ii].out = Network_server(port + i + 1); + int tmp = Network_server_bind(0, &port); // bind to a port + + master.slave[ii].in = Network_server_listen(tmp, NULL); // listen for connection + master.slave[ii].out = master.slave[ii].in; } - master.slave[ii].in = master.slave[ii].out; + + + if (master.slave[ii].out > master.fd_max) master.fd_max = master.slave[ii].out; + char buffer[BUFSIZ]; sprintf(buffer, "%s:%d", name, i); master.slave[ii].name = strdup(buffer); master.slave[ii].addr = strdup(addr); master.slave[ii].running = true; master.slave[ii].pid = -1; master.slave[ii].task = NULL; - master.slave[ii].task_pool = NULL; - - FILE * f = fdopen(master.slave[ii].in, "w"); setbuf(f, NULL); - fprintf(f, "name=%s\n", master.slave[ii].name); - - log_print(3, NULL, "Done absorbing slave %d...", i); - + master.slave[ii].task_pool = NULL; } @@ -892,21 +963,6 @@ void Master_absorb(char * addr, int port, int np) } -int Secure_connection(char * addr, int port) -{ - int result = fork(); - if (result == 0) - { - char buffer[BUFSIZ]; - sprintf(buffer, "%d:localhost:%d", port+1000, port); - freopen("/dev/null", "r", stdin); - freopen("/dev/null", "w", stdout); - freopen("/dev/null", "w", stderr); - execl("/usr/bin/ssh", "/usr/bin/ssh", "-N", "-R", buffer, addr, NULL); - } - return result; -} - void sigchld_handler(int signal) { @@ -918,13 +974,6 @@ void sigchld_handler(int signal) if (p == -1) error("sigchld_handler", "waitpid : %s", strerror(errno)); - if (WIFSIGNALED(s)) - { - int sig = WTERMSIG(s); - log_print(2, "sigchld_handler", "A child [%d] was terminated with signal %d; terminating self with same signal", p, sig); - kill(getpid(), sig); - return; - } int i = 0; for (i = 0; i < master.nSlaves; ++i) @@ -937,13 +986,28 @@ void sigchld_handler(int signal) return; } - log_print(1, "sigchld_handler", "Slave %d [%d] exited with code %d; restarting it",i, p, s); + fprintf(stderr, "Unexpected exit of slave %s", master.slave[i].name); + if (WIFSIGNALED(s)) + { + int sig = WTERMSIG(s); + fprintf(stderr, " due to %s", strsignal(sig)); + if (sig == SIGKILL) + { + printf(" - committing suicide\n"); + kill(getpid(), sig); + } + } + else + { + fprintf(stderr, " return code %d.",s); + } + fprintf(stderr, " Starting replacement.\n"); Make_slave(i); if (master.o->end != NULL) { - log_print(1, "sigchld_handler", "Trying to convince slave %d to be nice", i); + //log_print(1, "sigchld_handler", "Trying to convince slave %d to be nice", i); char buffer[BUFSIZ]; sprintf(buffer, "name=%s;echo -en \"%s\"\n", master.slave[i].name, master.o->end); if (write(master.slave[i].in, buffer, strlen(buffer)) <= 0) diff --git a/src/master.h b/src/master.h index 1a2d5e9..3595e23 100644 --- a/src/master.h +++ b/src/master.h @@ -13,7 +13,7 @@ extern void * Master_sender(void * args); extern void Master_setup(Options * o); extern void Master_cleanup(); extern void Master_send(); -extern void Master_absorb(char * addr, int port, int np); +extern void Master_absorb(char * addr, int np); extern void Make_slave(int i); extern void sigchld_handler(int signal); @@ -43,6 +43,10 @@ typedef struct int commands_active; + int * remote_err; // sockets used as stderr for remote shells + int nRemote; // number of remote shells + int remote_reserved; // number of sockets reserved + Options * o; } Master; diff --git a/src/network.c b/src/network.c index 5354ac3..bcd06d4 100644 --- a/src/network.c +++ b/src/network.c @@ -6,9 +6,21 @@ #define h_addr h_addr_list[0] -int Network_server(int port) {return Network_server_r(NULL, port);} -int Network_server_r(char * addr, int port) + + + +int Network_get_port(int sfd) +{ + static struct sockaddr_in sin; + static socklen_t len = sizeof(struct sockaddr_in); + + if (getsockname(sfd, (struct sockaddr *)&sin, &len) != 0) + error("Network_port", "getsockname : %s", strerror(errno)); + return ntohs(sin.sin_port); +} + +int Network_server_bind(int port, int * bound) { int sfd = socket(PF_INET, SOCK_STREAM, 0); if (sfd < 0) @@ -26,6 +38,15 @@ int Network_server_r(char * addr, int port) { error("Network_server", "Binding socket on port %d : %s", port, strerror(errno)); } + + if (bound != NULL) + *bound = Network_get_port(sfd); + return sfd; +} + +int Network_server_listen(int sfd, char * addr) +{ + int port = Network_get_port(sfd); if (listen(sfd, 1) < 0) { error("Network_server", "Listening on port %d : %s", port, strerror(errno)); @@ -52,12 +73,18 @@ int Network_server_r(char * addr, int port) assert(sfd >= 0); return sfd; - +} + +int Network_server(char * addr, int port) +{ + return Network_server_listen(Network_server_bind(port, &port), addr); } int Network_client(const char * addr, int port, int timeout) { int sfd = socket(PF_INET, SOCK_STREAM, 0); + + //log_print(2, "Network_client", "Created socket"); long arg = fcntl(sfd, F_GETFL, NULL); arg |= O_NONBLOCK; fcntl(sfd, F_SETFL, arg); @@ -75,6 +102,7 @@ int Network_client(const char * addr, int port, int timeout) bcopy ( hp->h_addr, &(server.sin_addr.s_addr), hp->h_length); server.sin_port = htons(port); + int res = connect(sfd, (struct sockaddr *) &server, sizeof(server)); @@ -91,9 +119,9 @@ int Network_client(const char * addr, int port, int timeout) struct timeval * tp; tp = (timeout < 0) ? NULL : &tv; - + int err = select(sfd+1, NULL, &writeSet, NULL, tp); - + if (err == 0) { error("Network_client", "Timed out trying to connect to %s:%d after %d seconds", addr, port, timeout); @@ -126,7 +154,8 @@ int Network_client(const char * addr, int port, int timeout) arg &= (~O_NONBLOCK); fcntl(sfd, F_SETFL, arg); - + + return sfd; } diff --git a/src/network.h b/src/network.h index b00e595..07c9459 100644 --- a/src/network.h +++ b/src/network.h @@ -15,10 +15,13 @@ #include #include -extern int Network_server(int port); -extern int Network_server_r(char * addr, int port); +extern int Network_get_port(int socket); // get port used by socket +extern int Network_server(char * addr, int port); extern int Network_client(const char * addr, int port, int timeout); +extern int Network_server_bind(int port, int * bound); +extern int Network_server_listen(int sfd, char * addr); + extern void Network_close(int sfd); #endif //_NETWORK_H diff --git a/src/options.c b/src/options.c index 6ef6119..ddd2ddc 100644 --- a/src/options.c +++ b/src/options.c @@ -35,6 +35,8 @@ void close_out() fclose(stdout); } +char name[BUFSIZ]; + void Initialise(int argc, char ** argv, Options * o) { srand(time(NULL)); @@ -44,9 +46,7 @@ void Initialise(int argc, char ** argv, Options * o) o->logfile = NULL; o->outfile = NULL; o->verbosity = 2; - o->port = 4000 + rand() % 1000; - o->slavefile = "slaves.swarm"; - o->dummy_shell = false; + o->port = 0; o->append = NULL; o->prepend = NULL; o->end = "\a\a\a"; @@ -54,7 +54,9 @@ void Initialise(int argc, char ** argv, Options * o) o->daemon = false; o->encrypt = true; o->interactive = true; - + + gethostname(name, sizeof(name)); + o->name = strdup(name); o->master_pid = getpid(); @@ -140,11 +142,17 @@ void ParseArguments(int argc, char ** argv, Options * o) error("ParseArguments", "No argument following %s switch", argv[i]); o->nCPU = atoi(argv[++i]); } - else if (argv[i][1] == 'm') + else if (argv[i][1] == 'r') { if (i >= argc-1) error("ParseArguments", "No argument following %s switch", argv[i]); o->master_addr = argv[++i]; + char * p = strstr(o->master_addr, ":"); + if (p != NULL) + { + *(p-1) = '\0'; + o->port = atoi(p); + } } else if (argv[i][1] == 'c') { diff --git a/src/options.h b/src/options.h index de365eb..24f35eb 100644 --- a/src/options.h +++ b/src/options.h @@ -14,14 +14,13 @@ typedef struct { char * program; + char * name; char * shell; char * master_addr; char * logfile; char * outfile; int verbosity; int port; - char * slavefile; - bool dummy_shell; char * prepend; char * append; char * end; diff --git a/src/options.o b/src/options.o new file mode 100644 index 0000000..88fbc13 Binary files /dev/null and b/src/options.o differ diff --git a/src/slave.c b/src/slave.c index 3327759..65e11c8 100644 --- a/src/slave.c +++ b/src/slave.c @@ -1,7 +1,7 @@ -#define _XOPEN_SOURCE +#define _XOPEN_SOURCE 700 #define _GNU_SOURCE -//#define _SIMPLE_SLAVE + #include "slave.h" #include @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -26,117 +27,182 @@ Slave * slave; +char name[BUFSIZ]; + +void Slave_shell(int i, char * shell); +void Slave_cleanup(); -int running; void Slave_main(Options * o) { + - if (fork() != 0) - exit(EXIT_SUCCESS); + setbuf(stdin, NULL); setbuf(stdout, NULL); setbuf(stderr, NULL); + + dup2(fileno(stdout), fileno(stderr)); // yes, this works, apparently + slave = (Slave*)(calloc(o->nCPU, sizeof(slave))); + atexit(Slave_cleanup); - o->verbosity = 100; - freopen(SLAVE_LOGFILE, "w", stderr); - setbuf(stderr, NULL); - slave = (Slave*)(calloc(o->nCPU, sizeof(Slave))); - int net_fd = -1; - if (o->encrypt) - net_fd = Network_client("localhost", o->port,100); + if (strcmp(o->master_addr, "-") != 0) + { + if (fork() != 0) + exit(EXIT_SUCCESS); + + //log_print(2, "Slave_main", "Using unsecured networking; connect to %s:%d", o->master_addr, o->port); + //log_print(2, "Slave_main", "Connecting to %s:%d", o->master_addr, o->port); + int net_fd = Network_client(o->master_addr, o->port, 100); + dup2(net_fd, fileno(stdin)); + dup2(net_fd, fileno(stdout)); + dup2(net_fd, fileno(stderr)); + + } else - net_fd = Network_client(o->master_addr, o->port,100); + { + o->master_addr = "localhost"; + //log_print(2, "Slave_main", "Using port forwarding; connect to %s", o->master_addr); + } - FILE * f = fdopen(net_fd, "w"); setbuf(f, NULL); - fprintf(f, "%d\n", o->nCPU); + char buffer[BUFSIZ]; - log_print(2, "Slave_main", "Waiting on bell from master"); - char c; - if (read(net_fd, &c, sizeof(char)) == 0 || c != '\a') - error("Slave_main", "Didn't get bell from master"); - + fgets(name, sizeof(name), stdin); + name[strlen(name)-1] = '\0'; + //log_print(2, "Slave_main", "Got name %s", name); + fprintf(stdout, "%d\n", o->nCPU); + //log_print(2, "Slave_main", "Wrote nCPU %d", o->nCPU); - log_print(2, "Slave_main", "Got bell from master"); - running = o->nCPU; + int port = 0; for (int i = 0; i < o->nCPU; ++i) { - int new_fd = net_fd; - if (i != o->nCPU-1) - { + //log_print(2, "Slave_main", "Waiting for port number..."); + fgets(buffer, sizeof(buffer), stdin); + + buffer[strlen(buffer)-1] = '\0'; + sscanf(buffer, "%d", &port); + //log_print(2, "Slave_main", "Port number %d", port); + slave[i].in = Network_client(o->master_addr, port,20); + //log_print(2, "Slave_main", "Connected to %s:%d\n", o->master_addr, port); + slave[i].out = slave[i].in; + + Slave_shell(i, o->shell); + } + - - if (read(net_fd, &c, sizeof(char)) == 0 || c != '\a') - error("Slave_main", "Didn't get bell from master authorising connection of slave %d", i); - sleep(1); + Slave_loop(o); - log_print(3, "Slave_main", "Connecting slave %d to port %d at time %d", i, o->port+i+1, time(NULL)); - if (o->encrypt) - new_fd = Network_client("localhost", o->port+i+1, 100); - else - new_fd = Network_client(o->master_addr, o->port+i+1, 100); + exit(EXIT_SUCCESS); +} - - - } +void Slave_shell(int i, char * shell) +{ + slave[i].pid = fork(); - slave[i].in = new_fd; slave[i].out = new_fd; - slave[i].pid = fork(); - if (slave[i].pid == 0) - { - dup2(slave[i].in, fileno(stdin)); - dup2(slave[i].out, fileno(stdout)); - execlp(o->shell, o->shell, NULL); - } + + if (slave[i].pid == 0) + { + dup2(slave[i].in, fileno(stdin)); + dup2(slave[i].out, fileno(stdout)); + //dup2(error_socket[1], fileno(stderr)); + + execlp(shell, shell, NULL); } - - Slave_loop(o); - free(slave); - exit(EXIT_SUCCESS); + // if the input is a network socket, this message gets sent to the master + // which will then echo it back to the socket and hence the shell + FILE * f = fdopen(slave[i].in, "w"); setbuf(f, NULL); + fprintf(f, "name=\"%s:%d\"\n", name,i); } void Slave_loop(Options * o) { - + fd_set readSet; + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 100000; + int p = -1; int s = 0; - - while (running > 0) + char buffer[BUFSIZ]; + while (true) { + FD_ZERO(&readSet); + FD_SET(fileno(stdin), &readSet); p = waitpid(-1, &s, 0); if (p == -1) { - log_print(0, "Slave_loop", "waitpid : %s", strerror(errno)); + //log_print(0, "Slave_loop", "waitpid : %s", strerror(errno)); continue; } - if (s != SHELL_EXIT_CODE) - { - // there was an error - int i = 0; - for (i = 0; i < o->nCPU; ++i) + //log_print(3, "Slave_loop", "Detected child %d exiting...", p); + + // check for an exit command from the master + select(fileno(stdin) + 1, &readSet, NULL, NULL, &tv); + + if (FD_ISSET(fileno(stdin), &readSet)) + { + fgets(buffer, sizeof(buffer), stdin); + if (strcmp(buffer, "exit\n") == 0) { - if (slave[i].pid == p) break; + log_print(2, "Slave_loop", "Received notification of exit.\n"); + exit(EXIT_SUCCESS); } - if (i >= o->nCPU) - error("Slave_loop", "No child matches pid %d", p); + } + + int i = 0; + for (i = 0; i < o->nCPU; ++i) + { + if (slave[i].pid == p) break; + } + if (i >= o->nCPU) + error("Slave_loop", "No child matches pid %d", p); + - log_print(0, "Slave_loop", "Child [%d] exits with status %d; restarting", p, s); - slave[i].pid = fork(); - if (slave[i].pid == 0) + + fprintf(stderr,"Unexpected exit of slave %s:%d", name, i); + if (WIFSIGNALED(s)) + { + int sig = WTERMSIG(s); + fprintf(stderr," due to %s", strsignal(sig)); + if (sig == SIGKILL) { - dup2(slave[i].in, fileno(stdin)); - dup2(slave[i].out, fileno(stdout)); - execlp(o->shell, o->shell, NULL); + fprintf(stderr," - %s committing suicide\n", name); + kill(getpid(), sig); } - - char buffer[] = "\f\a\a\a"; - if (write(slave[i].in, buffer, strlen(buffer)) <= 0) - log_print(0, "Slave_loop", "Slave %d input closed", i); } else - --running; + { + fprintf(stderr," return code %d.", s); + } + + + // cancel any tasks at the master for this slave + static int len = -1; + if (len < 0) + len = strlen(o->end); + write(slave[i].out, o->end, len); + + Slave_shell(i, o->shell); + + } } + +void Slave_cleanup() +{ + for (int i = 0; i < options.nCPU; ++i) + { + kill(slave[i].pid, SIGTERM); + } + sleep(1); + for (int i = 0; i < options.nCPU; ++i) + { + kill(slave[i].pid, SIGKILL); + } + free(slave); +} + + diff --git a/src/slave.h b/src/slave.h index 40b16ad..bf4db8f 100644 --- a/src/slave.h +++ b/src/slave.h @@ -17,7 +17,7 @@ typedef struct Task * task_pool; // tasks specific to the slave - int ssh_pid; + bool running; } Slave; diff --git a/src/ssh.c b/src/ssh.c new file mode 100644 index 0000000..0e1e197 --- /dev/null +++ b/src/ssh.c @@ -0,0 +1,688 @@ +#include "ssh.h" +#include "network.h" +#include "log.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +enum { + AUTH_NONE = 0, + AUTH_PASSWORD, + AUTH_PUBLICKEY +}; + +static bool ssh_fingerprint_ok(char * f); + +static void ssh_get_passwd(char * buffer, int len); + +static bool ssh_agent_auth(ssh * s); + +static bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts); + +static bool ssh_thread_running = false; +static int ssh_array_reserved = 0; +static int ssh_array_used = 0; +static ssh ** ssh_array = NULL; +static int ssh_thread_maxfd = 0; + +static int waitsocket(int socket_fd, LIBSSH2_SESSION *session) +{ + struct timeval timeout; + int rc; + fd_set fd; + fd_set *writefd = NULL; + fd_set *readfd = NULL; + int dir; + + timeout.tv_sec = 10; + timeout.tv_usec = 0; + + FD_ZERO(&fd); + + FD_SET(socket_fd, &fd); + + /* now make sure we wait in the correct direction */ + dir = libssh2_session_block_directions(session); + + + if(dir & LIBSSH2_SESSION_BLOCK_INBOUND) + readfd = &fd; + + if(dir & LIBSSH2_SESSION_BLOCK_OUTBOUND) + writefd = &fd; + + rc = select(socket_fd + 1, readfd, writefd, NULL, &timeout); + + return rc; +} + +ssh * ssh_new(char * username, char * addr, int port) +{ + ssh * s = (ssh*)(calloc(1, sizeof(ssh))); + s->user = username; + s->addr = addr; + + s->socket = Network_client(addr, port,100); + s->session = libssh2_session_init(); + if (s->session == NULL) + { + free(s); + log_print(2,"ssh_new", "libssh2_session_init returned NULL"); + return NULL; + } + + + + int err = libssh2_session_handshake(s->session, s->socket); + if (err != 0) + { + free(s); + log_print(2,"ssh_new", "libssh2_session_handshake fails - error code %d", err); + return NULL; + } + s->fingerprint = (char*)(libssh2_hostkey_hash(s->session, LIBSSH2_HOSTKEY_HASH_SHA1)); + if (!ssh_fingerprint_ok(s->fingerprint)) + { + free(s); + log_print(2,"ssh_new", "Fingerprint of host \"%s\" was not OK", addr); + return NULL; + } + + char * userauthlist = libssh2_userauth_list(s->session, username, strlen(username)); + + int auth = AUTH_NONE; + if (strstr(userauthlist, "password")) + auth |= AUTH_PASSWORD; + if (strstr(userauthlist, "publickey")) + auth |= AUTH_PUBLICKEY; + + bool ok = false; + + if (auth & AUTH_PUBLICKEY) + { + // first try connecting with agent + ok = ssh_agent_auth(s); + + + + if (!ok) + { + log_print(3, "ssh_new", "Agent authentication failed, looking at public keys"); + + if (SSH_DIR[0] == '~' && SSH_DIR[1] == '/') + { + char ssh_dir[BUFSIZ]; + sprintf(ssh_dir, "%s/%s",getenv("HOME"),SSH_DIR+2); + ok = ssh_publickey_auth(s, ssh_dir,3); + } + else + ok = ssh_publickey_auth(s, SSH_DIR,3); + } + + } + + if (auth & AUTH_PASSWORD && !ok) + { + log_print(3, "ssh_new", "public keys failed, try password"); + for (int i = 0; i < 3 && !ok; ++i) + { + printf("Password for %s@%s:", username, addr); + char password[BUFSIZ]; + ssh_get_passwd(password, BUFSIZ); + + if (libssh2_userauth_password(s->session, username, password) == 0) + { + ok = true; + } + } + if (!ok) + log_print(3, "ssh_new", "Failed to authenticate by password."); + } + + if (!ok) + { + free(s); + log_print(2, "ssh_new", "All attempts at authenticating failed."); + return NULL; + } + log_print(3, "ssh_new", "Authenticated!"); + + s->reserved_tunnels = 1; + s->tunnel = (ssh_tunnel*)(calloc(s->reserved_tunnels, sizeof(ssh_tunnel))); + s->nTunnels = 0; + libssh2_session_set_blocking(s->session, 0); + return s; +} + +void ssh_destroy(ssh * s) +{ + ssh_thread_del(s); + + for (int i = 0; i < s->nTunnels; ++i) + { + int err; + char buffer[BUFSIZ]; + do + { + err = libssh2_channel_read(s->tunnel[i].channel, buffer, sizeof(buffer)); + write(s->tunnel[i].forward_sock, buffer, err); + } + while (err > 0); + + while ((err = libssh2_channel_close(s->tunnel[i].channel)) == LIBSSH2_ERROR_EAGAIN) + waitsocket(s->socket, s->session); + + libssh2_channel_free(s->tunnel[i].channel); + close(s->tunnel[i].forward_sock); + } + + + libssh2_session_disconnect(s->session, "goodbye"); + libssh2_session_free(s->session); + + free(s->tunnel); +} + +bool ssh_fingerprint_ok(char * f) +{ + //TODO: Check fingerprint + log_print(1, "ssh_fingerprint_ok", "Unimplemented!"); + return true; +} + +void ssh_get_passwd(char * buffer, int len) +{ + struct termios oflags, nflags; + + tcgetattr(fileno(stdin), &oflags); + nflags = oflags; + nflags.c_lflag &= ~ECHO; + nflags.c_lflag |= ECHONL; + + if (tcsetattr(fileno(stdin), TCSANOW, &nflags) != 0) + { + error("ssh_get_passwd", "tcsetattr : %s", strerror(errno)); + } + + fgets(buffer, len * sizeof(char), stdin); + buffer[strlen(buffer) - 1] = '\0'; + + if (tcsetattr(fileno(stdin), TCSANOW, &oflags) != 0) + { + error("ssh_get_passwd", "tcsetattr : %s", strerror(errno)); + } +} + +bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts) +{ + + + DIR * d = opendir(dir); + struct dirent * dp; + if (d == NULL) + { + log_print(0, "ssh_publickey_auth", "Couldn't open directory %s : %s", dir, strerror(errno)); + return false; + } + + while ((dp = readdir(d)) != NULL) + { + + // skip public keys + if (strstr(dp->d_name, ".pub") != NULL) + continue; + + // assume file is a private key + // find corresponding public key + char pub[BUFSIZ]; char priv[BUFSIZ]; + if (dir[strlen(dir)-1] == '/') + { + sprintf(pub, "%s%s.pub", dir,dp->d_name); + sprintf(priv, "%s%s", dir, dp->d_name); + } + else + { + sprintf(pub, "%s/%s.pub", dir,dp->d_name); + sprintf(priv, "%s/%s", dir, dp->d_name); + } + + struct stat t; + if (stat(priv, &t) != 0) + { + log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", priv, strerror(errno)); + continue; + } + + if (!S_ISREG(t.st_mode)) + { + log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", priv); + continue; + } + + if (stat(pub, &t) != 0) + { + log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", pub, strerror(errno)); + continue; + } + + if (!S_ISREG(t.st_mode)) + { + log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", pub); + continue; + } + + + + + + + //libssh2_trace(s->session, LIBSSH2_TRACE_AUTH | LIBSSH2_TRACE_PUBLICKEY); + int err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,""); + if (err == 0) + { + log_print(1, "ssh_publickey_auth", "Shouldn't use keys with no passphrase"); + } + else if (err == LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED) + { + + char passphrase[BUFSIZ]; + for (int i = 0; i < nAttempts; ++i) + { + printf("Passphrase for key %s:", priv); + ssh_get_passwd(passphrase, BUFSIZ); + err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,passphrase); + if (err != LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED) break; + } + } + if (err == 0) + { + closedir(d); + return true; + } + } + closedir(d); + + + return false; + + + + +} + +bool ssh_agent_auth(ssh * s) +{ + LIBSSH2_AGENT * agent = libssh2_agent_init(s->session); + if (agent == NULL) + { + log_print(0, "ssh_agent_auth", "Couldn't initialise agent support."); + return false; + } + + if (libssh2_agent_connect(agent) != 0) + { + log_print(0, "ssh_agent_auth", "Failed to connect to ssh-agent."); + return false; + } + + if (libssh2_agent_list_identities(agent) != 0) + { + log_print(0, "ssh_agent_auth", "Failure requesting identities to ssh-agent."); + return false; + } + + struct libssh2_agent_publickey * identity = NULL; + struct libssh2_agent_publickey * prev_identity = NULL; + + while (true) + { + int err = libssh2_agent_get_identity(agent, &identity, prev_identity); + if (err == 1) + { + log_print(0, "ssh_agent_auth", "Couldn't continue authentication."); + return false; + } + if (err < 0) + { + log_print(0, "ssh_agent_auth", "Failure obtaining identity from ssh-agent support."); + return false; + } + + if (libssh2_agent_userauth(agent, s->user, identity) == 0) + { + log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s succeeded!", s->user, identity->comment); + return true; + } + else + { + log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s failed.", s->user, identity->comment); + } + prev_identity = identity; + } + + return false; + +} + +LIBSSH2_LISTENER * ssh_get_listener(ssh * s, int * port) +{ + pthread_mutex_lock(&ssh_thread_mutex); + libssh2_session_set_blocking(s->session, 1); + //libssh2_trace(s->session, ~0); + LIBSSH2_LISTENER * l = libssh2_channel_forward_listen_ex(s->session, "localhost", *port, port,1); + if (l == NULL) + { + char * error; + libssh2_session_last_error(s->session, &error, NULL, 0); + log_print(0, "ssh_get_listener", "Error: %s", error); + } + libssh2_session_set_blocking(s->session, 0); + pthread_mutex_unlock(&ssh_thread_mutex); + return l; +} + +void ssh_add_tunnel(ssh * s, LIBSSH2_LISTENER * listener, int socket) +{ + pthread_mutex_lock(&ssh_thread_mutex); + //log_print(3, "ssh_add_tunnel", "accepting connection..."); + libssh2_session_set_blocking(s->session , 1); + //libssh2_trace(s->session, ~0); + LIBSSH2_CHANNEL * channel = libssh2_channel_forward_accept(listener); + if (channel == NULL) + { + char * error; + libssh2_session_last_error(s->session, &error, NULL, 0); + log_print(0, "ssh_add_tunnel", "Error: %s", error); + } + libssh2_session_set_blocking(s->session , 0); + //log_print(3, "ssh_add_tunnel", "accepted remote connection..."); + + ssh_tunnel * t = s->tunnel+(s->nTunnels++); + t->forward_sock = socket; + t->channel = channel; + + if (socket > ssh_thread_maxfd) + ssh_thread_maxfd = socket; + + if (s->nTunnels >= s->reserved_tunnels) + { + s->reserved_tunnels *= 2; + s->tunnel = (ssh_tunnel*)(realloc(s->tunnel, s->reserved_tunnels * sizeof(ssh_tunnel))); + } + + pthread_mutex_unlock(&ssh_thread_mutex); +} + +void ssh_exec_swarm(ssh * s, int * port, int * socket, int np) +{ + + // secure things + LIBSSH2_CHANNEL * channel = NULL; + while ((channel = libssh2_channel_open_session(s->session)) == NULL + && libssh2_session_last_error(s->session, NULL, NULL, 0) == LIBSSH2_ERROR_EAGAIN) + { + waitsocket(s->socket, s->session); + } + + if (channel == NULL) + { + error("ssh_exec_swarm", "Couldn't create channel with ssh session"); + } + + + char buffer[BUFSIZ]; + + // connect secure + if (port == NULL && socket != NULL) + { + sprintf(buffer, "%s -r -", options.program); + if (np != 0) + sprintf(buffer, " -n %d", np); + } + else if (port != NULL && socket == NULL) + { + sprintf(buffer, "%s -r $(echo $SSH_CONNECTION | awk \'{print $1}\'):%d", options.program, *port); + if (np != 0) + sprintf(buffer, " -n %d", np); + + } + else + error("ssh_exec_swarm", "Exactly *one* of the port or socket pointers must not be NULL"); + + int err; + while ((err = libssh2_channel_exec(channel, buffer)) == LIBSSH2_ERROR_EAGAIN) + { + waitsocket(s->socket, s->session); + } + + if (socket != NULL) + { + + pthread_mutex_lock(&ssh_thread_mutex); + + ssh_tunnel * t = s->tunnel+(s->nTunnels++); + + t->forward_sock = *socket; + t->channel = channel; + + if (*socket > ssh_thread_maxfd) + ssh_thread_maxfd = *socket; + + + if (s->nTunnels >= s->reserved_tunnels) + { + s->reserved_tunnels *= 2; + s->tunnel = (ssh_tunnel*)(realloc(s->tunnel, s->reserved_tunnels * sizeof(ssh_tunnel))); + } + + pthread_mutex_unlock(&ssh_thread_mutex); + } + else + { + + // read everything and close the channel + while (true) + { + while ((err = libssh2_channel_read(channel, buffer, sizeof(buffer))) > 0); + if (err == LIBSSH2_ERROR_EAGAIN) + { + waitsocket(s->socket, s->session); + } + else + { + break; + } + } + + while ((err = libssh2_channel_close(channel)) == LIBSSH2_ERROR_EAGAIN) + { + waitsocket(s->socket, s->session); + } + libssh2_channel_free(channel); + } + + + + +} + + + + +pthread_mutex_t ssh_thread_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_t ssh_pthread; + +void * ssh_thread(void * args) +{ + + + fd_set readSet; + char buffer[BUFSIZ]; + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 100000; + + while (true) + { + //log_print(1, "ssh_thread", "loop - %d ssh's", ssh_array_used); + FD_ZERO(&readSet); + pthread_mutex_lock(&ssh_thread_mutex); + + if (!ssh_thread_running) break; + + for (int i = 0; i < ssh_array_used; ++i) + { + ssh * s = ssh_array[i]; + if (s == NULL) continue; + for (int j = 0; j < s->nTunnels; ++j) + { + FD_SET(s->tunnel[j].forward_sock, &readSet); + } + } + + pthread_mutex_unlock(&ssh_thread_mutex); + select(ssh_thread_maxfd+1, &readSet, NULL, NULL, &tv); + pthread_mutex_lock(&ssh_thread_mutex); + + for (int i = 0; i < ssh_array_used; ++i) + { + ssh * s = ssh_array[i]; + //log_print(2, "ssh_thread", "array[%d] = %p", i, s); + if (s == NULL) continue; + for (int j = 0; j < s->nTunnels; ++j) + { + //log_print(2, "ssh_thread", "Tunnel number %d, socket %d", j, s->tunnel[j].forward_sock); + if (FD_ISSET(s->tunnel[j].forward_sock, &readSet)) + { + //log_print(2, "ssh_thread", "reading from socket %d", s->tunnel[j].forward_sock); + int len = read(s->tunnel[j].forward_sock, buffer, sizeof(buffer)); + + if (len <= 0) + continue; + buffer[len] = '\0'; + int written = 0; int w = 0; + do + { + //log_print(2, "ssh_thread", "writing %s to channel", buffer); + w = libssh2_channel_write(s->tunnel[j].channel, buffer+written, len-written); + assert(w >= 0); + written += w; + } + while (w > 0 && written < len); + } + while (true) + { + //log_print(2, "ssh_thread", "Try to read from channel %p", s->tunnel[j].channel); + int len = libssh2_channel_read(s->tunnel[j].channel, buffer, sizeof(buffer)); + //log_print(2, "ssh_thread", "Read %s from channel", buffer); + if (len == LIBSSH2_ERROR_EAGAIN) break; + assert(len >= 0); + + int written = 0; int w = 0; + while (written < len) + { + //log_print(2, "ssh_thread", "Wrote %s to socket %d", buffer+written, s->tunnel[j].forward_sock); + w = write(s->tunnel[j].forward_sock, buffer+written, len-written); + written += w; + } + if (libssh2_channel_eof(s->tunnel[j].channel)) + { + //log_print(1, "ssh_thread", "Got to eof in channel %p", s->tunnel[j].channel); + } + } + } + } + pthread_mutex_unlock(&ssh_thread_mutex); + } + + return NULL; +} + +void ssh_thread_add(ssh * s) +{ + pthread_mutex_lock(&ssh_thread_mutex); + + ssh_array_used++; + + bool found = false; + for (int i = 0; (i < ssh_array_reserved && !found); ++i) + { + if (ssh_array[i] == NULL) + { + ssh_array[i] = s; + found = true; + break; + } + } + + + if (!found) + { + int old = ssh_array_reserved; + ssh_array_reserved = (ssh_array_reserved + 1) * 2; + if (ssh_array == NULL) + ssh_array = (ssh**)(calloc(ssh_array_reserved, sizeof(ssh*))); + else + { + ssh_array = (ssh**)(realloc(ssh_array, ssh_array_reserved * sizeof(ssh*))); + for (int i = old+1; i < ssh_array_reserved; ++i) + ssh_array[i] = NULL; + + } + ssh_array[old] = s; + } + + for (int i = 0; i < s->nTunnels; ++i) + { + if (s->tunnel[i].forward_sock > ssh_thread_maxfd) + ssh_thread_maxfd = s->tunnel[i].forward_sock; + } + + if (!ssh_thread_running) + { + ssh_thread_running = true; + sigset_t set; + int err; + sigfillset(&set); + err = pthread_sigmask(SIG_SETMASK, &set, NULL); + if (err != 0) + error("ssh_thread_add", "pthread_sigmask : %s", strerror(errno)); + err = pthread_create(&ssh_pthread, NULL, ssh_thread, NULL); + if (err != 0) + error("ssh_thread_add", "pthread_create : %s", strerror(errno)); + sigemptyset(&set); + err = pthread_sigmask(SIG_SETMASK, &set, NULL); + if (err != 0) + error("ssh_thread_add", "pthread_sigmask : %s", strerror(errno)); + } + + + + pthread_mutex_unlock(&ssh_thread_mutex); +} + +void ssh_thread_del(ssh * s) +{ + pthread_mutex_lock(&ssh_thread_mutex); + + for (int i = 0; i < ssh_array_reserved; ++i) + { + if (ssh_array[i] == s) + { + ssh_array[i] = NULL; + ssh_thread_running = !(--ssh_array_used == 0); + break; + } + } + + pthread_mutex_unlock(&ssh_thread_mutex); +} diff --git a/src/ssh.h b/src/ssh.h new file mode 100644 index 0000000..6d12e99 --- /dev/null +++ b/src/ssh.h @@ -0,0 +1,57 @@ +#ifndef _SSH_H +#define _SSH_H + +#include "network.h" +#include "master.h" +#include "options.h" + +#define SSH_DIR "~/.ssh/" +#define SSH_KNOWN_HOSTS "~/.ssh/known_hosts" + +#include +#include +#include +#include + +typedef struct +{ + int forward_sock; + LIBSSH2_CHANNEL * channel; + int port; + +} ssh_tunnel; + +typedef struct +{ + ssh_tunnel * tunnel; + int nTunnels; + int reserved_tunnels; + + int socket; + LIBSSH2_SESSION *session; + char * fingerprint; + + char * user; + char * addr; +} ssh; + +extern ssh * ssh_new(char * username, char * addr, int port); +extern void ssh_destroy(ssh * s); + +extern void ssh_exec_swarm(ssh * s, int * port, int * socket, int np); +extern LIBSSH2_LISTENER * ssh_get_listener(ssh * s, int * port); +extern void ssh_add_tunnel(ssh * s, LIBSSH2_LISTENER * listener, int socket); + +extern void * ssh_thread(void * args); +extern void ssh_thread_add(ssh * s); +extern void ssh_thread_del(ssh * s); + + +extern pthread_mutex_t ssh_thread_mutex; + +extern pthread_t ssh_pthread; + + +#endif //_SSH_H + +//EOF