Things seem to work...
[matches/swarm.git] / src / ssh.c
diff --git a/src/ssh.c b/src/ssh.c
new file mode 100644 (file)
index 0000000..0e1e197
--- /dev/null
+++ b/src/ssh.c
@@ -0,0 +1,688 @@
+#include "ssh.h"
+#include "network.h"
+#include "log.h"
+#include <termios.h>
+#include <dirent.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <errno.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <sys/time.h>
+#include <sys/select.h>
+#include <assert.h>
+#include <signal.h>
+
+
+enum {
+    AUTH_NONE = 0,
+    AUTH_PASSWORD,
+    AUTH_PUBLICKEY
+};
+
+static bool ssh_fingerprint_ok(char * f);
+
+static void ssh_get_passwd(char * buffer, int len);
+
+static bool ssh_agent_auth(ssh * s);
+
+static bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts);
+
+static bool ssh_thread_running = false;
+static int ssh_array_reserved = 0;
+static int ssh_array_used = 0;
+static ssh ** ssh_array = NULL;
+static int ssh_thread_maxfd = 0;
+
+static int waitsocket(int socket_fd, LIBSSH2_SESSION *session)
+{
+    struct timeval timeout;
+    int rc;
+    fd_set fd;
+    fd_set *writefd = NULL;
+    fd_set *readfd = NULL;
+    int dir;
+    timeout.tv_sec = 10;
+    timeout.tv_usec = 0;
+    FD_ZERO(&fd);
+    FD_SET(socket_fd, &fd);
+    /* now make sure we wait in the correct direction */ 
+    dir = libssh2_session_block_directions(session);
+
+    if(dir & LIBSSH2_SESSION_BLOCK_INBOUND)
+        readfd = &fd;
+    if(dir & LIBSSH2_SESSION_BLOCK_OUTBOUND)
+        writefd = &fd;
+    rc = select(socket_fd + 1, readfd, writefd, NULL, &timeout);
+    return rc;
+}
+
+ssh * ssh_new(char * username, char * addr, int port)
+{
+       ssh * s = (ssh*)(calloc(1, sizeof(ssh)));
+       s->user = username;
+       s->addr = addr;
+
+       s->socket = Network_client(addr, port,100);
+       s->session = libssh2_session_init();
+       if (s->session == NULL)
+       {
+               free(s);
+               log_print(2,"ssh_new", "libssh2_session_init returned NULL");
+               return NULL;
+       }
+
+
+
+       int err = libssh2_session_handshake(s->session, s->socket);
+       if (err != 0)
+       {
+               free(s);
+               log_print(2,"ssh_new", "libssh2_session_handshake fails - error code %d", err);
+               return NULL;
+       }
+       s->fingerprint = (char*)(libssh2_hostkey_hash(s->session, LIBSSH2_HOSTKEY_HASH_SHA1));
+       if (!ssh_fingerprint_ok(s->fingerprint))
+       {
+               free(s);
+               log_print(2,"ssh_new", "Fingerprint of host \"%s\" was not OK", addr);
+               return NULL;
+       }       
+       
+       char * userauthlist = libssh2_userauth_list(s->session, username, strlen(username));
+
+       int auth = AUTH_NONE;
+       if (strstr(userauthlist, "password"))
+               auth |= AUTH_PASSWORD;
+       if (strstr(userauthlist, "publickey"))
+               auth |= AUTH_PUBLICKEY;
+
+       bool ok = false;
+
+       if (auth & AUTH_PUBLICKEY)
+       {
+               // first try connecting with agent
+               ok = ssh_agent_auth(s);
+
+               
+               
+               if (!ok)
+               {
+                       log_print(3, "ssh_new", "Agent authentication failed, looking at public keys");
+
+                       if (SSH_DIR[0] == '~' && SSH_DIR[1] == '/')
+                       {
+                               char ssh_dir[BUFSIZ];
+                               sprintf(ssh_dir, "%s/%s",getenv("HOME"),SSH_DIR+2);
+                               ok = ssh_publickey_auth(s, ssh_dir,3);
+                       }               
+                       else
+                               ok = ssh_publickey_auth(s, SSH_DIR,3);
+               }
+               
+       }
+       
+       if (auth & AUTH_PASSWORD && !ok)
+       {
+               log_print(3, "ssh_new", "public keys failed, try password");
+               for (int i = 0; i < 3 && !ok; ++i)
+               {
+                       printf("Password for %s@%s:", username, addr);
+                       char password[BUFSIZ];
+                       ssh_get_passwd(password, BUFSIZ);
+       
+                       if (libssh2_userauth_password(s->session, username, password) == 0)
+                       {
+                               ok = true;
+                       }
+               }
+               if (!ok)
+                       log_print(3, "ssh_new", "Failed to authenticate by password.");
+       }
+       
+       if (!ok)
+       {
+               free(s);
+               log_print(2, "ssh_new", "All attempts at authenticating failed.");
+               return NULL;
+       }
+       log_print(3, "ssh_new", "Authenticated!");
+
+       s->reserved_tunnels = 1;
+       s->tunnel = (ssh_tunnel*)(calloc(s->reserved_tunnels, sizeof(ssh_tunnel)));
+       s->nTunnels = 0;
+       libssh2_session_set_blocking(s->session, 0);
+       return s;
+}
+
+void ssh_destroy(ssh * s)
+{
+       ssh_thread_del(s);
+
+       for (int i = 0; i < s->nTunnels; ++i)
+       {
+               int err;
+               char buffer[BUFSIZ];
+               do
+               {
+                       err = libssh2_channel_read(s->tunnel[i].channel, buffer, sizeof(buffer));
+                       write(s->tunnel[i].forward_sock, buffer, err);
+               }
+               while (err > 0);
+
+               while ((err = libssh2_channel_close(s->tunnel[i].channel)) == LIBSSH2_ERROR_EAGAIN)
+                       waitsocket(s->socket, s->session);
+               
+               libssh2_channel_free(s->tunnel[i].channel);
+               close(s->tunnel[i].forward_sock);
+       }
+
+
+       libssh2_session_disconnect(s->session, "goodbye");
+       libssh2_session_free(s->session);
+
+       free(s->tunnel);
+}
+
+bool ssh_fingerprint_ok(char * f)
+{
+       //TODO: Check fingerprint
+       log_print(1, "ssh_fingerprint_ok", "Unimplemented!");
+       return true;
+}
+
+void ssh_get_passwd(char * buffer, int len)
+{
+       struct termios oflags, nflags;
+    
+       tcgetattr(fileno(stdin), &oflags);
+       nflags = oflags;
+       nflags.c_lflag &= ~ECHO;
+       nflags.c_lflag |= ECHONL;
+
+       if (tcsetattr(fileno(stdin), TCSANOW, &nflags) != 0) 
+       {
+               error("ssh_get_passwd", "tcsetattr : %s", strerror(errno));
+       }
+
+       fgets(buffer, len * sizeof(char), stdin);
+       buffer[strlen(buffer) - 1] = '\0';
+   
+       if (tcsetattr(fileno(stdin), TCSANOW, &oflags) != 0)
+       {
+               error("ssh_get_passwd", "tcsetattr : %s", strerror(errno));
+       }
+}
+
+bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts)
+{
+
+       
+       DIR * d = opendir(dir);
+       struct dirent * dp;
+       if (d == NULL)
+       {
+               log_print(0, "ssh_publickey_auth", "Couldn't open directory %s : %s", dir, strerror(errno));
+               return false;
+       }
+
+       while ((dp = readdir(d)) != NULL)
+       {
+               
+               // skip public keys
+               if (strstr(dp->d_name, ".pub") != NULL)
+                       continue;
+               
+               // assume file is a private key 
+               // find corresponding public key
+               char pub[BUFSIZ]; char priv[BUFSIZ];
+               if (dir[strlen(dir)-1] == '/')
+               {
+                       sprintf(pub, "%s%s.pub", dir,dp->d_name);
+                       sprintf(priv, "%s%s", dir, dp->d_name);
+               }
+               else
+               {
+                       sprintf(pub, "%s/%s.pub", dir,dp->d_name);
+                       sprintf(priv, "%s/%s", dir, dp->d_name);
+               }
+               
+               struct stat t;
+               if (stat(priv, &t) != 0)
+               {
+                       log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", priv, strerror(errno));
+                       continue;
+               }
+
+               if (!S_ISREG(t.st_mode))
+               {
+                       log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", priv);
+                       continue;
+               }
+
+               if (stat(pub, &t) != 0)
+               {
+                       log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", pub, strerror(errno));
+                       continue;
+               }
+
+               if (!S_ISREG(t.st_mode))
+               {
+                       log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", pub);
+                       continue;
+               }
+       
+
+
+
+               
+                       
+               //libssh2_trace(s->session, LIBSSH2_TRACE_AUTH | LIBSSH2_TRACE_PUBLICKEY);
+               int err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,"");
+               if (err == 0)
+               {
+                       log_print(1, "ssh_publickey_auth", "Shouldn't use keys with no passphrase");
+               }
+               else if (err == LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED)
+               {
+
+                       char passphrase[BUFSIZ];
+                       for (int i = 0; i < nAttempts; ++i)
+                       {
+                               printf("Passphrase for key %s:", priv);
+                               ssh_get_passwd(passphrase, BUFSIZ);
+                               err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,passphrase);
+                               if (err != LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED) break;
+                       }
+               }
+               if (err == 0)
+               {
+                       closedir(d);
+                       return true;
+               }
+       }
+       closedir(d);
+
+       
+       return false;
+
+       
+
+       
+}
+
+bool ssh_agent_auth(ssh * s)
+{
+       LIBSSH2_AGENT * agent = libssh2_agent_init(s->session);
+       if (agent == NULL)
+       {
+               log_print(0, "ssh_agent_auth", "Couldn't initialise agent support.");
+               return false;
+       }
+
+       if (libssh2_agent_connect(agent) != 0)
+       {
+               log_print(0, "ssh_agent_auth", "Failed to connect to ssh-agent.");
+               return false;
+       }
+
+       if (libssh2_agent_list_identities(agent) != 0)
+       {
+               log_print(0, "ssh_agent_auth", "Failure requesting identities to ssh-agent.");
+               return false;
+       }
+
+       struct libssh2_agent_publickey * identity = NULL;
+       struct libssh2_agent_publickey * prev_identity = NULL;
+
+       while (true)
+       {
+               int err = libssh2_agent_get_identity(agent, &identity, prev_identity);
+               if (err == 1)
+               {
+                       log_print(0, "ssh_agent_auth", "Couldn't continue authentication.");
+                       return false;
+               }
+               if (err < 0)
+               {
+                       log_print(0, "ssh_agent_auth", "Failure obtaining identity from ssh-agent support.");
+                       return false;
+               }
+
+               if (libssh2_agent_userauth(agent, s->user, identity) == 0)
+               {
+                       log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s succeeded!", s->user, identity->comment);
+                       return true;
+               }
+               else
+               {
+                       log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s failed.", s->user, identity->comment);
+               }
+               prev_identity = identity;
+       }
+
+       return false;
+
+}
+
+LIBSSH2_LISTENER * ssh_get_listener(ssh * s, int * port)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+       libssh2_session_set_blocking(s->session, 1);
+       //libssh2_trace(s->session, ~0);
+       LIBSSH2_LISTENER * l = libssh2_channel_forward_listen_ex(s->session, "localhost", *port, port,1);
+       if (l == NULL)
+       {
+               char * error;
+               libssh2_session_last_error(s->session, &error, NULL, 0);
+               log_print(0, "ssh_get_listener", "Error: %s", error);
+       }
+       libssh2_session_set_blocking(s->session, 0);
+       pthread_mutex_unlock(&ssh_thread_mutex);
+       return l;
+}
+
+void ssh_add_tunnel(ssh * s, LIBSSH2_LISTENER * listener, int socket)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+       //log_print(3, "ssh_add_tunnel", "accepting connection...");
+       libssh2_session_set_blocking(s->session , 1);
+       //libssh2_trace(s->session, ~0);
+       LIBSSH2_CHANNEL * channel = libssh2_channel_forward_accept(listener);
+       if (channel == NULL)
+       {
+               char * error;
+               libssh2_session_last_error(s->session, &error, NULL, 0);
+               log_print(0, "ssh_add_tunnel", "Error: %s", error);
+       }
+       libssh2_session_set_blocking(s->session , 0);
+       //log_print(3, "ssh_add_tunnel", "accepted remote connection...");
+       
+       ssh_tunnel * t = s->tunnel+(s->nTunnels++);
+       t->forward_sock = socket;
+       t->channel = channel;
+
+       if (socket > ssh_thread_maxfd)
+               ssh_thread_maxfd = socket;
+       
+       if (s->nTunnels >= s->reserved_tunnels)
+       {
+               s->reserved_tunnels *= 2;
+               s->tunnel = (ssh_tunnel*)(realloc(s->tunnel, s->reserved_tunnels * sizeof(ssh_tunnel)));
+       }
+       
+       pthread_mutex_unlock(&ssh_thread_mutex);
+}
+
+void ssh_exec_swarm(ssh * s, int * port, int * socket, int np)
+{
+
+       // secure things
+       LIBSSH2_CHANNEL * channel = NULL;
+       while ((channel = libssh2_channel_open_session(s->session)) == NULL 
+               && libssh2_session_last_error(s->session, NULL, NULL, 0) == LIBSSH2_ERROR_EAGAIN)
+       {
+               waitsocket(s->socket, s->session);
+       }
+       
+       if (channel == NULL)
+       {
+               error("ssh_exec_swarm", "Couldn't create channel with ssh session");
+       }
+
+
+       char buffer[BUFSIZ];
+       
+       // connect secure
+       if (port == NULL && socket != NULL)
+       {
+               sprintf(buffer, "%s -r -", options.program);
+               if (np != 0)
+                       sprintf(buffer, " -n %d", np);
+       }
+       else if (port != NULL && socket == NULL)
+       {
+               sprintf(buffer, "%s -r $(echo $SSH_CONNECTION | awk \'{print $1}\'):%d", options.program, *port);
+               if (np != 0)
+                       sprintf(buffer, " -n %d", np);
+
+       }
+       else
+               error("ssh_exec_swarm", "Exactly *one* of the port or socket pointers must not be NULL");
+
+       int err;
+       while ((err = libssh2_channel_exec(channel, buffer)) == LIBSSH2_ERROR_EAGAIN)
+       {
+               waitsocket(s->socket, s->session);
+       }
+
+       if (socket != NULL)
+       {
+               
+               pthread_mutex_lock(&ssh_thread_mutex);
+
+               ssh_tunnel * t = s->tunnel+(s->nTunnels++);
+
+               t->forward_sock = *socket;
+               t->channel = channel;
+
+               if (*socket > ssh_thread_maxfd)
+                       ssh_thread_maxfd = *socket;
+
+       
+               if (s->nTunnels >= s->reserved_tunnels)
+               {
+                       s->reserved_tunnels *= 2;
+                       s->tunnel = (ssh_tunnel*)(realloc(s->tunnel, s->reserved_tunnels * sizeof(ssh_tunnel)));
+               }
+
+               pthread_mutex_unlock(&ssh_thread_mutex);
+       }
+       else
+       {
+               
+               // read everything and close the channel
+               while (true)
+               {
+                       while ((err = libssh2_channel_read(channel, buffer, sizeof(buffer))) > 0);
+                       if (err == LIBSSH2_ERROR_EAGAIN)
+                       {
+                               waitsocket(s->socket, s->session);
+                       }
+                       else
+                       {
+                               break;
+                       }
+               }
+
+               while ((err = libssh2_channel_close(channel)) == LIBSSH2_ERROR_EAGAIN)
+               {
+                       waitsocket(s->socket, s->session);
+               }
+               libssh2_channel_free(channel);
+       }
+       
+
+
+
+}
+
+
+
+
+pthread_mutex_t ssh_thread_mutex = PTHREAD_MUTEX_INITIALIZER;
+pthread_t ssh_pthread;
+
+void * ssh_thread(void * args)
+{
+
+
+       fd_set readSet;
+       char buffer[BUFSIZ];
+       struct timeval tv;
+       tv.tv_sec = 0;
+       tv.tv_usec = 100000;
+       
+       while (true)
+       {
+               //log_print(1, "ssh_thread", "loop - %d ssh's", ssh_array_used);
+               FD_ZERO(&readSet);
+               pthread_mutex_lock(&ssh_thread_mutex);
+
+               if (!ssh_thread_running) break;
+
+               for (int i = 0; i < ssh_array_used; ++i)
+               {
+                       ssh * s = ssh_array[i];
+                       if (s == NULL) continue;
+                       for (int j = 0; j < s->nTunnels; ++j)
+                       {
+                               FD_SET(s->tunnel[j].forward_sock, &readSet);
+                       }
+               }
+
+               pthread_mutex_unlock(&ssh_thread_mutex);
+               select(ssh_thread_maxfd+1, &readSet, NULL, NULL, &tv);
+               pthread_mutex_lock(&ssh_thread_mutex);
+
+               for (int i = 0; i < ssh_array_used; ++i)
+               {
+                       ssh * s = ssh_array[i];
+                       //log_print(2, "ssh_thread", "array[%d] = %p", i, s);
+                       if (s == NULL) continue;
+                       for (int j = 0; j < s->nTunnels; ++j)
+                       {
+                               //log_print(2, "ssh_thread", "Tunnel number %d, socket %d", j, s->tunnel[j].forward_sock);
+                               if (FD_ISSET(s->tunnel[j].forward_sock, &readSet))
+                               {
+                                       //log_print(2, "ssh_thread", "reading from socket %d", s->tunnel[j].forward_sock);
+                                       int len = read(s->tunnel[j].forward_sock, buffer, sizeof(buffer));
+                                       
+                                       if (len <= 0)
+                                               continue;
+                                       buffer[len] = '\0';
+                                       int written = 0; int w = 0;
+                                       do
+                                       {
+                                               //log_print(2, "ssh_thread", "writing %s to channel", buffer);
+                                               w = libssh2_channel_write(s->tunnel[j].channel, buffer+written, len-written);
+                                               assert(w >= 0);
+                                               written += w;
+                                       }
+                                       while (w > 0 && written < len);
+                               }
+                               while (true)
+                               {
+                                       //log_print(2, "ssh_thread", "Try to read from channel %p", s->tunnel[j].channel);
+                                       int len = libssh2_channel_read(s->tunnel[j].channel, buffer, sizeof(buffer));
+                                       //log_print(2, "ssh_thread", "Read %s from channel", buffer);
+                                       if (len == LIBSSH2_ERROR_EAGAIN) break;
+                                       assert(len >= 0);
+
+                                       int written = 0; int w = 0;
+                                       while (written < len)
+                                       {
+                                               //log_print(2, "ssh_thread", "Wrote %s to socket %d", buffer+written, s->tunnel[j].forward_sock);
+                                               w = write(s->tunnel[j].forward_sock, buffer+written, len-written);
+                                               written += w;
+                                       }
+                                       if (libssh2_channel_eof(s->tunnel[j].channel))
+                                       {
+                                               //log_print(1, "ssh_thread", "Got to eof in channel %p", s->tunnel[j].channel);
+                                       }
+                               }
+                       }
+               }
+               pthread_mutex_unlock(&ssh_thread_mutex);
+       }
+
+       return NULL;
+}
+
+void ssh_thread_add(ssh * s)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+
+       ssh_array_used++;
+
+       bool found = false;
+       for (int i = 0; (i < ssh_array_reserved && !found); ++i)
+       {
+               if (ssh_array[i] == NULL)
+               {
+                       ssh_array[i] = s;
+                       found = true;
+                       break;
+               }
+       }
+
+
+       if (!found)
+       {
+               int old = ssh_array_reserved;
+               ssh_array_reserved = (ssh_array_reserved + 1) * 2;
+               if (ssh_array == NULL)
+                       ssh_array = (ssh**)(calloc(ssh_array_reserved, sizeof(ssh*)));
+               else
+               {
+                       ssh_array = (ssh**)(realloc(ssh_array, ssh_array_reserved * sizeof(ssh*)));
+                       for (int i = old+1; i < ssh_array_reserved; ++i)
+                               ssh_array[i] = NULL;
+                       
+               }
+               ssh_array[old] = s;
+       }
+
+       for (int i = 0; i < s->nTunnels; ++i)
+       {
+               if (s->tunnel[i].forward_sock > ssh_thread_maxfd)
+                       ssh_thread_maxfd = s->tunnel[i].forward_sock;
+       }
+
+       if (!ssh_thread_running)
+       {
+               ssh_thread_running = true;
+               sigset_t set;
+               int err;
+               sigfillset(&set);
+               err = pthread_sigmask(SIG_SETMASK, &set, NULL);
+               if (err != 0)
+                       error("ssh_thread_add", "pthread_sigmask : %s", strerror(errno));
+               err = pthread_create(&ssh_pthread, NULL, ssh_thread, NULL);
+               if (err != 0)
+                       error("ssh_thread_add", "pthread_create : %s", strerror(errno));
+               sigemptyset(&set);
+               err = pthread_sigmask(SIG_SETMASK, &set, NULL);
+               if (err != 0)
+                       error("ssh_thread_add", "pthread_sigmask : %s", strerror(errno));
+       }
+
+       
+
+       pthread_mutex_unlock(&ssh_thread_mutex);
+}
+
+void ssh_thread_del(ssh * s)
+{
+       pthread_mutex_lock(&ssh_thread_mutex);
+       
+       for (int i = 0; i < ssh_array_reserved; ++i)
+       {
+               if (ssh_array[i] == s)
+               {
+                       ssh_array[i] = NULL;
+                       ssh_thread_running = !(--ssh_array_used == 0);
+                       break;
+               }
+       }
+
+       pthread_mutex_unlock(&ssh_thread_mutex);
+}

UCC git Repository :: git.ucc.asn.au