Cleanup and Implementations
[tpg/opendispense2.git] / src / server / server.c
1 /*
2  * OpenDispense 2 
3  * UCC (University [of WA] Computer Club) Electronic Accounting System
4  *
5  * server.c - Client Server Code
6  *
7  * This file is licenced under the 3-clause BSD Licence. See the file
8  * COPYING for full details.
9  */
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include "common.h"
13 #include <sys/socket.h>
14 #include <netinet/in.h>
15 #include <arpa/inet.h>
16 #include <unistd.h>
17 #include <string.h>
18
19 // HACKS
20 #define HACK_TPG_NOAUTH 1
21 #define HACK_ROOT_NOAUTH        1
22
23 // Statistics
24 #define MAX_CONNECTION_QUEUE    5
25 #define INPUT_BUFFER_SIZE       256
26
27 #define HASH_TYPE       SHA1
28 #define HASH_LENGTH     20
29
30 #define MSG_STR_TOO_LONG        "499 Command too long (limit "EXPSTR(INPUT_BUFFER_SIZE)")\n"
31
32 // === TYPES ===
33 typedef struct sClient
34 {
35          int    ID;     // Client ID
36          
37          int    bIsTrusted;     // Is the connection from a trusted host/port
38         
39         char    *Username;
40         char    Salt[9];
41         
42          int    UID;
43          int    bIsAuthed;
44 }       tClient;
45
46 // === PROTOTYPES ===
47 void    Server_Start(void);
48 void    Server_Cleanup(void);
49 void    Server_HandleClient(int Socket, int bTrusted);
50 char    *Server_ParseClientCommand(tClient *Client, char *CommandString);
51 // --- Commands ---
52 char    *Server_Cmd_USER(tClient *Client, char *Args);
53 char    *Server_Cmd_PASS(tClient *Client, char *Args);
54 char    *Server_Cmd_AUTOAUTH(tClient *Client, char *Args);
55 char    *Server_Cmd_ENUMITEMS(tClient *Client, char *Args);
56 char    *Server_Cmd_ITEMINFO(tClient *Client, char *Args);
57 char    *Server_Cmd_DISPENSE(tClient *Client, char *Args);
58 char    *Server_Cmd_GIVE(tClient *Client, char *Args);
59 char    *Server_Cmd_ADD(tClient *Client, char *Args);
60 char    *Server_Cmd_USERINFO(tClient *Client, char *Args);
61 // --- Helpers ---
62  int    GetUserAuth(const char *Salt, const char *Username, const uint8_t *Hash);
63 void    HexBin(uint8_t *Dest, char *Src, int BufSize);
64
65 // === GLOBALS ===
66  int    giServer_Port = 1020;
67  int    giServer_NextClientID = 1;
68 // - Commands
69 struct sClientCommand {
70         char    *Name;
71         char    *(*Function)(tClient *Client, char *Arguments);
72 }       gaServer_Commands[] = {
73         {"USER", Server_Cmd_USER},
74         {"PASS", Server_Cmd_PASS},
75         {"AUTOAUTH", Server_Cmd_AUTOAUTH},
76         {"ENUM_ITEMS", Server_Cmd_ENUMITEMS},
77         {"ITEM_INFO", Server_Cmd_ITEMINFO},
78         {"DISPENSE", Server_Cmd_DISPENSE},
79         {"GIVE", Server_Cmd_GIVE},
80         {"ADD", Server_Cmd_ADD},
81         {"USER_INFO", Server_Cmd_USERINFO}
82 };
83 #define NUM_COMMANDS    (sizeof(gaServer_Commands)/sizeof(gaServer_Commands[0]))
84  int    giServer_Socket;
85
86 // === CODE ===
87 /**
88  * \brief Open listenting socket and serve connections
89  */
90 void Server_Start(void)
91 {
92          int    client_socket;
93         struct sockaddr_in      server_addr, client_addr;
94
95         atexit(Server_Cleanup);
96
97         // Create Server
98         giServer_Socket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
99         if( giServer_Socket < 0 ) {
100                 fprintf(stderr, "ERROR: Unable to create server socket\n");
101                 return ;
102         }
103         
104         // Make listen address
105         memset(&server_addr, 0, sizeof(server_addr));
106         server_addr.sin_family = AF_INET;       // Internet Socket
107         server_addr.sin_addr.s_addr = htonl(INADDR_ANY);        // Listen on all interfaces
108         server_addr.sin_port = htons(giServer_Port);    // Port
109
110         // Bind
111         if( bind(giServer_Socket, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0 ) {
112                 fprintf(stderr, "ERROR: Unable to bind to 0.0.0.0:%i\n", giServer_Port);
113                 perror("Binding");
114                 return ;
115         }
116         
117         // Listen
118         if( listen(giServer_Socket, MAX_CONNECTION_QUEUE) < 0 ) {
119                 fprintf(stderr, "ERROR: Unable to listen to socket\n");
120                 perror("Listen");
121                 return ;
122         }
123         
124         printf("Listening on 0.0.0.0:%i\n", giServer_Port);
125         
126         for(;;)
127         {
128                 uint    len = sizeof(client_addr);
129                  int    bTrusted = 0;
130                 
131                 client_socket = accept(giServer_Socket, (struct sockaddr *) &client_addr, &len);
132                 if(client_socket < 0) {
133                         fprintf(stderr, "ERROR: Unable to accept client connection\n");
134                         return ;
135                 }
136                 
137                 if(giDebugLevel >= 2) {
138                         char    ipstr[INET_ADDRSTRLEN];
139                         inet_ntop(AF_INET, &client_addr.sin_addr, ipstr, INET_ADDRSTRLEN);
140                         printf("Client connection from %s:%i\n",
141                                 ipstr, ntohs(client_addr.sin_port));
142                 }
143                 
144                 // Trusted Connections
145                 if( ntohs(client_addr.sin_port) < 1024 )
146                 {
147                         // TODO: Make this runtime configurable
148                         switch( ntohl( client_addr.sin_addr.s_addr ) )
149                         {
150                         case 0x7F000001:        // 127.0.0.1    localhost
151                         //case 0x825E0D00:      // 130.95.13.0
152                         case 0x825E0D12:        // 130.95.13.18 mussel
153                         case 0x825E0D17:        // 130.95.13.23 martello
154                                 bTrusted = 1;
155                                 break;
156                         default:
157                                 break;
158                         }
159                 }
160                 
161                 // TODO: Multithread this?
162                 Server_HandleClient(client_socket, bTrusted);
163                 
164                 close(client_socket);
165         }
166 }
167
168 void Server_Cleanup(void)
169 {
170         printf("Close(%i)\n", giServer_Socket);
171         close(giServer_Socket);
172 }
173
174 /**
175  * \brief Reads from a client socket and parses the command strings
176  * \param Socket        Client socket number/handle
177  * \param bTrusted      Is the client trusted?
178  */
179 void Server_HandleClient(int Socket, int bTrusted)
180 {
181         char    inbuf[INPUT_BUFFER_SIZE];
182         char    *buf = inbuf;
183          int    remspace = INPUT_BUFFER_SIZE-1;
184          int    bytes = -1;
185         tClient clientInfo = {0};
186         
187         // Initialise Client info
188         clientInfo.ID = giServer_NextClientID ++;
189         clientInfo.bIsTrusted = bTrusted;
190         
191         // Read from client
192         /*
193          * Notes:
194          * - The `buf` and `remspace` variables allow a line to span several
195          *   calls to recv(), if a line is not completed in one recv() call
196          *   it is saved to the beginning of `inbuf` and `buf` is updated to
197          *   the end of it.
198          */
199         while( (bytes = recv(Socket, buf, remspace, 0)) > 0 )
200         {
201                 char    *eol, *start;
202                 buf[bytes] = '\0';      // Allow us to use stdlib string functions on it
203                 
204                 // Split by lines
205                 start = inbuf;
206                 while( (eol = strchr(start, '\n')) )
207                 {
208                         char    *ret;
209                         *eol = '\0';
210                         ret = Server_ParseClientCommand(&clientInfo, start);
211                         printf("ret = %s", ret);
212                         // `ret` is a string on the heap
213                         send(Socket, ret, strlen(ret), 0);
214                         free(ret);
215                         start = eol + 1;
216                 }
217                 
218                 // Check if there was an incomplete line
219                 if( *start != '\0' ) {
220                          int    tailBytes = bytes - (start-buf);
221                         // Roll back in buffer
222                         memcpy(inbuf, start, tailBytes);
223                         remspace -= tailBytes;
224                         if(remspace == 0) {
225                                 send(Socket, MSG_STR_TOO_LONG, sizeof(MSG_STR_TOO_LONG), 0);
226                                 buf = inbuf;
227                                 remspace = INPUT_BUFFER_SIZE - 1;
228                         }
229                 }
230                 else {
231                         buf = inbuf;
232                         remspace = INPUT_BUFFER_SIZE - 1;
233                 }
234         }
235         
236         // Check for errors
237         if( bytes < 0 ) {
238                 fprintf(stderr, "ERROR: Unable to recieve from client on socket %i\n", Socket);
239                 return ;
240         }
241         
242         if(giDebugLevel >= 2) {
243                 printf("Client %i: Disconnected\n", clientInfo.ID);
244         }
245 }
246
247 /**
248  * \brief Parses a client command and calls the required helper function
249  * \param Client        Pointer to client state structure
250  * \param CommandString Command from client (single line of the command)
251  * \return Heap String to return to the client
252  */
253 char *Server_ParseClientCommand(tClient *Client, char *CommandString)
254 {
255         char    *space, *args;
256          int    i;
257         
258         // Split at first space
259         space = strchr(CommandString, ' ');
260         if(space == NULL) {
261                 args = NULL;
262         }
263         else {
264                 *space = '\0';
265                 args = space + 1;
266         }
267         
268         // Find command
269         for( i = 0; i < NUM_COMMANDS; i++ )
270         {
271                 if(strcmp(CommandString, gaServer_Commands[i].Name) == 0)
272                         return gaServer_Commands[i].Function(Client, args);
273         }
274         
275         return strdup("400 Unknown Command\n");
276 }
277
278 // ---
279 // Commands
280 // ---
281 /**
282  * \brief Set client username
283  * 
284  * Usage: USER <username>
285  */
286 char *Server_Cmd_USER(tClient *Client, char *Args)
287 {
288         char    *ret;
289         
290         // Debug!
291         if( giDebugLevel )
292                 printf("Client %i authenticating as '%s'\n", Client->ID, Args);
293         
294         // Save username
295         if(Client->Username)
296                 free(Client->Username);
297         Client->Username = strdup(Args);
298         
299         #if USE_SALT
300         // Create a salt (that changes if the username is changed)
301         // Yes, I know, I'm a little paranoid, but who isn't?
302         Client->Salt[0] = 0x21 + (rand()&0x3F);
303         Client->Salt[1] = 0x21 + (rand()&0x3F);
304         Client->Salt[2] = 0x21 + (rand()&0x3F);
305         Client->Salt[3] = 0x21 + (rand()&0x3F);
306         Client->Salt[4] = 0x21 + (rand()&0x3F);
307         Client->Salt[5] = 0x21 + (rand()&0x3F);
308         Client->Salt[6] = 0x21 + (rand()&0x3F);
309         Client->Salt[7] = 0x21 + (rand()&0x3F);
310         
311         // TODO: Also send hash type to use, (SHA1 or crypt according to [DAA])
312         ret = mkstr("100 SALT %s\n", Client->Salt);
313         #else
314         ret = strdup("100 User Set\n");
315         #endif
316         return ret;
317 }
318
319 /**
320  * \brief Authenticate as a user
321  * 
322  * Usage: PASS <hash>
323  */
324 char *Server_Cmd_PASS(tClient *Client, char *Args)
325 {
326         uint8_t clienthash[HASH_LENGTH] = {0};
327         
328         // Read user's hash
329         HexBin(clienthash, Args, HASH_LENGTH);
330         
331         // TODO: Decrypt password passed
332         
333         Client->UID = GetUserAuth(Client->Salt, Client->Username, clienthash);
334
335         if( Client->UID != -1 ) {
336                 Client->bIsAuthed = 1;
337                 return strdup("200 Auth OK\n");
338         }
339
340         if( giDebugLevel ) {
341                  int    i;
342                 printf("Client %i: Password hash ", Client->ID);
343                 for(i=0;i<HASH_LENGTH;i++)
344                         printf("%02x", clienthash[i]&0xFF);
345                 printf("\n");
346         }
347         
348         return strdup("401 Auth Failure\n");
349 }
350
351 /**
352  * \brief Authenticate as a user without a password
353  * 
354  * Usage: AUTOAUTH <user>
355  */
356 char *Server_Cmd_AUTOAUTH(tClient *Client, char *Args)
357 {
358         char    *spos = strchr(Args, ' ');
359         if(spos)        *spos = '\0';   // Remove characters after the ' '
360         
361         // Check if trusted
362         if( !Client->bIsTrusted ) {
363                 if(giDebugLevel)
364                         printf("Client %i: Untrusted client attempting to AUTOAUTH\n", Client->ID);
365                 return strdup("401 Untrusted\n");
366         }
367         
368         // Get UID
369         Client->UID = GetUserID( Args );
370         if( Client->UID < 0 ) {
371                 if(giDebugLevel)
372                         printf("Client %i: Unknown user '%s'\n", Client->ID, Args);
373                 return strdup("401 Auth Failure\n");
374         }
375         
376         if(giDebugLevel)
377                 printf("Client %i: Authenticated as '%s' (%i)\n", Client->ID, Args, Client->UID);
378         
379         return strdup("200 Auth OK\n");
380 }
381
382 /**
383  * \brief Enumerate the items that the server knows about
384  */
385 char *Server_Cmd_ENUMITEMS(tClient *Client, char *Args)
386 {
387          int    retLen;
388          int    i;
389         char    *ret;
390
391         retLen = snprintf(NULL, 0, "201 Items %i", giNumItems);
392
393         for( i = 0; i < giNumItems; i ++ )
394         {
395                 retLen += snprintf(NULL, 0, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
396         }
397
398         ret = malloc(retLen+1);
399         retLen = 0;
400         retLen += sprintf(ret+retLen, "201 Items %i", giNumItems);
401
402         for( i = 0; i < giNumItems; i ++ ) {
403                 retLen += sprintf(ret+retLen, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
404         }
405
406         strcat(ret, "\n");
407
408         return ret;
409 }
410
411 tItem *_GetItemFromString(char *String)
412 {
413         tHandler        *handler;
414         char    *type = String;
415         char    *colon = strchr(String, ':');
416          int    num, i;
417         
418         if( !colon ) {
419                 return NULL;
420         }
421
422         num = atoi(colon+1);
423         *colon = '\0';
424
425         // Find handler
426         handler = NULL;
427         for( i = 0; i < giNumHandlers; i ++ )
428         {
429                 if( strcmp(gaHandlers[i]->Name, type) == 0) {
430                         handler = gaHandlers[i];
431                         break;
432                 }
433         }
434         if( !handler ) {
435                 return NULL;
436         }
437
438         // Find item
439         for( i = 0; i < giNumItems; i ++ )
440         {
441                 if( gaItems[i].Handler != handler )     continue;
442                 if( gaItems[i].ID != num )      continue;
443                 return &gaItems[i];
444         }
445         return NULL;
446 }
447
448 /**
449  * \brief Fetch information on a specific item
450  */
451 char *Server_Cmd_ITEMINFO(tClient *Client, char *Args)
452 {
453          int    retLen = 0;
454         char    *ret;
455         tItem   *item = _GetItemFromString(Args);
456         
457         if( !item ) {
458                 return strdup("406 Bad Item ID\n");
459         }
460
461         // Create return
462         retLen = snprintf(NULL, 0, "202 Item %s:%i %i %s\n",
463                 item->Handler->Name, item->ID, item->Price, item->Name);
464         ret = malloc(retLen+1);
465         sprintf(ret, "202 Item %s:%i %i %s\n",
466                 item->Handler->Name, item->ID, item->Price, item->Name);
467
468         return ret;
469 }
470
471 char *Server_Cmd_DISPENSE(tClient *Client, char *Args)
472 {
473         tItem   *item;
474          int    ret;
475         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
476
477         item = _GetItemFromString(Args);
478         if( !item ) {
479                 return strdup("406 Bad Item ID\n");
480         }
481
482         switch( ret = DispenseItem( Client->UID, item ) )
483         {
484         case 0: return strdup("200 Dispense OK\n");
485         case 1: return strdup("501 Unable to dispense\n");
486         case 2: return strdup("402 Poor You\n");
487         default:
488                 return strdup("500 Dispense Error\n");
489         }
490 }
491
492 char *Server_Cmd_GIVE(tClient *Client, char *Args)
493 {
494         char    *recipient, *ammount, *reason;
495          int    uid, iAmmount;
496         
497         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
498
499         recipient = Args;
500
501         ammount = strchr(Args, ' ');
502         if( !ammount )  return strdup("407 Invalid Argument, expected 3 parameters, 1 encountered\n");
503         *ammount = '\0';
504         ammount ++;
505
506         reason = strchr(ammount, ' ');
507         if( !reason )   return strdup("407 Invalid Argument, expected 3 parameters, 2 encountered\n");
508         *reason = '\0';
509         reason ++;
510
511         // Get recipient
512         uid = GetUserID(recipient);
513         if( uid == -1 ) return strdup("404 Invalid target user");
514
515         // Parse ammount
516         iAmmount = atoi(ammount);
517         if( iAmmount <= 0 )     return strdup("407 Invalid Argument, ammount must be > zero\n");
518
519         // Do give
520         switch( DispenseGive(Client->UID, uid, iAmmount, reason) )
521         {
522         case 0:
523                 return strdup("200 Give OK\n");
524         case 2:
525                 return strdup("402 Poor You\n");
526         default:
527                 return strdup("500 Unknown error\n");
528         }
529 }
530
531 char *Server_Cmd_ADD(tClient *Client, char *Args)
532 {
533         char    *user, *ammount, *reason;
534          int    uid, iAmmount;
535         
536         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
537
538         user = Args;
539
540         ammount = strchr(Args, ' ');
541         if( !ammount )  return strdup("407 Invalid Argument, expected 3 parameters, 1 encountered\n");
542         *ammount = '\0';
543         ammount ++;
544
545         reason = strchr(ammount, ' ');
546         if( !reason )   return strdup("407 Invalid Argument, expected 3 parameters, 2 encountered\n");
547         *reason = '\0';
548         reason ++;
549
550         // TODO: Check if the current user is in coke/higher
551
552         // Get recipient
553         uid = GetUserID(user);
554         if( uid == -1 ) return strdup("404 Invalid user");
555
556         // Parse ammount
557         iAmmount = atoi(ammount);
558         if( iAmmount == 0 && ammount[0] != '0' )        return strdup("407 Invalid Argument, ammount must be > zero\n");
559
560         // Do give
561         switch( DispenseAdd(uid, Client->UID, iAmmount, reason) )
562         {
563         case 0:
564                 return strdup("200 Add OK\n");
565         case 2:
566                 return strdup("402 Poor Guy\n");
567         default:
568                 return strdup("500 Unknown error\n");
569         }
570 }
571
572 char *Server_Cmd_USERINFO(tClient *Client, char *Args)
573 {
574          int    uid;
575         char    *user = Args;
576         char    *space;
577         
578         space = strchr(user, ' ');
579         if(space)       *space = '\0';
580         
581         // Get recipient
582         uid = GetUserID(user);
583         if( uid == -1 ) return strdup("404 Invalid user");
584
585         return mkstr("202 User %s %i user\n", user, GetBalance(uid));
586 }
587
588 /**
589  * \brief Authenticate a user
590  * \return User ID, or -1 if authentication failed
591  */
592 int GetUserAuth(const char *Salt, const char *Username, const uint8_t *ProvidedHash)
593 {
594         #if 0
595         uint8_t h[20];
596          int    ofs = strlen(Username) + strlen(Salt);
597         char    input[ ofs + 40 + 1];
598         char    tmp[4 + strlen(Username) + 1];  // uid=%s
599         #endif
600         
601         #if HACK_TPG_NOAUTH
602         if( strcmp(Username, "tpg") == 0 )
603                 return GetUserID("tpg");
604         #endif
605         #if HACK_ROOT_NOAUTH
606         if( strcmp(Username, "root") == 0 )
607                 return GetUserID("root");
608         #endif
609         
610         #if 0
611         //
612         strcpy(input, Username);
613         strcpy(input, Salt);
614         // TODO: Get user's SHA-1 hash
615         sprintf(tmp, "uid=%s", Username);
616         ldap_search_s(ld, "", LDAP_SCOPE_BASE, tmp, "userPassword", 0, res);
617         
618         sprintf(input+ofs, "%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x",
619                 h[ 0], h[ 1], h[ 2], h[ 3], h[ 4], h[ 5], h[ 6], h[ 7], h[ 8], h[ 9],
620                 h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19]
621                 );
622         // Then create the hash from the provided salt
623         // Compare that with the provided hash
624         #endif
625         
626         return -1;
627 }
628
629 // --- INTERNAL HELPERS ---
630 // TODO: Move to another file
631 void HexBin(uint8_t *Dest, char *Src, int BufSize)
632 {
633          int    i;
634         for( i = 0; i < BufSize; i ++ )
635         {
636                 uint8_t val = 0;
637                 
638                 if('0' <= *Src && *Src <= '9')
639                         val |= (*Src-'0') << 4;
640                 else if('A' <= *Src && *Src <= 'F')
641                         val |= (*Src-'A'+10) << 4;
642                 else if('a' <= *Src && *Src <= 'f')
643                         val |= (*Src-'a'+10) << 4;
644                 else
645                         break;
646                 Src ++;
647                 
648                 if('0' <= *Src && *Src <= '9')
649                         val |= (*Src-'0');
650                 else if('A' <= *Src && *Src <= 'F')
651                         val |= (*Src-'A'+10);
652                 else if('a' <= *Src && *Src <= 'f')
653                         val |= (*Src-'a'+10);
654                 else
655                         break;
656                 Src ++;
657                 
658                 Dest[i] = val;
659         }
660         for( ; i < BufSize; i++ )
661                 Dest[i] = 0;
662 }
663
664 /**
665  * \brief Decode a Base64 value
666  */
667 int UnBase64(uint8_t *Dest, char *Src, int BufSize)
668 {
669         uint32_t        val;
670          int    i, j;
671         char    *start_src = Src;
672         
673         for( i = 0; i+2 < BufSize; i += 3 )
674         {
675                 val = 0;
676                 for( j = 0; j < 4; j++, Src ++ ) {
677                         if('A' <= *Src && *Src <= 'Z')
678                                 val |= (*Src - 'A') << ((3-j)*6);
679                         else if('a' <= *Src && *Src <= 'z')
680                                 val |= (*Src - 'a' + 26) << ((3-j)*6);
681                         else if('0' <= *Src && *Src <= '9')
682                                 val |= (*Src - '0' + 52) << ((3-j)*6);
683                         else if(*Src == '+')
684                                 val |= 62 << ((3-j)*6);
685                         else if(*Src == '/')
686                                 val |= 63 << ((3-j)*6);
687                         else if(!*Src)
688                                 break;
689                         else if(*Src != '=')
690                                 j --;   // Ignore invalid characters
691                 }
692                 Dest[i  ] = (val >> 16) & 0xFF;
693                 Dest[i+1] = (val >> 8) & 0xFF;
694                 Dest[i+2] = val & 0xFF;
695                 if(j != 4)      break;
696         }
697         
698         // Finish things off
699         if(i   < BufSize)
700                 Dest[i] = (val >> 16) & 0xFF;
701         if(i+1 < BufSize)
702                 Dest[i+1] = (val >> 8) & 0xFF;
703         
704         return Src - start_src;
705 }

UCC git Repository :: git.ucc.asn.au