Working on password-based user auth
[tpg/opendispense2.git] / src / client / main.c
index 5f5f1c8..c82d12d 100644 (file)
@@ -22,6 +22,7 @@
 #include <sys/socket.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
+#include <openssl/sha.h>       // SHA1
 
 // === TYPES ===
 typedef struct sItem {
 
 // === TYPES ===
 typedef struct sItem {
@@ -32,6 +33,7 @@ typedef struct sItem {
 
 // === PROTOTYPES ===
  int   ShowNCursesUI(void);
 
 // === PROTOTYPES ===
  int   ShowNCursesUI(void);
+void   PrintAlign(int Row, int Col, int Width, const char *Left, char Pad1, const char *Mid, char Pad2, const char *Right, ...);
 
  int   sendf(int Socket, const char *Format, ...);
  int   OpenConnection(const char *Host, int Port);
 
  int   sendf(int Socket, const char *Format, ...);
  int   OpenConnection(const char *Host, int Port);
@@ -45,8 +47,7 @@ char  *gsDispenseServer = "localhost";
  int   giDispensePort = 11020;
 tItem  *gaItems;
  int   giNumItems;
  int   giDispensePort = 11020;
 tItem  *gaItems;
  int   giNumItems;
-regex_t        gArrayRegex;
-regex_t        gItemRegex;
+regex_t        gArrayRegex, gItemRegex, gSaltRegex;
 
 // === CODE ===
 int main(int argc, char *argv[])
 
 // === CODE ===
 int main(int argc, char *argv[])
@@ -60,6 +61,8 @@ int main(int argc, char *argv[])
        CompileRegex(&gArrayRegex, "^([0-9]{3})\\s+([A-Za-z]+)\\s+([0-9]+)", REG_EXTENDED);     //
        // > Code Type Ident Price Desc
        CompileRegex(&gItemRegex, "^([0-9]{3})\\s+(.+?)\\s+(.+?)\\s+([0-9]+)\\s+(.+)$", REG_EXTENDED);
        CompileRegex(&gArrayRegex, "^([0-9]{3})\\s+([A-Za-z]+)\\s+([0-9]+)", REG_EXTENDED);     //
        // > Code Type Ident Price Desc
        CompileRegex(&gItemRegex, "^([0-9]{3})\\s+(.+?)\\s+(.+?)\\s+([0-9]+)\\s+(.+)$", REG_EXTENDED);
+       // > Code 'SALT' salt
+       CompileRegex(&gSaltRegex, "^([0-9]{3})\\s+(.+)\\s+(.+)$", REG_EXTENDED);
        
        // Connect to server
        sock = OpenConnection(gsDispenseServer, giDispensePort);
        
        // Connect to server
        sock = OpenConnection(gsDispenseServer, giDispensePort);
@@ -155,8 +158,6 @@ int main(int argc, char *argv[])
                printf("%3i %s\n", gaItems[i].Price, gaItems[i].Desc);
        }
        
                printf("%3i %s\n", gaItems[i].Price, gaItems[i].Desc);
        }
        
-       Authenticate(sock);
-       
        // and choose what to dispense
        // TODO: ncurses interface (with separation between item classes)
        // - Hmm... that would require standardising the item ID to be <class>:<index>
        // and choose what to dispense
        // TODO: ncurses interface (with separation between item classes)
        // - Hmm... that would require standardising the item ID to be <class>:<index>
@@ -193,6 +194,8 @@ int main(int argc, char *argv[])
        }
        #endif
        
        }
        #endif
        
+       Authenticate(sock);
+       
        if( i >= 0 )
        {       
                // Dispense!
        if( i >= 0 )
        {       
                // Dispense!
@@ -231,6 +234,25 @@ int main(int argc, char *argv[])
        return 0;
 }
 
        return 0;
 }
 
+void ShowItemAt(int Row, int Col, int Width, int Index)
+{
+        int    _x, _y, times;
+       
+       move( Row, Col );
+       
+       if( Index < 0 || Index >= giNumItems ) {
+               printw("%02i OOR", Index);
+               return ;
+       }
+       printw("%02i %s", Index, gaItems[Index].Desc);
+       
+       getyx(stdscr, _y, _x);
+       // Assumes max 4 digit prices
+       times = Width - 4 - (_x - Col); // TODO: Better handling for large prices
+       while(times--)  addch(' ');
+       printw("%4i", gaItems[Index].Price);
+}
+
 /**
  */
 int ShowNCursesUI(void)
 /**
  */
 int ShowNCursesUI(void)
