More cleanup, implementing `dispense give`
[tpg/opendispense2.git] / src / server / server.c
1 /*
2  * OpenDispense 2 
3  * UCC (University [of WA] Computer Club) Electronic Accounting System
4  *
5  * server.c - Client Server Code
6  *
7  * This file is licenced under the 3-clause BSD Licence. See the file
8  * COPYING for full details.
9  */
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include "common.h"
13 #include <sys/socket.h>
14 #include <netinet/in.h>
15 #include <arpa/inet.h>
16 #include <unistd.h>
17 #include <string.h>
18
19 // HACKS
20 #define HACK_TPG_NOAUTH 1
21 #define HACK_ROOT_NOAUTH        1
22
23 // Statistics
24 #define MAX_CONNECTION_QUEUE    5
25 #define INPUT_BUFFER_SIZE       256
26
27 #define HASH_TYPE       SHA1
28 #define HASH_LENGTH     20
29
30 #define MSG_STR_TOO_LONG        "499 Command too long (limit "EXPSTR(INPUT_BUFFER_SIZE)")\n"
31
32 // === TYPES ===
33 typedef struct sClient
34 {
35          int    ID;     // Client ID
36          
37          int    bIsTrusted;     // Is the connection from a trusted host/port
38         
39         char    *Username;
40         char    Salt[9];
41         
42          int    UID;
43          int    bIsAuthed;
44 }       tClient;
45
46 // === PROTOTYPES ===
47 void    Server_Start(void);
48 void    Server_Cleanup(void);
49 void    Server_HandleClient(int Socket, int bTrusted);
50 char    *Server_ParseClientCommand(tClient *Client, char *CommandString);
51 // --- Commands ---
52 char    *Server_Cmd_USER(tClient *Client, char *Args);
53 char    *Server_Cmd_PASS(tClient *Client, char *Args);
54 char    *Server_Cmd_AUTOAUTH(tClient *Client, char *Args);
55 char    *Server_Cmd_ENUMITEMS(tClient *Client, char *Args);
56 char    *Server_Cmd_ITEMINFO(tClient *Client, char *Args);
57 char    *Server_Cmd_DISPENSE(tClient *Client, char *Args);
58 // --- Helpers ---
59  int    GetUserAuth(const char *Salt, const char *Username, const uint8_t *Hash);
60 void    HexBin(uint8_t *Dest, char *Src, int BufSize);
61
62 // === GLOBALS ===
63  int    giServer_Port = 1020;
64  int    giServer_NextClientID = 1;
65 // - Commands
66 struct sClientCommand {
67         char    *Name;
68         char    *(*Function)(tClient *Client, char *Arguments);
69 }       gaServer_Commands[] = {
70         {"USER", Server_Cmd_USER},
71         {"PASS", Server_Cmd_PASS},
72         {"AUTOAUTH", Server_Cmd_AUTOAUTH},
73         {"ENUM_ITEMS", Server_Cmd_ENUMITEMS},
74         {"ITEM_INFO", Server_Cmd_ITEMINFO},
75         {"DISPENSE", Server_Cmd_DISPENSE}
76 };
77 #define NUM_COMMANDS    (sizeof(gaServer_Commands)/sizeof(gaServer_Commands[0]))
78  int    giServer_Socket;
79
80 // === CODE ===
81 /**
82  * \brief Open listenting socket and serve connections
83  */
84 void Server_Start(void)
85 {
86          int    client_socket;
87         struct sockaddr_in      server_addr, client_addr;
88
89         atexit(Server_Cleanup);
90
91         // Create Server
92         giServer_Socket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
93         if( giServer_Socket < 0 ) {
94                 fprintf(stderr, "ERROR: Unable to create server socket\n");
95                 return ;
96         }
97         
98         // Make listen address
99         memset(&server_addr, 0, sizeof(server_addr));
100         server_addr.sin_family = AF_INET;       // Internet Socket
101         server_addr.sin_addr.s_addr = htonl(INADDR_ANY);        // Listen on all interfaces
102         server_addr.sin_port = htons(giServer_Port);    // Port
103
104         // Bind
105         if( bind(giServer_Socket, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0 ) {
106                 fprintf(stderr, "ERROR: Unable to bind to 0.0.0.0:%i\n", giServer_Port);
107                 perror("Binding");
108                 return ;
109         }
110         
111         // Listen
112         if( listen(giServer_Socket, MAX_CONNECTION_QUEUE) < 0 ) {
113                 fprintf(stderr, "ERROR: Unable to listen to socket\n");
114                 perror("Listen");
115                 return ;
116         }
117         
118         printf("Listening on 0.0.0.0:%i\n", giServer_Port);
119         
120         for(;;)
121         {
122                 uint    len = sizeof(client_addr);
123                  int    bTrusted = 0;
124                 
125                 client_socket = accept(giServer_Socket, (struct sockaddr *) &client_addr, &len);
126                 if(client_socket < 0) {
127                         fprintf(stderr, "ERROR: Unable to accept client connection\n");
128                         return ;
129                 }
130                 
131                 if(giDebugLevel >= 2) {
132                         char    ipstr[INET_ADDRSTRLEN];
133                         inet_ntop(AF_INET, &client_addr.sin_addr, ipstr, INET_ADDRSTRLEN);
134                         printf("Client connection from %s:%i\n",
135                                 ipstr, ntohs(client_addr.sin_port));
136                 }
137                 
138                 // Trusted Connections
139                 if( ntohs(client_addr.sin_port) < 1024 )
140                 {
141                         // TODO: Make this runtime configurable
142                         switch( ntohl( client_addr.sin_addr.s_addr ) )
143                         {
144                         case 0x7F000001:        // 127.0.0.1    localhost
145                         //case 0x825E0D00:      // 130.95.13.0
146                         case 0x825E0D12:        // 130.95.13.18 mussel
147                         case 0x825E0D17:        // 130.95.13.23 martello
148                                 bTrusted = 1;
149                                 break;
150                         default:
151                                 break;
152                         }
153                 }
154                 
155                 // TODO: Multithread this?
156                 Server_HandleClient(client_socket, bTrusted);
157                 
158                 close(client_socket);
159         }
160 }
161
162 void Server_Cleanup(void)
163 {
164         printf("Close(%i)\n", giServer_Socket);
165         close(giServer_Socket);
166 }
167
168 /**
169  * \brief Reads from a client socket and parses the command strings
170  * \param Socket        Client socket number/handle
171  * \param bTrusted      Is the client trusted?
172  */
173 void Server_HandleClient(int Socket, int bTrusted)
174 {
175         char    inbuf[INPUT_BUFFER_SIZE];
176         char    *buf = inbuf;
177          int    remspace = INPUT_BUFFER_SIZE-1;
178          int    bytes = -1;
179         tClient clientInfo = {0};
180         
181         // Initialise Client info
182         clientInfo.ID = giServer_NextClientID ++;
183         clientInfo.bIsTrusted = bTrusted;
184         
185         // Read from client
186         /*
187          * Notes:
188          * - The `buf` and `remspace` variables allow a line to span several
189          *   calls to recv(), if a line is not completed in one recv() call
190          *   it is saved to the beginning of `inbuf` and `buf` is updated to
191          *   the end of it.
192          */
193         while( (bytes = recv(Socket, buf, remspace, 0)) > 0 )
194         {
195                 char    *eol, *start;
196                 buf[bytes] = '\0';      // Allow us to use stdlib string functions on it
197                 
198                 // Split by lines
199                 start = inbuf;
200                 while( (eol = strchr(start, '\n')) )
201                 {
202                         char    *ret;
203                         *eol = '\0';
204                         ret = Server_ParseClientCommand(&clientInfo, start);
205                         // `ret` is a string on the heap
206                         send(Socket, ret, strlen(ret), 0);
207                         free(ret);
208                         start = eol + 1;
209                 }
210                 
211                 // Check if there was an incomplete line
212                 if( *start != '\0' ) {
213                          int    tailBytes = bytes - (start-buf);
214                         // Roll back in buffer
215                         memcpy(inbuf, start, tailBytes);
216                         remspace -= tailBytes;
217                         if(remspace == 0) {
218                                 send(Socket, MSG_STR_TOO_LONG, sizeof(MSG_STR_TOO_LONG), 0);
219                                 buf = inbuf;
220                                 remspace = INPUT_BUFFER_SIZE - 1;
221                         }
222                 }
223                 else {
224                         buf = inbuf;
225                         remspace = INPUT_BUFFER_SIZE - 1;
226                 }
227         }
228         
229         // Check for errors
230         if( bytes < 0 ) {
231                 fprintf(stderr, "ERROR: Unable to recieve from client on socket %i\n", Socket);
232                 return ;
233         }
234         
235         if(giDebugLevel >= 2) {
236                 printf("Client %i: Disconnected\n", clientInfo.ID);
237         }
238 }
239
240 /**
241  * \brief Parses a client command and calls the required helper function
242  * \param Client        Pointer to client state structure
243  * \param CommandString Command from client (single line of the command)
244  * \return Heap String to return to the client
245  */
246 char *Server_ParseClientCommand(tClient *Client, char *CommandString)
247 {
248         char    *space, *args;
249          int    i;
250         
251         // Split at first space
252         space = strchr(CommandString, ' ');
253         if(space == NULL) {
254                 args = NULL;
255         }
256         else {
257                 *space = '\0';
258                 args = space + 1;
259         }
260         
261         // Find command
262         for( i = 0; i < NUM_COMMANDS; i++ )
263         {
264                 if(strcmp(CommandString, gaServer_Commands[i].Name) == 0)
265                         return gaServer_Commands[i].Function(Client, args);
266         }
267         
268         return strdup("400 Unknown Command\n");
269 }
270
271 // ---
272 // Commands
273 // ---
274 /**
275  * \brief Set client username
276  * 
277  * Usage: USER <username>
278  */
279 char *Server_Cmd_USER(tClient *Client, char *Args)
280 {
281         char    *ret;
282         
283         // Debug!
284         if( giDebugLevel )
285                 printf("Client %i authenticating as '%s'\n", Client->ID, Args);
286         
287         // Save username
288         if(Client->Username)
289                 free(Client->Username);
290         Client->Username = strdup(Args);
291         
292         #if USE_SALT
293         // Create a salt (that changes if the username is changed)
294         // Yes, I know, I'm a little paranoid, but who isn't?
295         Client->Salt[0] = 0x21 + (rand()&0x3F);
296         Client->Salt[1] = 0x21 + (rand()&0x3F);
297         Client->Salt[2] = 0x21 + (rand()&0x3F);
298         Client->Salt[3] = 0x21 + (rand()&0x3F);
299         Client->Salt[4] = 0x21 + (rand()&0x3F);
300         Client->Salt[5] = 0x21 + (rand()&0x3F);
301         Client->Salt[6] = 0x21 + (rand()&0x3F);
302         Client->Salt[7] = 0x21 + (rand()&0x3F);
303         
304         // TODO: Also send hash type to use, (SHA1 or crypt according to [DAA])
305         // "100 Salt xxxxXXXX\n"
306         ret = strdup("100 SALT xxxxXXXX\n");
307         sprintf(ret, "100 SALT %s\n", Client->Salt);
308         #else
309         ret = strdup("100 User Set\n");
310         #endif
311         return ret;
312 }
313
314 /**
315  * \brief Authenticate as a user
316  * 
317  * Usage: PASS <hash>
318  */
319 char *Server_Cmd_PASS(tClient *Client, char *Args)
320 {
321         uint8_t clienthash[HASH_LENGTH] = {0};
322         
323         // Read user's hash
324         HexBin(clienthash, Args, HASH_LENGTH);
325         
326         // TODO: Decrypt password passed
327         
328         Client->UID = GetUserAuth(Client->Salt, Client->Username, clienthash);
329
330         if( Client->UID != -1 ) {
331                 Client->bIsAuthed = 1;
332                 return strdup("200 Auth OK\n");
333         }
334
335         if( giDebugLevel ) {
336                  int    i;
337                 printf("Client %i: Password hash ", Client->ID);
338                 for(i=0;i<HASH_LENGTH;i++)
339                         printf("%02x", clienthash[i]&0xFF);
340                 printf("\n");
341         }
342         
343         return strdup("401 Auth Failure\n");
344 }
345
346 /**
347  * \brief Authenticate as a user without a password
348  * 
349  * Usage: AUTOAUTH <user>
350  */
351 char *Server_Cmd_AUTOAUTH(tClient *Client, char *Args)
352 {
353         char    *spos = strchr(Args, ' ');
354         if(spos)        *spos = '\0';   // Remove characters after the ' '
355         
356         // Check if trusted
357         if( !Client->bIsTrusted ) {
358                 if(giDebugLevel)
359                         printf("Client %i: Untrusted client attempting to AUTOAUTH\n", Client->ID);
360                 return strdup("401 Untrusted\n");
361         }
362         
363         // Get UID
364         Client->UID = GetUserID( Args );
365         if( Client->UID < 0 ) {
366                 if(giDebugLevel)
367                         printf("Client %i: Unknown user '%s'\n", Client->ID, Args);
368                 return strdup("401 Auth Failure\n");
369         }
370         
371         if(giDebugLevel)
372                 printf("Client %i: Authenticated as '%s' (%i)\n", Client->ID, Args, Client->UID);
373         
374         return strdup("200 Auth OK\n");
375 }
376
377 /**
378  * \brief Enumerate the items that the server knows about
379  */
380 char *Server_Cmd_ENUMITEMS(tClient *Client, char *Args)
381 {
382          int    retLen;
383          int    i;
384         char    *ret;
385
386         retLen = snprintf(NULL, 0, "201 Items %i", giNumItems);
387
388         for( i = 0; i < giNumItems; i ++ )
389         {
390                 retLen += snprintf(NULL, 0, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
391         }
392
393         ret = malloc(retLen+1);
394         retLen = 0;
395         retLen += sprintf(ret+retLen, "201 Items %i", giNumItems);
396
397         for( i = 0; i < giNumItems; i ++ ) {
398                 retLen += sprintf(ret+retLen, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
399         }
400
401         strcat(ret, "\n");
402
403         return ret;
404 }
405
406 tItem *_GetItemFromString(char *String)
407 {
408         tHandler        *handler;
409         char    *type = String;
410         char    *colon = strchr(String, ':');
411          int    num, i;
412         
413         if( !colon ) {
414                 return NULL;
415         }
416
417         num = atoi(colon+1);
418         *colon = '\0';
419
420         // Find handler
421         handler = NULL;
422         for( i = 0; i < giNumHandlers; i ++ )
423         {
424                 if( strcmp(gaHandlers[i]->Name, type) == 0) {
425                         handler = gaHandlers[i];
426                         break;
427                 }
428         }
429         if( !handler ) {
430                 return NULL;
431         }
432
433         // Find item
434         for( i = 0; i < giNumItems; i ++ )
435         {
436                 if( gaItems[i].Handler != handler )     continue;
437                 if( gaItems[i].ID != num )      continue;
438                 return &gaItems[i];
439         }
440         return NULL;
441 }
442
443 /**
444  * \brief Fetch information on a specific item
445  */
446 char *Server_Cmd_ITEMINFO(tClient *Client, char *Args)
447 {
448          int    retLen = 0;
449         char    *ret;
450         tItem   *item = _GetItemFromString(Args);
451         
452         if( !item ) {
453                 return strdup("406 Bad Item ID\n");
454         }
455
456         // Create return
457         retLen = snprintf(NULL, 0, "202 Item %s:%i %i %s\n",
458                 item->Handler->Name, item->ID, item->Price, item->Name);
459         ret = malloc(retLen+1);
460         sprintf(ret, "202 Item %s:%i %i %s\n",
461                 item->Handler->Name, item->ID, item->Price, item->Name);
462
463         return ret;
464 }
465
466 char *Server_Cmd_DISPENSE(tClient *Client, char *Args)
467 {
468         tItem   *item;
469          int    ret;
470         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
471
472         item = _GetItemFromString(Args);
473         if( !item ) {
474                 return strdup("406 Bad Item ID\n");
475         }
476
477         switch( ret = DispenseItem( Client->UID, item ) )
478         {
479         case 0: return strdup("200 Dispense OK\n");
480         case 1: return strdup("501 Unable to dispense\n");
481         case 2: return strdup("402 Poor You\n");
482         default:
483                 return strdup("500 Dispense Error\n");
484         }
485 }
486
487 char *Server_Cmd_GIVE(tClient *Client, char *Args)
488 {
489         char    *recipient, *ammount, *reason;
490          int    uid, iAmmount;
491         
492         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
493
494         recipient = Args;
495
496         ammount = strchr(Args, ' ');
497         if( !ammount )  return strdup("407 Invalid Argument, expected 3 parameters, 1 encountered\n");
498         *ammount = '\0';
499         ammount ++;
500
501         reason = strchr(ammount, ' ');
502         if( !reason )   return strdup("407 Invalid Argument, expected 3 parameters, 2 encountered\n");
503         *reason = '\0';
504         reason ++;
505
506         // Get recipient
507         uid = GetUserID(recipient);
508         if( uid == -1 ) return strdup("404 Invalid target user");
509
510         // Parse ammount
511         iAmmount = atoi(ammount);
512         if( iAmmount <= 0 )     return strdup("407 Invalid Argument, ammount must be > zero\n");
513
514         // Do give
515         switch( DispenseGive(Client->UID, uid, iAmmount, reason) )
516         {
517         case 0:
518                 return strdup("200 Give OK\n");
519         case 2:
520                 return strdup("402 Poor You\n");
521         default:
522                 return strdup("500 Unknown error\n");
523         }
524 }
525
526 /**
527  * \brief Authenticate a user
528  * \return User ID, or -1 if authentication failed
529  */
530 int GetUserAuth(const char *Salt, const char *Username, const uint8_t *ProvidedHash)
531 {
532         #if 0
533         uint8_t h[20];
534          int    ofs = strlen(Username) + strlen(Salt);
535         char    input[ ofs + 40 + 1];
536         char    tmp[4 + strlen(Username) + 1];  // uid=%s
537         #endif
538         
539         #if HACK_TPG_NOAUTH
540         if( strcmp(Username, "tpg") == 0 )
541                 return GetUserID("tpg");
542         #endif
543         #if HACK_ROOT_NOAUTH
544         if( strcmp(Username, "root") == 0 )
545                 return GetUserID("root");
546         #endif
547         
548         #if 0
549         //
550         strcpy(input, Username);
551         strcpy(input, Salt);
552         // TODO: Get user's SHA-1 hash
553         sprintf(tmp, "uid=%s", Username);
554         ldap_search_s(ld, "", LDAP_SCOPE_BASE, tmp, "userPassword", 0, res);
555         
556         sprintf(input+ofs, "%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x",
557                 h[ 0], h[ 1], h[ 2], h[ 3], h[ 4], h[ 5], h[ 6], h[ 7], h[ 8], h[ 9],
558                 h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19]
559                 );
560         // Then create the hash from the provided salt
561         // Compare that with the provided hash
562         #endif
563         
564         return -1;
565 }
566
567 // --- INTERNAL HELPERS ---
568 // TODO: Move to another file
569 void HexBin(uint8_t *Dest, char *Src, int BufSize)
570 {
571          int    i;
572         for( i = 0; i < BufSize; i ++ )
573         {
574                 uint8_t val = 0;
575                 
576                 if('0' <= *Src && *Src <= '9')
577                         val |= (*Src-'0') << 4;
578                 else if('A' <= *Src && *Src <= 'F')
579                         val |= (*Src-'A'+10) << 4;
580                 else if('a' <= *Src && *Src <= 'f')
581                         val |= (*Src-'a'+10) << 4;
582                 else
583                         break;
584                 Src ++;
585                 
586                 if('0' <= *Src && *Src <= '9')
587                         val |= (*Src-'0');
588                 else if('A' <= *Src && *Src <= 'F')
589                         val |= (*Src-'A'+10);
590                 else if('a' <= *Src && *Src <= 'f')
591                         val |= (*Src-'a'+10);
592                 else
593                         break;
594                 Src ++;
595                 
596                 Dest[i] = val;
597         }
598         for( ; i < BufSize; i++ )
599                 Dest[i] = 0;
600 }
601
602 /**
603  * \brief Decode a Base64 value
604  */
605 int UnBase64(uint8_t *Dest, char *Src, int BufSize)
606 {
607         uint32_t        val;
608          int    i, j;
609         char    *start_src = Src;
610         
611         for( i = 0; i+2 < BufSize; i += 3 )
612         {
613                 val = 0;
614                 for( j = 0; j < 4; j++, Src ++ ) {
615                         if('A' <= *Src && *Src <= 'Z')
616                                 val |= (*Src - 'A') << ((3-j)*6);
617                         else if('a' <= *Src && *Src <= 'z')
618                                 val |= (*Src - 'a' + 26) << ((3-j)*6);
619                         else if('0' <= *Src && *Src <= '9')
620                                 val |= (*Src - '0' + 52) << ((3-j)*6);
621                         else if(*Src == '+')
622                                 val |= 62 << ((3-j)*6);
623                         else if(*Src == '/')
624                                 val |= 63 << ((3-j)*6);
625                         else if(!*Src)
626                                 break;
627                         else if(*Src != '=')
628                                 j --;   // Ignore invalid characters
629                 }
630                 Dest[i  ] = (val >> 16) & 0xFF;
631                 Dest[i+1] = (val >> 8) & 0xFF;
632                 Dest[i+2] = val & 0xFF;
633                 if(j != 4)      break;
634         }
635         
636         // Finish things off
637         if(i   < BufSize)
638                 Dest[i] = (val >> 16) & 0xFF;
639         if(i+1 < BufSize)
640                 Dest[i+1] = (val >> 8) & 0xFF;
641         
642         return Src - start_src;
643 }

UCC git Repository :: git.ucc.asn.au