Cleaning up client code and server responses
[tpg/opendispense2.git] / src / server / server.c
1 /*
2  * OpenDispense 2 
3  * UCC (University [of WA] Computer Club) Electronic Accounting System
4  *
5  * server.c - Client Server Code
6  *
7  * This file is licenced under the 3-clause BSD Licence. See the file
8  * COPYING for full details.
9  */
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include "common.h"
13 #include <sys/socket.h>
14 #include <netinet/in.h>
15 #include <arpa/inet.h>
16 #include <unistd.h>
17 #include <string.h>
18
19 // HACKS
20 #define HACK_TPG_NOAUTH 1
21
22 // Statistics
23 #define MAX_CONNECTION_QUEUE    5
24 #define INPUT_BUFFER_SIZE       256
25
26 #define HASH_TYPE       SHA1
27 #define HASH_LENGTH     20
28
29 #define MSG_STR_TOO_LONG        "499 Command too long (limit "EXPSTR(INPUT_BUFFER_SIZE)")\n"
30
31 // === TYPES ===
32 typedef struct sClient
33 {
34          int    ID;     // Client ID
35          
36          int    bIsTrusted;     // Is the connection from a trusted host/port
37         
38         char    *Username;
39         char    Salt[9];
40         
41          int    UID;
42          int    bIsAuthed;
43 }       tClient;
44
45 // === PROTOTYPES ===
46 void    Server_Start(void);
47 void    Server_Cleanup(void);
48 void    Server_HandleClient(int Socket, int bTrusted);
49 char    *Server_ParseClientCommand(tClient *Client, char *CommandString);
50 // --- Commands ---
51 char    *Server_Cmd_USER(tClient *Client, char *Args);
52 char    *Server_Cmd_PASS(tClient *Client, char *Args);
53 char    *Server_Cmd_AUTOAUTH(tClient *Client, char *Args);
54 char    *Server_Cmd_ENUMITEMS(tClient *Client, char *Args);
55 char    *Server_Cmd_ITEMINFO(tClient *Client, char *Args);
56 char    *Server_Cmd_DISPENSE(tClient *Client, char *Args);
57 // --- Helpers ---
58  int    GetUserAuth(const char *Salt, const char *Username, const uint8_t *Hash);
59 void    HexBin(uint8_t *Dest, char *Src, int BufSize);
60
61 // === GLOBALS ===
62  int    giServer_Port = 1020;
63  int    giServer_NextClientID = 1;
64 // - Commands
65 struct sClientCommand {
66         char    *Name;
67         char    *(*Function)(tClient *Client, char *Arguments);
68 }       gaServer_Commands[] = {
69         {"USER", Server_Cmd_USER},
70         {"PASS", Server_Cmd_PASS},
71         {"AUTOAUTH", Server_Cmd_AUTOAUTH},
72         {"ENUM_ITEMS", Server_Cmd_ENUMITEMS},
73         {"ITEM_INFO", Server_Cmd_ITEMINFO},
74         {"DISPENSE", Server_Cmd_DISPENSE}
75 };
76 #define NUM_COMMANDS    (sizeof(gaServer_Commands)/sizeof(gaServer_Commands[0]))
77  int    giServer_Socket;
78
79 // === CODE ===
80 /**
81  * \brief Open listenting socket and serve connections
82  */
83 void Server_Start(void)
84 {
85          int    client_socket;
86         struct sockaddr_in      server_addr, client_addr;
87
88         atexit(Server_Cleanup);
89
90         // Create Server
91         giServer_Socket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
92         if( giServer_Socket < 0 ) {
93                 fprintf(stderr, "ERROR: Unable to create server socket\n");
94                 return ;
95         }
96         
97         // Make listen address
98         memset(&server_addr, 0, sizeof(server_addr));
99         server_addr.sin_family = AF_INET;       // Internet Socket
100         server_addr.sin_addr.s_addr = htonl(INADDR_ANY);        // Listen on all interfaces
101         server_addr.sin_port = htons(giServer_Port);    // Port
102
103         // Bind
104         if( bind(giServer_Socket, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0 ) {
105                 fprintf(stderr, "ERROR: Unable to bind to 0.0.0.0:%i\n", giServer_Port);
106                 perror("Binding");
107                 return ;
108         }
109         
110         // Listen
111         if( listen(giServer_Socket, MAX_CONNECTION_QUEUE) < 0 ) {
112                 fprintf(stderr, "ERROR: Unable to listen to socket\n");
113                 perror("Listen");
114                 return ;
115         }
116         
117         printf("Listening on 0.0.0.0:%i\n", giServer_Port);
118         
119         for(;;)
120         {
121                 uint    len = sizeof(client_addr);
122                  int    bTrusted = 0;
123                 
124                 client_socket = accept(giServer_Socket, (struct sockaddr *) &client_addr, &len);
125                 if(client_socket < 0) {
126                         fprintf(stderr, "ERROR: Unable to accept client connection\n");
127                         return ;
128                 }
129                 
130                 if(giDebugLevel >= 2) {
131                         char    ipstr[INET_ADDRSTRLEN];
132                         inet_ntop(AF_INET, &client_addr.sin_addr, ipstr, INET_ADDRSTRLEN);
133                         printf("Client connection from %s:%i\n",
134                                 ipstr, ntohs(client_addr.sin_port));
135                 }
136                 
137                 // Trusted Connections
138                 if( ntohs(client_addr.sin_port) < 1024 )
139                 {
140                         // TODO: Make this runtime configurable
141                         switch( ntohl( client_addr.sin_addr.s_addr ) )
142                         {
143                         case 0x7F000001:        // 127.0.0.1    localhost
144                         //case 0x825E0D00:      // 130.95.13.0
145                         case 0x825E0D12:        // 130.95.13.18 mussel
146                         case 0x825E0D17:        // 130.95.13.23 martello
147                                 bTrusted = 1;
148                                 break;
149                         default:
150                                 break;
151                         }
152                 }
153                 
154                 // TODO: Multithread this?
155                 Server_HandleClient(client_socket, bTrusted);
156                 
157                 close(client_socket);
158         }
159 }
160
161 void Server_Cleanup(void)
162 {
163         printf("Close(%i)\n", giServer_Socket);
164         close(giServer_Socket);
165 }
166
167 /**
168  * \brief Reads from a client socket and parses the command strings
169  * \param Socket        Client socket number/handle
170  * \param bTrusted      Is the client trusted?
171  */
172 void Server_HandleClient(int Socket, int bTrusted)
173 {
174         char    inbuf[INPUT_BUFFER_SIZE];
175         char    *buf = inbuf;
176          int    remspace = INPUT_BUFFER_SIZE-1;
177          int    bytes = -1;
178         tClient clientInfo = {0};
179         
180         // Initialise Client info
181         clientInfo.ID = giServer_NextClientID ++;
182         clientInfo.bIsTrusted = bTrusted;
183         
184         // Read from client
185         /*
186          * Notes:
187          * - The `buf` and `remspace` variables allow a line to span several
188          *   calls to recv(), if a line is not completed in one recv() call
189          *   it is saved to the beginning of `inbuf` and `buf` is updated to
190          *   the end of it.
191          */
192         while( (bytes = recv(Socket, buf, remspace, 0)) > 0 )
193         {
194                 char    *eol, *start;
195                 buf[bytes] = '\0';      // Allow us to use stdlib string functions on it
196                 
197                 // Split by lines
198                 start = inbuf;
199                 while( (eol = strchr(start, '\n')) )
200                 {
201                         char    *ret;
202                         *eol = '\0';
203                         ret = Server_ParseClientCommand(&clientInfo, start);
204                         // `ret` is a string on the heap
205                         send(Socket, ret, strlen(ret), 0);
206                         free(ret);
207                         start = eol + 1;
208                 }
209                 
210                 // Check if there was an incomplete line
211                 if( *start != '\0' ) {
212                          int    tailBytes = bytes - (start-buf);
213                         // Roll back in buffer
214                         memcpy(inbuf, start, tailBytes);
215                         remspace -= tailBytes;
216                         if(remspace == 0) {
217                                 send(Socket, MSG_STR_TOO_LONG, sizeof(MSG_STR_TOO_LONG), 0);
218                                 buf = inbuf;
219                                 remspace = INPUT_BUFFER_SIZE - 1;
220                         }
221                 }
222                 else {
223                         buf = inbuf;
224                         remspace = INPUT_BUFFER_SIZE - 1;
225                 }
226         }
227         
228         // Check for errors
229         if( bytes < 0 ) {
230                 fprintf(stderr, "ERROR: Unable to recieve from client on socket %i\n", Socket);
231                 return ;
232         }
233         
234         if(giDebugLevel >= 2) {
235                 printf("Client %i: Disconnected\n", clientInfo.ID);
236         }
237 }
238
239 /**
240  * \brief Parses a client command and calls the required helper function
241  * \param Client        Pointer to client state structure
242  * \param CommandString Command from client (single line of the command)
243  * \return Heap String to return to the client
244  */
245 char *Server_ParseClientCommand(tClient *Client, char *CommandString)
246 {
247         char    *space, *args;
248          int    i;
249         
250         // Split at first space
251         space = strchr(CommandString, ' ');
252         if(space == NULL) {
253                 args = NULL;
254         }
255         else {
256                 *space = '\0';
257                 args = space + 1;
258         }
259         
260         // Find command
261         for( i = 0; i < NUM_COMMANDS; i++ )
262         {
263                 if(strcmp(CommandString, gaServer_Commands[i].Name) == 0)
264                         return gaServer_Commands[i].Function(Client, args);
265         }
266         
267         return strdup("400 Unknown Command\n");
268 }
269
270 // ---
271 // Commands
272 // ---
273 /**
274  * \brief Set client username
275  * 
276  * Usage: USER <username>
277  */
278 char *Server_Cmd_USER(tClient *Client, char *Args)
279 {
280         char    *ret;
281         
282         // Debug!
283         if( giDebugLevel )
284                 printf("Client %i authenticating as '%s'\n", Client->ID, Args);
285         
286         // Save username
287         if(Client->Username)
288                 free(Client->Username);
289         Client->Username = strdup(Args);
290         
291         #if USE_SALT
292         // Create a salt (that changes if the username is changed)
293         // Yes, I know, I'm a little paranoid, but who isn't?
294         Client->Salt[0] = 0x21 + (rand()&0x3F);
295         Client->Salt[1] = 0x21 + (rand()&0x3F);
296         Client->Salt[2] = 0x21 + (rand()&0x3F);
297         Client->Salt[3] = 0x21 + (rand()&0x3F);
298         Client->Salt[4] = 0x21 + (rand()&0x3F);
299         Client->Salt[5] = 0x21 + (rand()&0x3F);
300         Client->Salt[6] = 0x21 + (rand()&0x3F);
301         Client->Salt[7] = 0x21 + (rand()&0x3F);
302         
303         // TODO: Also send hash type to use, (SHA1 or crypt according to [DAA])
304         // "100 Salt xxxxXXXX\n"
305         ret = strdup("100 SALT xxxxXXXX\n");
306         sprintf(ret, "100 SALT %s\n", Client->Salt);
307         #else
308         ret = strdup("100 User Set\n");
309         #endif
310         return ret;
311 }
312
313 /**
314  * \brief Authenticate as a user
315  * 
316  * Usage: PASS <hash>
317  */
318 char *Server_Cmd_PASS(tClient *Client, char *Args)
319 {
320         uint8_t clienthash[HASH_LENGTH] = {0};
321         
322         // Read user's hash
323         HexBin(clienthash, Args, HASH_LENGTH);
324         
325         // TODO: Decrypt password passed
326         
327         Client->UID = GetUserAuth(Client->Salt, Client->Username, clienthash);
328
329         if( Client->UID != -1 ) {
330                 Client->bIsAuthed = 1;
331                 return strdup("200 Auth OK\n");
332         }
333
334         if( giDebugLevel ) {
335                  int    i;
336                 printf("Client %i: Password hash ", Client->ID);
337                 for(i=0;i<HASH_LENGTH;i++)
338                         printf("%02x", clienthash[i]&0xFF);
339                 printf("\n");
340         }
341         
342         return strdup("401 Auth Failure\n");
343 }
344
345 /**
346  * \brief Authenticate as a user without a password
347  * 
348  * Usage: AUTOAUTH <user>
349  */
350 char *Server_Cmd_AUTOAUTH(tClient *Client, char *Args)
351 {
352         char    *spos = strchr(Args, ' ');
353         if(spos)        *spos = '\0';   // Remove characters after the ' '
354         
355         // Check if trusted
356         if( !Client->bIsTrusted ) {
357                 if(giDebugLevel)
358                         printf("Client %i: Untrusted client attempting to AUTOAUTH\n", Client->ID);
359                 return strdup("401 Untrusted\n");
360         }
361         
362         // Get UID
363         Client->UID = GetUserID( Args );
364         if( Client->UID < 0 ) {
365                 if(giDebugLevel)
366                         printf("Client %i: Unknown user '%s'\n", Client->ID, Args);
367                 return strdup("401 Auth Failure\n");
368         }
369         
370         if(giDebugLevel)
371                 printf("Client %i: Authenticated as '%s' (%i)\n", Client->ID, Args, Client->UID);
372         
373         return strdup("200 Auth OK\n");
374 }
375
376 /**
377  * \brief Enumerate the items that the server knows about
378  */
379 char *Server_Cmd_ENUMITEMS(tClient *Client, char *Args)
380 {
381 //       int    nItems = giNumItems;
382          int    retLen;
383          int    i;
384         char    *ret;
385
386         retLen = snprintf(NULL, 0, "201 Items %i", giNumItems);
387
388         for( i = 0; i < giNumItems; i ++ )
389         {
390                 retLen += snprintf(NULL, 0, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
391         }
392
393         ret = malloc(retLen+1);
394         retLen = 0;
395         retLen += sprintf(ret+retLen, "201 Items %i", giNumItems);
396
397         for( i = 0; i < giNumItems; i ++ ) {
398                 retLen += sprintf(ret+retLen, " %s:%i", gaItems[i].Handler->Name, gaItems[i].ID);
399         }
400
401         strcat(ret, "\n");
402
403         return ret;
404 }
405
406 tItem *_GetItemFromString(char *String)
407 {
408         tHandler        *handler;
409         char    *type = String;
410         char    *colon = strchr(String, ':');
411          int    num, i;
412         
413         if( !colon ) {
414                 return NULL;
415         }
416
417         num = atoi(colon+1);
418         *colon = '\0';
419
420         // Find handler
421         handler = NULL;
422         for( i = 0; i < giNumHandlers; i ++ )
423         {
424                 if( strcmp(gaHandlers[i]->Name, type) == 0) {
425                         handler = gaHandlers[i];
426                         break;
427                 }
428         }
429         if( !handler ) {
430                 return NULL;
431         }
432
433         // Find item
434         for( i = 0; i < giNumItems; i ++ )
435         {
436                 if( gaItems[i].Handler != handler )     continue;
437                 if( gaItems[i].ID != num )      continue;
438                 return &gaItems[i];
439         }
440         return NULL;
441 }
442
443 /**
444  * \brief Fetch information on a specific item
445  */
446 char *Server_Cmd_ITEMINFO(tClient *Client, char *Args)
447 {
448          int    retLen = 0;
449         char    *ret;
450         tItem   *item = _GetItemFromString(Args);
451         
452         if( !item ) {
453                 return strdup("406 Bad Item ID\n");
454         }
455
456         // Create return
457         retLen = snprintf(NULL, 0, "202 Item %s:%i %i %s\n",
458                 item->Handler->Name, item->ID, item->Price, item->Name);
459         ret = malloc(retLen+1);
460         sprintf(ret, "202 Item %s:%i %i %s\n",
461                 item->Handler->Name, item->ID, item->Price, item->Name);
462
463         return ret;
464 }
465
466 char *Server_Cmd_DISPENSE(tClient *Client, char *Args)
467 {
468         tItem   *item;
469          int    ret;
470         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
471
472         item = _GetItemFromString(Args);
473         if( !item ) {
474                 return strdup("406 Bad Item ID\n");
475         }
476
477         switch( ret = DispenseItem( Client->UID, item ) )
478         {
479         case 0: return strdup("200 Dispense OK\n");
480         case 1: return strdup("501 Unable to dispense\n");
481         case 2: return strdup("402 Poor You\n");
482         default:
483                 return strdup("500 Dispense Error\n");
484         }
485 }
486
487 char *Server_Cmd_GIVE(tClient *Client, char *Args)
488 {
489         char    *recipient, *ammount, *reason;
490          int    uid, iAmmount;
491         
492         if( !Client->bIsAuthed )        return strdup("401 Not Authenticated\n");
493
494         recipient = Args;
495
496         ammount = strchr(Args, ' ');
497         if( !ammount )  return strdup("407 Invalid Argument, expected 3 parameters, 1 encountered\n");
498         *ammount = '\0';
499         ammount ++;
500
501         reason = strchr(ammount, ' ');
502         if( !reason )   return strdup("407 Invalid Argument, expected 3 parameters, 2 encountered\n");
503         *reason = '\0';
504         reason ++;
505
506         // Get recipient
507         uid = GetUserID(recipient);
508         if( uid == -1 ) return strdup("404 Invalid target user");
509
510         // Parse ammount
511         iAmmount = atoi(ammount);
512         if( iAmmount <= 0 )     return strdup("407 Invalid Argument, ammount must be > zero\n");
513
514         // Do give
515         switch( Transfer(Client->UID, uid, iAmmount, reason) )
516         {
517         case 0:
518                 return strdup("200 Give OK\n");
519         default:
520                 return strdup("402 Poor You\n");
521         }
522 }
523
524 /**
525  * \brief Authenticate a user
526  * \return User ID, or -1 if authentication failed
527  */
528 int GetUserAuth(const char *Salt, const char *Username, const uint8_t *ProvidedHash)
529 {
530         #if 0
531         uint8_t h[20];
532          int    ofs = strlen(Username) + strlen(Salt);
533         char    input[ ofs + 40 + 1];
534         char    tmp[4 + strlen(Username) + 1];  // uid=%s
535         #endif
536         
537         #if HACK_TPG_NOAUTH
538         if( strcmp(Username, "tpg") == 0 )
539                 return GetUserID("tpg");
540         #endif
541         
542         #if 0
543         //
544         strcpy(input, Username);
545         strcpy(input, Salt);
546         // TODO: Get user's SHA-1 hash
547         sprintf(tmp, "uid=%s", Username);
548         ldap_search_s(ld, "", LDAP_SCOPE_BASE, tmp, "userPassword", 0, res);
549         
550         sprintf(input+ofs, "%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x",
551                 h[ 0], h[ 1], h[ 2], h[ 3], h[ 4], h[ 5], h[ 6], h[ 7], h[ 8], h[ 9],
552                 h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19]
553                 );
554         // Then create the hash from the provided salt
555         // Compare that with the provided hash
556         #endif
557         
558         return -1;
559 }
560
561 // --- INTERNAL HELPERS ---
562 // TODO: Move to another file
563 void HexBin(uint8_t *Dest, char *Src, int BufSize)
564 {
565          int    i;
566         for( i = 0; i < BufSize; i ++ )
567         {
568                 uint8_t val = 0;
569                 
570                 if('0' <= *Src && *Src <= '9')
571                         val |= (*Src-'0') << 4;
572                 else if('A' <= *Src && *Src <= 'F')
573                         val |= (*Src-'A'+10) << 4;
574                 else if('a' <= *Src && *Src <= 'f')
575                         val |= (*Src-'a'+10) << 4;
576                 else
577                         break;
578                 Src ++;
579                 
580                 if('0' <= *Src && *Src <= '9')
581                         val |= (*Src-'0');
582                 else if('A' <= *Src && *Src <= 'F')
583                         val |= (*Src-'A'+10);
584                 else if('a' <= *Src && *Src <= 'f')
585                         val |= (*Src-'a'+10);
586                 else
587                         break;
588                 Src ++;
589                 
590                 Dest[i] = val;
591         }
592         for( ; i < BufSize; i++ )
593                 Dest[i] = 0;
594 }
595
596 /**
597  * \brief Decode a Base64 value
598  */
599 int UnBase64(uint8_t *Dest, char *Src, int BufSize)
600 {
601         uint32_t        val;
602          int    i, j;
603         char    *start_src = Src;
604         
605         for( i = 0; i+2 < BufSize; i += 3 )
606         {
607                 val = 0;
608                 for( j = 0; j < 4; j++, Src ++ ) {
609                         if('A' <= *Src && *Src <= 'Z')
610                                 val |= (*Src - 'A') << ((3-j)*6);
611                         else if('a' <= *Src && *Src <= 'z')
612                                 val |= (*Src - 'a' + 26) << ((3-j)*6);
613                         else if('0' <= *Src && *Src <= '9')
614                                 val |= (*Src - '0' + 52) << ((3-j)*6);
615                         else if(*Src == '+')
616                                 val |= 62 << ((3-j)*6);
617                         else if(*Src == '/')
618                                 val |= 63 << ((3-j)*6);
619                         else if(!*Src)
620                                 break;
621                         else if(*Src != '=')
622                                 j --;   // Ignore invalid characters
623                 }
624                 Dest[i  ] = (val >> 16) & 0xFF;
625                 Dest[i+1] = (val >> 8) & 0xFF;
626                 Dest[i+2] = val & 0xFF;
627                 if(j != 4)      break;
628         }
629         
630         // Finish things off
631         if(i   < BufSize)
632                 Dest[i] = (val >> 16) & 0xFF;
633         if(i+1 < BufSize)
634                 Dest[i+1] = (val >> 8) & 0xFF;
635         
636         return Src - start_src;
637 }

UCC git Repository :: git.ucc.asn.au