AcessNative - TCP client implemented, buggy
[tpg/acess2.git] / AcessNative / acesskernel_src / server.c
1 /*
2  * Acess2 Native Kernel
3  * - Acess kernel emulation on another OS using SDL and TCP/UDP
4  *
5  * Syscall Server
6  */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <SDL/SDL.h>
#ifdef __WIN32__
# include <windows.h>
# include <winsock.h>
#else
# include <unistd.h>
# include <sys/select.h>
# include <sys/socket.h>
# include <netinet/in.h>
# include <arpa/inet.h> // inet_ntop
#endif
#include "../syscalls.h"
21 //#include <debug.h>
22
// Transport selection: 1 = one TCP connection per client (accepted by
// Server_ListenThread), 0 = a single shared UDP socket for all clients.
23 #define USE_TCP 1
// Number of slots in the fixed gaServer_Clients[] table (no dynamic growth yet)
24 #define MAX_CLIENTS     16
25
26 // === TYPES ===
/*
 * tClient - one slot of the client table (gaServer_Clients).
 * A slot whose ClientID is 0 is free; Server_GetClient claims free slots.
 */
27 typedef struct {
28          int    ClientID;	// Acess process ID served by this slot; 0 = free
29         SDL_Thread      *WorkerThread;	// Runs Server_WorkerThread for this slot
30         #if USE_TCP
31          int    Socket;	// Client connection; 0 until Server_ListenThread attaches one
32         #else
33         tRequestHeader  *CurrentRequest;	// Request queued for the worker; NULL when idle
34         struct sockaddr_in      ClientAddr;	// UDP peer that owns this slot; replies go here
35         SDL_cond        *WaitFlag;	// Signalled by the listen thread when CurrentRequest is set
36         SDL_mutex       *Mutex;	// Mutex used together with WaitFlag
37         #endif
38 }       tClient;
39
40 // === IMPORTS ===
// Runs the system call described by Request. Returns a reply header and stores
// its size in *ReturnLength, or returns NULL on failure. The UDP worker frees
// the returned header, so it is presumably heap-allocated (TODO confirm).
41 extern tRequestHeader *SyscallRecieve(tRequestHeader *Request, int *ReturnLength);
// Presumably creates a new emulated process; the returned ID is used as a new
// ClientID by Server_GetClient.
42 extern int      Threads_CreateRootProcess(void);
// Called by a worker before servicing a client's requests; presumably selects
// the emulated thread that subsequent syscalls run as (TODO confirm).
43 extern void     Threads_SetThread(int TID);
44 // HACK: Should have these in a header
45 extern void     Log_Debug(const char *Subsys, const char *Message, ...);
46 extern void     Log_Notice(const char *Subsys, const char *Message, ...);
47 extern void     Log_Warning(const char *Subsys, const char *Message, ...);
48
49 // === PROTOTYPES ===
50 tClient *Server_GetClient(int ClientID);
51  int    Server_WorkerThread(void *ClientPtr);
52  int    SyscallServer(void);
53  int    Server_ListenThread(void *Unused);
54
55 // === GLOBALS ===
// gSocket: the server socket opened by SyscallServer (TCP listening socket or
// the shared UDP socket, depending on USE_TCP).
56 #ifdef __WIN32__
57 WSADATA gWinsock;
58 SOCKET  gSocket = INVALID_SOCKET;
59 #else
// POSIX stand-in for Winsock's INVALID_SOCKET: socket() returns -1 on failure
60 # define INVALID_SOCKET -1
61  int    gSocket = INVALID_SOCKET;
62 #endif
// Client table. Static storage, so every slot starts zeroed (ClientID 0 = free).
63 tClient gaServer_Clients[MAX_CLIENTS];
// Thread running Server_ListenThread, started by SyscallServer
64 SDL_Thread      *gpServer_ListenThread;
65
66 // === CODE ===
67 int Server_GetClientID(void)
68 {
69          int    i;
70         Uint32  thisId = SDL_ThreadID();
71         
72         for( i = 0; i < MAX_CLIENTS; i ++ )
73         {
74                 if( SDL_GetThreadID(gaServer_Clients[i].WorkerThread) == thisId )
75                         return gaServer_Clients[i].ClientID;
76         }
77         
78         fprintf(stderr, "ERROR: Server_GetClientID - Thread is not allocated\n");
79         
80         return 0;
81 }
82
83 tClient *Server_GetClient(int ClientID)
84 {
85         tClient *ret = NULL;
86          int    i;
87         
88         // Allocate an ID if needed
89         if(ClientID == 0)
90                 ClientID = Threads_CreateRootProcess();
91         
92         for( i = 0; i < MAX_CLIENTS; i ++ )
93         {
94                 if( gaServer_Clients[i].ClientID == ClientID ) {
95                         return &gaServer_Clients[i];
96                 }
97                 if(!ret && gaServer_Clients[i].ClientID == 0)
98                         ret = &gaServer_Clients[i];
99         }
100         
101         // Uh oh, no free slots
102         // TODO: Dynamic allocation
103         if( !ret )
104                 return NULL;
105         
106         // Allocate a thread for the process
107         ret->ClientID = ClientID;
108         #if USE_TCP
109         ret->Socket = 0;
110         #else
111         ret->CurrentRequest = NULL;
112         #endif
113                 
114         if( !ret->WorkerThread ) {
115                 #if USE_TCP
116                 #else
117                 ret->WaitFlag = SDL_CreateCond();
118                 ret->Mutex = SDL_CreateMutex();
119                 SDL_mutexP( ret->Mutex );
120                 #endif
121                 ret->WorkerThread = SDL_CreateThread( Server_WorkerThread, ret );
122         }
123         
124         return ret;
125 }
126
127 int Server_WorkerThread(void *ClientPtr)
128 {
129         tClient *Client = ClientPtr;
130         
131         #if USE_TCP
132         for( ;; )
133         {
134                 fd_set  fds;
135                  int    nfd = Client->Socket+1;
136                 FD_ZERO(&fds);
137                 FD_SET(Client->Socket, &fds);
138                 
139                 int rv = select(nfd, &fds, NULL, NULL, NULL);   // TODO: Timeouts?
140                 if(rv <= 0) {
141                         perror("select");
142                         continue ;
143                 }
144                 
145                 if( FD_ISSET(Client->Socket, &fds) )
146                 {
147                         const int       ciMaxParamCount = 6;
148                         char    lbuf[sizeof(tRequestHeader) + ciMaxParamCount*sizeof(tRequestValue)];
149                         tRequestHeader  *hdr = (void*)lbuf;
150                         size_t  len = recv(Client->Socket, hdr, sizeof(*hdr), 0);
151                         Log_Debug("Server", "%i bytes of header", len);
152                         if( len == 0 )  break;
153                         if( len != sizeof(*hdr) ) {
154                                 // Oops?
155                                 Log_Warning("Server", "FD%i bad sized (%i != exp %i)",
156                                         Client->Socket, len, sizeof(*hdr));
157                                 continue ;
158                         }
159
160                         if( hdr->NParams > ciMaxParamCount ) {
161                                 // Oops.
162                                 Log_Warning("Server", "FD%i too many params (%i > max %i)",
163                                         Client->Socket, hdr->NParams, ciMaxParamCount);
164                                 continue ;
165                         }
166
167                         len = recv(Client->Socket, hdr->Params, hdr->NParams*sizeof(tRequestValue), 0);
168                         Log_Debug("Server", "%i bytes of params", len);
169                         if( len != hdr->NParams*sizeof(tRequestValue) ) {
170                                 // Oops.
171                         }
172
173                         // Get buffer size
174                         size_t  hdrsize = sizeof(tRequestHeader) + hdr->NParams*sizeof(tRequestValue);
175                         size_t  bufsize = hdrsize;
176                          int    i;
177                         for( i = 0; i < hdr->NParams; i ++ )
178                         {
179                                 if( hdr->Params[i].Flags & ARG_FLAG_ZEROED )
180                                         ;
181                                 else {
182                                         bufsize += hdr->Params[i].Length;
183                                 }
184                         }
185
186                         // Allocate full buffer
187                         hdr = malloc(bufsize);
188                         memcpy(hdr, lbuf, hdrsize);
189                         len = recv(Client->Socket, hdr->Params + hdr->NParams, bufsize - hdrsize, 0);
190                         Log_Debug("Server", "%i bytes of data", len);
191                         if( len != bufsize - hdrsize ) {
192                                 // Oops?
193                         }
194
195                          int    retlen;
196                         tRequestHeader  *retHeader;
197                         retHeader = SyscallRecieve(hdr, &retlen);
198                         if( !retHeader ) {
199                                 // Some sort of error
200                                 Log_Warning("Server", "SyscallRecieve failed?");
201                                 continue ;
202                         }
203                         
204                         send(Client->Socket, retHeader, retlen, 0); 
205
206                         // Clean up
207                         free(hdr);
208                 }
209         }
210         #else
211         tRequestHeader  *retHeader;
212         tRequestHeader  errorHeader;
213          int    retSize = 0;
214          int    sentSize;
215          int    cur_client_id = 0;
216         for( ;; )
217         {
218                 // Wait for something to do
219                 while( Client->CurrentRequest == NULL )
220                         SDL_CondWait(Client->WaitFlag, Client->Mutex);
221                 
222 //              Log_Debug("AcessSrv", "Worker got message %p", Client->CurrentRequest);
223                 
224                 if(Client->ClientID != cur_client_id) {
225 //                      Log_Debug("AcessSrv", "Client thread ID changed from %i to %i",
226 //                              cur_client_id, Client->ClientID);
227                         Threads_SetThread( Client->ClientID );
228                         cur_client_id = Client->ClientID;
229                 }
230                 
231                 // Debug
232                 {
233                         int     callid = Client->CurrentRequest->CallID;
234                         Log_Debug("AcessSrv", "Client %i request %i %s",
235                                 Client->ClientID, callid,
236                                 callid < N_SYSCALLS ? casSYSCALL_NAMES[callid] : "UNK"
237                                 );
238                 }
239                 
240                 // Get the response
241                 retHeader = SyscallRecieve(Client->CurrentRequest, &retSize);
242
243                 if( !retHeader ) {
244                         // Return an error to the client
245                         printf("ERROR: SyscallRecieve failed\n");
246                         errorHeader.CallID = Client->CurrentRequest->CallID;
247                         errorHeader.NParams = 0;
248                         retHeader = &errorHeader;
249                         retSize = sizeof(errorHeader);
250                 }
251                 
252                 // Set ID
253                 retHeader->ClientID = Client->ClientID;
254                 
255                 // Mark the thread as ready for another job
256                 free(Client->CurrentRequest);
257                 Client->CurrentRequest = 0;
258                 
259 //              Log_Debug("AcessSrv", "Sending %i to %x:%i (Client %i)",
260 //                      retSize, ntohl(Client->ClientAddr.sin_addr.s_addr),
261 //                      ntohs(Client->ClientAddr.sin_port),
262 //                      Client->ClientID
263 //                      );
264                 
265                 // Return the data
266                 sentSize = sendto(gSocket, retHeader, retSize, 0,
267                         (struct sockaddr*)&Client->ClientAddr, sizeof(Client->ClientAddr)
268                         );
269                 if( sentSize != retSize ) {
270                         perror("Server_WorkerThread - send");
271                 }
272                 
273                 // Free allocated header
274                 if( retHeader != &errorHeader )
275                         free( retHeader );
276         }
277         #endif
278 }
279
280 int SyscallServer(void)
281 {
282         struct sockaddr_in      server;
283         
284         #ifdef __WIN32__
285         /* Open windows connection */
286         if (WSAStartup(0x0101, &gWinsock) != 0)
287         {
288                 fprintf(stderr, "Could not open Windows connection.\n");
289                 exit(0);
290         }
291         #endif
292         
293         #if USE_TCP
294         // Open TCP Connection
295         gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
296         #else
297         // Open UDP Connection
298         gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
299         #endif
300         if (gSocket == INVALID_SOCKET)
301         {
302                 fprintf(stderr, "Could not create socket.\n");
303                 #if __WIN32__
304                 WSACleanup();
305                 #endif
306                 exit(0);
307         }
308         
309         // Set server address
310         memset(&server, 0, sizeof(struct sockaddr_in));
311         server.sin_family = AF_INET;
312         server.sin_port = htons(SERVER_PORT);
313         server.sin_addr.s_addr = htonl(INADDR_ANY);
314         
315         // Bind
316         if( bind(gSocket, (struct sockaddr *)&server, sizeof(struct sockaddr_in)) == -1 )
317         {
318                 fprintf(stderr, "Cannot bind address to socket.\n");
319                 perror("SyscallServer - bind");
320                 #if __WIN32__
321                 closesocket(gSocket);
322                 WSACleanup();
323                 #else
324                 close(gSocket);
325                 #endif
326                 exit(0);
327         }
328         
329         #if USE_TCP
330         listen(gSocket, 5);
331         #endif
332         
333         Log_Notice("AcessSrv", "Listening on 0.0.0.0:%i", SERVER_PORT);
334         gpServer_ListenThread = SDL_CreateThread( Server_ListenThread, NULL );
335         return 0;
336 }
337
338 int Server_ListenThread(void *Unused)
339 {       
340         // Wait for something to do :)
341         for( ;; )
342         {
343                 #if USE_TCP
344                 struct sockaddr_in      clientaddr;
345                 socklen_t       clientSize = sizeof(clientaddr);
346                  int    clientSock = accept(gSocket, (struct sockaddr*)&clientaddr, &clientSize);
347                 if( clientSock < 0 ) {
348                         perror("SyscallServer - accept");
349                         break ;
350                 }
351
352                 char    addrstr[4*8+8+1];
353                 inet_ntop(clientaddr.sin_family, &clientaddr.sin_addr, addrstr, sizeof(addrstr));
354                 Log_Debug("Server", "Client connection %s:%i\n", addrstr, ntohs(clientaddr.sin_port));
355                 
356                 // Perform auth
357                 size_t  len;
358                 tRequestAuthHdr authhdr;
359                 len = recv(clientSock, &authhdr, sizeof(authhdr), 0);
360                 if( len != sizeof(authhdr) ) {
361                         // Some form of error?
362                         Log_Warning("Server", "Client auth block bad size (%i != exp %i)",
363                                 len, sizeof(authhdr));
364                         continue ;
365                 }
366                 
367                 Log_Debug("Server", "Client assumed PID %i", authhdr.pid);
368
369                 tClient *client;
370                 if( authhdr.pid == 0 ) {
371                         // Allocate PID and client structure/thread
372                         client = Server_GetClient(0);
373                         client->Socket = clientSock;
374                         authhdr.pid = client->ClientID;
375                 }
376                 else {
377                         // Get client structure and make sure it's unused
378                         // - Auth token / verifcation?
379                         client = Server_GetClient(authhdr.pid);
380                         if( client->Socket != 0 ) {
381                                 Log_Warning("Server", "Client (%i)%p owned by FD%i but %s:%i tried to use it",
382                                         authhdr.pid, client, addrstr, clientaddr.sin_port);
383                                 authhdr.pid = 0;
384                         }
385                         else {
386                                 client->Socket = clientSock;
387                         }
388                 }
389                 Log_Debug("Server", "Client given PID %i", authhdr.pid);
390                 
391                 len = send(clientSock, &authhdr, sizeof(authhdr), 0);
392                 if( len != sizeof(authhdr) ) {
393                         // Ok, this is an error
394                         perror("Sending auth reply");
395                 }
396
397                 // All done, client thread should be watching now               
398
399                 #else
400                 char    data[BUFSIZ];
401                 tRequestHeader  *req = (void*)data;
402                 struct sockaddr_in      addr;
403                 uint    clientSize = sizeof(addr);
404                  int    length;
405                 tClient *client;
406                 
407                 length = recvfrom(gSocket, data, BUFSIZ, 0, (struct sockaddr*)&addr, &clientSize);
408                 
409                 if( length == -1 ) {
410                         perror("SyscallServer - recv");
411                         break;
412                 }
413                 
414                 // Hand off to a worker thread
415                 // - TODO: Actually have worker threads
416 //              Log_Debug("Server", "%i bytes from %x:%i", length,
417 //                      ntohl(addr.sin_addr.s_addr), ntohs(addr.sin_port));
418                 
419                 client = Server_GetClient(req->ClientID);
420                 // NOTE: Hack - Should check if all zero
421                 if( req->ClientID == 0 || client->ClientAddr.sin_port == 0 )
422                 {
423                         memcpy(&client->ClientAddr, &addr, sizeof(addr));
424                 }
425                 else if( memcmp(&client->ClientAddr, &addr, sizeof(addr)) != 0 )
426                 {
427                         printf("ClientID %i used by %x:%i\n",
428                                 client->ClientID, ntohl(addr.sin_addr.s_addr), ntohs(addr.sin_port));
429                         printf(" actually owned by %x:%i\n",
430                                 ntohl(client->ClientAddr.sin_addr.s_addr), ntohs(client->ClientAddr.sin_port));
431                         continue;
432                 }
433                 
434                 if( client->CurrentRequest ) {
435                         printf("Worker thread for %x:%i is busy\n",
436                                 ntohl(client->ClientAddr.sin_addr.s_addr), ntohs(client->ClientAddr.sin_port));
437                         continue;
438                 }
439                 
440 //              Log_Debug("AcessSrv", "Message from Client %i (%p)",
441 //                      client->ClientID, client);
442
443                 // Make a copy of the request data      
444                 req = malloc(length);
445                 memcpy(req, data, length);
446                 client->CurrentRequest = req;
447                 SDL_CondSignal(client->WaitFlag);
448                 #endif
449         }
450         return -1;
451 }

UCC git Repository :: git.ucc.asn.au