Things seem to work...
[matches/swarm.git] / src / master.c
index e61f5fa..2734d07 100644 (file)
 #include <assert.h>
 #include <ctype.h>
 #include "slave.h"
+#include <string.h>
 #include <setjmp.h>
+#include <sys/types.h>
+#include <pwd.h>
 
 #include <unistd.h>
 #include <regex.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <netinet/tcp.h>
+#include "ssh.h"
 
 //#define THREAD_SENDING // You decided to solve a problem with threads; now you have two problems
 
 // probably not that great to use threads anyway, since it eats one of your cores
 // It probably spends 90% of its time sleeping, and 9.9% unlocking mutexes
 
+// the signal handler now breaks threads... don't use them
+
 #ifdef THREAD_SENDING
 pthread_t sender_thread;
 pthread_mutex_t sender_lock = PTHREAD_MUTEX_INITIALIZER;
@@ -69,12 +75,19 @@ void Master_main(Options * o)
 
 void Master_setup(Options * o)
 {
+       int err = libssh2_init(0);
+       if (err != 0)
+       {       
+               error("Master_setup", "Initialising libssh2 - error code %d", err);
+       }
+       
        signal(SIGCHLD, sigchld_handler);
        master.o = o;
        master.barrier_number = -1;
        master.last_number = -1;
        master.nSlaves = o->nCPU;
        master.running = master.nSlaves;
+       master.nRemote = 0; master.remote_err = NULL;
        if (master.nSlaves == 0)
                error("Master_setup", "No CPUs to start slaves with!");
 
@@ -119,7 +132,6 @@ void Make_slave(int i)
        master.slave[i].in = sv[1];
        master.slave[i].out = sv[1];
        master.slave[i].running = true;
-       master.slave[i].ssh_pid = 0;
 }
 
 void Master_input(char c)
@@ -174,9 +186,9 @@ void Master_input(char c)
                                                log_print(0, "Master_input", "No host specified for ABSORB directive");
                                        char * np = strtok(NULL, " ");
                                        if (np != NULL)
-                                               Master_absorb(cmd, options.port, atoi(np));
+                                               Master_absorb(cmd, atoi(np));
                                        else
-                                               Master_absorb(cmd, options.port, 0);
+                                               Master_absorb(cmd, 0);
                                }
                                else if (strcmp(cmd, "OUTPUT") == 0)
                                {
@@ -367,8 +379,11 @@ void Master_output(int i, char c)
        #endif //THREAD_SENDING
        if (t == NULL)
        {
-               log_print(0, "Master_output", "Read input from %s, but no task assigned!",master.slave[i].name);
-               error(NULL, "Please refrain from echoing three bell characters in a row.");
+               log_print(3, "Master_output", "Echo %c back to slave %d", c, i);
+               write(master.slave[i].in, &c, sizeof(char));
+               //log_print(0, "Master_output", "Read input from %s, but no task assigned!",master.slave[i].name);
+               //error(NULL, "Please refrain from echoing three bell characters in a row.");
+               return;
        }
 
                        
@@ -395,8 +410,10 @@ void Master_output(int i, char c)
                {
                        fprintf(stdout, "%d:\n", t->number);
                }
+               /*
                else if (t->output[t->outlen-1] == '\f')
                {
+                       
                        log_print(2, "Master_output", "Slave %d requests name (%s)", i, master.slave[i].name);
                        static int bufsiz = BUFSIZ;
                        char * buffer = (char*)(calloc(bufsiz, sizeof(char)));
@@ -420,6 +437,7 @@ void Master_output(int i, char c)
                                master.slave[i].task_pool = t2;
                        master.last_number = t2->number;
                }
+               */
                else
                {
                        fprintf(stdout, "%d: %s", t->number, t->output); 
@@ -522,9 +540,9 @@ Task * Master_tasker(int i)
 void Master_loop()
 {
        
-       if (sigsetjmp(env,true) != 0)
+       if (sigsetjmp(env,true) != 0) // completely necessary evil
        {
-               log_print(2, "Master_loop", "Restored from longjmp");
+               //log_print(2, "Master_loop", "Restored from longjmp");
        }
        fd_set readSet;
        //fd_set writeSet;
@@ -539,6 +557,7 @@ void Master_loop()
 
        bool quit = false;
        bool input = true;
+       char buffer[BUFSIZ];
        
        while (!quit)
        {
@@ -569,6 +588,11 @@ void Master_loop()
                {
                        if (master.slave[i].running) FD_SET(master.slave[i].out, &readSet);
                }
+
+               for (int i = 0; i < master.nRemote; ++i)
+               {
+                       FD_SET(master.remote_err[i], &readSet);
+               }
                
                select(master.fd_max+1, &readSet, NULL, NULL, NULL);
                
@@ -611,6 +635,16 @@ void Master_loop()
                                
                        }
                }
+
+               for (int i = 0; i < master.nRemote; ++i)
+               {
+                       if (FD_ISSET(master.remote_err[i], &readSet))
+                       {
+                               int len = read(master.remote_err[i], buffer, sizeof(buffer));
+                               buffer[len] = '\0';
+                               fprintf(stderr, "%s", buffer);
+                       }
+               }
                
        }
 
@@ -648,7 +682,7 @@ void Master_send()
        }
        write(send_task.slave_fd, "\n", 1*sizeof(char));
        master.commands_active++;