@@ -241,7 +263,6 @@ int ShowNCursesUI(void)
        const int       displayMinWidth = 34;
        const int       displayMinItems = 8;
        char    *titleString = "Dispense";
        const int       displayMinWidth = 34;
        const int       displayMinItems = 8;
        char    *titleString = "Dispense";
-        int    titleStringLen = strlen(titleString);
         int    itemCount = displayMinItems;
         int    itemBase = 0;
         
         int    itemCount = displayMinItems;
         int    itemBase = 0;
         
@@ -258,21 +279,11 @@ int ShowNCursesUI(void)
        for( ;; )
        {
                // Header
        for( ;; )
        {
                // Header
-               move( yBase, xBase );
-               addch('/');
-               times = width/2 - titleStringLen/2 - 2;
-               while(times --) addch('-');
-               addch(' ');
-               addstr(titleString);
-               addch(' ');
-               times = width/2 - titleStringLen/2 - 2;
-               while(times --) addch('-');
-               addch('\\');
+               PrintAlign(yBase, xBase, width, "/", '-', titleString, '-', "\\");
                
                // Items
                for( i = 0; i < itemCount; i ++ )
                {
                
                // Items
                for( i = 0; i < itemCount; i ++ )
                {
-                        int    _x, _y;
                        move( yBase + 1 + i, xBase );
                        addch('|');
                        addch(' ');
                        move( yBase + 1 + i, xBase );
                        addch('|');
                        addch(' ');
@@ -290,16 +301,8 @@ int ShowNCursesUI(void)
                        }
                        // Show an item
                        else {
                        }
                        // Show an item
                        else {
-                               if( itemBase + i < 0 || itemBase + i >= giNumItems ) {
-                                       printw("%02i %i OOR", itemBase + i, i);
-                                       continue ;
-                               }
-                               printw("%02i %s", itemBase + i, gaItems[itemBase + i].Desc);
-                               
-                               getyx(stdscr, _y, _x);
-                               times = width - 6 - (_x - xBase);       // TODO: Better handling for large prices
-                               while(times--)  addch(' ');
-                               printw("%4i ", gaItems[itemBase + i].Price);
+                               ShowItemAt( yBase + 1 + i, xBase + 2, width - 4, itemBase + i);
+                               addch(' ');
                        }
                        
                        // Scrollbar (if needed)
                        }
                        
                        // Scrollbar (if needed)
@@ -326,30 +329,9 @@ int ShowNCursesUI(void)
                }
                
                // Footer
                }
                
                // Footer
-               move( yBase + 1 + itemCount, xBase );
-               addch('\\');
-               times = width/2 - titleStringLen/2 - 2;
-               while(times --) addch('-');
-               addch(' ');
-               addstr(titleString);
-               addch(' ');
-               times = width/2 - titleStringLen/2 - 2;
-               while(times --) addch('-');
-               addch('/');
-               
-               move( yBase + 1 + itemCount + 1, xBase );
-               {
-                        int    count = itemCount-2;
-                        int    ofs = itemBase;
-                       if( itemBase == 0 )     count ++;
-                       else    ofs ++;
-                       if( itemBase == giNumItems-itemCount) {
-                               count ++;
-                               ofs ++;
-                       }
-                       printw("%i - %i / %i items", itemBase, itemBase+count, giNumItems);
-               }
+               PrintAlign(yBase+height-2, xBase, width, "\\", '-', "", '-', "/");
                
                
+               // Get input
                ch = getch();
                
                if( ch == '\x1B' ) {
                ch = getch();
                
                if( ch == '\x1B' ) {
@@ -385,6 +367,55 @@ int ShowNCursesUI(void)
        return -1;
 }
 
        return -1;
 }
 
