0e1e1975d22f4643e601474601a8e0b6b9d0e5c9
[matches/swarm.git] / src / ssh.c
1 #include "ssh.h"
2 #include "network.h"
3 #include "log.h"
4 #include <termios.h>
5 #include <dirent.h>
6 #include <sys/types.h>
7 #include <unistd.h>
8 #include <errno.h>
9 #include <sys/types.h>
10 #include <unistd.h>
11 #include <sys/time.h>
12 #include <sys/select.h>
13 #include <assert.h>
14 #include <signal.h>
15
16
/*
 * Authentication methods the server may advertise.  ssh_new() ORs these
 * together into a single int, so each non-zero value is a distinct bit.
 */
enum {
    AUTH_NONE      = 0,
    AUTH_PASSWORD  = 1 << 0,  /* "password" found in the userauth list */
    AUTH_PUBLICKEY = 1 << 1   /* "publickey" found in the userauth list */
};
22
23 static bool ssh_fingerprint_ok(char * f);
24
25 static void ssh_get_passwd(char * buffer, int len);
26
27 static bool ssh_agent_auth(ssh * s);
28
29 static bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts);
30
31 static bool ssh_thread_running = false;
32 static int ssh_array_reserved = 0;
33 static int ssh_array_used = 0;
34 static ssh ** ssh_array = NULL;
35 static int ssh_thread_maxfd = 0;
36
37 static int waitsocket(int socket_fd, LIBSSH2_SESSION *session)
38 {
39     struct timeval timeout;
40     int rc;
41     fd_set fd;
42     fd_set *writefd = NULL;
43     fd_set *readfd = NULL;
44     int dir;
45  
46     timeout.tv_sec = 10;
47     timeout.tv_usec = 0;
48  
49     FD_ZERO(&fd);
50  
51     FD_SET(socket_fd, &fd);
52  
53     /* now make sure we wait in the correct direction */ 
54     dir = libssh2_session_block_directions(session);
55
56  
57     if(dir & LIBSSH2_SESSION_BLOCK_INBOUND)
58         readfd = &fd;
59  
60     if(dir & LIBSSH2_SESSION_BLOCK_OUTBOUND)
61         writefd = &fd;
62  
63     rc = select(socket_fd + 1, readfd, writefd, NULL, &timeout);
64  
65     return rc;
66 }
67
68 ssh * ssh_new(char * username, char * addr, int port)
69 {
70         ssh * s = (ssh*)(calloc(1, sizeof(ssh)));
71         s->user = username;
72         s->addr = addr;
73
74         s->socket = Network_client(addr, port,100);
75         s->session = libssh2_session_init();
76         if (s->session == NULL)
77         {
78                 free(s);
79                 log_print(2,"ssh_new", "libssh2_session_init returned NULL");
80                 return NULL;
81         }
82
83
84
85         int err = libssh2_session_handshake(s->session, s->socket);
86         if (err != 0)
87         {
88                 free(s);
89                 log_print(2,"ssh_new", "libssh2_session_handshake fails - error code %d", err);
90                 return NULL;
91         }
92         s->fingerprint = (char*)(libssh2_hostkey_hash(s->session, LIBSSH2_HOSTKEY_HASH_SHA1));
93         if (!ssh_fingerprint_ok(s->fingerprint))
94         {
95                 free(s);
96                 log_print(2,"ssh_new", "Fingerprint of host \"%s\" was not OK", addr);
97                 return NULL;
98         }       
99         
100         char * userauthlist = libssh2_userauth_list(s->session, username, strlen(username));
101
102         int auth = AUTH_NONE;
103         if (strstr(userauthlist, "password"))
104                 auth |= AUTH_PASSWORD;
105         if (strstr(userauthlist, "publickey"))
106                 auth |= AUTH_PUBLICKEY;
107
108         bool ok = false;
109
110         if (auth & AUTH_PUBLICKEY)
111         {
112                 // first try connecting with agent
113                 ok = ssh_agent_auth(s);
114
115                 
116                 
117                 if (!ok)
118                 {
119                         log_print(3, "ssh_new", "Agent authentication failed, looking at public keys");
120
121                         if (SSH_DIR[0] == '~' && SSH_DIR[1] == '/')
122                         {
123                                 char ssh_dir[BUFSIZ];
124                                 sprintf(ssh_dir, "%s/%s",getenv("HOME"),SSH_DIR+2);
125                                 ok = ssh_publickey_auth(s, ssh_dir,3);
126                         }               
127                         else
128                                 ok = ssh_publickey_auth(s, SSH_DIR,3);
129                 }
130                 
131         }
132         
133         if (auth & AUTH_PASSWORD && !ok)
134         {
135                 log_print(3, "ssh_new", "public keys failed, try password");
136                 for (int i = 0; i < 3 && !ok; ++i)
137                 {
138                         printf("Password for %s@%s:", username, addr);
139                         char password[BUFSIZ];
140                         ssh_get_passwd(password, BUFSIZ);
141         
142                         if (libssh2_userauth_password(s->session, username, password) == 0)
143                         {
144                                 ok = true;
145                         }
146                 }
147                 if (!ok)
148                         log_print(3, "ssh_new", "Failed to authenticate by password.");
149         }
150         
151         if (!ok)
152         {
153                 free(s);
154                 log_print(2, "ssh_new", "All attempts at authenticating failed.");
155                 return NULL;
156         }
157         log_print(3, "ssh_new", "Authenticated!");
158
159         s->reserved_tunnels = 1;
160         s->tunnel = (ssh_tunnel*)(calloc(s->reserved_tunnels, sizeof(ssh_tunnel)));
161         s->nTunnels = 0;
162         libssh2_session_set_blocking(s->session, 0);
163         return s;
164 }
165
166 void ssh_destroy(ssh * s)
167 {
168         ssh_thread_del(s);
169
170         for (int i = 0; i < s->nTunnels; ++i)
171         {
172                 int err;
173                 char buffer[BUFSIZ];
174                 do
175                 {
176                         err = libssh2_channel_read(s->tunnel[i].channel, buffer, sizeof(buffer));
177                         write(s->tunnel[i].forward_sock, buffer, err);
178                 }
179                 while (err > 0);
180
181                 while ((err = libssh2_channel_close(s->tunnel[i].channel)) == LIBSSH2_ERROR_EAGAIN)
182                         waitsocket(s->socket, s->session);
183                 
184                 libssh2_channel_free(s->tunnel[i].channel);
185                 close(s->tunnel[i].forward_sock);
186         }
187
188
189         libssh2_session_disconnect(s->session, "goodbye");
190         libssh2_session_free(s->session);
191
192         free(s->tunnel);
193 }
194
/*
 * Decide whether the host key fingerprint f is acceptable.
 *
 * Placeholder: no verification is performed yet.  Every call logs that
 * the check is unimplemented and accepts the fingerprint.
 */
