6bf0434dd83fedd0c2a8ef2b4d472c3ad2407566
[tpg/opendispense2.git] / server / src / server.c
1 /*
2  * OpenDispense 2 
3  * UCC (University [of WA] Computer Club) Electronic Accounting System
4  *
5  * server.c - Client Server Code
6  *
7  * This file is licenced under the 3-clause BSD Licence. See the file
8  * COPYING for full details.
9  */
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include "common.h"
18
// --- Networking limits ---
// Backlog length handed to listen()
#define MAX_CONNECTION_QUEUE	5
// Size of the per-client line buffer; one byte is kept for the NUL terminator.
// Must stay a plain number: MSG_STR_TOO_LONG stringifies it with EXPSTR.
#define INPUT_BUFFER_SIZE	256

// --- Password hashing ---
// Hash algorithm identifier (SHA512 is presumably defined elsewhere) and its
// digest size in bytes
#define HASH_TYPE	SHA512
#define HASH_LENGTH	64

// --- Canned replies ---
// Sent when a client line does not fit in the input buffer
#define MSG_STR_TOO_LONG	"499 Command too long (limit "EXPSTR(INPUT_BUFFER_SIZE)")\n"
26
// === TYPES ===
typedef struct sClient	tClient;

/**
 * \brief State kept for one connected client
 */
struct sClient
{
	 int	ID;	// Connection number, handed out from giServer_NextClientID
	 int	bIsTrusted;	// Non-zero when the peer is a trusted host/port

	char	*Username;	// Heap copy of the name given with USER, or NULL
	char	Salt[9];	// Per-user salt characters plus a NUL terminator

	 int	UID;	// User ID looked up during authentication
	 int	bIsAuthed;	// Authentication flag
};

// === PROTOTYPES ===
// Connection handling
void	Server_Start(void);
void	Server_HandleClient(int Socket, int bTrusted);
char	*Server_ParseClientCommand(tClient *Client, char *CommandString);

// Protocol command handlers (each returns a heap string for the client)
char	*Server_Cmd_USER(tClient *Client, char *Args);
char	*Server_Cmd_PASS(tClient *Client, char *Args);
char	*Server_Cmd_AUTOAUTH(tClient *Client, char *Args);

