Things seem to work...
authorSam Moore <[email protected]>
Sat, 23 Feb 2013 08:31:49 +0000 (16:31 +0800)
committerSam Moore <[email protected]>
Sat, 23 Feb 2013 08:31:49 +0000 (16:31 +0800)
So I'll commit before I break everything!

16 files changed:
src/Makefile
src/debug.c [new file with mode: 0644]
src/log.c
src/log.h
src/main.c
src/master.c
src/master.h
src/network.c
src/network.h
src/options.c
src/options.h
src/options.o [new file with mode: 0644]
src/slave.c
src/slave.h
src/ssh.c [new file with mode: 0644]
src/ssh.h [new file with mode: 0644]

index a9f6729..da2cf0d 100644 (file)
@@ -1,10 +1,10 @@
 # Makefile for swarm
 
 CXX = gcc
-LIBRARIES = -lm -lpthread #-lGL -lglut -lGLU -lpthread
+LIBRARIES = /usr/local/lib/libssh2.a -lm -lpthread -lssl -lcrypto -lz #-lGL -lglut -lGLU -lpthread
 FLAGS = --std=c99 -D_POSIX_C_SOURCE=200112L -Wall -pedantic -g
 PREPROCESSOR_FLAGS = 
-LINK_OBJ = options.o log.o task.o network.o master.o daemon.o slave.o main.o
+LINK_OBJ = options.o log.o task.o network.o ssh.o master.o daemon.o slave.o main.o
 
 
 BIN = swarm
diff --git a/src/debug.c b/src/debug.c
new file mode 100644 (file)
index 0000000..83682bd
--- /dev/null
@@ -0,0 +1,10 @@
+#include <stdarg.h>
+
+void libssh2_debug_fuck(LIBSSH2_SESSION * session, int level, char * fmt, ...) // printf-style debug sink: writes the formatted message to stderr
+{ // session and level are accepted but never read
+       va_list va;
+       va_start(va, fmt);
+       vfprintf(stderr, fmt, va); // fmt is used directly as the format string for the variadic arguments
+       va_end(va);
+       fprintf(stderr, "\n"); // every message is terminated with a newline
+} // NOTE(review): only <stdarg.h> is included above; vfprintf/fprintf/stderr need <stdio.h>, and LIBSSH2_SESSION presumably comes from <libssh2.h> - confirm
index 867ab64..3a8c55c 100644 (file)
--- a/src/log.c
+++ b/src/log.c
@@ -14,17 +14,17 @@ void log_print(int level, char * funct, char * fmt, ...)
        char severity[BUFSIZ];
        switch (level)
        {
-               case 0:
-                       sprintf(severity, "Error");
+               case LOGERR:
+                       sprintf(severity, "ERROR");
                        break;
-               case 1:
-                       sprintf(severity, "Warning");
+               case LOGWARN:
+                       sprintf(severity, "WARNING");
                        break;
-               case 2:
-                       sprintf(severity, "Notice");
+               case LOGNOTE:
+                       sprintf(severity, "NOTICE");
                        break;
-               case 3:
-                       sprintf(severity, "Info");
+               case LOGINFO:
+                       sprintf(severity, "INFO");
                        break;
                default:
                        sprintf(severity, "DEBUG");
@@ -32,7 +32,7 @@ void log_print(int level, char * funct, char * fmt, ...)
        }
 
        if (funct != NULL)
-               last_len = fprintf(stderr, "%s [%d] : %s in %s - ", options.program, getpid(), severity, funct);
+               last_len = fprintf(stderr, "%s [%d] : %s : %s - ", options.program, getpid(), severity, funct);
        else
        {
                for (int i = 0; i < last_len; ++i);
index 24e5947..75e2780 100644 (file)
--- a/src/log.h
+++ b/src/log.h
@@ -7,6 +7,8 @@
 
 #include <stdarg.h>
 
+enum {LOGERR=0, LOGWARN=1, LOGNOTE=2, LOGINFO=3,LOGDEBUG=4}; // severity levels for log_print's level argument; log_print labels any value other than the first four as DEBUG
+
 extern void log_print(int level, char * funct, char * fmt,...);
 extern void error(char * funct, char * fmt, ...);
 
index a7666dc..baab2b7 100644 (file)
@@ -39,7 +39,10 @@ int main(int argc, char ** argv)
                        Master_main(&options);
        }
        else
+       {
+               fprintf(stderr, "%p %s", options.master_addr, options.master_addr);
                Slave_main(&options);
+       }
 
        exit(EXIT_SUCCESS);
        return 0;       
index e61f5fa..2734d07 100644 (file)
 #include <assert.h>
 #include <ctype.h>
 #include "slave.h"
+#include <string.h>
 #include <setjmp.h>
+#include <sys/types.h>
+#include <pwd.h>
 
 #include <unistd.h>
 #include <regex.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <netinet/tcp.h>
+#include "ssh.h"
 
 //#define THREAD_SENDING // You decided to solve a problem with threads; now you have two problems
 
 // probably not that great to use threads anyway, since it eats one of your cores
 // It probably spends 90% of its time sleeping, and 9.9% unlocking mutexes
 
+// the signal handler now breaks threads... don't use them
+
 #ifdef THREAD_SENDING
 pthread_t sender_thread;
 pthread_mutex_t sender_lock = PTHREAD_MUTEX_INITIALIZER;
@@ -69,12 +75,19 @@ void Master_main(Options * o)
 
 void Master_setup(Options * o)
 {
+       int err = libssh2_init(0);
+       if (err != 0)
+       {       
+               error("Master_setup", "Initialising libssh2 - error code %d", err);
+       }
+       
        signal(SIGCHLD, sigchld_handler);
        master.o = o;
        master.barrier_number = -1;
        master.last_number = -1;
        master.nSlaves = o->nCPU;
        master.running = master.nSlaves;
+       master.nRemote = 0; master.remote_err = NULL;
        if (master.nSlaves == 0)
                error("Master_setup", "No CPUs to start slaves with!");
 
@@ -119,7 +132,6 @@ void Make_slave(int i)
        master.slave[i].in = sv[1];
        master.slave[i].out = sv[1];
        master.slave[i].running = true;
-       master.slave[i].ssh_pid = 0;
 }
 
 void Master_input(char c)
@@ -174,9 +186,9 @@ void Master_input(char c)
                                                log_print(0, "Master_input", "No host specified for ABSORB directive");
                                        char * np = strtok(NULL, " ");
                                        if (np != NULL)
-                                               Master_absorb(cmd, options.port, atoi(np));
+                                               Master_absorb(cmd, atoi(np));
                                        else
-                                               Master_absorb(cmd, options.port, 0);
+                                               Master_absorb(cmd, 0);
                                }
                                else if (strcmp(cmd, "OUTPUT") == 0)
                                {
@@ -367,8 +379,11 @@ void Master_output(int i, char c)
        #endif //THREAD_SENDING
        if (t == NULL)
        {
-               log_print(0, "Master_output", "Read input from %s, but no task assigned!",master.slave[i].name);
-               error(NULL, "Please refrain from echoing three bell characters in a row.");
+               log_print(3, "Master_output", "Echo %c back to slave %d", c, i);
+               write(master.slave[i].in, &c, sizeof(char));
+               //log_print(0, "Master_output", "Read input from %s, but no task assigned!",master.slave[i].name);
+               //error(NULL, "Please refrain from echoing three bell characters in a row.");
+               return;
        }
 
                        
@@ -395,8 +410,10 @@ void Master_output(int i, char c)
                {
                        fprintf(stdout, "%d:\n", t->number);
                }
+               /*
                else if (t->output[t->outlen-1] == '\f')
                {
+                       
                        log_print(2, "Master_output", "Slave %d requests name (%s)", i, master.slave[i].name);
                        static int bufsiz = BUFSIZ;
                        char * buffer = (char*)(calloc(bufsiz, sizeof(char)));
@@ -420,6 +437,7 @@ void Master_output(int i, char c)
                                master.slave[i].task_pool = t2;
                        master.last_number = t2->number;
                }
+               */
                else
                {
                        fprintf(stdout, "%d: %s", t->number, t->output); 
@@ -522,9 +540,9 @@ Task * Master_tasker(int i)
 void Master_loop()
 {
        
-       if (sigsetjmp(env,true) != 0)
+       if (sigsetjmp(env,true) != 0) // completely necessary evil
        {
-               log_print(2, "Master_loop", "Restored from longjmp");
+               //log_print(2, "Master_loop", "Restored from longjmp");
        }
        fd_set readSet;
        //fd_set writeSet;
@@ -539,6 +557,7 @@ void Master_loop()
 
        bool quit = false;
        bool input = true;
+       char buffer[BUFSIZ];
        
        while (!quit)
        {
@@ -569,6 +588,11 @@ void Master_loop()
                {
                        if (master.slave[i].running) FD_SET(master.slave[i].out, &readSet);
                }
+
+               for (int i = 0; i < master.nRemote; ++i)
+               {
+                       FD_SET(master.remote_err[i], &readSet);
+               }
                
                select(master.fd_max+1, &readSet, NULL, NULL, NULL);
                
@@ -611,6 +635,16 @@ void Master_loop()
                                
                        }
                }
+
+               for (int i = 0; i < master.nRemote; ++i)
+               {
+                       if (FD_ISSET(master.remote_err[i], &readSet))
+                       {
+                               int len = read(master.remote_err[i], buffer, sizeof(buffer));
+                               buffer[len] = '\0';
+                               fprintf(stderr, "%s", buffer);
+                       }
+               }
                
        }
 
@@ -648,7 +682,7 @@ void Master_send()
        }
        write(send_task.slave_fd, "\n", 1*sizeof(char));
        master.commands_active++;