bool ssh_fingerprint_ok(char * f)
{
        (void)f; /* not inspected until the check is written */

        //TODO: Check fingerprint
        const bool accepted = true;
        log_print(1, "ssh_fingerprint_ok", "Unimplemented!");
        return accepted;
}
201
/*
 * Read one line from stdin into buffer (capacity len bytes, including the
 * terminator) with terminal echo switched off, then restore the terminal.
 *
 * The trailing newline, if there is one, is stripped.  If nothing could
 * be read (EOF or read error) buffer is left as the empty string.
 * Failures of the termios calls are reported through error().
 */
void ssh_get_passwd(char * buffer, int len)
{
        if (buffer == NULL || len <= 0)
                return;
        buffer[0] = '\0';

        // callers print a prompt without a newline; push it out before echo goes off
        fflush(stdout);

        struct termios oflags, nflags;
        memset(&oflags, 0, sizeof(oflags));

        if (tcgetattr(fileno(stdin), &oflags) != 0)
        {
                error("ssh_get_passwd", "tcgetattr : %s", strerror(errno));
        }
        nflags = oflags;
        nflags.c_lflag &= ~ECHO;   // don't show what is typed...
        nflags.c_lflag |= ECHONL;  // ...but still echo the final newline

        if (tcsetattr(fileno(stdin), TCSANOW, &nflags) != 0)
        {
                error("ssh_get_passwd", "tcsetattr : %s", strerror(errno));
        }

        if (fgets(buffer, len, stdin) == NULL)
        {
                // EOF or error: the buffer contents are indeterminate, reset them
                buffer[0] = '\0';
        }
        else
        {
                // strip the newline only if it is really there (the line may have
                // been truncated, or ended by EOF without a newline)
                size_t n = strlen(buffer);
                if (n > 0 && buffer[n - 1] == '\n')
                        buffer[n - 1] = '\0';
        }

        if (tcsetattr(fileno(stdin), TCSANOW, &oflags) != 0)
        {
                error("ssh_get_passwd", "tcsetattr : %s", strerror(errno));
        }
}
224
225 bool ssh_publickey_auth(ssh * s, char * dir, int nAttempts)
226 {
227
228         
229         DIR * d = opendir(dir);
230         struct dirent * dp;
231         if (d == NULL)
232         {
233                 log_print(0, "ssh_publickey_auth", "Couldn't open directory %s : %s", dir, strerror(errno));
234                 return false;
235         }
236
237         while ((dp = readdir(d)) != NULL)
238         {
239                 
240                 // skip public keys
241                 if (strstr(dp->d_name, ".pub") != NULL)
242                         continue;
243                 
244                 // assume file is a private key 
245                 // find corresponding public key
246                 char pub[BUFSIZ]; char priv[BUFSIZ];
247                 if (dir[strlen(dir)-1] == '/')
248                 {
249                         sprintf(pub, "%s%s.pub", dir,dp->d_name);
250                         sprintf(priv, "%s%s", dir, dp->d_name);
251                 }
252                 else
253                 {
254                         sprintf(pub, "%s/%s.pub", dir,dp->d_name);
255                         sprintf(priv, "%s/%s", dir, dp->d_name);
256                 }
257                 
258                 struct stat t;
259                 if (stat(priv, &t) != 0)
260                 {
261                         log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", priv, strerror(errno));
262                         continue;
263                 }
264
265                 if (!S_ISREG(t.st_mode))
266                 {
267                         log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", priv);
268                         continue;
269                 }
270
271                 if (stat(pub, &t) != 0)
272                 {
273                         log_print(3,"ssh_publickey_auth", "Can't stat file %s : %s", pub, strerror(errno));
274                         continue;
275                 }
276
277                 if (!S_ISREG(t.st_mode))
278                 {
279                         log_print(3, "ssh_publickey_auth", "%s doesn't appear to be a regular file", pub);
280                         continue;
281                 }
282         
283
284
285
286                 
287                         
288                 //libssh2_trace(s->session, LIBSSH2_TRACE_AUTH | LIBSSH2_TRACE_PUBLICKEY);
289                 int err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,"");
290                 if (err == 0)
291                 {
292                         log_print(1, "ssh_publickey_auth", "Shouldn't use keys with no passphrase");
293                 }
294                 else if (err == LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED)
295                 {
296
297                         char passphrase[BUFSIZ];
298                         for (int i = 0; i < nAttempts; ++i)
299                         {
300                                 printf("Passphrase for key %s:", priv);
301                                 ssh_get_passwd(passphrase, BUFSIZ);
302                                 err = libssh2_userauth_publickey_fromfile(s->session, s->user, pub, priv,passphrase);
303                                 if (err != LIBSSH2_ERROR_PUBLICKEY_UNVERIFIED) break;
304                         }
305                 }
306                 if (err == 0)
307                 {
308                         closedir(d);
309                         return true;
310                 }
311         }
312         closedir(d);
313
314         
315         return false;
316
317         
318
319         
320 }
321
322 bool ssh_agent_auth(ssh * s)
323 {
324         LIBSSH2_AGENT * agent = libssh2_agent_init(s->session);
325         if (agent == NULL)
326         {
327                 log_print(0, "ssh_agent_auth", "Couldn't initialise agent support.");
328                 return false;
329         }
330
331         if (libssh2_agent_connect(agent) != 0)
332         {
333                 log_print(0, "ssh_agent_auth", "Failed to connect to ssh-agent.");
334                 return false;
335         }
336
337         if (libssh2_agent_list_identities(agent) != 0)
338         {
339                 log_print(0, "ssh_agent_auth", "Failure requesting identities to ssh-agent.");
340                 return false;
341         }
342
343         struct libssh2_agent_publickey * identity = NULL;
344         struct libssh2_agent_publickey * prev_identity = NULL;
345
346         while (true)
347         {
348                 int err = libssh2_agent_get_identity(agent, &identity, prev_identity);
349                 if (err == 1)
350                 {
351                         log_print(0, "ssh_agent_auth", "Couldn't continue authentication.");
352                         return false;
353                 }
354                 if (err < 0)
355                 {
356                         log_print(0, "ssh_agent_auth", "Failure obtaining identity from ssh-agent support.");
357                         return false;
358                 }
359
360                 if (libssh2_agent_userauth(agent, s->user, identity) == 0)
361                 {
362                         log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s succeeded!", s->user, identity->comment);
363                         return true;
364                 }
365                 else
366                 {
367                         log_print(3, "ssh_agent_auth", "Authentication with username %s and public key %s failed.", s->user, identity->comment);
368                 }
369                 prev_identity = identity;
370         }
371
372         return false;
373
374 }
375
376 LIBSSH2_LISTENER * ssh_get_listener(ssh * s, int * port)
377 {
378         pthread_mutex_lock(&ssh_thread_mutex);
379         libssh2_session_set_blocking(s->session, 1);
380         //libssh2_trace(s->session, ~0);
381         LIBSSH2_LISTENER * l = libssh2_channel_forward_listen_ex(s->session, "localhost", *port, port,1);
382         if (l == NULL)
383         {
384                 char * error;
385                 libssh2_session_last_error(s->session, &error, NULL, 0);
386                 log_print(0, "ssh_get_listener", "Error: %s", error);
387         }
388         libssh2_session_set_blocking(s->session, 0);
389         pthread_mutex_unlock(&ssh_thread_mutex);
390         return l;
391 }
392
393 void ssh_add_tunnel(ssh * s, LIBSSH2_LISTENER * listener, int socket)
394 {
395         pthread_mutex_lock(&ssh_thread_mutex);
396         //log_print(3, "ssh_add_tunnel", "accepting connection...");
397         libssh2_session_set_blocking(s->session , 1);
398         //libssh2_trace(s->session, ~0);
399         LIBSSH2_CHANNEL * channel = libssh2_channel_forward_accept(listener);
400         if (channel == NULL)
401         {
402                 char * error;
403                 libssh2_session_last_error(s->session, &error, NULL, 0);
404                 log_print(0, "ssh_add_tunnel", "Error: %s", error);
405         }
406         libssh2_session_set_blocking(s->session , 0);
407         //log_print(3, "ssh_add_tunnel", "accepted remote connection...");
408         
409         ssh_tunnel * t = s->tunnel+(s->nTunnels++);
410         t->forward_sock = socket;
411         t->channel = channel;
412
413         if (socket > ssh_thread_maxfd)
414                 ssh_thread_maxfd = socket;
415         
416         if (s->nTunnels >= s->reserved_tunnels)
417         {
418                 s->reserved_tunnels *= 2;
419                 s->tunnel = (ssh_tunnel*)(realloc(s->tunnel, s->reserved_tunnels * sizeof(ssh_tunnel)));
420         }
421         
422         pthread_mutex_unlock(&ssh_thread_mutex);
423 }
424
425 void ssh_exec_swarm(ssh * s, int * port, int * socket, int np)
426 {
427
428         // secure things
429         LIBSSH2_CHANNEL * channel = NULL;
430         while ((channel = libssh2_channel_open_session(s->session)) == NULL 
431                 && libssh2_session_last_error(s->session, NULL, NULL, 0) == LIBSSH2_ERROR_EAGAIN)
432         {
433                 waitsocket(s->socket, s->session);
434         }
435         
436         if (channel == NULL)
437         {
438                 error("ssh_exec_swarm", "Couldn't create channel with ssh session");
439         }
440
441
442         char buffer[BUFSIZ];
443         
444         // connect secure
445         if (port == NULL && socket != NULL)
446         {
447                 sprintf(buffer, "%s -r -", options.program);
448                 if (np != 0)
449                         sprintf(buffer, " -n %d", np);
450         }
451         else if (port != NULL && socket == NULL)
452         {
453                 sprintf(buffer, "%s -r $(echo $SSH_CONNECTION | awk \'{print $1}\'):%d", options.program, *port);
454                 if (np != 0)
455                         sprintf(buffer, " -n %d", np);
456
457         }
458         else
459                 error("ssh_exec_swarm", "Exactly *one* of the port or socket pointers must not be NULL");
460
461         int err;
462         while ((err = libssh2_channel_exec(channel, buffer)) == LIBSSH2_ERROR_EAGAIN)
463         {
464                 waitsocket(s->socket, s->session);
465         }
466
467         if (socket != NULL)
468         {
469                 
470                 pthread_mutex_lock(&ssh_thread_mutex);
471
472                 ssh_tunnel * t = s->tunnel+(s->nTunnels++);
473
474                 t->forward_sock = *socket;
475                 t->channel = channel;
476
477                 if (*socket > ssh_thread_maxfd)
478                         ssh_thread_maxfd = *socket;
479
480         
481                 if (s->nTunnels >= s->reserved_tunnels)
482                 {
483                         s->reserved_tunnels *= 2;
484                         s->tunnel = (ssh_tunnel*)(realloc(s->tunnel, s->reserved_tunnels * sizeof(ssh_tunnel)));
485                 }
486
487                 pthread_mutex_unlock(&ssh_thread_mutex);
488         }
489         else
490         {
491                 
492                 // read everything and close the channel
493                 while (true)
494                 {
495                         while ((err = libssh2_channel_read(channel, buffer, sizeof(buffer))) > 0);
496                         if (err == LIBSSH2_ERROR_EAGAIN)
497                         {
498                                 waitsocket(s->socket, s->session);
499                         }
500                         else
501                         {
502                                 break;
503                         }
504                 }
505
506                 while ((err = libssh2_channel_close(channel)) == LIBSSH2_ERROR_EAGAIN)
507                 {
508                         waitsocket(s->socket, s->session);
509                 }
510                 libssh2_channel_free(channel);
511         }
512         
513
514
515
516 }
517
518
519
520
/* Handle of the forwarding thread most recently started by ssh_thread_add(). */
pthread_t ssh_pthread;