// Helpers
void	HexBin(uint8_t *Dest, char *Src, int BufSize);
51
52 // === GLOBALS ===
53  int    giServer_Port = 1020;
54  int    giServer_NextClientID = 1;
55 // - Commands
56 struct sClientCommand {
57         char    *Name;
58         char    *(*Function)(tClient *Client, char *Arguments);
59 }       gaServer_Commands[] = {
60         {"USER", Server_Cmd_USER},
61         {"PASS", Server_Cmd_PASS},
62         {"AUTOAUTH", Server_Cmd_AUTOAUTH}
63 };
64 #define NUM_COMMANDS    (sizeof(gaServer_Commands)/sizeof(gaServer_Commands[0]))
65
66 // === CODE ===
67 /**
68  * \brief Open listenting socket and serve connections
69  */
70 void Server_Start(void)
71 {
72          int    server_socket, client_socket;
73         struct sockaddr_in      server_addr, client_addr;
74
75         // Create Server
76         server_socket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
77         if( server_socket < 0 ) {
78                 fprintf(stderr, "ERROR: Unable to create server socket\n");
79                 return ;
80         }
81         
82         // Make listen address
83         memset(&server_addr, 0, sizeof(server_addr));
84         server_addr.sin_family = AF_INET;       // Internet Socket
85         server_addr.sin_addr.s_addr = htonl(INADDR_ANY);        // Listen on all interfaces
86         server_addr.sin_port = htons(giServer_Port);    // Port
87
88         // Bind
89         if( bind(server_socket, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0 ) {
90                 fprintf(stderr, "ERROR: Unable to bind to 0.0.0.0:%i\n", giServer_Port);
91                 return ;
92         }
93         
94         // Listen
95         if( listen(server_socket, MAX_CONNECTION_QUEUE) < 0 ) {
96                 fprintf(stderr, "ERROR: Unable to listen to socket\n");
97                 return ;
98         }
99         
100         printf("Listening on 0.0.0.0:%i\n", giServer_Port);
101         
102         for(;;)
103         {
104                 uint    len = sizeof(client_addr);
105                  int    bTrusted = 0;
106                 
107                 client_socket = accept(server_socket, (struct sockaddr *) &client_addr, &len);
108                 if(client_socket < 0) {
109                         fprintf(stderr, "ERROR: Unable to accept client connection\n");
110                         return ;
111                 }
112                 
113                 if(giDebugLevel >= 2) {
114                         char    ipstr[INET_ADDRSTRLEN];
115                         inet_ntop(AF_INET, &client_addr.sin_addr, ipstr, INET_ADDRSTRLEN);
116                         printf("Client connection from %s:%i\n",
117                                 ipstr, ntohs(client_addr.sin_port));
118                 }
119                 
120                 // Trusted Connections
121                 if( ntohs(client_addr.sin_port) < 1024 )
122                 {
123                         // TODO: Make this runtime configurable
124                         switch( ntohl( client_addr.sin_addr.s_addr ) )
125                         {
126                         case 0x7F000001:        // 127.0.0.1    localhost
127                         //case 0x825E0D00:      // 130.95.13.0
128                         case 0x825E0D12:        // 130.95.13.18 mussel
129                         case 0x825E0D17:        // 130.95.13.23 martello
130                                 bTrusted = 1;
131                                 break;
132                         default:
133                                 break;
134                         }
135                 }
136                 
137                 // TODO: Multithread this?
138                 Server_HandleClient(client_socket, bTrusted);
139                 
140                 close(client_socket);
141         }
142 }
143
144 /**
145  * \brief Reads from a client socket and parses the command strings
146  * \param Socket        Client socket number/handle
147  * \param bTrusted      Is the client trusted?
148  */
149 void Server_HandleClient(int Socket, int bTrusted)
150 {
151         char    inbuf[INPUT_BUFFER_SIZE];
152         char    *buf = inbuf;
153          int    remspace = INPUT_BUFFER_SIZE-1;
154          int    bytes = -1;
155         tClient clientInfo = {0};
156         
157         // Initialise Client info
158         clientInfo.ID = giServer_NextClientID ++;
159         clientInfo.bIsTrusted = bTrusted;
160         
161         // Read from client
162         /*
163          * Notes:
164          * - The `buf` and `remspace` variables allow a line to span several
165          *   calls to recv(), if a line is not completed in one recv() call
166          *   it is saved to the beginning of `inbuf` and `buf` is updated to
167          *   the end of it.
168          */
169         while( (bytes = recv(Socket, buf, remspace, 0)) > 0 )
170         {
171                 char    *eol, *start;
172                 buf[bytes] = '\0';      // Allow us to use stdlib string functions on it
173                 
174                 // Split by lines
175                 start = inbuf;
176                 while( (eol = strchr(start, '\n')) )
177                 {
178                         char    *ret;
179                         *eol = '\0';
180                         ret = Server_ParseClientCommand(&clientInfo, start);
181                         // `ret` is a string on the heap
182                         send(Socket, ret, strlen(ret), 0);
183                         free(ret);
184                         start = eol + 1;
185                 }
186                 
187                 // Check if there was an incomplete line
188                 if( *start != '\0' ) {
189                          int    tailBytes = bytes - (start-buf);
190                         // Roll back in buffer
191                         memcpy(inbuf, start, tailBytes);
192                         remspace -= tailBytes;
193                         if(remspace == 0) {
194                                 send(Socket, MSG_STR_TOO_LONG, sizeof(MSG_STR_TOO_LONG), 0);
195                                 buf = inbuf;
196                                 remspace = INPUT_BUFFER_SIZE - 1;
197                         }
198                 }
199                 else {
200                         buf = inbuf;
201                         remspace = INPUT_BUFFER_SIZE - 1;
202                 }
203         }
204         
205         // Check for errors
206         if( bytes < 0 ) {
207                 fprintf(stderr, "ERROR: Unable to recieve from client on socket %i\n", Socket);
208                 return ;
209         }
210         
211         if(giDebugLevel >= 2) {
212                 printf("Client %i: Disconnected\n", clientInfo.ID);
213         }
214 }
215
216 /**
217  * \brief Parses a client command and calls the required helper function
218  * \param Client        Pointer to client state structure
219  * \param CommandString Command from client (single line of the command)
220  * \return Heap String to return to the client
221  */
222 char *Server_ParseClientCommand(tClient *Client, char *CommandString)
223 {
224         char    *space, *args;
225          int    i;
226         
227         // Split at first space
228         space = strchr(CommandString, ' ');
229         if(space == NULL) {
230                 args = NULL;
231         }
232         else {
233                 *space = '\0';
234                 args = space + 1;
235         }
236         
237         // Find command
238         for( i = 0; i < NUM_COMMANDS; i++ )
239         {
240                 if(strcmp(CommandString, gaServer_Commands[i].Name) == 0)
241                         return gaServer_Commands[i].Function(Client, args);
242         }
243         
244         return strdup("400 Unknown Command\n");
245 }
246
247 // ---
248 // Commands
249 // ---
250 /**
251  * \brief Set client username
252  * 
253  * Usage: USER <username>
254  */
255 char *Server_Cmd_USER(tClient *Client, char *Args)
256 {
257         char    *ret;
258         
259         // Debug!
260         if( giDebugLevel )
261                 printf("Client %i authenticating as '%s'\n", Client->ID, Args);
262         
263         // Save username
264         if(Client->Username)
265                 free(Client->Username);
266         Client->Username = strdup(Args);
267         
268         #if USE_SALT
269         // Create a salt (that changes if the username is changed)
270         // Yes, I know, I'm a little paranoid, but who isn't?
271         Client->Salt[0] = 0x21 + (rand()&0x3F);
272         Client->Salt[1] = 0x21 + (rand()&0x3F);
273         Client->Salt[2] = 0x21 + (rand()&0x3F);
274         Client->Salt[3] = 0x21 + (rand()&0x3F);
275         Client->Salt[4] = 0x21 + (rand()&0x3F);
276         Client->Salt[5] = 0x21 + (rand()&0x3F);
277         Client->Salt[6] = 0x21 + (rand()&0x3F);
278         Client->Salt[7] = 0x21 + (rand()&0x3F);
279         
280         // "100 Salt xxxxXXXX\n"
281         ret = strdup("100 SALT xxxxXXXX\n");
282         sprintf(ret, "100 SALT %s\n", Client->Salt);
283         #else
284         ret = strdup("100 User Set\n");
285         #endif
286         return ret;
287 }
288
289 /**
290  * \brief Authenticate as a user
291  * 
292  * Usage: PASS <hash>
293  */
294 char *Server_Cmd_PASS(tClient *Client, char *Args)
295 {
296         uint8_t clienthash[HASH_LENGTH] = {0};
297         
298         // Read user's hash
299         HexBin(clienthash, Args, HASH_LENGTH);
300         
301         if( giDebugLevel ) {
302                  int    i;
303                 printf("Client %i: Password hash ", Client->ID);
304                 for(i=0;i<HASH_LENGTH;i++)
305                         printf("%02x", clienthash[i]&0xFF);
306                 printf("\n");
307         }
308         
309         return strdup("401 Auth Failure\n");
310 }
311
312 /**
313  * \brief Authenticate as a user without a password
314  * 
315  * Usage: AUTOAUTH <user>
316  */
317 char *Server_Cmd_AUTOAUTH(tClient *Client, char *Args)
318 {
319         char    *spos = strchr(Args, ' ');
320         if(spos)        *spos = '\0';   // Remove characters after the ' '
321         
322         // Check if trusted
323         if( !Client->bIsTrusted ) {
324                 if(giDebugLevel)
325                         printf("Client %i: Untrusted client attempting to AUTOAUTH\n", Client->ID);
326                 return strdup("401 Untrusted\n");
327         }
328         
329         // Get UID
330         Client->UID = GetUserID( Args );
331         if( Client->UID < 0 ) {
332                 if(giDebugLevel)
333                         printf("Client %i: Unknown user '%s'\n", Client->ID, Args);
334                 return strdup("401 Auth Failure\n");
335         }
336         
337         if(giDebugLevel)
338                 printf("Client %i: Authenticated as '%s' (%i)\n", Client->ID, Args, Client->UID);
339         
340         return strdup("200 Auth OK\n");
341 }
342
// --- INTERNAL HELPERS ---
// TODO: Move to another file

/**
 * \brief Value of a single hexadecimal digit
 * \param Ch	Character to convert ('0'-'9', 'A'-'F', 'a'-'f')
 * \return 0..15, or -1 if \a Ch is not a hex digit
 */
static int HexBin_int_Nibble(char Ch)
{
	if( '0' <= Ch && Ch <= '9' )	return Ch - '0';
	if( 'A' <= Ch && Ch <= 'F' )	return Ch - 'A' + 10;
	if( 'a' <= Ch && Ch <= 'f' )	return Ch - 'a' + 10;
	return -1;
}

/**
 * \brief Decode a hexadecimal string into bytes
 * \param Dest	Output buffer of \a BufSize bytes
 * \param Src	Hex string (case-insensitive); may be NULL, treated as empty
 * \param BufSize	Number of bytes to produce
 *
 * Decoding stops at the first character that is not a hex digit (including
 * the NUL terminator or an unpaired last digit); every remaining byte of
 * \a Dest is set to zero.
 */
void HexBin(uint8_t *Dest, char *Src, int BufSize)
{
	 int	i = 0;

	if( Src != NULL )
	{
		for( ; i < BufSize; i ++ )
		{
			 int	hi, lo;

			hi = HexBin_int_Nibble(Src[0]);
			if( hi < 0 )	break;
			// Src[0] is a digit (not NUL), so Src[1] is still inside the string
			lo = HexBin_int_Nibble(Src[1]);
			if( lo < 0 )	break;

			Dest[i] = (uint8_t)((hi << 4) | lo);
			Src += 2;
		}
	}

	// Zero-fill whatever was not decoded
	for( ; i < BufSize; i ++ )
		Dest[i] = 0;
}
377
/**
 * \brief Read one 4-character Base64 group
 * \param SrcPtr	In/out: current input position; advanced past every character
 *        consumed, stopping at (not past) the NUL terminator
 * \param Value	Out: the group's 24 bits; missing or '=' sextets are zero
 * \return Number of alphabet/padding characters read (4 unless input ended)
 */
static int UnBase64_int_DecodeGroup(char **SrcPtr, uint32_t *Value)
{
	char	*src = *SrcPtr;
	uint32_t	val = 0;
	 int	count = 0;

	while( count < 4 && *src != '\0' )
	{
		char	ch = *src++;
		uint32_t	sextet;

		if( 'A' <= ch && ch <= 'Z' )	sextet = ch - 'A';
		else if( 'a' <= ch && ch <= 'z' )	sextet = ch - 'a' + 26;
		else if( '0' <= ch && ch <= '9' )	sextet = ch - '0' + 52;
		else if( ch == '+' )	sextet = 62;
		else if( ch == '/' )	sextet = 63;
		else if( ch == '=' )	sextet = 0;	// Padding: fills a slot, adds no bits
		else	continue;	// Ignore invalid characters (e.g. whitespace)

		val |= sextet << ((3 - count) * 6);
		count ++;
	}

	*SrcPtr = src;
	*Value = val;
	return count;
}

/**
 * \brief Decode a Base64 value
 * \param Dest	Output buffer
 * \param Src	NUL-terminated Base64 text; characters outside the alphabet are
 *        skipped and '=' contributes zero bits
 * \param BufSize	Maximum number of bytes to write to \a Dest
 * \return Number of characters of \a Src consumed
 *
 * Decoding stops once \a BufSize bytes have been written or the input ends.
 * If the input ends part-way through a group, that group is still written
 * (zero-padded, bounded by \a BufSize).
 */
int UnBase64(uint8_t *Dest, char *Src, int BufSize)
{
	char	*start_src = Src;
	 int	i;

	for( i = 0; i < BufSize; i += 3 )
	{
		uint32_t	val;
		 int	nChars = UnBase64_int_DecodeGroup(&Src, &val);
		 int	k;

		// Emit up to three bytes; a final partial group is decoded from the
		// input rather than reusing the previous group's bits
		for( k = 0; k < 3 && i + k < BufSize; k ++ )
			Dest[i+k] = (val >> (16 - 8*k)) & 0xFF;

		if( nChars != 4 )
			break;	// Input ended inside this group
	}

	return (int)(Src - start_src);
}

UCC git Repository :: git.ucc.asn.au