Removed debug from server, cleaning client
[tpg/opendispense2.git] / src / server / server.c
1 /*
2  * OpenDispense 2 
3  * UCC (University [of WA] Computer Club) Electronic Accounting System
4  *
5  * server.c - Client Server Code
6  *
7  * This file is licenced under the 3-clause BSD Licence. See the file
8  * COPYING for full details.
9  */
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include "common.h"
13 #include <sys/socket.h>
14 #include <netinet/in.h>
15 #include <arpa/inet.h>
16 #include <unistd.h>
17 #include <string.h>
18
19 // HACKS
20 #define HACK_TPG_NOAUTH 1
21 #define HACK_ROOT_NOAUTH        1
22
23 // Statistics
24 #define MAX_CONNECTION_QUEUE    5
25 #define INPUT_BUFFER_SIZE       256
26
27 #define HASH_TYPE       SHA1
28 #define HASH_LENGTH     20
29
30 #define MSG_STR_TOO_LONG        "499 Command too long (limit "EXPSTR(INPUT_BUFFER_SIZE)")\n"
31
32 // === TYPES ===
33 typedef struct sClient
34 {
35          int    ID;     // Client ID
36          
37          int    bIsTrusted;     // Is the connection from a trusted host/port
38         
39         char    *Username;
40         char    Salt[9];
41         
42          int    UID;
43          int    bIsAuthed;
44 }       tClient;
45
46 // === PROTOTYPES ===
47 void    Server_Start(void);
48 void    Server_Cleanup(void);
49 void    Server_HandleClient(int Socket, int bTrusted);
50 char    *Server_ParseClientCommand(tClient *Client, char *CommandString);
51 // --- Commands ---
52 char    *Server_Cmd_USER(tClient *Client, char *Args);
53 char    *Server_Cmd_PASS(tClient *Client, char *Args);
54 char    *Server_Cmd_AUTOAUTH(tClient *Client, char *Args);
55 char    *Server_Cmd_ENUMITEMS(tClient *Client, char *Args);
56 char    *Server_Cmd_ITEMINFO(tClient *Client, char *Args);
57 char    *Server_Cmd_DISPENSE(tClient *Client, char *Args);
58 char    *Server_Cmd_GIVE(tClient *Client, char *Args);
59 char    *Server_Cmd_ADD(tClient *Client, char *Args);
60 char    *Server_Cmd_USERINFO(tClient *Client, char *Args);
61 // --- Helpers ---
62  int    GetUserAuth(const char *Salt, const char *Username, const uint8_t *Hash);
63 void    HexBin(uint8_t *Dest, char *Src, int BufSize);
64
65 // === GLOBALS ===
66  int    giServer_Port = 1020;
67  int    giServer_NextClientID = 1;
68 // - Commands
69 struct sClientCommand {
70         char    *Name;
71         char    *(*Function)(tClient *Client, char *Arguments);
72 }       gaServer_Commands[] = {
73         {"USER", Server_Cmd_USER},
74         {"PASS", Server_Cmd_PASS},
75         {"AUTOAUTH", Server_Cmd_AUTOAUTH},
76         {"ENUM_ITEMS", Server_Cmd_ENUMITEMS},
77         {"ITEM_INFO", Server_Cmd_ITEMINFO},
78         {"DISPENSE", Server_Cmd_DISPENSE},
79         {"GIVE", Server_Cmd_GIVE},
80         {"ADD", Server_Cmd_ADD},
81         {"USER_INFO", Server_Cmd_USERINFO}
82 };
83 #define NUM_COMMANDS    (sizeof(gaServer_Commands)/sizeof(gaServer_Commands[0]))
84  int    giServer_Socket;
85
86 // === CODE ===
87 /**
88  * \brief Open listenting socket and serve connections
89  */
90 void Server_Start(void)
91 {
92          int    client_socket;
93         struct sockaddr_in      server_addr, client_addr;
94
95         atexit(Server_Cleanup);
96
97         // Create Server
98         giServer_Socket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
99         if( giServer_Socket < 0 ) {
100                 fprintf(stderr, "ERROR: Unable to create server socket\n");
101                 return ;
102         }
103         
104         // Make listen address
105         memset(&server_addr, 0, sizeof(server_addr));
106         server_addr.sin_family = AF_INET;       // Internet Socket
107         server_addr.sin_addr.s_addr = htonl(INADDR_ANY);        // Listen on all interfaces
108         server_addr.sin_port = htons(giServer_Port);    // Port
109
110         // Bind
111         if( bind(giServer_Socket, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0 ) {
112                 fprintf(stderr, "ERROR: Unable to bind to 0.0.0.0:%i\n", giServer_Port);
113                 perror("Binding");
114                 return ;
115         }
116         
117         // Listen
118         if( listen(giServer_Socket, MAX_CONNECTION_QUEUE) < 0 ) {
119                 fprintf(stderr, "ERROR: Unable to listen to socket\n");
120                 perror("Listen");
121                 return ;
122         }
123         
124         printf("Listening on 0.0.0.0:%i\n", giServer_Port);
125         
126         for(;;)
127         {
128                 uint    len = sizeof(client_addr);
129                  int    bTrusted = 0;
130                 
131                 client_socket = accept(giServer_Socket, (struct sockaddr *) &client_addr, &len);
132                 if(client_socket < 0) {
133                         fprintf(stderr, "ERROR: Unable to accept client connection\n");
134                         return ;
135                 }
136                 
137                 if(giDebugLevel >= 2) {
138                         char    ipstr[INET_ADDRSTRLEN];
139                         inet_ntop(AF_INET, &client_addr.sin_addr, ipstr, INET_ADDRSTRLEN);
140                         printf("Client connection from %s:%i\n",
141                                 ipstr, ntohs(client_addr.sin_port));
142                 }
143                 
144                 // Trusted Connections
145                 if( ntohs(client_addr.sin_port) < 1024 )
146                 {
147                         // TODO: Make this runtime configurable
148                         switch( ntohl( client_addr.sin_addr.s_addr ) )
149                         {
150                         case 0x7F000001:        // 127.0.0.1    localhost
151                         //case 0x825E0D00:      // 130.95.13.0
152                         case 0x825E0D12:        // 130.95.13.18 mussel
153                         case 0x825E0D17:        // 130.95.13.23 martello
154                                 bTrusted = 1;
155                                 break;
156                         default:
157                                 break;
158                         }
159                 }
160                 
161                 // TODO: Multithread this?
162                 Server_HandleClient(client_socket, bTrusted);
163                 
164                 close(client_socket);
165         }
166 }
167
168 void Server_Cleanup(void)
169 {
170         printf("Close(%i)\n", giServer_Socket);
171         close(giServer_Socket);
172 }
173
174 /**
175  * \brief Reads from a client socket and parses the command strings
176  * \param Socket        Client socket number/handle
177  * \param bTrusted      Is the client trusted?
178  */
179 void Server_HandleClient(int Socket, int bTrusted)
180 {
181         char    inbuf[INPUT_BUFFER_SIZE];
182         char    *buf = inbuf;
183          int    remspace = INPUT_BUFFER_SIZE-1;
184          int    bytes = -1;
185         tClient clientInfo = {0};
186         
187         // Initialise Client info
188         clientInfo.ID = giServer_NextClientID ++;
189         clientInfo.bIsTrusted = bTrusted;
190         
191         // Read from client
192         /*
193          * Notes:
194          * - The `buf` and `remspace` variables allow a line to span several
195          *   calls to recv(), if a line is not completed in one recv() call
196          *   it is saved to the beginning of `inbuf` and `buf` is updated to
197          *   the end of it.
198          */
199         while( (bytes = recv(Socket, buf, remspace, 0)) > 0 )
200         {
201                 char    *eol, *start;
202                 buf[bytes] = '\0';      // Allow us to use stdlib string functions on it
203                 
204                 // Split by lines
205                 start = inbuf;
206                 while( (eol = strchr(start, '\n')) )
207                 {
208                         char    *ret;
209                         *eol = '\0';
210                         ret = Server_ParseClientCommand(&clientInfo, start);
211                         
212                         #if DEBUG_TRACE_CLIENT
213                         //printf("ret = %s", ret);
214                         #endif
215                         
216                         // `ret` is a string on the heap
217                         send(Socket, ret, strlen(ret), 0);
218                         free(ret);
219                         start = eol + 1;
220                 }
221                 
222                 // Check if there was an incomplete line
223                 if( *start != '\0' ) {
224                          int    tailBytes = bytes - (start-buf);
225                         // Roll back in buffer
226                         memcpy(inbuf, start, tailBytes);
227                         remspace -= tailBytes;
228                         if(remspace == 0) {
229                                 send(Socket, MSG_STR_TOO_LONG, sizeof(MSG_STR_TOO_LONG), 0);
230                                 buf = inbuf;
231                                 remspace = INPUT_BUFFER_SIZE - 1;
232                         }
233                 }
234                 else {
235                         buf = inbuf;
236                         remspace = INPUT_BUFFER_SIZE - 1;
237                 }
238         }
239         
240         // Check for errors
241         if( bytes < 0 ) {
242                 fprintf(stderr, "ERROR: Unable to recieve from client on socket %i\n", Socket);
243                 return ;
244         }
245         
246         if(giDebugLevel >= 2) {
247                 printf("Client %i: Disconnected\n", clientInfo.ID);
248         }
249 }
250
251 /**
252  * \brief Parses a client command and calls the required helper function
253  * \param Client        Pointer to client state structure
254  * \param CommandString Command from client (single line of the command)
255  * \return Heap String to return to the client
256  */
257 char *Server_ParseClientCommand(tClient *Client, char *CommandString)
258 {
259         char    *space, *args;
260          int    i;
261         
262         // Split at first space
263         space = strchr(CommandString, ' ');
264         if(space == NULL) {
265                 args = NULL;
266         }
267         else {
268                 *space = '\0';
269                 args = space + 1;
270         }
271         
272         // Find command
273         for( i = 0; i < NUM_COMMANDS; i++ )
274         {
275                 if(strcmp(CommandString, gaServer_Commands[i].Name) == 0)
276                         return gaServer_Commands[i].Function(Client, args);
277         }
278         
279         return strdup("400 Unknown Command\n");
280 }
281
282 // ---
283 // Commands
284 // ---
285 /**
286  * \brief Set client username
287  * 
288  * Usage: USER <username>
289  */
290 char *Server_Cmd_USER(tClient *Client, char *Args)
291 {
292         char    *ret;
293         
294         // Debug!
295         if( giDebugLevel )
296                 printf("Client %i authenticating as '%s'\n", Client->ID, Args);
297         
298         // Save username
299         if(Client->Username)
300                 free(Client->Username);
301         Client->Username = strdup(Args);
302         
303         #if USE_SALT
304         // Create a salt (that changes if the username is changed)
305         // Yes, I know, I'm a little paranoid, but who isn't?
306         Client->Salt[0] = 0x21 + (rand()&0x3F);
307         Client->Salt[1] = 0x21 + (rand()&0x3F);
308         Client->Salt[2] = 0x21 + (rand()&0x3F);
309         Client->Salt[3] = 0x21 + (rand()&0x3F);
310         Client->Salt[4] = 0x21 + (rand()&0x3F);
311         Client->Salt[5] = 0x21 + (rand()&0x3F);
312         Client->Salt[6] = 0x21 + (rand()&0x3F);
313         Client->Salt[7] = 0x21 + (rand()&0x3F);
314         
315         // TODO: Also send hash type to use, (SHA1 or crypt according to [DAA])
316         ret = mkstr("100 SALT %s\n", Client->Salt);
317         #else
318         ret = strdup("100 User Set\n");
319         #endif
320         return ret;
321 }
322
323 /**
324  * \brief Authenticate as a user
325  * 
326  * Usage: PASS <hash>
327  */
328 char *Server_Cmd_PASS(tClient *Client, char *Args)
329 {
330         uint8_t clienthash[HASH_LENGTH] = {0};
331         
332         // Read user's hash
333         HexBin(clienthash, Args, HASH_LENGTH);
334         
335         // TODO: Decrypt password passed
336         
337         Client->UID = GetUserAuth(Client->Salt, Client->Username, clienthash);
338
339         if( Client->UID != -1 ) {
340                 Client->bIsAuthed = 1;
341                 return strdup("200 Auth OK\n");
342         }
343
344         if( giDebugLevel ) {
345                  int    i;
346                 printf("Client %i: Password hash ", Client->ID);
347                 for(i=0;i<HASH_LENGTH;i++)
348                         printf("%02x", clienthash[i]&0xFF);
349                 printf("\n");
350         }
351         
352         return strdup("401 Auth Failure\n");
353 }
354
355 /**
356  * \brief Authenticate as a user without a password
357  * 
358  * Usage: AUTOAUTH <user>
359  */
360 char *Server_Cmd_AUTOAUTH(tClient *Client, char *Args)
361 {
362         char    *spos = strchr(Args, ' ');
363         if(spos)        *spos = '\0';   // Remove characters after the ' '
364         
365         // Check if trusted
366         if( !Client->bIsTrusted ) {
367                 if(giDebugLevel)
368                         printf("Client %i: Untrusted client attempting to AUTOAUTH\n", Client->ID);
369                 return strdup("401 Untrusted\n");
370         }
371         
372         // Get UID
373         Client->UID = GetUserID( Args );
374         if( Client->UID < 0 ) {
375                 if(giDebugLevel)
376                         printf("Client %i: Unknown user '%s'\n", Client->ID, Args);
377                 return strdup("401 Auth Failure\n");
378         }
379         
380         if(giDebugLevel)
381                 printf("Client %i: Authenticated as '%s' (%i)\n", Client->ID, Args, Client->UID);
382         
383         return strdup("200 Auth OK\n");
384 }
385
386 /**
387  * \brief Enumerate the items that the server knows about
388  */
389 char *Server_Cmd_ENUMITEMS(tClient *Client, char *Args)
390 {
391          int    retLen;
392          int    i;
393         char    *ret;
394
395         retLen = snprintf(NULL, 0, "201 Items %i", giNumItems);
396
397         for( i = 0; i < giNumItems; i ++ )
398         {
399                 retLen += snprintf(NULL, 0, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
400         }
401
402         ret = malloc(retLen+1);
403         retLen = 0;
404         retLen += sprintf(ret+retLen, "201 Items %i", giNumItems);
405
406         for( i = 0; i < giNumItems; i ++ ) {
407                 retLen += sprintf(ret+retLen, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
408         }
409
410         strcat(ret, "\n");
411
412         return ret;
413 }
414
415 tItem *_GetItemFromString(char *String)
416 {
417         tHandler        *handler;
418         char    *type = String;
419         char    *colon = strchr(String, ':');
420          int    num, i;
421         
422         if( !colon ) {
423                 return NULL;
424         }
425
426         num = atoi(colon+1);
427         *colon = '\0';
428
429         // Find handler
430         handler = NULL;
431         for( i = 0; i < giNumHandlers; i ++ )
432         {
433                 if( strcmp(gaHandlers[i]->Name, type) == 0) {
434                         handler = gaHandlers[i];
435                         break;
436                 }
437         }
438         if( !handler ) {
439                 return NULL;
440         }
441
442         // Find item
443         for( i = 0; i < giNumItems; i ++ )
444         {
445                 if( gaItems[i].Handler != handler )     continue;
446                 if( gaItems[i].ID != num )      continue;
447                 return &gaItems[i];
448         }
449         return NULL;
450 }
451
452 /**
453  * \brief Fetch information on a specific item
454  */
455 char *Server_Cmd_ITEMINFO(tClient *Client, char *Args)
456 {
457          int    retLen = 0;
458         char    *ret;
459         tItem   *item = _GetItemFromString(Args);
460         
461         if( !item ) {
462                 return strdup("406 Bad Item ID\n");
463         }
464
465         // Create return
466         retLen = snprintf(NULL, 0, "202 Item %s:%i %i %s\n",
467                 item->Handler->Name, item->ID, item->Price, item->Name);
468         ret = malloc(retLen+1);
469         sprintf(ret, "202 Item %s:%i %i %s\n",
470                 item->Handler->Name, item->ID, item->Price, item->Name);
471
472         return ret;
473 }
474
475 char *Server_Cmd_DISPENSE(tClient *Client, char *Args)
476 {
477         tItem   *item;
478          int    ret;
479         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
480
481         item = _GetItemFromString(Args);
482         if( !item ) {
483                 return strdup("406 Bad Item ID\n");
484         }
485
486         switch( ret = DispenseItem( Client->UID, item ) )
487         {
488         case 0: return strdup("200 Dispense OK\n");
489         case 1: return strdup("501 Unable to dispense\n");
490         case 2: return strdup("402 Poor You\n");
491         default:
492                 return strdup("500 Dispense Error\n");
493         }
494 }
495
496 char *Server_Cmd_GIVE(tClient *Client, char *Args)
497 {
498         char    *recipient, *ammount, *reason;
499          int    uid, iAmmount;
500         
501         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
502
503         recipient = Args;
504
505         ammount = strchr(Args, ' ');
506         if( !ammount )  return strdup("407 Invalid Argument, expected 3 parameters, 1 encountered\n");
507         *ammount = '\0';
508         ammount ++;
509
510         reason = strchr(ammount, ' ');
511         if( !reason )   return strdup("407 Invalid Argument, expected 3 parameters, 2 encountered\n");
512         *reason = '\0';
513         reason ++;
514
515         // Get recipient
516         uid = GetUserID(recipient);
517         if( uid == -1 ) return strdup("404 Invalid target user");
518
519         // Parse ammount
520         iAmmount = atoi(ammount);
521         if( iAmmount <= 0 )     return strdup("407 Invalid Argument, ammount must be > zero\n");
522
523         // Do give
524         switch( DispenseGive(Client->UID, uid, iAmmount, reason) )
525         {
526         case 0:
527                 return strdup("200 Give OK\n");
528         case 2:
529                 return strdup("402 Poor You\n");
530         default:
531                 return strdup("500 Unknown error\n");
532         }
533 }
534
535 char *Server_Cmd_ADD(tClient *Client, char *Args)
536 {
537         char    *user, *ammount, *reason;
538          int    uid, iAmmount;
539         
540         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
541
542         user = Args;
543
544         ammount = strchr(Args, ' ');
545         if( !ammount )  return strdup("407 Invalid Argument, expected 3 parameters, 1 encountered\n");
546         *ammount = '\0';
547         ammount ++;
548
549         reason = strchr(ammount, ' ');
550         if( !reason )   return strdup("407 Invalid Argument, expected 3 parameters, 2 encountered\n");
551         *reason = '\0';
552         reason ++;
553
554         // TODO: Check if the current user is in coke/higher
555
556         // Get recipient
557         uid = GetUserID(user);
558         if( uid == -1 ) return strdup("404 Invalid user");
559
560         // Parse ammount
561         iAmmount = atoi(ammount);
562         if( iAmmount == 0 && ammount[0] != '0' )        return strdup("407 Invalid Argument, ammount must be > zero\n");
563
564         // Do give
565         switch( DispenseAdd(uid, Client->UID, iAmmount, reason) )
566         {
567         case 0:
568                 return strdup("200 Add OK\n");
569         case 2:
570                 return strdup("402 Poor Guy\n");
571         default:
572                 return strdup("500 Unknown error\n");
573         }
574 }
575
576 char *Server_Cmd_USERINFO(tClient *Client, char *Args)
577 {
578          int    uid;
579         char    *user = Args;
580         char    *space;
581         
582         space = strchr(user, ' ');
583         if(space)       *space = '\0';
584         
585         // Get recipient
586         uid = GetUserID(user);
587         if( uid == -1 ) return strdup("404 Invalid user");
588
589         return mkstr("202 User %s %i user\n", user, GetBalance(uid));
590 }
591
592 /**
593  * \brief Authenticate a user
594  * \return User ID, or -1 if authentication failed
595  */
596 int GetUserAuth(const char *Salt, const char *Username, const uint8_t *ProvidedHash)
597 {
598         #if 0
599         uint8_t h[20];
600          int    ofs = strlen(Username) + strlen(Salt);
601         char    input[ ofs + 40 + 1];
602         char    tmp[4 + strlen(Username) + 1];  // uid=%s
603         #endif
604         
605         #if HACK_TPG_NOAUTH
606         if( strcmp(Username, "tpg") == 0 )
607                 return GetUserID("tpg");
608         #endif
609         #if HACK_ROOT_NOAUTH
610         if( strcmp(Username, "root") == 0 )
611                 return GetUserID("root");
612         #endif
613         
614         #if 0
615         //
616         strcpy(input, Username);
617         strcpy(input, Salt);
618         // TODO: Get user's SHA-1 hash
619         sprintf(tmp, "uid=%s", Username);
620         ldap_search_s(ld, "", LDAP_SCOPE_BASE, tmp, "userPassword", 0, res);
621         
622         sprintf(input+ofs, "%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x",
623                 h[ 0], h[ 1], h[ 2], h[ 3], h[ 4], h[ 5], h[ 6], h[ 7], h[ 8], h[ 9],
624                 h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19]
625                 );
626         // Then create the hash from the provided salt
627         // Compare that with the provided hash
628         #endif
629         
630         return -1;
631 }
632
633 // --- INTERNAL HELPERS ---
634 // TODO: Move to another file
635 void HexBin(uint8_t *Dest, char *Src, int BufSize)
636 {
637          int    i;
638         for( i = 0; i < BufSize; i ++ )
639         {
640                 uint8_t val = 0;
641                 
642                 if('0' <= *Src && *Src <= '9')
643                         val |= (*Src-'0') << 4;
644                 else if('A' <= *Src && *Src <= 'F')
645                         val |= (*Src-'A'+10) << 4;
646                 else if('a' <= *Src && *Src <= 'f')
647                         val |= (*Src-'a'+10) << 4;
648                 else
649                         break;
650                 Src ++;
651                 
652                 if('0' <= *Src && *Src <= '9')
653                         val |= (*Src-'0');
654                 else if('A' <= *Src && *Src <= 'F')
655                         val |= (*Src-'A'+10);
656                 else if('a' <= *Src && *Src <= 'f')
657                         val |= (*Src-'a'+10);
658                 else
659                         break;
660                 Src ++;
661                 
662                 Dest[i] = val;
663         }
664         for( ; i < BufSize; i++ )
665                 Dest[i] = 0;
666 }
667
668 /**
669  * \brief Decode a Base64 value
670  */
671 int UnBase64(uint8_t *Dest, char *Src, int BufSize)
672 {
673         uint32_t        val;
674          int    i, j;
675         char    *start_src = Src;
676         
677         for( i = 0; i+2 < BufSize; i += 3 )
678         {
679                 val = 0;
680                 for( j = 0; j < 4; j++, Src ++ ) {
681                         if('A' <= *Src && *Src <= 'Z')
682                                 val |= (*Src - 'A') << ((3-j)*6);
683                         else if('a' <= *Src && *Src <= 'z')
684                                 val |= (*Src - 'a' + 26) << ((3-j)*6);
685                         else if('0' <= *Src && *Src <= '9')
686                                 val |= (*Src - '0' + 52) << ((3-j)*6);
687                         else if(*Src == '+')
688                                 val |= 62 << ((3-j)*6);
689                         else if(*Src == '/')
690                                 val |= 63 << ((3-j)*6);
691                         else if(!*Src)
692                                 break;
693                         else if(*Src != '=')
694                                 j --;   // Ignore invalid characters
695                 }
696                 Dest[i  ] = (val >> 16) & 0xFF;
697                 Dest[i+1] = (val >> 8) & 0xFF;
698                 Dest[i+2] = val & 0xFF;
699                 if(j != 4)      break;
700         }
701         
702         // Finish things off
703         if(i   < BufSize)
704                 Dest[i] = (val >> 16) & 0xFF;
705         if(i+1 < BufSize)
706                 Dest[i+1] = (val >> 8) & 0xFF;
707         
708         return Src - start_src;
709 }

UCC git Repository :: git.ucc.asn.au