/* Lock taken around the session table, the tunnel arrays and libssh2 calls in this file. */
pthread_mutex_t ssh_thread_mutex = PTHREAD_MUTEX_INITIALIZER;
523
524 void * ssh_thread(void * args)
525 {
526
527
528         fd_set readSet;
529         char buffer[BUFSIZ];
530         struct timeval tv;
531         tv.tv_sec = 0;
532         tv.tv_usec = 100000;
533         
534         while (true)
535         {
536                 //log_print(1, "ssh_thread", "loop - %d ssh's", ssh_array_used);
537                 FD_ZERO(&readSet);
538                 pthread_mutex_lock(&ssh_thread_mutex);
539
540                 if (!ssh_thread_running) break;
541
542                 for (int i = 0; i < ssh_array_used; ++i)
543                 {
544                         ssh * s = ssh_array[i];
545                         if (s == NULL) continue;
546                         for (int j = 0; j < s->nTunnels; ++j)
547                         {
548                                 FD_SET(s->tunnel[j].forward_sock, &readSet);
549                         }
550                 }
551
552                 pthread_mutex_unlock(&ssh_thread_mutex);
553                 select(ssh_thread_maxfd+1, &readSet, NULL, NULL, &tv);
554                 pthread_mutex_lock(&ssh_thread_mutex);
555
556                 for (int i = 0; i < ssh_array_used; ++i)
557                 {
558                         ssh * s = ssh_array[i];
559                         //log_print(2, "ssh_thread", "array[%d] = %p", i, s);
560                         if (s == NULL) continue;
561                         for (int j = 0; j < s->nTunnels; ++j)
562                         {
563                                 //log_print(2, "ssh_thread", "Tunnel number %d, socket %d", j, s->tunnel[j].forward_sock);
564                                 if (FD_ISSET(s->tunnel[j].forward_sock, &readSet))
565                                 {
566                                         //log_print(2, "ssh_thread", "reading from socket %d", s->tunnel[j].forward_sock);
567                                         int len = read(s->tunnel[j].forward_sock, buffer, sizeof(buffer));
568                                         
569                                         if (len <= 0)
570                                                 continue;
571                                         buffer[len] = '\0';
572                                         int written = 0; int w = 0;
573                                         do
574                                         {
575                                                 //log_print(2, "ssh_thread", "writing %s to channel", buffer);
576                                                 w = libssh2_channel_write(s->tunnel[j].channel, buffer+written, len-written);
577                                                 assert(w >= 0);
578                                                 written += w;
579                                         }
580                                         while (w > 0 && written < len);
581                                 }
582                                 while (true)
583                                 {
584                                         //log_print(2, "ssh_thread", "Try to read from channel %p", s->tunnel[j].channel);
585                                         int len = libssh2_channel_read(s->tunnel[j].channel, buffer, sizeof(buffer));
586                                         //log_print(2, "ssh_thread", "Read %s from channel", buffer);
587                                         if (len == LIBSSH2_ERROR_EAGAIN) break;
588                                         assert(len >= 0);
589
590                                         int written = 0; int w = 0;
591                                         while (written < len)
592                                         {
593                                                 //log_print(2, "ssh_thread", "Wrote %s to socket %d", buffer+written, s->tunnel[j].forward_sock);
594                                                 w = write(s->tunnel[j].forward_sock, buffer+written, len-written);
595                                                 written += w;
596                                         }
597                                         if (libssh2_channel_eof(s->tunnel[j].channel))
598                                         {
599                                                 //log_print(1, "ssh_thread", "Got to eof in channel %p", s->tunnel[j].channel);
600                                         }
601                                 }
602                         }
603                 }
604                 pthread_mutex_unlock(&ssh_thread_mutex);
605         }
606
607         return NULL;
608 }
609
610 void ssh_thread_add(ssh * s)
611 {
612         pthread_mutex_lock(&ssh_thread_mutex);
613
614         ssh_array_used++;
615
616         bool found = false;
617         for (int i = 0; (i < ssh_array_reserved && !found); ++i)
618         {
619                 if (ssh_array[i] == NULL)
620                 {
621                         ssh_array[i] = s;
622                         found = true;
623                         break;
624                 }
625         }
626
627
628         if (!found)
629         {
630                 int old = ssh_array_reserved;
631                 ssh_array_reserved = (ssh_array_reserved + 1) * 2;
632                 if (ssh_array == NULL)
633                         ssh_array = (ssh**)(calloc(ssh_array_reserved, sizeof(ssh*)));
634                 else
635                 {
636                         ssh_array = (ssh**)(realloc(ssh_array, ssh_array_reserved * sizeof(ssh*)));
637                         for (int i = old+1; i < ssh_array_reserved; ++i)
638                                 ssh_array[i] = NULL;
639                         
640                 }
641                 ssh_array[old] = s;
642         }
643
644         for (int i = 0; i < s->nTunnels; ++i)
645         {
646                 if (s->tunnel[i].forward_sock > ssh_thread_maxfd)
647                         ssh_thread_maxfd = s->tunnel[i].forward_sock;
648         }
649
650         if (!ssh_thread_running)
651         {
652                 ssh_thread_running = true;
653                 sigset_t set;
654                 int err;
655                 sigfillset(&set);
656                 err = pthread_sigmask(SIG_SETMASK, &set, NULL);
657                 if (err != 0)
658                         error("ssh_thread_add", "pthread_sigmask : %s", strerror(errno));
659                 err = pthread_create(&ssh_pthread, NULL, ssh_thread, NULL);
660                 if (err != 0)
661                         error("ssh_thread_add", "pthread_create : %s", strerror(errno));
662                 sigemptyset(&set);
663                 err = pthread_sigmask(SIG_SETMASK, &set, NULL);
664                 if (err != 0)
665                         error("ssh_thread_add", "pthread_sigmask : %s", strerror(errno));
666         }
667
668         
669
670         pthread_mutex_unlock(&ssh_thread_mutex);
671 }
672
673 void ssh_thread_del(ssh * s)
674 {
675         pthread_mutex_lock(&ssh_thread_mutex);
676         
677         for (int i = 0; i < ssh_array_reserved; ++i)
678         {
679                 if (ssh_array[i] == s)
680                 {
681                         ssh_array[i] = NULL;
682                         ssh_thread_running = !(--ssh_array_used == 0);
683                         break;
684                 }
685         }
686
687         pthread_mutex_unlock(&ssh_thread_mutex);
688 }

UCC git Repository :: git.ucc.asn.au