-       log_print(3, "Master_sender", "Sent task %d \"%s\" - %d tasks active", send_task.task->number, send_task.task->message, master.commands_active);
+       log_print(3, "Master_sender", "Sent task %d \"%s\" on socket %d - %d tasks active", send_task.task->number, send_task.task->message, send_task.slave_fd, master.commands_active);
 }
 
 #ifdef THREAD_SENDING
@@ -702,6 +736,16 @@ void Master_cleanup()
 
        signal(SIGCHLD, SIG_IGN); // ignore child exits now
 
+       // tell all remote nodes to exit
+       for (int i = 0; i < master.nRemote; ++i)
+       {
+               FILE * f = fdopen(master.remote_err[i], "r+"); setbuf(f, NULL);
+
+               fprintf(f, "exit\n");
+       
+               fclose(f);
+       }
+
        for (int i = 0; i < master.nSlaves; ++i)
        {
                
@@ -717,13 +761,6 @@ void Master_cleanup()
                if (master.slave[i].pid <= 0) 
                {
                        Network_close(master.slave[i].in);
-                       if (master.slave[i].ssh_pid > 0)
-                       {
-                               log_print(2, "Master_cleanup", "Killing ssh instance %d", master.slave[i].ssh_pid);
-                               kill(master.slave[i].ssh_pid, 15);
-                               if (kill(master.slave[i].ssh_pid, 0) == 0)
-                                       kill(master.slave[i].ssh_pid, 9);
-                       }
                }
                else
                {
@@ -740,83 +777,109 @@ void Master_cleanup()
        free(master.buffer);
        if (master.outfile != NULL)
                free(master.outfile);
-}
 
-void * start_server(void * args)
-{
+       libssh2_exit();
 
-       *(int*)(args) = Network_server(*(int*)args);
-       log_print(2, "start_server", "started network server");
-       return NULL;
 }
 
-int Secure_connection(char * addr, int port);
 
-void Master_absorb(char * addr, int port, int np)
+
+
+void Master_absorb(char * addr, int np)
 {
+       int port = 0;
+
+       char * user = strstr(addr, "@");
+       if (user != NULL)
+       {
+               *(user-1) = '\0';
+               char * t = user;
+               user = addr;
+               addr = t;
+       }
+       else
+       {
+               user = getpwuid(geteuid())->pw_name;
+       }
+       log_print(3, "Master_absorb", "User %s at address %s", user, addr);
+
+       // ssh to the host on port 22
+       ssh * s = ssh_new(user, addr, 22);
+       if (s == NULL)
+       {
+               log_print(0, "Master_absorb", "Couldn't ssh to %s@%s", user, addr);
+               return;
+       }
+       
+
+       // work out the name to give to the shells
        char * name = strtok(addr, ":");
        name = strtok(NULL, ":");
        if (name == NULL)
+               name = addr; // default is host:X
+       else
+               *(name-1) = '\0'; // otherwise use name:X
+
+       // setup array of remote stderr file descriptors
+       if (master.nRemote++ == 0)
        {
-               name = addr;
+               master.remote_err = (int*)(calloc(master.nRemote, sizeof(int)));
+               master.remote_reserved = master.nRemote;
        }
-       else
+       else if (master.nRemote >= master.remote_reserved)
        {
-               *(name-1) = '\0';
+               // resize dynamically
+               master.remote_reserved *= 2;
+               master.remote_err = (int*)(realloc(master.remote_err,master.remote_reserved * sizeof(int)));
        }
 
 
-       //log_print(0, "name is %s\n", name);
        
-       int first_ssh = 0;
+       int sfd = -1;
        if (master.o->encrypt)
-               first_ssh = Secure_connection(addr, port);
-
-       //pthread_t ss;
-       //int net_fd = port;
-       //pthread_create(&ss, NULL, start_server, (void*)(&net_fd));
-       
-       char buffer[BUFSIZ];
-       if (fork() == 0)
        {
-               // The alternative to this kind of terrible hack is OpenMPI's "opal"
-               // This involves >1000 lines of operating system independent ways to get an IP address from a local interface 
-               // Which will then be completely useless if there is any sort of NAT involved
-               
-               //freopen("/dev/null", "r", stdin);
-               //freopen("/dev/null", "w", stdout);
-               //freopen("/dev/null", "w", stderr);
+               int sv[2];
+               if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0)
+                       error("Master_absorb", "Couldn't create socket for remote swarm");
+               sfd = sv[0];
 
-               char * cmd = buffer+sprintf(buffer, "swarm -p ");
-               if (master.o->encrypt)
-                       cmd = cmd+sprintf(cmd, "%d -e", port+1000);
-               else
-                       cmd = cmd+sprintf(cmd, "%d -u", port);
-       
-               if (np > 0)
-                       cmd = cmd+sprintf(cmd, " -n %d", np);
-               sprintf(cmd, " -m $(echo $SSH_CONNECTION | awk \'{print $1}\')");
-               log_print(3, "Master_absorb", "Execing %s", buffer);
-               execlp("ssh", "ssh", "-f", addr, buffer, NULL);
+               ssh_exec_swarm(s, NULL, sv+1, np); // start swarm remotely forward it to the socket
+               ssh_thread_add(s); // add the ssh to the thread
        }
-       log_print(3, "Master_absorb", "Listening on port %d", port);
-       int net_fd = Network_server(port);
-       log_print(3, "Master_absorb", "Created network server on port %d", port);
+       else
+       {
+               sfd = Network_server_bind(0, &port); // dynamically bind to a port
+               ssh_exec_swarm(s, &port, NULL, np); // start swarm remotely, have it connect on the port
+               ssh_destroy(s); // don't need the ssh anymore
+               sfd = Network_server_listen(sfd, NULL); // accept connections and pray
+       }
+       master.remote_err[master.nRemote-1] = sfd;
+       if (sfd > master.fd_max)
+               master.fd_max = sfd;
+
 
+       char buffer[BUFSIZ];
+
+       int newSlaves = 0;
+       
+       int len = sprintf(buffer, "%s\n", name);
+       int w = 0;
+       while (w < len)
+               w += write(sfd, buffer+w, len-w);
        
-       char * s = buffer;
-       while (read(net_fd, s, sizeof(char)) != 0)
+       
+       len = 0;
+       do
        {
-               if (*s == '\n')
-               {
-                       *s = '\0';
-                       break;
-               }
-               ++s;
+               len = read(sfd, buffer+len, sizeof(buffer));
+               buffer[len] = '\0';
        }
+       while (buffer[len-1] != '\n');
+       buffer[len-1] = '\0';
+       newSlaves = atoi(buffer);
+
+       
 
-       int newSlaves = atoi(buffer);
-       log_print(3, "Master_absorb", "Absorbing %d slaves from %s\n", newSlaves, addr);
        if (newSlaves == 0)
        {
                error("Master_absorb", "No slaves to absorb from %s", addr);
@@ -825,62 +888,70 @@ void Master_absorb(char * addr, int port, int np)
        master.slave = (Slave*)(realloc(master.slave, (master.nSlaves + newSlaves) * sizeof(Slave)));
        if (master.slave == NULL)
        {
-               error("Master_absorb", "Resizing slave array from %d to %d slaves : %s", master.nSlaves, master.nSlaves + newSlaves, strerror(errno));
+               error("Master_absorb", "Resizing slave array from %d to %d slaves : %s", 
+                       master.nSlaves, master.nSlaves + newSlaves, strerror(errno));
        }
 
-
-       if (master.o->encrypt)
-       {
-               for (int i = 0; i < newSlaves-1; ++i)
-                       master.slave[master.nSlaves+i].ssh_pid = Secure_connection(addr, port+i+1);
-
-       }       
-       master.slave[master.nSlaves+newSlaves-1].ssh_pid = first_ssh;
        
+       for (int i = 0; i < newSlaves; ++i)
+       {
+               int ii = master.nSlaves + i;
 
+               if (master.o->encrypt)
+               {
+                       int sv[2];
+                       if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0)
+                               error("Master_absorb", "Couldn't create socket for remote swarm");
 
+                       
+                       LIBSSH2_LISTENER * listener = NULL;
+                       // libssh2 can't finalise the connection when the port is dynamic (ie: port = 0)
+                       while (listener == NULL)
+                       {
+                               port = 20000 + rand() % 30000; // so pick ports at random until binding succeeds
+                               listener = ssh_get_listener(s, &port); // port forward to the socket
+                       }
 
-       
-
-       char c = '\a';
-       log_print(3, "Master_absorb", "Writing bell to slave");
-       write(net_fd, &c, sizeof(char));
-       
+                       log_print(4,"Master_absorb", "Chose port %d", port);
+                       int len = sprintf(buffer, "%d\n", port);
+                       
+                       int w = 0;
+                       while (w < len)
+                       {
+                               w += write(sfd, buffer+w, len-w);
+                       }
+                       usleep(200000); // give ssh_thread a chance to actually send the data
 
+                       log_print(4, "Master_absorb", "Creating tunnel...");
+                       ssh_add_tunnel(s, listener, sv[1]);
+                       master.slave[ii].in = sv[0];
+                       master.slave[ii].out = sv[0];
 
-       for (int i = 0; i < newSlaves; ++i)
-       {
-               log_print(3, NULL, "Absorbing slave %d...", i);
-               int ii = master.nSlaves + i;
-               if (i == newSlaves-1)
-               {
-                       master.slave[ii].out = net_fd;
+                       log_print(4, "Master_absorb", "Tunnel for slave %d using socket %d<->%d setup", ii, sv[0], sv[1]);
                }
                else
                {
-                       log_print(3, "Master_absorb", "Creating server %d at time %d", i, time(NULL));
-                       write(net_fd, &c, sizeof(char));
-                       master.slave[ii].out = Network_server(port + i + 1);
+                       int tmp = Network_server_bind(0, &port); // bind to a port      
+                       
+                       master.slave[ii].in = Network_server_listen(tmp, NULL); // listen for connection
+                       master.slave[ii].out = master.slave[ii].in;
                }
-               master.slave[ii].in = master.slave[ii].out;
+
+       
+               
                
 
                if (master.slave[ii].out > master.fd_max)
                        master.fd_max = master.slave[ii].out;
 
+               char buffer[BUFSIZ];
                sprintf(buffer, "%s:%d", name, i);
                master.slave[ii].name = strdup(buffer);
                master.slave[ii].addr = strdup(addr);
                master.slave[ii].running = true;
                master.slave[ii].pid = -1;
                master.slave[ii].task = NULL;
-               master.slave[ii].task_pool = NULL;
-
-               FILE * f = fdopen(master.slave[ii].in, "w"); setbuf(f, NULL);
-               fprintf(f, "name=%s\n", master.slave[ii].name);
-
-               log_print(3, NULL, "Done absorbing slave %d...", i);
-       
+               master.slave[ii].task_pool = NULL;      
        }       
 
 
@@ -892,21 +963,6 @@ void Master_absorb(char * addr, int port, int np)
 
 }
 
-int Secure_connection(char * addr, int port)
-{
-       int result = fork();
-       if (result == 0)
-       {
-               char buffer[BUFSIZ];
-               sprintf(buffer, "%d:localhost:%d", port+1000, port);
-               freopen("/dev/null", "r", stdin);
-               freopen("/dev/null", "w", stdout);
-               freopen("/dev/null", "w", stderr);
-               execl("/usr/bin/ssh", "/usr/bin/ssh", "-N", "-R", buffer, addr, NULL);
-       }
-       return result;
-}
-
 void sigchld_handler(int signal)
 {
 
@@ -918,13 +974,6 @@ void sigchld_handler(int signal)
        if (p == -1)
                error("sigchld_handler", "waitpid : %s", strerror(errno));
        
-       if (WIFSIGNALED(s))
-       {
-               int sig = WTERMSIG(s);
-               log_print(2, "sigchld_handler", "A child [%d] was terminated with signal %d; terminating self with same signal", p, sig);
-               kill(getpid(), sig);
-               return;
-       }
 
        int i = 0;
        for (i = 0; i < master.nSlaves; ++i)
@@ -937,13 +986,28 @@ void sigchld_handler(int signal)
                return;
        }
 
-       log_print(1, "sigchld_handler", "Slave %d [%d] exited with code %d; restarting it",i, p, s);
+       fprintf(stderr, "Unexpected exit of slave %s", master.slave[i].name);
+       if (WIFSIGNALED(s))
+       {
+               int sig = WTERMSIG(s);
+               fprintf(stderr, " due to %s", strsignal(sig));
+               if (sig == SIGKILL)
+               {
+                       printf(" - committing suicide\n");
+                       kill(getpid(), sig);
+               }
+       }
+       else
+       {               
+               fprintf(stderr, " return code %d.",s);
+       }
+       fprintf(stderr, " Starting replacement.\n");
 
        Make_slave(i);
 
        if (master.o->end != NULL)
        {
-               log_print(1, "sigchld_handler", "Trying to convince slave %d to be nice", i);
+               //log_print(1, "sigchld_handler", "Trying to convince slave %d to be nice", i);
                char buffer[BUFSIZ];
                sprintf(buffer, "name=%s;echo -en \"%s\"\n", master.slave[i].name, master.o->end);
                if (write(master.slave[i].in, buffer, strlen(buffer)) <= 0)
index 1a2d5e9..3595e23 100644 (file)
@@ -13,7 +13,7 @@ extern void * Master_sender(void * args);
 extern void Master_setup(Options * o);
 extern void Master_cleanup();
 extern void Master_send();
-extern void Master_absorb(char * addr, int port, int np);
+extern void Master_absorb(char * addr, int np);
 
 extern void Make_slave(int i);
 extern void sigchld_handler(int signal);
@@ -43,6 +43,10 @@ typedef struct
 
        int commands_active;
 
+       int * remote_err; // sockets used as stderr for remote shells
+       int nRemote; // number of remote shells
+       int remote_reserved; // number of sockets reserved
+
        Options * o;
 } Master;
 
index 5354ac3..bcd06d4 100644 (file)
@@ -6,9 +6,21 @@
 
 #define h_addr h_addr_list[0]
 
-int Network_server(int port) {return Network_server_r(NULL, port);}
 
-int Network_server_r(char * addr, int port)
+
+
+
+int Network_get_port(int sfd) // returns, in host byte order, the local port that socket sfd is bound to
+{
+       struct sockaddr_in sin = {0}; // automatic and zeroed: if error() returns after a failure the result is still 0
+       socklen_t len = sizeof(sin); // value-result argument: getsockname overwrites it, so it must be reset on every call
+
+       if (getsockname(sfd, (struct sockaddr *)&sin, &len) != 0)
+                error("Network_get_port", "getsockname : %s", strerror(errno)); // label now matches the function name
+       return ntohs(sin.sin_port);
+}
+
+int Network_server_bind(int port, int * bound)
 {
        int sfd = socket(PF_INET, SOCK_STREAM, 0);
        if (sfd < 0)
@@ -26,6 +38,15 @@ int Network_server_r(char * addr, int port)
        {
                error("Network_server", "Binding socket on port %d : %s", port, strerror(errno));
        }
+
+       if (bound != NULL)
+               *bound = Network_get_port(sfd);
+       return sfd;     
+}
+
+int Network_server_listen(int sfd, char * addr)
+{
+       int port = Network_get_port(sfd);
        if (listen(sfd, 1) < 0)
        {
                error("Network_server", "Listening on port %d : %s", port, strerror(errno));
@@ -52,12 +73,18 @@ int Network_server_r(char * addr, int port)
        assert(sfd >= 0);
 
        return sfd;
-       
+}
+
+int Network_server(char * addr, int port) // one-shot helper: bind a socket to port, then hand it to Network_server_listen and return that call's result
+{
+       return Network_server_listen(Network_server_bind(port, &port), addr); // &port receives the port actually bound; addr is passed through to the listen step unchanged
+}
 
 int Network_client(const char * addr, int port, int timeout)
 {
        int sfd = socket(PF_INET, SOCK_STREAM, 0);
+
+       //log_print(2, "Network_client", "Created socket");
        long arg = fcntl(sfd, F_GETFL, NULL);
        arg |= O_NONBLOCK;
        fcntl(sfd, F_SETFL, arg);
@@ -75,6 +102,7 @@ int Network_client(const char * addr, int port, int timeout)
        bcopy ( hp->h_addr, &(server.sin_addr.s_addr), hp->h_length);
        server.sin_port = htons(port);
 
+
        int res = connect(sfd, (struct sockaddr *) &server, sizeof(server));
        
 
@@ -91,9 +119,9 @@ int Network_client(const char * addr, int port, int timeout)
 
                struct timeval * tp;
                tp = (timeout < 0) ? NULL : &tv;
-       
+               
                int err = select(sfd+1, NULL, &writeSet, NULL, tp);
-
+               
                if (err == 0)
                {
                        error("Network_client", "Timed out trying to connect to %s:%d after %d seconds", addr, port, timeout);
@@ -126,7 +154,8 @@ int Network_client(const char * addr, int port, int timeout)
        arg &= (~O_NONBLOCK);
        fcntl(sfd, F_SETFL, arg);
        
-
+       
+       
        return sfd;
 }
 
index b00e595..07c9459 100644 (file)
 #include <strings.h>
 #include <stdarg.h>
 
-extern int Network_server(int port);
-extern int Network_server_r(char * addr, int port);
+extern int Network_get_port(int socket); // get port used by socket
+extern int Network_server(char * addr, int port);
 extern int Network_client(const char * addr, int port, int timeout);
 
+extern int Network_server_bind(int port, int * bound);
+extern int Network_server_listen(int sfd, char * addr);
+
 extern void Network_close(int sfd);
 
 #endif //_NETWORK_H
index 6ef6119..ddd2ddc 100644 (file)
@@ -35,6 +35,8 @@ void close_out()
        fclose(stdout);
 }
 
+char name[BUFSIZ];
+
 void Initialise(int argc, char ** argv, Options * o)
 {
        srand(time(NULL));
@@ -44,9 +46,7 @@ void Initialise(int argc, char ** argv, Options * o)
        o->logfile = NULL;
        o->outfile = NULL;
        o->verbosity = 2;
-       o->port = 4000 + rand() % 1000;
-       o->slavefile = "slaves.swarm";
-       o->dummy_shell = false;
+       o->port = 0;
        o->append = NULL;
        o->prepend = NULL;
        o->end = "\a\a\a";
@@ -54,7 +54,9 @@ void Initialise(int argc, char ** argv, Options * o)
        o->daemon = false;
        o->encrypt = true;
        o->interactive = true;
-       
+
+       gethostname(name, sizeof(name));
+       o->name = strdup(name);
 
        o->master_pid = getpid();
        
@@ -140,11 +142,17 @@ void ParseArguments(int argc, char ** argv, Options * o)
                                        error("ParseArguments", "No argument following %s switch", argv[i]);
                                o->nCPU = atoi(argv[++i]);
                        }
-                       else if (argv[i][1] == 'm')
+                       else if (argv[i][1] == 'r')
                        {
                                if (i >= argc-1)
                                        error("ParseArguments", "No argument following %s switch", argv[i]);
                                o->master_addr = argv[++i];
+                               char * p = strstr(o->master_addr, ":");
+                               if (p != NULL)
+                               {
+                                       *(p-1) = '\0';
+                                       o->port = atoi(p);
+                               }
                        }
                        else if (argv[i][1] == 'c')
                        {
index de365eb..24f35eb 100644 (file)
 typedef struct
 {
        char * program;
+       char * name;
        char * shell;
        char * master_addr;
        char * logfile;
        char * outfile;
        int verbosity;
        int port;
-       char * slavefile;
-       bool dummy_shell;
        char * prepend;
        char * append;
        char * end;
diff --git a/src/options.o b/src/options.o
new file mode 100644 (file)
index 0000000..88fbc13
Binary files /dev/null and b/src/options.o differ
index 3327759..65e11c8 100644 (file)
@@ -1,7 +1,7 @@
-#define _XOPEN_SOURCE
+#define _XOPEN_SOURCE 700
 #define _GNU_SOURCE
 
-//#define _SIMPLE_SLAVE
+
 
 #include "slave.h"
 #include <assert.h>
@@ -12,6 +12,7 @@
 #include <errno.h>
 #include <pty.h>
 #include <fcntl.h>
+#include <string.h>
 
 #include <pthread.h>
 #include <syslog.h>
 
 Slave * slave;
 
+char name[BUFSIZ];
+
+void Slave_shell(int i, char * shell);
+void Slave_cleanup();
 
-int running;
 
+/**
+ * Entry point of a slave process.
+ *
+ * Unless master_addr is "-", the process forks into the background, connects
+ * to the master and redirects stdin/stdout/stderr to that connection; with
+ * "-" it keeps its inherited streams and uses "localhost" for the per-shell
+ * connections. It then reads its name (one line) from stdin, reports nCPU on
+ * stdout, and for each CPU reads a port number, connects to it and starts a
+ * shell on that connection. Finally it supervises the shells in Slave_loop.
+ * Never returns.
+ *
+ * @param o Program options (nCPU, master_addr, port, shell are used here)
+ */
 void Slave_main(Options * o)
 {
+
        
-       if (fork() != 0)
-               exit(EXIT_SUCCESS);
+       setbuf(stdin, NULL); setbuf(stdout, NULL); setbuf(stderr, NULL);
+
+       dup2(fileno(stdout), fileno(stderr)); // yes, this works, apparently
 
+       // One record per shell. Size by the element (*slave), not by the pointer:
+       // sizeof(slave) is only the size of a pointer and under-allocates the array.
+       slave = (Slave*)(calloc(o->nCPU, sizeof(*slave)));
+       if (slave == NULL)
+               error("Slave_main", "calloc : %s", strerror(errno));
+       atexit(Slave_cleanup);
 
-       o->verbosity = 100;
-       freopen(SLAVE_LOGFILE, "w", stderr);
-       setbuf(stderr, NULL);
-       slave = (Slave*)(calloc(o->nCPU, sizeof(Slave)));
 
-       int net_fd = -1;
-       if (o->encrypt)
-               net_fd = Network_client("localhost", o->port,100);
+       if (strcmp(o->master_addr, "-") != 0)
+       {
+               if (fork() != 0)
+                       exit(EXIT_SUCCESS);
+
+               //log_print(2, "Slave_main", "Using unsecured networking; connect to %s:%d", o->master_addr, o->port);
+               //log_print(2, "Slave_main", "Connecting to %s:%d", o->master_addr, o->port);
+               int net_fd = Network_client(o->master_addr, o->port, 100);
+               dup2(net_fd, fileno(stdin));
+               dup2(net_fd, fileno(stdout));
+               dup2(net_fd, fileno(stderr));
+               
+       }
        else
-               net_fd = Network_client(o->master_addr, o->port,100);
+       {
+               o->master_addr = "localhost";
+               //log_print(2, "Slave_main", "Using port forwarding; connect to %s", o->master_addr);
+       }
 
-       FILE * f = fdopen(net_fd, "w"); setbuf(f, NULL);
-       fprintf(f, "%d\n", o->nCPU);
+       char buffer[BUFSIZ];
 
-       log_print(2, "Slave_main", "Waiting on bell from master");
-       char c;
-       if (read(net_fd, &c, sizeof(char)) == 0 || c != '\a')
-               error("Slave_main", "Didn't get bell from master");
-       
+       // The master sends our name as a single line; drop the newline if present.
+       // (Indexing strlen()-1 blindly underflows when nothing was read.)
+       if (fgets(name, sizeof(name), stdin) == NULL)
+               error("Slave_main", "Didn't get a name from master");
+       name[strcspn(name, "\n")] = '\0';
+       //log_print(2, "Slave_main", "Got name %s", name);
 
+       fprintf(stdout, "%d\n", o->nCPU);
+       //log_print(2, "Slave_main", "Wrote nCPU %d", o->nCPU);
        
 
-       log_print(2, "Slave_main", "Got bell from master");
-       running = o->nCPU;
+       int port = 0;
        for (int i = 0; i < o->nCPU; ++i)
        {
-               int new_fd = net_fd;
-               if (i != o->nCPU-1)
-               {
+               //log_print(2, "Slave_main", "Waiting for port number...");
+               // One port number per line; "%d" skips the trailing newline itself.
+               if (fgets(buffer, sizeof(buffer), stdin) == NULL
+                       || sscanf(buffer, "%d", &port) != 1)
+               {
+                       error("Slave_main", "Didn't get a port number for shell %d from master", i);
+               }
+               //log_print(2, "Slave_main", "Port number %d", port);
+               slave[i].in = Network_client(o->master_addr, port,20);
+               //log_print(2, "Slave_main", "Connected to %s:%d\n", o->master_addr, port);
+               slave[i].out = slave[i].in;
+
+               Slave_shell(i, o->shell);
+       }
+       
 
-                       
-                       if (read(net_fd, &c, sizeof(char)) == 0 || c != '\a')
-                               error("Slave_main", "Didn't get bell from master authorising connection of slave %d", i);
-                       sleep(1);
+       Slave_loop(o);
 
-                       log_print(3, "Slave_main", "Connecting slave %d to port %d at time %d", i, o->port+i+1, time(NULL));
-                       if (o->encrypt)
-                               new_fd = Network_client("localhost", o->port+i+1, 100);
-                       else
-                               new_fd = Network_client(o->master_addr, o->port+i+1, 100);
+       exit(EXIT_SUCCESS);
+}
 
-                       
-                       
-               }
+/**
+ * Start (or restart) the shell for slot i.
+ *
+ * Forks; the child attaches slave[i].in / slave[i].out to its stdin / stdout
+ * and executes the shell. The parent records the child's pid in slave[i].pid
+ * and writes a `name="<name>:<i>"` line to slave[i].in.
+ *
+ * @param i     Index into the global slave array
+ * @param shell Program to execute (looked up via PATH by execlp)
+ */
+void Slave_shell(int i, char * shell)
+{
+       slave[i].pid = fork();
 
-               slave[i].in = new_fd; slave[i].out = new_fd;
 
-               slave[i].pid = fork();
-               if (slave[i].pid == 0)
-               {
-                       dup2(slave[i].in, fileno(stdin));
-                       dup2(slave[i].out, fileno(stdout));
-                       execlp(o->shell, o->shell, NULL);
-               }
+       if (slave[i].pid < 0)
+       {
+               error("Slave_shell", "fork : %s", strerror(errno));
+               return;
+       }
+
+       if (slave[i].pid == 0)
+       {
+               dup2(slave[i].in, fileno(stdin));
+               dup2(slave[i].out, fileno(stdout));
+               //dup2(error_socket[1], fileno(stderr));
+
+               // The variadic sentinel must be a null *pointer*, hence the cast
+               execlp(shell, shell, (char*)NULL);
+
+               // Only reached if exec failed. Leave immediately, and without running
+               // the atexit handlers, so the child never continues as a second copy
+               // of the supervising parent.
+               _exit(EXIT_FAILURE);
        }
-       
-       Slave_loop(o);
 
-       free(slave);
-       exit(EXIT_SUCCESS);
+       // if the input is a network socket, this message gets sent to the master
+       // which will then echo it back to the socket and hence the shell
+       // Written straight to the descriptor: wrapping it in a FILE with fdopen()
+       // on every (re)start would leak one stream per call.
+       dprintf(slave[i].in, "name=\"%s:%d\"\n", name,i);
 }
 
+/**
+ * Supervise the shells started by Slave_shell; loops until the process exits.
+ *
+ * Each time a child terminates, the loop first polls stdin (100ms) for an
+ * "exit" line from the master and exits if one arrived. Otherwise it reports
+ * the unexpected exit on stderr, writes o->end to the slot's output and
+ * restarts the shell. If the child was killed by SIGKILL the process sends
+ * itself the same signal.
+ *
+ * @param o Program options (nCPU, end, shell are used here)
+ */
 void Slave_loop(Options * o)
 {
-       
+       fd_set readSet;
+       struct timeval tv;
+
        int p = -1; int s = 0;
-       
-       while (running > 0)
+       char buffer[BUFSIZ];
+       while (true)
        {
+               FD_ZERO(&readSet);
+               FD_SET(fileno(stdin), &readSet);
                p = waitpid(-1, &s, 0);
                if (p == -1)
                {
-                       log_print(0, "Slave_loop", "waitpid : %s", strerror(errno));
+                       //log_print(0, "Slave_loop", "waitpid : %s", strerror(errno));
                        continue;
                }
-               if (s != SHELL_EXIT_CODE)
-               {
-                       // there was an error
 
-                       int i = 0;
-                       for (i = 0; i < o->nCPU; ++i)
+               //log_print(3, "Slave_loop", "Detected child %d exiting...", p);
+
+               // check for an exit command from the master
+               // select() may overwrite the timeout with the time left, so it is
+               // re-armed for every call; the set is only trusted when select succeeds
+               tv.tv_sec = 0;
+               tv.tv_usec = 100000;
+               if (select(fileno(stdin) + 1, &readSet, NULL, NULL, &tv) > 0
+                       && FD_ISSET(fileno(stdin), &readSet))
+               {
+                       // On EOF/error fgets leaves the buffer untouched: don't compare it
+                       if (fgets(buffer, sizeof(buffer), stdin) != NULL
+                               && strcmp(buffer, "exit\n") == 0)
                        {
-                               if (slave[i].pid == p) break;
+                               log_print(2, "Slave_loop", "Received notification of exit.\n");
+                               exit(EXIT_SUCCESS);
                        }
-                       if (i >= o->nCPU)
-                               error("Slave_loop", "No child matches pid %d", p);
+               }
+               
+               int i = 0;
+               for (i = 0; i < o->nCPU; ++i)
+               {
+                       if (slave[i].pid == p) break;
+               }
+               if (i >= o->nCPU)
+                       error("Slave_loop", "No child matches pid %d", p);
+
 
-                       log_print(0, "Slave_loop", "Child [%d] exits with status %d; restarting", p, s);
-                       slave[i].pid = fork();
-                       if (slave[i].pid == 0)
+               
+               fprintf(stderr,"Unexpected exit of slave %s:%d", name, i);
+               if (WIFSIGNALED(s))
+               {
+                       int sig = WTERMSIG(s);
+                       fprintf(stderr," due to %s", strsignal(sig));
+                       if (sig == SIGKILL)
                        {
-                               dup2(slave[i].in, fileno(stdin));
-                               dup2(slave[i].out, fileno(stdout));
-                               execlp(o->shell, o->shell, NULL);
+                               fprintf(stderr," - %s committing suicide\n", name);
+                               kill(getpid(), sig);
                        }
-
-                       char buffer[] = "\f\a\a\a";
-                       if (write(slave[i].in, buffer, strlen(buffer)) <= 0)
-                               log_print(0, "Slave_loop", "Slave %d input closed", i);
                }
                else
-                       --running;
+               {               
+                       // s is the encoded wait status; the exit code must be extracted from it
+                       fprintf(stderr," return code %d.", WEXITSTATUS(s));
+               }
+               
+
+               // cancel any tasks at the master for this slave
+               static int len = -1;
+               if (len < 0)
+                       len = strlen(o->end);
+               write(slave[i].out, o->end, len);
+
+               Slave_shell(i, o->shell);
+
+               
        }
 }
+
+/**
+ * atexit handler: terminate every shell that was started, then free the
+ * slave array. Shells get SIGTERM first and, one second later, SIGKILL.
+ */
+void Slave_cleanup()
+{
+       if (slave == NULL)
+               return;
+
+       // Slots that never had a shell still hold pid 0 from calloc (or -1 after a
+       // failed fork). kill(0, ...) would signal our whole process group and
+       // kill(-1, ...) every process we may signal, so only real pids are used.
+       for (int i = 0; i < options.nCPU; ++i)
+       {
+               if (slave[i].pid > 0)
+                       kill(slave[i].pid, SIGTERM);
+       }
+       sleep(1);
+       for (int i = 0; i < options.nCPU; ++i)
+       {
+               if (slave[i].pid > 0)
+                       kill(slave[i].pid, SIGKILL);
+       }
+       free(slave);
+       slave = NULL;
+}
+
+
index 40b16ad..bf4db8f 100644 (file)
@@ -17,7 +17,7 @@ typedef struct
        Task * task_pool; // tasks specific to the slave
 
 
-       int ssh_pid; 
+       
 
        bool running;
 } Slave;
diff --git a/src/ssh.c b/src/ssh.c
new file mode 100644 (file)
index 0000000..0e1e197
--- /dev/null
+++ b/src/ssh.c
@@ -0,0 +1,688 @@
+#include "ssh.h"
+#include "network.h"
+#include "log.h"
+#include <termios.h>
+#include <dirent.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <errno.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <sys/time.h>
+#include <sys/select.h>
+#include <assert.h>
+#include <signal.h>
+
+
+// Authentication methods offered by the server. ssh_new ORs these into an
+// int and tests them with &, so the non-zero values must stay distinct bits
+// (1 and 2 as enumerated here).
+enum {
+    AUTH_NONE = 0,
+    AUTH_PASSWORD,
+    AUTH_PUBLICKEY
+};
+
+// Decide whether a host key fingerprint is acceptable
+static bool ssh_fingerprint_ok(char * f);
+
+// Read a line from stdin into buffer with terminal echo switched off
+static void ssh_get_passwd(char * buffer, int len);
+
+// Try to authenticate the session through ssh-agent
+static bool ssh_agent_auth(ssh * s);
+
+// Try to authenticate with the key pairs found in directory dir
+static bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts);
+
+// State shared with the forwarding thread; accessed under ssh_thread_mutex
+// by ssh_thread, ssh_thread_add and ssh_thread_del.
+static bool ssh_thread_running = false; // thread leaves its loop when false
+static int ssh_array_reserved = 0; // number of slots allocated in ssh_array
+static int ssh_array_used = 0; // number of sessions currently registered
+static ssh ** ssh_array = NULL; // registered sessions; unused slots are NULL
+static int ssh_thread_maxfd = 0; // highest forward_sock seen, for select()
+
+/*
+ * Block (for at most 10 seconds) until socket_fd is ready in whichever
+ * direction(s) the libssh2 session reports it is waiting on.
+ * Returns the result of select().
+ */
+static int waitsocket(int socket_fd, LIBSSH2_SESSION *session)
+{
+    struct timeval timeout = { .tv_sec = 10, .tv_usec = 0 };
+
+    fd_set fds;
+    FD_ZERO(&fds);
+    FD_SET(socket_fd, &fds);
+
+    /* now make sure we wait in the correct direction */
+    const int directions = libssh2_session_block_directions(session);
+    fd_set *readfd = (directions & LIBSSH2_SESSION_BLOCK_INBOUND) ? &fds : NULL;
+    fd_set *writefd = (directions & LIBSSH2_SESSION_BLOCK_OUTBOUND) ? &fds : NULL;
+
+    return select(socket_fd + 1, readfd, writefd, NULL, &timeout);
+}
+
+/**
+ * Release what a partially constructed ssh owns (libssh2 session, socket,
+ * the struct itself). Always returns NULL so failure paths in ssh_new can
+ * simply `return ssh_new_abort(s);`.
+ */
+static ssh * ssh_new_abort(ssh * s)
+{
+       if (s->session != NULL)
+               libssh2_session_free(s->session);
+       if (s->socket >= 0)
+               close(s->socket);
+       free(s);
+       return NULL;
+}
+
+/**
+ * Open an authenticated SSH session.
+ *
+ * Connects to addr:port, performs the SSH handshake, checks the host
+ * fingerprint and then authenticates, trying in order: ssh-agent, the key
+ * pairs in SSH_DIR, and (up to three times) a password typed on stdin.
+ * Only methods the server advertises are attempted. On success the session
+ * is switched to non-blocking mode.
+ *
+ * @param username Remote user; the pointer is stored, not copied
+ * @param addr     Remote host; the pointer is stored, not copied
+ * @param port     Remote port
+ * @return A new ssh with space for one tunnel, or NULL on any failure
+ */
+ssh * ssh_new(char * username, char * addr, int port)
+{
+       ssh * s = (ssh*)(calloc(1, sizeof(ssh)));
+       if (s == NULL)
+       {
+               log_print(2,"ssh_new", "calloc : %s", strerror(errno));
+               return NULL;
+       }
+       s->user = username;
+       s->addr = addr;
+
+       s->socket = Network_client(addr, port,100);
+       s->session = libssh2_session_init();
+       if (s->session == NULL)
+       {
+               log_print(2,"ssh_new", "libssh2_session_init returned NULL");
+               return ssh_new_abort(s);
+       }
+
+       int err = libssh2_session_handshake(s->session, s->socket);
+       if (err != 0)
+       {
+               log_print(2,"ssh_new", "libssh2_session_handshake fails - error code %d", err);
+               return ssh_new_abort(s);
+       }
+       s->fingerprint = (char*)(libssh2_hostkey_hash(s->session, LIBSSH2_HOSTKEY_HASH_SHA1));
+       if (!ssh_fingerprint_ok(s->fingerprint))
+       {
+               log_print(2,"ssh_new", "Fingerprint of host \"%s\" was not OK", addr);
+               return ssh_new_abort(s);
+       }
+
+       bool ok = false;
+       int auth = AUTH_NONE;
+
+       // The list is NULL both on error and when the server accepted the "none"
+       // method outright; it must not be handed to strstr in either case.
+       char * userauthlist = libssh2_userauth_list(s->session, username, strlen(username));
+       if (userauthlist == NULL)
+       {
+               ok = (libssh2_userauth_authenticated(s->session) != 0);
+       }
+       else
+       {
+               if (strstr(userauthlist, "password"))
+                       auth |= AUTH_PASSWORD;
+               if (strstr(userauthlist, "publickey"))
+                       auth |= AUTH_PUBLICKEY;
+       }
+
+       if (auth & AUTH_PUBLICKEY)
+       {
+               // first try connecting with agent
+               ok = ssh_agent_auth(s);
+
+               if (!ok)
+               {
+                       log_print(3, "ssh_new", "Agent authentication failed, looking at public keys");
+
+                       if (SSH_DIR[0] == '~' && SSH_DIR[1] == '/')
+                       {
+                               // Expand the leading "~/" by hand; skip key files when HOME is
+                               // unset or the expanded path would not fit the buffer.
+                               char ssh_dir[BUFSIZ];
+                               char * home = getenv("HOME");
+                               int n = (home == NULL) ? -1 : snprintf(ssh_dir, sizeof(ssh_dir), "%s/%s", home, SSH_DIR+2);
+                               if (n >= 0 && (size_t)n < sizeof(ssh_dir))
+                                       ok = ssh_publickey_auth(s, ssh_dir,3);
+                               else
+                                       log_print(1, "ssh_new", "Couldn't expand %s; skipping public keys", SSH_DIR);
+                       }
+                       else
+                               ok = ssh_publickey_auth(s, SSH_DIR,3);
+               }
+       }
+
+       if (auth & AUTH_PASSWORD && !ok)
+       {
+               log_print(3, "ssh_new", "public keys failed, try password");
+               for (int i = 0; i < 3 && !ok; ++i)
+               {
+                       printf("Password for %s@%s:", username, addr);
+                       char password[BUFSIZ];
+                       ssh_get_passwd(password, BUFSIZ);
+
+                       if (libssh2_userauth_password(s->session, username, password) == 0)
+                       {
+                               ok = true;
+                       }
+               }
+               if (!ok)
+                       log_print(3, "ssh_new", "Failed to authenticate by password.");
+       }
+
+       if (!ok)
+       {
+               log_print(2, "ssh_new", "All attempts at authenticating failed.");
+               return ssh_new_abort(s);
+       }
+       log_print(3, "ssh_new", "Authenticated!");
+
+       s->reserved_tunnels = 1;
+       s->tunnel = (ssh_tunnel*)(calloc(s->reserved_tunnels, sizeof(ssh_tunnel)));
+       if (s->tunnel == NULL)
+       {
+               log_print(2, "ssh_new", "calloc : %s", strerror(errno));
+               return ssh_new_abort(s);
+       }
+       s->nTunnels = 0;
+       libssh2_session_set_blocking(s->session, 0);
+       return s;
+}
+
+/**
+ * Shut down an ssh created by ssh_new.
+ *
+ * Unregisters s from the forwarding thread, then for every tunnel forwards
+ * whatever can still be read from the channel to the tunnel's socket, closes
+ * and frees the channel and closes the socket. Finally disconnects and frees
+ * the libssh2 session and releases the tunnel array.
+ *
+ * NOTE(review): neither s itself nor s->socket is released here - confirm
+ * that the caller does so.
+ *
+ * @param s Session to tear down
+ */
+void ssh_destroy(ssh * s)
+{
+       ssh_thread_del(s);
+
+       for (int i = 0; i < s->nTunnels; ++i)
+       {
+               int err;
+               char buffer[BUFSIZ];
+
+               // Only a positive return is a byte count; zero (end of data) or a
+               // negative error code must never be used as the length for write().
+               while ((err = libssh2_channel_read(s->tunnel[i].channel, buffer, sizeof(buffer))) > 0)
+               {
+                       int written = 0;
+                       while (written < err)
+                       {
+                               int w = write(s->tunnel[i].forward_sock, buffer+written, err-written);
+                               if (w <= 0)
+                                       break;
+                               written += w;
+                       }
+                       if (written < err)
+                               break; // socket no longer accepts data; stop draining
+               }
+
+               while ((err = libssh2_channel_close(s->tunnel[i].channel)) == LIBSSH2_ERROR_EAGAIN)
+                       waitsocket(s->socket, s->session);
+
+               libssh2_channel_free(s->tunnel[i].channel);
+               close(s->tunnel[i].forward_sock);
+       }
+
+
+       libssh2_session_disconnect(s->session, "goodbye");
+       libssh2_session_free(s->session);
+
+       free(s->tunnel);
+       // Leave no dangling pointers behind in case s is inspected again
+       s->tunnel = NULL;
+       s->nTunnels = 0;
+       s->session = NULL;
+}
+
+/**
+ * Decide whether host fingerprint f is trusted.
+ * Not implemented yet: logs a warning and accepts every fingerprint.
+ */
+bool ssh_fingerprint_ok(char * f)
+{
+       //TODO: Check fingerprint
+       const bool accepted = true;
+       log_print(1, "ssh_fingerprint_ok", "Unimplemented!");
+       return accepted;
+}
+
+/**
+ * Read one line from stdin with terminal echo disabled (the newline typed
+ * by the user is still echoed), then restore the terminal settings.
+ *
+ * @param buffer Receives the line without its trailing newline; set to the
+ *               empty string if nothing could be read
+ * @param len    Size of buffer in bytes
+ */
+void ssh_get_passwd(char * buffer, int len)
+{
+       struct termios oflags, nflags;
+
+       if (len <= 0)
+               return;
+       buffer[0] = '\0';
+
+       if (tcgetattr(fileno(stdin), &oflags) != 0)
+       {
+               error("ssh_get_passwd", "tcgetattr : %s", strerror(errno));
+       }
+       nflags = oflags;
+       nflags.c_lflag &= ~ECHO;
+       nflags.c_lflag |= ECHONL;
+
+       if (tcsetattr(fileno(stdin), TCSANOW, &nflags) != 0)
+       {
+               error("ssh_get_passwd", "tcsetattr : %s", strerror(errno));
+       }
+
+       if (fgets(buffer, len, stdin) == NULL)
+       {
+               // EOF or read error: hand back an empty string, not stale memory
+               buffer[0] = '\0';
+       }
+       else
+       {
+               // Remove the newline only if it is really there: an over-long or
+               // unterminated line must not lose its last character, and an
+               // empty string must not be indexed at -1.
+               size_t n = strlen(buffer);
+               if (n > 0 && buffer[n - 1] == '\n')
+                       buffer[n - 1] = '\0';
+       }
+
+       if (tcsetattr(fileno(stdin), TCSANOW, &oflags) != 0)
+       {
+               error("ssh_get_passwd", "tcsetattr : %s", strerror(errno));
+       }
+}
+
+/**
+ * Try to authenticate s with the key pairs found in a directory.
+ *
+ * Every entry whose name does not contain ".pub" is treated as a private
+ * key and paired with "<entry>.pub"; both must be regular files. Each pair
+ * is first tried with an empty passphrase and, if libssh2 answers
+ * LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED, with up to nAttempts passphrases read
+ * from stdin.
+ *
+ * @param s         Session to authenticate (s->session and s->user are used)
+ * @param dir       Directory to scan, with or without a trailing '/'
+ * @param nAttempts Passphrase prompts allowed per key
+ * @return true as soon as one key pair authenticates, false otherwise
+ */
+bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts)
+{
+       DIR * d = opendir(dir);
+       struct dirent * dp;
+       if (d == NULL)
+       {
+               log_print(0, "ssh_publickey_auth", "Couldn't open directory %s : %s", dir, strerror(errno));
+               return false;
+       }
+
+       // Insert a '/' between directory and file name only when dir lacks one
+       size_t dirlen = strlen(dir);
+       const char * sep = (dirlen > 0 && dir[dirlen-1] == '/') ? "" : "/";
+
+       while ((dp = readdir(d)) != NULL)
+       {
+               // skip public keys
+               if (strstr(dp->d_name, ".pub") != NULL)
+                       continue;
+
+               // assume file is a private key
+               // find corresponding public key
+               char pub[BUFSIZ]; char priv[BUFSIZ];
+
+               // Bounded formatting: the public key path is the longer of the two,
+               // so if it fits, the private key path fits as well.
+               int n = snprintf(pub, sizeof(pub), "%s%s%s.pub", dir, sep, dp->d_name);
+               if (n < 0 || (size_t)n >= sizeof(pub))
+               {
+                       log_print(3, "ssh_publickey_auth", "Path too long for %s in %s", dp->d_name, dir);
+                       continue;
+               }
+               snprintf(priv, sizeof(priv), "%s%s%s", dir, sep, dp->d_name);
+
+               struct stat t;
+               if (stat(priv, &t) != 0)
+               {
+                       log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", priv, strerror(errno));
+                       continue;
+               }
+
+               if (!S_ISREG(t.st_mode))
+               {
+                       log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", priv);
+                       continue;
+               }
+
+               if (stat(pub, &t) != 0)
+               {
+                       log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", pub, strerror(errno));
+                       continue;
+               }
+
+               if (!S_ISREG(t.st_mode))
+               {
+                       log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", pub);
+                       continue;
+               }
+
+               //libssh2_trace(s->session, LIBSSH2_TRACE_AUTH | LIBSSH2_TRACE_PUBLICKEY);
+               int err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,"");
+               if (err == 0)
+               {
+                       log_print(1, "ssh_publickey_auth", "Shouldn't use keys with no passphrase");
+               }
+               else if (err == LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED)
+               {
+                       char passphrase[BUFSIZ];
+                       for (int i = 0; i < nAttempts; ++i)
+                       {
+                               printf("Passphrase for key %s:", priv);
+                               ssh_get_passwd(passphrase, BUFSIZ);
+                               err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,passphrase);
+                               if (err != LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED) break;
+                       }
+               }
+               if (err == 0)
+               {
+                       closedir(d);
+                       return true;
+               }
+       }
+       closedir(d);
+
+       return false;
+}
+
+/**
+ * Try to authenticate s through a running ssh-agent.
+ *
+ * Walks the identities the agent offers and tries each one for s->user
+ * until one is accepted or the list is exhausted. The agent handle is
+ * disconnected and freed before returning, whatever the outcome.
+ *
+ * @param s Session to authenticate (s->session and s->user are used)
+ * @return true if an identity was accepted, false otherwise
+ */
+bool ssh_agent_auth(ssh * s)
+{
+       bool ok = false;
+       bool connected = false;
+       struct libssh2_agent_publickey * identity = NULL;
+       struct libssh2_agent_publickey * prev_identity = NULL;
+
+       LIBSSH2_AGENT * agent = libssh2_agent_init(s->session);
+       if (agent == NULL)
+       {
+               log_print(0, "ssh_agent_auth", "Couldn't initialise agent support.");
+               return false;
+       }
+
+       if (libssh2_agent_connect(agent) != 0)
+       {
+               log_print(0, "ssh_agent_auth", "Failed to connect to ssh-agent.");
+               goto done;
+       }
+       connected = true;
+
+       if (libssh2_agent_list_identities(agent) != 0)
+       {
+               log_print(0, "ssh_agent_auth", "Failure requesting identities to ssh-agent.");
+               goto done;
+       }
+
+       while (true)
+       {
+               int err = libssh2_agent_get_identity(agent, &identity, prev_identity);
+               if (err == 1)
+               {
+                       // 1 means the end of the identity list was reached
+                       log_print(0, "ssh_agent_auth", "Couldn't continue authentication.");
+                       break;
+               }
+               if (err < 0)
+               {
+                       log_print(0, "ssh_agent_auth", "Failure obtaining identity from ssh-agent support.");
+                       break;
+               }
+
+               if (libssh2_agent_userauth(agent, s->user, identity) == 0)
+               {
+                       log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s succeeded!", s->user, identity->comment);
+                       ok = true;
+                       break;
+               }
+               else
+               {
+                       log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s failed.", s->user, identity->comment);
+               }
+               prev_identity = identity;
+       }
+
+done:
+       // The agent handle is only needed while authenticating; release it on
+       // every path so neither the handle nor its connection is leaked.
+       if (connected)
+               libssh2_agent_disconnect(agent);
+       libssh2_agent_free(agent);
+       return ok;
+}
+
+/**
+ * Ask the remote end of s to listen on "localhost" and forward connections.
+ *
+ * *port is passed to libssh2 as the requested port and port as the location
+ * for the bound port. The session is switched to blocking mode for the
+ * request and back to non-blocking afterwards, all under ssh_thread_mutex.
+ *
+ * @return The listener, or NULL (after logging libssh2's error message)
+ */
+LIBSSH2_LISTENER * ssh_get_listener(ssh * s, int * port)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+
+       LIBSSH2_SESSION * session = s->session;
+       libssh2_session_set_blocking(session, 1);
+
+       LIBSSH2_LISTENER * listener = libssh2_channel_forward_listen_ex(session, "localhost", *port, port,1);
+       if (!listener)
+       {
+               char * reason;
+               libssh2_session_last_error(session, &reason, NULL, 0);
+               log_print(0, "ssh_get_listener", "Error: %s", reason);
+       }
+
+       libssh2_session_set_blocking(session, 0);
+       pthread_mutex_unlock(&ssh_thread_mutex);
+
+       return listener;
+}
+
+/**
+ * Accept one forwarded connection from listener and pair it with a socket.
+ *
+ * The accepted channel and the socket are stored as a new tunnel of s, so
+ * that ssh_thread relays data between them. If no channel could be accepted
+ * the error is logged and no tunnel is added.
+ *
+ * @param s        Session that owns the listener
+ * @param listener Listener obtained from ssh_get_listener
+ * @param socket   Descriptor to pair with the accepted channel
+ */
+void ssh_add_tunnel(ssh * s, LIBSSH2_LISTENER * listener, int socket)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+       //log_print(3, "ssh_add_tunnel", "accepting connection...");
+       libssh2_session_set_blocking(s->session , 1);
+       //libssh2_trace(s->session, ~0);
+       LIBSSH2_CHANNEL * channel = libssh2_channel_forward_accept(listener);
+       if (channel == NULL)
+       {
+               char * reason;
+               libssh2_session_last_error(s->session, &reason, NULL, 0);
+               log_print(0, "ssh_add_tunnel", "Error: %s", reason);
+       }
+       libssh2_session_set_blocking(s->session , 0);
+       //log_print(3, "ssh_add_tunnel", "accepted remote connection...");
+
+       // A tunnel without a channel would make the forwarding thread operate
+       // on a NULL channel, so it is not registered at all.
+       if (channel == NULL)
+       {
+               pthread_mutex_unlock(&ssh_thread_mutex);
+               return;
+       }
+
+       ssh_tunnel * t = s->tunnel+(s->nTunnels++);
+       t->forward_sock = socket;
+       t->channel = channel;
+
+       if (socket > ssh_thread_maxfd)
+               ssh_thread_maxfd = socket;
+
+       // Keep at least one free slot for the next tunnel. The old array is only
+       // replaced once realloc has succeeded, so it stays valid on failure.
+       if (s->nTunnels >= s->reserved_tunnels)
+       {
+               int grown = s->reserved_tunnels * 2;
+               ssh_tunnel * tmp = (ssh_tunnel*)(realloc(s->tunnel, grown * sizeof(ssh_tunnel)));
+               if (tmp == NULL)
+               {
+                       // No room for a future tunnel: drop this one again rather
+                       // than leave the array full.
+                       s->nTunnels--;
+                       libssh2_channel_free(channel);
+                       pthread_mutex_unlock(&ssh_thread_mutex);
+                       error("ssh_add_tunnel", "realloc : %s", strerror(errno));
+                       return;
+               }
+               s->tunnel = tmp;
+               s->reserved_tunnels = grown;
+       }
+
+       pthread_mutex_unlock(&ssh_thread_mutex);
+}
+
+/**
+ * Start the swarm program on the remote host of s.
+ *
+ * Exactly one of port / socket must be non-NULL:
+ *  - socket != NULL: runs "<program> -r -" and registers the exec channel
+ *    with *socket as a tunnel, so the remote program's I/O is relayed
+ *    through the ssh session.
+ *  - port != NULL: runs "<program> -r <client address>:<*port>" so the
+ *    remote program connects back by itself; the channel is then drained,
+ *    closed and freed.
+ * If np is non-zero, " -n <np>" is appended to the command.
+ *
+ * @param s      Authenticated, non-blocking session
+ * @param port   Port for the remote program to connect back to, or NULL
+ * @param socket Local socket to pair with the exec channel, or NULL
+ * @param np     Value for the -n switch; 0 omits the switch
+ */
+void ssh_exec_swarm(ssh * s, int * port, int * socket, int np)
+{
+
+       // secure things
+       LIBSSH2_CHANNEL * channel = NULL;
+       while ((channel = libssh2_channel_open_session(s->session)) == NULL
+               && libssh2_session_last_error(s->session, NULL, NULL, 0) == LIBSSH2_ERROR_EAGAIN)
+       {
+               waitsocket(s->socket, s->session);
+       }
+
+       if (channel == NULL)
+       {
+               error("ssh_exec_swarm", "Couldn't create channel with ssh session");
+       }
+
+
+       char buffer[BUFSIZ];
+       int used = 0; // length of the command built so far
+
+       // connect secure
+       if (port == NULL && socket != NULL)
+       {
+               used = snprintf(buffer, sizeof(buffer), "%s -r -", options.program);
+       }
+       else if (port != NULL && socket == NULL)
+       {
+               used = snprintf(buffer, sizeof(buffer), "%s -r $(echo $SSH_CONNECTION | awk \'{print $1}\'):%d", options.program, *port);
+       }
+       else
+               error("ssh_exec_swarm", "Exactly *one* of the port or socket pointers must not be NULL");
+
+       if (used < 0 || (size_t)used >= sizeof(buffer))
+               error("ssh_exec_swarm", "Command line too long");
+
+       // Append the -n switch *after* the command; formatting it at the start
+       // of the buffer would replace the command with just " -n <np>".
+       if (np != 0)
+       {
+               int extra = snprintf(buffer+used, sizeof(buffer)-used, " -n %d", np);
+               if (extra < 0 || (size_t)extra >= sizeof(buffer)-used)
+                       error("ssh_exec_swarm", "Command line too long");
+       }
+
+       int err;
+       while ((err = libssh2_channel_exec(channel, buffer)) == LIBSSH2_ERROR_EAGAIN)
+       {
+               waitsocket(s->socket, s->session);
+       }
+       if (err != 0)
+               log_print(0, "ssh_exec_swarm", "libssh2_channel_exec fails - error code %d", err);
+
+       if (socket != NULL)
+       {
+
+               pthread_mutex_lock(&ssh_thread_mutex);
+
+               ssh_tunnel * t = s->tunnel+(s->nTunnels++);
+
+               t->forward_sock = *socket;
+               t->channel = channel;
+
+               if (*socket > ssh_thread_maxfd)
+                       ssh_thread_maxfd = *socket;
+
+               // Keep one free slot for the next tunnel; only replace the array
+               // once realloc has succeeded.
+               if (s->nTunnels >= s->reserved_tunnels)
+               {
+                       int grown = s->reserved_tunnels * 2;
+                       ssh_tunnel * tmp = (ssh_tunnel*)(realloc(s->tunnel, grown * sizeof(ssh_tunnel)));
+                       if (tmp == NULL)
+                       {
+                               pthread_mutex_unlock(&ssh_thread_mutex);
+                               error("ssh_exec_swarm", "realloc : %s", strerror(errno));
+                               return;
+                       }
+                       s->tunnel = tmp;
+                       s->reserved_tunnels = grown;
+               }
+
+               pthread_mutex_unlock(&ssh_thread_mutex);
+       }
+       else
+       {
+
+               // read everything and close the channel
+               while (true)
+               {
+                       while ((err = libssh2_channel_read(channel, buffer, sizeof(buffer))) > 0);
+                       if (err == LIBSSH2_ERROR_EAGAIN)
+                       {
+                               waitsocket(s->socket, s->session);
+                       }
+                       else
+                       {
+                               break;
+                       }
+               }
+
+               while ((err = libssh2_channel_close(channel)) == LIBSSH2_ERROR_EAGAIN)
+               {
+                       waitsocket(s->socket, s->session);
+               }
+               libssh2_channel_free(channel);
+       }
+}
+
+
+
+
+// Locked around every access to the session array and the tunnel lists
+pthread_mutex_t ssh_thread_mutex = PTHREAD_MUTEX_INITIALIZER;
+// Handle of the forwarding thread created in ssh_thread_add
+pthread_t ssh_pthread;
+
+/**
+ * Body of the forwarding thread (started by ssh_thread_add).
+ *
+ * Until ssh_thread_running becomes false it repeatedly
+ *  1. waits up to 100ms for data on any tunnel's forward_sock,
+ *  2. copies data from readable sockets into their channels, and
+ *  3. copies whatever each channel has available to its socket.
+ * ssh_thread_mutex is held except while waiting in select().
+ *
+ * @param args Unused
+ * @return Always NULL
+ */
+void * ssh_thread(void * args)
+{
+       fd_set readSet;
+       char buffer[BUFSIZ];
+       struct timeval tv;
+
+       while (true)
+       {
+               //log_print(1, "ssh_thread", "loop - %d ssh's", ssh_array_used);
+               FD_ZERO(&readSet);
+               pthread_mutex_lock(&ssh_thread_mutex);
+
+               if (!ssh_thread_running)
+               {
+                       // Leaving with the mutex held would block every later
+                       // ssh_thread_add / ssh_thread_del forever
+                       pthread_mutex_unlock(&ssh_thread_mutex);
+                       break;
+               }
+
+               // Walk every slot: a deleted session leaves a NULL hole, so live
+               // sessions can sit at indices beyond ssh_array_used.
+               for (int i = 0; i < ssh_array_reserved; ++i)
+               {
+                       ssh * s = ssh_array[i];
+                       if (s == NULL) continue;
+                       for (int j = 0; j < s->nTunnels; ++j)
+                       {
+                               FD_SET(s->tunnel[j].forward_sock, &readSet);
+                       }
+               }
+               int maxfd = ssh_thread_maxfd;
+
+               pthread_mutex_unlock(&ssh_thread_mutex);
+
+               // select() may overwrite the timeout, so it is re-armed every time;
+               // after a failure the set is undefined and is treated as empty
+               tv.tv_sec = 0;
+               tv.tv_usec = 100000;
+               if (select(maxfd+1, &readSet, NULL, NULL, &tv) < 0)
+                       FD_ZERO(&readSet);
+
+               pthread_mutex_lock(&ssh_thread_mutex);
+
+               for (int i = 0; i < ssh_array_reserved; ++i)
+               {
+                       ssh * s = ssh_array[i];
+                       //log_print(2, "ssh_thread", "array[%d] = %p", i, s);
+                       if (s == NULL) continue;
+                       for (int j = 0; j < s->nTunnels; ++j)
+                       {
+                               int sock = s->tunnel[j].forward_sock;
+                               LIBSSH2_CHANNEL * channel = s->tunnel[j].channel;
+
+                               // socket -> channel
+                               if (FD_ISSET(sock, &readSet))
+                               {
+                                       int len = read(sock, buffer, sizeof(buffer));
+
+                                       if (len <= 0)
+                                               continue;
+
+                                       int written = 0;
+                                       while (written < len)
+                                       {
+                                               int w = libssh2_channel_write(channel, buffer+written, len-written);
+                                               if (w == LIBSSH2_ERROR_EAGAIN)
+                                               {
+                                                       // non-blocking session: wait until it can progress
+                                                       waitsocket(s->socket, s->session);
+                                                       continue;
+                                               }
+                                               assert(w >= 0);
+                                               if (w == 0)
+                                                       break;
+                                               written += w;
+                                       }
+                               }
+
+                               // channel -> socket, until the channel has nothing more for now
+                               while (true)
+                               {
+                                       int len = libssh2_channel_read(channel, buffer, sizeof(buffer));
+                                       if (len == LIBSSH2_ERROR_EAGAIN) break;
+                                       assert(len >= 0);
+                                       if (len == 0)
+                                               break; // no more data (e.g. EOF); looping on would never end
+
+                                       int written = 0;
+                                       while (written < len)
+                                       {
+                                               int w = write(sock, buffer+written, len-written);
+                                               if (w <= 0)
+                                                       break; // adding -1 to `written` would loop forever
+                                               written += w;
+                                       }
+                               }
+                       }
+               }
+               pthread_mutex_unlock(&ssh_thread_mutex);
+       }
+
+       return NULL;
+}
+
+/**
+ * Register s with the forwarding thread, starting the thread if it is not
+ * running.
+ *
+ * s is stored in the first free slot of the session array, which is grown
+ * when full. The thread is created with all signals blocked; the caller's
+ * own signal mask is restored afterwards.
+ *
+ * @param s Session whose tunnels should be serviced
+ */
+void ssh_thread_add(ssh * s)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+
+       bool found = false;
+       for (int i = 0; (i < ssh_array_reserved && !found); ++i)
+       {
+               if (ssh_array[i] == NULL)
+               {
+                       ssh_array[i] = s;
+                       found = true;
+               }
+       }
+
+       if (!found)
+       {
+               // realloc(NULL, n) allocates, so the first allocation and later
+               // growth share one path. The old array is kept until it succeeds.
+               int old = ssh_array_reserved;
+               int grown = (old + 1) * 2;
+               ssh ** tmp = (ssh**)(realloc(ssh_array, grown * sizeof(ssh*)));
+               if (tmp == NULL)
+               {
+                       pthread_mutex_unlock(&ssh_thread_mutex);
+                       error("ssh_thread_add", "realloc : %s", strerror(errno));
+                       return;
+               }
+               for (int i = old; i < grown; ++i)
+                       tmp[i] = NULL;
+               ssh_array = tmp;
+               ssh_array_reserved = grown;
+               ssh_array[old] = s;
+       }
+       ssh_array_used++;
+
+       for (int i = 0; i < s->nTunnels; ++i)
+       {
+               if (s->tunnel[i].forward_sock > ssh_thread_maxfd)
+                       ssh_thread_maxfd = s->tunnel[i].forward_sock;
+       }
+
+       if (!ssh_thread_running)
+       {
+               ssh_thread_running = true;
+               sigset_t set;
+               sigset_t previous;
+               int err;
+
+               // Block everything while the thread is created so that it inherits
+               // a fully blocked mask, then put back the mask the caller had.
+               // pthread functions return the error number instead of setting errno.
+               sigfillset(&set);
+               err = pthread_sigmask(SIG_SETMASK, &set, &previous);
+               if (err != 0)
+                       error("ssh_thread_add", "pthread_sigmask : %s", strerror(err));
+               err = pthread_create(&ssh_pthread, NULL, ssh_thread, NULL);
+               if (err != 0)
+                       error("ssh_thread_add", "pthread_create : %s", strerror(err));
+               err = pthread_sigmask(SIG_SETMASK, &previous, NULL);
+               if (err != 0)
+                       error("ssh_thread_add", "pthread_sigmask : %s", strerror(err));
+       }
+
+       pthread_mutex_unlock(&ssh_thread_mutex);
+}
+
+/**
+ * Unregister s from the forwarding thread.
+ *
+ * Clears the first slot that holds s and decrements the session count; once
+ * the count reaches zero the thread is told to stop. Does nothing if s is
+ * not registered.
+ */
+void ssh_thread_del(ssh * s)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+
+       // Locate the slot holding s (if any)
+       int slot = 0;
+       while (slot < ssh_array_reserved && ssh_array[slot] != s)
+               ++slot;
+
+       if (slot < ssh_array_reserved)
+       {
+               ssh_array[slot] = NULL;
+               --ssh_array_used;
+               ssh_thread_running = (ssh_array_used != 0);
+       }
+
+       pthread_mutex_unlock(&ssh_thread_mutex);
+}
diff --git a/src/ssh.h b/src/ssh.h
new file mode 100644 (file)
index 0000000..6d12e99
--- /dev/null
+++ b/src/ssh.h
@@ -0,0 +1,57 @@
+#ifndef _SSH_H
+#define _SSH_H
+
+#include "network.h"
+#include "master.h"
+#include "options.h"
+
+#define SSH_DIR "~/.ssh/"
+#define SSH_KNOWN_HOSTS "~/.ssh/known_hosts"
+
+#include <libssh2.h>
+#include <pthread.h>
+#include <sys/select.h>
+#include <sys/fcntl.h>
+
+// One forwarded connection: a socket paired with a libssh2 channel.
+// ssh_thread relays data between the two in both directions.
+typedef struct
+{
+       int forward_sock; // socket end of the tunnel; closed by ssh_destroy
+       LIBSSH2_CHANNEL * channel; // channel end of the tunnel; closed and freed by ssh_destroy
+       int port; // NOTE(review): not referenced anywhere in ssh.c - confirm intended use
+       
+} ssh_tunnel;
+
+// An SSH session together with the tunnels that run through it.
+// Created by ssh_new, torn down by ssh_destroy.
+typedef struct
+{
+       ssh_tunnel * tunnel; // array of tunnels, reserved_tunnels entries allocated
+       int nTunnels; // number of entries of tunnel in use
+       int reserved_tunnels; // allocated length of tunnel; doubled when it fills up
+
+       int socket; // connection to the server, as returned by Network_client
+       LIBSSH2_SESSION *session; // libssh2 session running over socket
+       char * fingerprint; // host key hash as returned by libssh2_hostkey_hash (SHA1)
+
+       char * user; // username passed to ssh_new (pointer stored, not copied)
+       char * addr; // address passed to ssh_new (pointer stored, not copied)
+} ssh;
+
+// Connect to addr:port and authenticate as username; NULL on failure
+extern ssh * ssh_new(char * username, char * addr, int port);
+// Close all tunnels of s and disconnect its session
+extern void ssh_destroy(ssh * s);
+
+// Run the swarm program remotely; exactly one of port / socket must be non-NULL
+extern void ssh_exec_swarm(ssh * s, int * port, int * socket, int np);
+// Request a remote listener on "localhost"; *port is the requested port
+extern LIBSSH2_LISTENER * ssh_get_listener(ssh * s, int * port);
+// Accept a connection from listener and tunnel it to socket
+extern void ssh_add_tunnel(ssh * s, LIBSSH2_LISTENER * listener, int socket);
+
+// Thread function that relays data for all registered sessions
+extern void * ssh_thread(void * args);
+// Register s with the relay thread (starting the thread if needed)
+extern void ssh_thread_add(ssh * s);
+// Unregister s from the relay thread
+extern void ssh_thread_del(ssh * s);
+
+
+// Guards the session array and the tunnel lists
+extern pthread_mutex_t ssh_thread_mutex;
+
+// Handle of the relay thread
+extern pthread_t ssh_pthread;
+
+
+#endif //_SSH_H
+
+//EOF

UCC git Repository :: git.ucc.asn.au