+void PrintAlign(int Row, int Col, int Width, const char *Left, char Pad1, const char *Mid, char Pad2, const char *Right, ...)
+{
+        int    lLen, mLen, rLen;
+        int    times;
+       
+       va_list args;
+       
+       // Get the length of the strings
+       va_start(args, Right);
+       lLen = vsnprintf(NULL, 0, Left, args);
+       mLen = vsnprintf(NULL, 0, Mid, args);
+       rLen = vsnprintf(NULL, 0, Right, args);
+       va_end(args);
+       
+       // Sanity check
+       if( lLen + mLen/2 > Width/2 || mLen/2 + rLen > Width/2 ) {
+               return ;        // TODO: What to do?
+       }
+       
+       move(Row, Col);
+       
+       // Render strings
+       va_start(args, Right);
+       // - Left
+       {
+               char    tmp[lLen+1];
+               vsnprintf(tmp, lLen+1, Left, args);
+               addstr(tmp);
+       }
+       // - Left padding
+       times = Width/2 - mLen/2 - lLen;
+       while(times--)  addch(Pad1);
+       // - Middle
+       {
+               char    tmp[mLen+1];
+               vsnprintf(tmp, mLen+1, Mid, args);
+               addstr(tmp);
+       }
+       // - Right Padding
+       times = Width/2 - mLen/2 - rLen;
+       while(times--)  addch(Pad2);
+       // - Right
+       {
+               char    tmp[rLen+1];
+               vsnprintf(tmp, rLen+1, Right, args);
+               addstr(tmp);
+       }
+}
+
 // === HELPERS ===
 int sendf(int Socket, const char *Format, ...)
 {
 // === HELPERS ===
 int sendf(int Socket, const char *Format, ...)
 {
@@ -454,6 +485,8 @@ void Authenticate(int Socket)
        struct passwd   *pwd;
        char    buf[512];
         int    responseCode;
        struct passwd   *pwd;
        char    buf[512];
         int    responseCode;
+       char    salt[32];
+       regmatch_t      matches[4];
        
        // Get user name
        pwd = getpwuid( getuid() );
        
        // Get user name
        pwd = getpwuid( getuid() );
@@ -471,6 +504,50 @@ void Authenticate(int Socket)
        case 200:       // Authenticated, return :)
                return ;
        case 401:       // Untrusted, attempt password authentication
        case 200:       // Authenticated, return :)
                return ;
        case 401:       // Untrusted, attempt password authentication
+               sendf(Socket, "USER %s\n", pwd->pw_name);
+               printf("Using username %s\n", pwd->pw_name);
+               
+               recv(Socket, buf, 511, 0);
+               trim(buf);
+               // TODO: Get Salt
+               // Expected format: 100 SALT <something> ...
+               // OR             : 100 User Set
+               printf("string = '%s'\n", buf);
+               RunRegex(&gSaltRegex, buf, 4, matches, "Malformed server response");
+               if( atoi(buf) != 100 ) {
+                       exit(-1);       // ERROR
+               }
+               if( memcmp( buf+matches[2].rm_so, "SALT", matches[2].rm_eo - matches[2].rm_so) == 0) {
+                       // Set salt
+                       memcpy( salt, buf + matches[3].rm_so, matches[3].rm_eo - matches[3].rm_so );
+                       salt[ matches[3].rm_eo - matches[3].rm_so ] = 0;
+                       printf("Salt: '%s'\n", salt);
+               }
+               
+               fflush(stdout);
+               {
+                        int    ofs = strlen(pwd->pw_name)+strlen(salt);
+                       char    tmp[ofs+20];
+                       char    *pass = getpass("Password: ");
+                       uint8_t h[20];
+                       
+                       strcpy(tmp, pwd->pw_name);
+                       strcat(tmp, salt);
+                       SHA1( (unsigned char*)pass, strlen(pass), h );
+                       memcpy(tmp+ofs, h, 20);
+                       
+                       // Hash all that
+                       SHA1( (unsigned char*)tmp, ofs+20, h );
+                       sprintf(buf, "%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x",
+                               h[ 0], h[ 1], h[ 2], h[ 3], h[ 4], h[ 5], h[ 6], h[ 7], h[ 8], h[ 9],
+                               h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19]
+                               );
+                       printf("Final hash: '%s'\n", buf);
+                       fflush(stdout); // Debug
+               }
+               
+               sendf(Socket, "PASS %s\n", buf);
+               recv(Socket, buf, 511, 0);
                break;
        case 404:       // Bad Username
                fprintf(stderr, "Bad Username '%s'\n", pwd->pw_name);
                break;
        case 404:       // Bad Username
                fprintf(stderr, "Bad Username '%s'\n", pwd->pw_name);

UCC git Repository :: git.ucc.asn.au