-       log_print(3, "Master_sender", "Sent task %d \"%s\" - %d tasks active", send_task.task->number, send_task.task->message, master.commands_active);
+       log_print(3, "Master_sender", "Sent task %d \"%s\" on socket %d - %d tasks active", send_task.task->number, send_task.task->message, send_task.slave_fd, master.commands_active);
 }
 
 #ifdef THREAD_SENDING
@@ -702,6 +736,16 @@ void Master_cleanup()
 
        signal(SIGCHLD, SIG_IGN); // ignore child exits now
 
+       // tell all remote nodes to exit
+       for (int i = 0; i < master.nRemote; ++i)
+       {
+               FILE * f = fdopen(master.remote_err[i], "r+"); setbuf(f, NULL);
+
+               fprintf(f, "exit\n");
+       
+               fclose(f);
+       }
+
        for (int i = 0; i < master.nSlaves; ++i)
        {
                
@@ -717,13 +761,6 @@ void Master_cleanup()
                if (master.slave[i].pid <= 0) 
                {
                        Network_close(master.slave[i].in);
-                       if (master.slave[i].ssh_pid > 0)
-                       {
-                               log_print(2, "Master_cleanup", "Killing ssh instance %d", master.slave[i].ssh_pid);
-                               kill(master.slave[i].ssh_pid, 15);
-                               if (kill(master.slave[i].ssh_pid, 0) == 0)
-                                       kill(master.slave[i].ssh_pid, 9);
-                       }
                }
                else
                {
@@ -740,83 +777,109 @@ void Master_cleanup()
        free(master.buffer);
        if (master.outfile != NULL)
                free(master.outfile);
-}
 
-void * start_server(void * args)
-{
+       libssh2_exit();
 
-       *(int*)(args) = Network_server(*(int*)args);
-       log_print(2, "start_server", "started network server");
-       return NULL;
 }
 
-int Secure_connection(char * addr, int port);
 
-void Master_absorb(char * addr, int port, int np)
+
+
+void Master_absorb(char * addr, int np)
 {
+       int port = 0;
+
+       char * user = strstr(addr, "@");
+       if (user != NULL)
+       {
+               *(user-1) = '\0';
+               char * t = user;
+               user = addr;
+               addr = t;
+       }
+       else
+       {
+               user = getpwuid(geteuid())->pw_name;
+       }
+       log_print(3, "Master_absorb", "User %s at address %s", user, addr);
+
+       // ssh to the host on port 22
+       ssh * s = ssh_new(user, addr, 22);
+       if (s == NULL)
+       {
+               log_print(0, "Master_absorb", "Couldn't ssh to %s@%s", user, addr);
+               return;
+       }
+       
+
+       // work out the name to give to the shells
        char * name = strtok(addr, ":");
        name = strtok(NULL, ":");
        if (name == NULL)
+               name = addr; // default is host:X
+       else
+               *(name-1) = '\0'; // otherwise use name:X
+
+       // setup array of remote stderr file descriptors
+       if (master.nRemote++ == 0)
        {
-               name = addr;
+               master.remote_err = (int*)(calloc(master.nRemote, sizeof(int)));
+               master.remote_reserved = master.nRemote;
        }
-       else
+       else if (master.nRemote >= master.remote_reserved)
        {
-               *(name-1) = '\0';
+               // resize dynamically
+               master.remote_reserved *= 2;
+               master.remote_err = (int*)(realloc(master.remote_err,master.remote_reserved * sizeof(int)));
        }
 
 
-       //log_print(0, "name is %s\n", name);
        
-       int first_ssh = 0;
+       int sfd = -1;
        if (master.o->encrypt)
-               first_ssh = Secure_connection(addr, port);
-
-       //pthread_t ss;
-       //int net_fd = port;
-       //pthread_create(&ss, NULL, start_server, (void*)(&net_fd));
-       
-       char buffer[BUFSIZ];
-       if (fork() == 0)
        {
-               // The alternative to this kind of terrible hack is OpenMPI's "opal"
-               // This involves >1000 lines of operating system independent ways to get an IP address from a local interface 
-               // Which will then be completely useless if there is any sort of NAT involved
-               
-               //freopen("/dev/null", "r", stdin);
-               //freopen("/dev/null", "w", stdout);
-               //freopen("/dev/null", "w", stderr);
+               int sv[2];
+               if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0)
+                       error("Master_absorb", "Couldn't create socket for remote swarm");
+               sfd = sv[0];
 
-               char * cmd = buffer+sprintf(buffer, "swarm -p ");
-               if (master.o->encrypt)
-                       cmd = cmd+sprintf(cmd, "%d -e", port+1000);
-               else
-                       cmd = cmd+sprintf(cmd, "%d -u", port);
-       
-               if (np > 0)
-                       cmd = cmd+sprintf(cmd, " -n %d", np);
-               sprintf(cmd, " -m $(echo $SSH_CONNECTION | awk \'{print $1}\')");
-               log_print(3, "Master_absorb", "Execing %s", buffer);
-               execlp("ssh", "ssh", "-f", addr, buffer, NULL);
+               ssh_exec_swarm(s, NULL, sv+1, np); // start swarm remotely and forward it to the socket
+               ssh_thread_add(s); // add the ssh to the thread
        }
-       log_print(3, "Master_absorb", "Listening on port %d", port);
-       int net_fd = Network_server(port);
-       log_print(3, "Master_absorb", "Created network server on port %d", port);
+       else
+       {
+               sfd = Network_server_bind(0, &port); // dynamically bind to a port
+               ssh_exec_swarm(s, &port, NULL, np); // start swarm remotely, have it connect on the port
+               ssh_destroy(s); // don't need the ssh anymore
+               sfd = Network_server_listen(sfd, NULL); // accept connections and pray
+       }
+       master.remote_err[master.nRemote-1] = sfd;
+       if (sfd > master.fd_max)
+               master.fd_max = sfd;
+
 
+       char buffer[BUFSIZ];
+
+       int newSlaves = 0;
+       
+       int len = sprintf(buffer, "%s\n", name);
+       int w = 0;
+       while (w < len)
+               w += write(sfd, buffer+w, len-w);
        
-       char * s = buffer;
-       while (read(net_fd, s, sizeof(char)) != 0)
+       
+       len = 0;
+       do
        {
-               if (*s == '\n')
-               {
-                       *s = '\0';
-                       break;
-               }
-               ++s;
+               len = read(sfd, buffer+len, sizeof(buffer));
+               buffer[len] = '\0';
        }
+       while (buffer[len-1] != '\n');
+       buffer[len-1] = '\0';
+       newSlaves = atoi(buffer);
+
+       
 
-       int newSlaves = atoi(buffer);
-       log_print(3, "Master_absorb", "Absorbing %d slaves from %s\n", newSlaves, addr);
        if (newSlaves == 0)
        {
                error("Master_absorb", "No slaves to absorb from %s", addr);
@@ -825,62 +888,70 @@ void Master_absorb(char * addr, int port, int np)
        master.slave = (Slave*)(realloc(master.slave, (master.nSlaves + newSlaves) * sizeof(Slave)));
        if (master.slave == NULL)
        {
-               error("Master_absorb", "Resizing slave array from %d to %d slaves : %s", master.nSlaves, master.nSlaves + newSlaves, strerror(errno));
+               error("Master_absorb", "Resizing slave array from %d to %d slaves : %s", 
+                       master.nSlaves, master.nSlaves + newSlaves, strerror(errno));
        }
 
-
-       if (master.o->encrypt)
-       {
-               for (int i = 0; i < newSlaves-1; ++i)
-                       master.slave[master.nSlaves+i].ssh_pid = Secure_connection(addr, port+i+1);
-
-       }       
-       master.slave[master.nSlaves+newSlaves-1].ssh_pid = first_ssh;
        
+       for (int i = 0; i < newSlaves; ++i)
+       {
+               int ii = master.nSlaves + i;
 
+               if (master.o->encrypt)
+               {
+                       int sv[2];
+                       if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0)
+                               error("Master_absorb", "Couldn't create socket for remote swarm");
 
+                       
+                       LIBSSH2_LISTENER * listener = NULL;
+                       // libssh2 can't finalise the connection when the port is dynamic (ie: port = 0)
+                       while (listener == NULL)
+                       {
+                               port = 20000 + rand() % 30000; // so pick ports at random until binding succeeds
+                               listener = ssh_get_listener(s, &port); // port forward to the socket
+                       }
 
-       
-
-       char c = '\a';
-       log_print(3, "Master_absorb", "Writing bell to slave");
-       write(net_fd, &c, sizeof(char));
-       
+                       log_print(4,"Master_absorb", "Chose port %d", port);
+                       int len = sprintf(buffer, "%d\n", port);
+                       
+                       int w = 0;
+                       while (w < len)
+                       {
+                               w += write(sfd, buffer+w, len-w);
+                       }
+                       usleep(200000); // give ssh_thread a chance to actually send the data
 
+                       log_print(4, "Master_absorb", "Creating tunnel...");
+                       ssh_add_tunnel(s, listener, sv[1]);
+                       master.slave[ii].in = sv[0];
+                       master.slave[ii].out = sv[0];
 
-       for (int i = 0; i < newSlaves; ++i)
-       {
-               log_print(3, NULL, "Absorbing slave %d...", i);
-               int ii = master.nSlaves + i;
-               if (i == newSlaves-1)
-               {
-                       master.slave[ii].out = net_fd;
+                       log_print(4, "Master_absorb", "Tunnel for slave %d using socket %d<->%d setup", ii, sv[0], sv[1]);
                }
                else
                {
-                       log_print(3, "Master_absorb", "Creating server %d at time %d", i, time(NULL));
-                       write(net_fd, &c, sizeof(char));
-                       master.slave[ii].out = Network_server(port + i + 1);
+                       int tmp = Network_server_bind(0, &port); // bind to a port      
+                       
+                       master.slave[ii].in = Network_server_listen(tmp, NULL); // listen for connection
+                       master.slave[ii].out = master.slave[ii].in;
                }
-               master.slave[ii].in = master.slave[ii].out;
+
+       
+               
                
 
                if (master.slave[ii].out > master.fd_max)
                        master.fd_max = master.slave[ii].out;
 
+               char buffer[BUFSIZ];
                sprintf(buffer, "%s:%d", name, i);
                master.slave[ii].name = strdup(buffer);
                master.slave[ii].addr = strdup(addr);
                master.slave[ii].running = true;
                master.slave[ii].pid = -1;
                master.slave[ii].task = NULL;
-               master.slave[ii].task_pool = NULL;
-
-               FILE * f = fdopen(master.slave[ii].in, "w"); setbuf(f, NULL);
-               fprintf(f, "name=%s\n", master.slave[ii].name);
-
-               log_print(3, NULL, "Done absorbing slave %d...", i);
-       
+               master.slave[ii].task_pool = NULL;      
        }       
 
 
@@ -892,21 +963,6 @@ void Master_absorb(char * addr, int port, int np)
 
 }
 
-int Secure_connection(char * addr, int port)
-{
-       int result = fork();
-       if (result == 0)
-       {
-               char buffer[BUFSIZ];
-               sprintf(buffer, "%d:localhost:%d", port+1000, port);
-               freopen("/dev/null", "r", stdin);
-               freopen("/dev/null", "w", stdout);
-               freopen("/dev/null", "w", stderr);
-               execl("/usr/bin/ssh", "/usr/bin/ssh", "-N", "-R", buffer, addr, NULL);
-       }
-       return result;
-}
-
 void sigchld_handler(int signal)
 {
 
@@ -918,13 +974,6 @@ void sigchld_handler(int signal)
        if (p == -1)
                error("sigchld_handler", "waitpid : %s", strerror(errno));
        
-       if (WIFSIGNALED(s))
-       {
-               int sig = WTERMSIG(s);
-               log_print(2, "sigchld_handler", "A child [%d] was terminated with signal %d; terminating self with same signal", p, sig);
-               kill(getpid(), sig);
-               return;
-       }
 
        int i = 0;
        for (i = 0; i < master.nSlaves; ++i)
@@ -937,13 +986,28 @@ void sigchld_handler(int signal)
                return;
        }
 
-       log_print(1, "sigchld_handler", "Slave %d [%d] exited with code %d; restarting it",i, p, s);
+       fprintf(stderr, "Unexpected exit of slave %s", master.slave[i].name);
+       if (WIFSIGNALED(s))
+       {
+               int sig = WTERMSIG(s);
+               fprintf(stderr, " due to %s", strsignal(sig));
+               if (sig == SIGKILL)
+               {
+                       printf(" - committing suicide\n");
+                       kill(getpid(), sig);
+               }
+       }
+       else
+       {               
+               fprintf(stderr, " return code %d.",s);
+       }
+       fprintf(stderr, " Starting replacement.\n");
 
        Make_slave(i);
 
        if (master.o->end != NULL)
        {
-               log_print(1, "sigchld_handler", "Trying to convince slave %d to be nice", i);
+               //log_print(1, "sigchld_handler", "Trying to convince slave %d to be nice", i);
                char buffer[BUFSIZ];
                sprintf(buffer, "name=%s;echo -en \"%s\"\n", master.slave[i].name, master.o->end);
                if (write(master.slave[i].in, buffer, strlen(buffer)) <= 0)

UCC git Repository :: git.ucc.asn.au