Working on password-based user auth
[tpg/opendispense2.git] / src / client / main.c
1 /*
2  * OpenDispense 2 
3  * UCC (University [of WA] Computer Club) Electronic Accounting System
4  * - Dispense Client
5  *
6  * main.c - Core and Initialisation
7  *
8  * This file is licenced under the 3-clause BSD Licence. See the file
9  * COPYING for full details.
10  */
11 #include <stdlib.h>
12 #include <stdio.h>
13 #include <string.h>
14 #include <ctype.h>      // isspace
15 #include <stdarg.h>
16 #include <regex.h>
17 #include <ncurses.h>
18
19 #include <unistd.h>     // close
20 #include <netdb.h>      // gethostbyname
21 #include <pwd.h>        // getpwuids
22 #include <sys/socket.h>
23 #include <netinet/in.h>
24 #include <arpa/inet.h>
25 #include <openssl/sha.h>        // SHA1
26
27 // === TYPES ===
28 typedef struct sItem {
29         char    *Ident;
30         char    *Desc;
31          int    Price;
32 }       tItem;
33
34 // === PROTOTYPES ===
35  int    ShowNCursesUI(void);
36 void    PrintAlign(int Row, int Col, int Width, const char *Left, char Pad1, const char *Mid, char Pad2, const char *Right, ...);
37
38  int    sendf(int Socket, const char *Format, ...);
39  int    OpenConnection(const char *Host, int Port);
40 void    Authenticate(int Socket);
41 char    *trim(char *string);
42  int    RunRegex(regex_t *regex, const char *string, int nMatches, regmatch_t *matches, const char *errorMessage);
43 void    CompileRegex(regex_t *regex, const char *pattern, int flags);
44
45 // === GLOBALS ===
46 char    *gsDispenseServer = "localhost";
47  int    giDispensePort = 11020;
48 tItem   *gaItems;
49  int    giNumItems;
50 regex_t gArrayRegex, gItemRegex, gSaltRegex;
51
52 // === CODE ===
53 int main(int argc, char *argv[])
54 {
55          int    sock;
56          int    i, responseCode, len;
57         char    buffer[BUFSIZ];
58         
59         // -- Create regular expressions
60         // > Code Type Count ...
61         CompileRegex(&gArrayRegex, "^([0-9]{3})\\s+([A-Za-z]+)\\s+([0-9]+)", REG_EXTENDED);     //
62         // > Code Type Ident Price Desc
63         CompileRegex(&gItemRegex, "^([0-9]{3})\\s+(.+?)\\s+(.+?)\\s+([0-9]+)\\s+(.+)$", REG_EXTENDED);
64         // > Code 'SALT' salt
65         CompileRegex(&gSaltRegex, "^([0-9]{3})\\s+(.+)\\s+(.+)$", REG_EXTENDED);
66         
67         // Connect to server
68         sock = OpenConnection(gsDispenseServer, giDispensePort);
69         if( sock < 0 )  return -1;
70
71         // Determine what to do
72         if( argc > 1 )
73         {
74                 if( strcmp(argv[1], "acct") == 0 )
75                 {
76                         // Alter account
77                         // List accounts
78                         return 0;
79                 }
80         }
81
82         // Ask server for stock list
83         send(sock, "ENUM_ITEMS\n", 11, 0);
84         len = recv(sock, buffer, BUFSIZ-1, 0);
85         buffer[len] = '\0';
86         
87         trim(buffer);
88         
89         printf("Output: %s\n", buffer);
90         
91         responseCode = atoi(buffer);
92         if( responseCode != 201 )
93         {
94                 fprintf(stderr, "Unknown response from dispense server (Response Code %i)\n", responseCode);
95                 return -1;
96         }
97         
98         // Get item list
99         {
100                 char    *itemType, *itemStart;
101                  int    count;
102                 regmatch_t      matches[4];
103                 
104                 // Expected format: 201 Items <count> <item1> <item2> ...
105                 RunRegex(&gArrayRegex, buffer, 4, matches, "Malformed server response");
106                 
107                 itemType = &buffer[ matches[2].rm_so ]; buffer[ matches[2].rm_eo ] = '\0';
108                 count = atoi( &buffer[ matches[3].rm_so ] );
109                 
110                 // Check array type
111                 if( strcmp(itemType, "Items") != 0 ) {
112                         // What the?!
113                         fprintf(stderr, "Unexpected array type, expected 'Items', got '%s'\n",
114                                 itemType);
115                         return -1;
116                 }
117                 
118                 itemStart = &buffer[ matches[3].rm_eo ];
119                 
120                 gaItems = malloc( count * sizeof(tItem) );
121                 
122                 for( giNumItems = 0; giNumItems < count && itemStart; giNumItems ++ )
123                 {
124                         char    *next = strchr( ++itemStart, ' ' );
125                         if( next )      *next = '\0';
126                         gaItems[giNumItems].Ident = strdup(itemStart);
127                         itemStart = next;
128                 }
129         }
130         
131         // Get item information
132         for( i = 0; i < giNumItems; i ++ )
133         {
134                 regmatch_t      matches[6];
135                 
136                 // Print item Ident
137                 printf("%2i %s\t", i, gaItems[i].Ident);
138                 
139                 // Get item info
140                 sendf(sock, "ITEM_INFO %s\n", gaItems[i].Ident);
141                 len = recv(sock, buffer, BUFSIZ-1, 0);
142                 buffer[len] = '\0';
143                 trim(buffer);
144                 
145                 responseCode = atoi(buffer);
146                 if( responseCode != 202 ) {
147                         fprintf(stderr, "Unknown response from dispense server (Response Code %i)\n", responseCode);
148                         return -1;
149                 }
150                 
151                 RunRegex(&gItemRegex, buffer, 6, matches, "Malformed server response");
152                 
153                 buffer[ matches[3].rm_eo ] = '\0';
154                 
155                 gaItems[i].Price = atoi( buffer + matches[4].rm_so );
156                 gaItems[i].Desc = strdup( buffer + matches[5].rm_so );
157                 
158                 printf("%3i %s\n", gaItems[i].Price, gaItems[i].Desc);
159         }
160         
161         // and choose what to dispense
162         // TODO: ncurses interface (with separation between item classes)
163         // - Hmm... that would require standardising the item ID to be <class>:<index>
164         // Oh, why not :)
165         
166         #if 1
167         i = ShowNCursesUI();
168         #else
169         
170         for(;;)
171         {
172                 char    *buf;
173                 
174                 fgets(buffer, BUFSIZ, stdin);
175                 
176                 buf = trim(buffer);
177                 
178                 if( buf[0] == 'q' )     break;
179                 
180                 i = atoi(buf);
181                 
182                 printf("buf = '%s', atoi(buf) = %i\n", buf, i);
183                 
184                 if( i != 0 || buf[0] == '0' )
185                 {
186                         printf("i = %i\n", i);
187                         
188                         if( i < 0 || i >= giNumItems ) {
189                                 printf("Bad item (should be between 0 and %i)\n", giNumItems);
190                                 continue;
191                         }
192                         break;
193                 }
194         }
195         #endif
196         
197         Authenticate(sock);
198         
199         if( i >= 0 )
200         {       
201                 // Dispense!
202                 sendf(sock, "DISPENSE %s\n", gaItems[i].Ident);
203                 
204                 len = recv(sock, buffer, BUFSIZ-1, 0);
205                 buffer[len] = '\0';
206                 trim(buffer);
207                 
208                 responseCode = atoi(buffer);
209                 switch( responseCode )
210                 {
211                 case 200:
212                         printf("Dispense OK\n");
213                         break;
214                 case 401:
215                         printf("Not authenticated\n");
216                         break;
217                 case 402:
218                         printf("Insufficient balance\n");
219                         break;
220                 case 406:
221                         printf("Bad item name, bug report\n");
222                         break;
223                 case 500:
224                         printf("Item failed to dispense, is the slot empty?\n");
225                         break;
226                 default:
227                         printf("Unknown response code %i\n", responseCode);
228                         break;
229                 }
230         }
231
232         close(sock);
233
234         return 0;
235 }
236
237 void ShowItemAt(int Row, int Col, int Width, int Index)
238 {
239          int    _x, _y, times;
240         
241         move( Row, Col );
242         
243         if( Index < 0 || Index >= giNumItems ) {
244                 printw("%02i OOR", Index);
245                 return ;
246         }
247         printw("%02i %s", Index, gaItems[Index].Desc);
248         
249         getyx(stdscr, _y, _x);
250         // Assumes max 4 digit prices
251         times = Width - 4 - (_x - Col); // TODO: Better handling for large prices
252         while(times--)  addch(' ');
253         printw("%4i", gaItems[Index].Price);
254 }
255
256 /**
257  */
258 int ShowNCursesUI(void)
259 {
260          int    ch;
261          int    i, times;
262          int    xBase, yBase;
263         const int       displayMinWidth = 34;
264         const int       displayMinItems = 8;
265         char    *titleString = "Dispense";
266          int    itemCount = displayMinItems;
267          int    itemBase = 0;
268          
269          int    height = itemCount + 3;
270          int    width = displayMinWidth;
271          
272         // Enter curses mode
273         initscr();
274         raw(); noecho();
275         
276         xBase = COLS/2 - width/2;
277         yBase = LINES/2 - height/2;
278         
279         for( ;; )
280         {
281                 // Header
282                 PrintAlign(yBase, xBase, width, "/", '-', titleString, '-', "\\");
283                 
284                 // Items
285                 for( i = 0; i < itemCount; i ++ )
286                 {
287                         move( yBase + 1 + i, xBase );
288                         addch('|');
289                         addch(' ');
290                         
291                         // Check for ... row
292                         if( i == 0 && itemBase > 0 ) {
293                                 printw("   ...");
294                                 times = width - 1 - 8;
295                                 while(times--)  addch(' ');
296                         }
297                         else if( i == itemCount - 1 && itemBase < giNumItems - itemCount ) {
298                                 printw("   ...");
299                                 times = width - 1 - 8;
300                                 while(times--)  addch(' ');
301                         }
302                         // Show an item
303                         else {
304                                 ShowItemAt( yBase + 1 + i, xBase + 2, width - 4, itemBase + i);
305                                 addch(' ');
306                         }
307                         
308                         // Scrollbar (if needed)
309                         if( giNumItems > itemCount ) {
310                                 if( i == 0 ) {
311                                         addch('A');
312                                 }
313                                 else if( i == itemCount - 1 ) {
314                                         addch('V');
315                                 }
316                                 else {
317                                          int    percentage = itemBase * 100 / (giNumItems-itemCount);
318                                         if( i-1 == percentage*(itemCount-3)/100 ) {
319                                                 addch('#');
320                                         }
321                                         else {
322                                                 addch('|');
323                                         }
324                                 }
325                         }
326                         else {
327                                 addch('|');
328                         }
329                 }
330                 
331                 // Footer
332                 PrintAlign(yBase+height-2, xBase, width, "\\", '-', "", '-', "/");
333                 
334                 // Get input
335                 ch = getch();
336                 
337                 if( ch == '\x1B' ) {
338                         ch = getch();
339                         if( ch == '[' ) {
340                                 ch = getch();
341                                 
342                                 switch(ch)
343                                 {
344                                 case 'B':
345                                         if( itemBase < giNumItems - (itemCount) )
346                                                 itemBase ++;
347                                         break;
348                                 case 'A':
349                                         if( itemBase > 0 )
350                                                 itemBase --;
351                                         break;
352                                 }
353                         }
354                         else {
355                                 
356                         }
357                 }
358                 else {
359                         break;
360                 }
361                 
362         }
363         
364         
365         // Leave
366         endwin();
367         return -1;
368 }
369
370 void PrintAlign(int Row, int Col, int Width, const char *Left, char Pad1, const char *Mid, char Pad2, const char *Right, ...)
371 {
372          int    lLen, mLen, rLen;
373          int    times;
374         
375         va_list args;
376         
377         // Get the length of the strings
378         va_start(args, Right);
379         lLen = vsnprintf(NULL, 0, Left, args);
380         mLen = vsnprintf(NULL, 0, Mid, args);
381         rLen = vsnprintf(NULL, 0, Right, args);
382         va_end(args);
383         
384         // Sanity check
385         if( lLen + mLen/2 > Width/2 || mLen/2 + rLen > Width/2 ) {
386                 return ;        // TODO: What to do?
387         }
388         
389         move(Row, Col);
390         
391         // Render strings
392         va_start(args, Right);
393         // - Left
394         {
395                 char    tmp[lLen+1];
396                 vsnprintf(tmp, lLen+1, Left, args);
397                 addstr(tmp);
398         }
399         // - Left padding
400         times = Width/2 - mLen/2 - lLen;
401         while(times--)  addch(Pad1);
402         // - Middle
403         {
404                 char    tmp[mLen+1];
405                 vsnprintf(tmp, mLen+1, Mid, args);
406                 addstr(tmp);
407         }
408         // - Right Padding
409         times = Width/2 - mLen/2 - rLen;
410         while(times--)  addch(Pad2);
411         // - Right
412         {
413                 char    tmp[rLen+1];
414                 vsnprintf(tmp, rLen+1, Right, args);
415                 addstr(tmp);
416         }
417 }
418
419 // === HELPERS ===
420 int sendf(int Socket, const char *Format, ...)
421 {
422         va_list args;
423          int    len;
424         
425         va_start(args, Format);
426         len = vsnprintf(NULL, 0, Format, args);
427         va_end(args);
428         
429         {
430                 char    buf[len+1];
431                 va_start(args, Format);
432                 vsnprintf(buf, len+1, Format, args);
433                 va_end(args);
434                 
435                 return send(Socket, buf, len, 0);
436         }
437 }
438
439 int OpenConnection(const char *Host, int Port)
440 {
441         struct hostent  *host;
442         struct sockaddr_in      serverAddr;
443          int    sock;
444         
445         host = gethostbyname(Host);
446         if( !host ) {
447                 fprintf(stderr, "Unable to look up '%s'\n", Host);
448                 return -1;
449         }
450         
451         memset(&serverAddr, 0, sizeof(serverAddr));
452         
453         serverAddr.sin_family = AF_INET;        // IPv4
454         // NOTE: I have a suspicion that IPv6 will play sillybuggers with this :)
455         serverAddr.sin_addr.s_addr = *((unsigned long *) host->h_addr_list[0]);
456         serverAddr.sin_port = htons(Port);
457         
458         sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
459         if( sock < 0 ) {
460                 fprintf(stderr, "Failed to create socket\n");
461                 return -1;
462         }
463         
464         #if USE_AUTOAUTH
465         {
466                 struct sockaddr_in      localAddr;
467                 memset(&localAddr, 0, sizeof(localAddr));
468                 localAddr.sin_family = AF_INET; // IPv4
469                 localAddr.sin_port = 1023;      // IPv4
470                 // Attempt to bind to low port for autoauth
471                 bind(sock, &localAddr, sizeof(localAddr));
472         }
473         #endif
474         
475         if( connect(sock, (struct sockaddr *) &serverAddr, sizeof(serverAddr)) < 0 ) {
476                 fprintf(stderr, "Failed to connect to server\n");
477                 return -1;
478         }
479         
480         return sock;
481 }
482
483 void Authenticate(int Socket)
484 {
485         struct passwd   *pwd;
486         char    buf[512];
487          int    responseCode;
488         char    salt[32];
489         regmatch_t      matches[4];
490         
491         // Get user name
492         pwd = getpwuid( getuid() );
493         
494         // Attempt automatic authentication
495         sendf(Socket, "AUTOAUTH %s\n", pwd->pw_name);
496         
497         // Check if it worked
498         recv(Socket, buf, 511, 0);
499         trim(buf);
500         
501         responseCode = atoi(buf);
502         switch( responseCode )
503         {
504         case 200:       // Authenticated, return :)
505                 return ;
506         case 401:       // Untrusted, attempt password authentication
507                 sendf(Socket, "USER %s\n", pwd->pw_name);
508                 printf("Using username %s\n", pwd->pw_name);
509                 
510                 recv(Socket, buf, 511, 0);
511                 trim(buf);
512                 // TODO: Get Salt
513                 // Expected format: 100 SALT <something> ...
514                 // OR             : 100 User Set
515                 printf("string = '%s'\n", buf);
516                 RunRegex(&gSaltRegex, buf, 4, matches, "Malformed server response");
517                 if( atoi(buf) != 100 ) {
518                         exit(-1);       // ERROR
519                 }
520                 if( memcmp( buf+matches[2].rm_so, "SALT", matches[2].rm_eo - matches[2].rm_so) == 0) {
521                         // Set salt
522                         memcpy( salt, buf + matches[3].rm_so, matches[3].rm_eo - matches[3].rm_so );
523                         salt[ matches[3].rm_eo - matches[3].rm_so ] = 0;
524                         printf("Salt: '%s'\n", salt);
525                 }
526                 
527                 fflush(stdout);
528                 {
529                          int    ofs = strlen(pwd->pw_name)+strlen(salt);
530                         char    tmp[ofs+20];
531                         char    *pass = getpass("Password: ");
532                         uint8_t h[20];
533                         
534                         strcpy(tmp, pwd->pw_name);
535                         strcat(tmp, salt);
536                         SHA1( (unsigned char*)pass, strlen(pass), h );
537                         memcpy(tmp+ofs, h, 20);
538                         
539                         // Hash all that
540                         SHA1( (unsigned char*)tmp, ofs+20, h );
541                         sprintf(buf, "%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x",
542                                 h[ 0], h[ 1], h[ 2], h[ 3], h[ 4], h[ 5], h[ 6], h[ 7], h[ 8], h[ 9],
543                                 h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19]
544                                 );
545                         printf("Final hash: '%s'\n", buf);
546                         fflush(stdout); // Debug
547                 }
548                 
549                 sendf(Socket, "PASS %s\n", buf);
550                 recv(Socket, buf, 511, 0);
551                 break;
552         case 404:       // Bad Username
553                 fprintf(stderr, "Bad Username '%s'\n", pwd->pw_name);
554                 exit(-1);
555         default:
556                 fprintf(stderr, "Unkown response code %i from server\n", responseCode);
557                 printf("%s\n", buf);
558                 exit(-1);
559         }
560         
561         printf("%s\n", buf);
562 }
563
564 char *trim(char *string)
565 {
566          int    i;
567         
568         while( isspace(*string) )
569                 string ++;
570         
571         for( i = strlen(string); i--; )
572         {
573                 if( isspace(string[i]) )
574                         string[i] = '\0';
575                 else
576                         break;
577         }
578         
579         return string;
580 }
581
582 int RunRegex(regex_t *regex, const char *string, int nMatches, regmatch_t *matches, const char *errorMessage)
583 {
584          int    ret;
585         
586         ret = regexec(regex, string, nMatches, matches, 0);
587         if( ret ) {
588                 size_t  len = regerror(ret, regex, NULL, 0);
589                 char    errorStr[len];
590                 regerror(ret, regex, errorStr, len);
591                 printf("string = '%s'\n", string);
592                 fprintf(stderr, "%s\n%s", errorMessage, errorStr);
593                 exit(-1);
594         }
595         
596         return ret;
597 }
598
599 void CompileRegex(regex_t *regex, const char *pattern, int flags)
600 {
601          int    ret = regcomp(regex, pattern, flags);
602         if( ret ) {
603                 size_t  len = regerror(ret, regex, NULL, 0);
604                 char    errorStr[len];
605                 regerror(ret, regex, errorStr, len);
606                 fprintf(stderr, "Regex compilation failed - %s\n", errorStr);
607                 exit(-1);
608         }
609 }

UCC git Repository :: git.ucc.asn.au