Merge branch 'master' of git://git.ucc.asn.au/tpg/acess2
[tpg/acess2.git] / AcessNative / acesskernel_src / server.c
/*
 * Acess2 Native Kernel
 * - Acess kernel emulation on another OS using SDL and TCP sockets
 *   (UDP when USE_TCP is 0)
 *
 * Syscall Server
 */
#include <errno.h>
#include <stdint.h>	// SIZE_MAX
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <SDL/SDL.h>
#ifdef __WIN32__
# include <windows.h>
# include <winsock.h>
#else
# include <unistd.h>
# include <sys/select.h>	// select, fd_set
# include <sys/socket.h>
# include <netinet/in.h>
# include <arpa/inet.h>	// inet_ntop
#endif
#include "../syscalls.h"
//#include <debug.h>
22
// Transport selection: 1 = one TCP connection per client, 0 = a single shared
// UDP socket with requests handed to worker threads via a condition variable.
#define USE_TCP 1
// Size of the fixed gaServer_Clients table (no dynamic growth, see Server_GetClient)
#define MAX_CLIENTS	16

// === TYPES ===
/*
 * Per-client server state; one slot per connected emulated process.
 */
typedef struct {
	// Process ID this slot serves; 0 marks the slot as free (Server_GetClient)
	 int	ClientID;
	// SDL thread running Server_WorkerThread for this slot
	SDL_Thread	*WorkerThread;
	#if USE_TCP
	// Connected client socket; 0 means "not assigned yet" (set by the listen thread)
	 int	Socket;
	#else
	// Request waiting for the worker; NULL when the worker is idle
	tRequestHeader	*CurrentRequest;
	// Address replies are sent back to
	struct sockaddr_in	ClientAddr;
	// Signalled by the listen thread when CurrentRequest is set
	SDL_cond	*WaitFlag;
	// Mutex paired with WaitFlag for SDL_CondWait
	SDL_mutex	*Mutex;
	#endif
}	tClient;
39
// === IMPORTS ===
// Executes one syscall request and returns a response buffer, with its size in
// bytes stored in *ReturnLength. The UDP worker path treats NULL as failure and
// free()s a non-NULL response.
extern tRequestHeader *SyscallRecieve(tRequestHeader *Request, int *ReturnLength);
// Creates a new emulated process; the result is used as a new ClientID
extern int	Threads_CreateRootProcess(void);
// Presumably binds the calling host thread to emulated thread TID - TODO confirm
extern void	Threads_SetThread(int TID);
// HACK: Should have these in a header
// printf-style loggers: subsystem name, format string, arguments
extern void	Log_Debug(const char *Subsys, const char *Message, ...);
extern void	Log_Notice(const char *Subsys, const char *Message, ...);
extern void	Log_Warning(const char *Subsys, const char *Message, ...);

// === PROTOTYPES ===
// Look up (or claim) the slot for a client ID
tClient *Server_GetClient(int ClientID);
// Per-client worker thread entry point (SDL thread function)
 int	Server_WorkerThread(void *ClientPtr);
// Open the server socket and start the listen thread
 int	SyscallServer(void);
// Listen thread entry point (SDL thread function)
 int	Server_ListenThread(void *Unused);
54
// === GLOBALS ===
#ifdef __WIN32__
// Winsock state, filled in by WSAStartup() in SyscallServer()
WSADATA	gWinsock;
// Server socket: listening (TCP) or datagram (UDP) socket
SOCKET	gSocket = INVALID_SOCKET;
#else
// Provide the Winsock name so both builds can compare gSocket against it
# define INVALID_SOCKET -1
 int	gSocket = INVALID_SOCKET;
#endif
// Client slot table; a slot with ClientID == 0 is free
tClient	gaServer_Clients[MAX_CLIENTS];
// Thread running Server_ListenThread, created by SyscallServer()
SDL_Thread	*gpServer_ListenThread;
65
66 // === CODE ===
67 int Server_GetClientID(void)
68 {
69          int    i;
70         Uint32  thisId = SDL_ThreadID();
71         
72         for( i = 0; i < MAX_CLIENTS; i ++ )
73         {
74                 if( SDL_GetThreadID(gaServer_Clients[i].WorkerThread) == thisId )
75                         return gaServer_Clients[i].ClientID;
76         }
77         
78         fprintf(stderr, "ERROR: Server_GetClientID - Thread is not allocated\n");
79         
80         return 0;
81 }
82
83 tClient *Server_GetClient(int ClientID)
84 {
85         tClient *ret = NULL;
86          int    i;
87         
88         // Allocate an ID if needed
89         if(ClientID == 0)
90                 ClientID = Threads_CreateRootProcess();
91         
92         for( i = 0; i < MAX_CLIENTS; i ++ )
93         {
94                 if( gaServer_Clients[i].ClientID == ClientID ) {
95                         return &gaServer_Clients[i];
96                 }
97                 if(!ret && gaServer_Clients[i].ClientID == 0)
98                         ret = &gaServer_Clients[i];
99         }
100         
101         // Uh oh, no free slots
102         // TODO: Dynamic allocation
103         if( !ret )
104                 return NULL;
105         
106         // Allocate a thread for the process
107         ret->ClientID = ClientID;
108         #if USE_TCP
109         ret->Socket = 0;
110         #else
111         ret->CurrentRequest = NULL;
112         #endif
113                 
114         if( !ret->WorkerThread ) {
115                 #if USE_TCP
116                 #else
117                 ret->WaitFlag = SDL_CreateCond();
118                 ret->Mutex = SDL_CreateMutex();
119                 SDL_mutexP( ret->Mutex );
120                 #endif
121                 ret->WorkerThread = SDL_CreateThread( Server_WorkerThread, ret );
122         }
123         
124         return ret;
125 }
126
127 int Server_WorkerThread(void *ClientPtr)
128 {
129         tClient *Client = ClientPtr;
130         
131         #if USE_TCP
132         for( ;; )
133         {
134                 fd_set  fds;
135                  int    nfd = Client->Socket;
136                 FD_ZERO(&fds);
137                 FD_SET(Client->Socket, &fds);
138                 
139                 select(nfd, &fds, NULL, NULL, NULL);    // TODO: Timeouts?
140                 
141                 if( FD_ISSET(Client->Socket, &fds) )
142                 {
143                         const int       ciMaxParamCount = 6;
144                         char    lbuf[sizeof(tRequestHeader) + ciMaxParamCount*sizeof(tRequestValue)];
145                         tRequestHeader  *hdr = (void*)lbuf;
146                         size_t  len = recv(Client->Socket, hdr, sizeof(*hdr), 0);
147                         if( len != sizeof(hdr) ) {
148                                 // Oops?
149                         }
150
151                         if( hdr->NParams > ciMaxParamCount ) {
152                                 // Oops.
153                         }
154
155                         len = recv(Client->Socket, hdr->Params, hdr->NParams*sizeof(tRequestValue), 0);
156                         if( len != hdr->NParams*sizeof(tRequestValue) ) {
157                                 // Oops.
158                         }
159
160                         // Get buffer size
161                         size_t  hdrsize = sizeof(tRequestHeader) + hdr->NParams*sizeof(tRequestValue);
162                         size_t  bufsize = hdrsize;
163                          int    i;
164                         for( i = 0; i < hdr->NParams; i ++ )
165                         {
166                                 if( hdr->Params[i].Flags & ARG_FLAG_ZEROED )
167                                         ;
168                                 else {
169                                         bufsize += hdr->Params[i].Length;
170                                 }
171                         }
172
173                         // Allocate full buffer
174                         hdr = malloc(bufsize);
175                         memcpy(hdr, lbuf, hdrsize);
176                         len = recv(Client->Socket, hdr->Params + hdr->NParams, bufsize - hdrsize, 0);
177                         if( len != bufsize - hdrsize ) {
178                                 // Oops?
179                         }
180
181                          int    retlen;
182                         tRequestHeader  *retHeader;
183                         retHeader = SyscallRecieve(hdr, &retlen);
184                         if( !retHeader ) {
185                                 // Some sort of error
186                         }
187                         
188                         send(Client->Socket, retHeader, retlen, 0); 
189
190                         // Clean up
191                         free(hdr);
192                 }
193         }
194         #else
195         tRequestHeader  *retHeader;
196         tRequestHeader  errorHeader;
197          int    retSize = 0;
198          int    sentSize;
199          int    cur_client_id = 0;
200         for( ;; )
201         {
202                 // Wait for something to do
203                 while( Client->CurrentRequest == NULL )
204                         SDL_CondWait(Client->WaitFlag, Client->Mutex);
205                 
206 //              Log_Debug("AcessSrv", "Worker got message %p", Client->CurrentRequest);
207                 
208                 if(Client->ClientID != cur_client_id) {
209 //                      Log_Debug("AcessSrv", "Client thread ID changed from %i to %i",
210 //                              cur_client_id, Client->ClientID);
211                         Threads_SetThread( Client->ClientID );
212                         cur_client_id = Client->ClientID;
213                 }
214                 
215                 // Debug
216                 {
217                         int     callid = Client->CurrentRequest->CallID;
218                         Log_Debug("AcessSrv", "Client %i request %i %s",
219                                 Client->ClientID, callid,
220                                 callid < N_SYSCALLS ? casSYSCALL_NAMES[callid] : "UNK"
221                                 );
222                 }
223                 
224                 // Get the response
225                 retHeader = SyscallRecieve(Client->CurrentRequest, &retSize);
226
227                 if( !retHeader ) {
228                         // Return an error to the client
229                         printf("ERROR: SyscallRecieve failed\n");
230                         errorHeader.CallID = Client->CurrentRequest->CallID;
231                         errorHeader.NParams = 0;
232                         retHeader = &errorHeader;
233                         retSize = sizeof(errorHeader);
234                 }
235                 
236                 // Set ID
237                 retHeader->ClientID = Client->ClientID;
238                 
239                 // Mark the thread as ready for another job
240                 free(Client->CurrentRequest);
241                 Client->CurrentRequest = 0;
242                 
243 //              Log_Debug("AcessSrv", "Sending %i to %x:%i (Client %i)",
244 //                      retSize, ntohl(Client->ClientAddr.sin_addr.s_addr),
245 //                      ntohs(Client->ClientAddr.sin_port),
246 //                      Client->ClientID
247 //                      );
248                 
249                 // Return the data
250                 sentSize = sendto(gSocket, retHeader, retSize, 0,
251                         (struct sockaddr*)&Client->ClientAddr, sizeof(Client->ClientAddr)
252                         );
253                 if( sentSize != retSize ) {
254                         perror("Server_WorkerThread - send");
255                 }
256                 
257                 // Free allocated header
258                 if( retHeader != &errorHeader )
259                         free( retHeader );
260         }
261         #endif
262 }
263
264 int SyscallServer(void)
265 {
266         struct sockaddr_in      server;
267         
268         #ifdef __WIN32__
269         /* Open windows connection */
270         if (WSAStartup(0x0101, &gWinsock) != 0)
271         {
272                 fprintf(stderr, "Could not open Windows connection.\n");
273                 exit(0);
274         }
275         #endif
276         
277         #if USE_TCP
278         // Open TCP Connection
279         gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
280         #else
281         // Open UDP Connection
282         gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
283         #endif
284         if (gSocket == INVALID_SOCKET)
285         {
286                 fprintf(stderr, "Could not create socket.\n");
287                 #if __WIN32__
288                 WSACleanup();
289                 #endif
290                 exit(0);
291         }
292         
293         // Set server address
294         memset(&server, 0, sizeof(struct sockaddr_in));
295         server.sin_family = AF_INET;
296         server.sin_port = htons(SERVER_PORT);
297         server.sin_addr.s_addr = htonl(INADDR_ANY);
298         
299         // Bind
300         if( bind(gSocket, (struct sockaddr *)&server, sizeof(struct sockaddr_in)) == -1 )
301         {
302                 fprintf(stderr, "Cannot bind address to socket.\n");
303                 perror("SyscallServer - bind");
304                 #if __WIN32__
305                 closesocket(gSocket);
306                 WSACleanup();
307                 #else
308                 close(gSocket);
309                 #endif
310                 exit(0);
311         }
312         
313         #if USE_TCP
314         listen(gSocket, 5);
315         #endif
316         
317         Log_Notice("AcessSrv", "Listening on 0.0.0.0:%i", SERVER_PORT);
318         gpServer_ListenThread = SDL_CreateThread( Server_ListenThread, NULL );
319         return 0;
320 }
321
322 int Server_ListenThread(void *Unused)
323 {       
324         // Wait for something to do :)
325         for( ;; )
326         {
327                 #if USE_TCP
328                 struct sockaddr_in      clientaddr;
329                 socklen_t       clientSize = sizeof(clientaddr);
330                  int    clientSock = accept(gSocket, (struct sockaddr*)&clientaddr, &clientSize);
331                 if( clientSock < 0 ) {
332                         perror("SyscallServer - accept");
333                         break ;
334                 }
335
336                 char    addrstr[4*8+8+1];
337                 inet_ntop(clientaddr.sin_family, &clientaddr.sin_addr, addrstr, sizeof(addrstr));
338                 Log_Debug("Server", "Client connection %s:%i\n", addrstr, ntohs(clientaddr.sin_port));
339                 
340                 // Perform auth
341                 size_t  len;
342                 tRequestAuthHdr authhdr;
343                 len = recv(clientSock, &authhdr, sizeof(authhdr), 0);
344                 if( len != sizeof(authhdr) ) {
345                         // Some form of error?
346                 }
347                 
348                 tClient *client;
349                 if( authhdr.pid == 0 ) {
350                         // Allocate PID and client structure/thread
351                         client = Server_GetClient(0);
352                         client->Socket = clientSock;
353                         authhdr.pid = client->ClientID;
354                 }
355                 else {
356                         // Get client structure and make sure it's unused
357                         // - Auth token / verifcation?
358                         client = Server_GetClient(authhdr.pid);
359                         if( client->Socket != 0 ) {
360                                 Log_Warning("Server", "Client (%i)%p owned by FD%i but %s:%i tried to use it",
361                                         authhdr.pid, client, addrstr, clientaddr.sin_port);
362                                 authhdr.pid = 0;
363                         }
364                         else {
365                                 client->Socket = clientSock;
366                         }
367                 }
368                 
369                 len = send(clientSock, &authhdr, sizeof(authhdr), 0);
370                 if( len != sizeof(authhdr) ) {
371                         // Ok, this is an error
372                         perror("Sending auth reply");
373                 }
374
375                 // All done, client thread should be watching now               
376
377                 #else
378                 char    data[BUFSIZ];
379                 tRequestHeader  *req = (void*)data;
380                 struct sockaddr_in      addr;
381                 uint    clientSize = sizeof(addr);
382                  int    length;
383                 tClient *client;
384                 
385                 length = recvfrom(gSocket, data, BUFSIZ, 0, (struct sockaddr*)&addr, &clientSize);
386                 
387                 if( length == -1 ) {
388                         perror("SyscallServer - recv");
389                         break;
390                 }
391                 
392                 // Hand off to a worker thread
393                 // - TODO: Actually have worker threads
394 //              Log_Debug("Server", "%i bytes from %x:%i", length,
395 //                      ntohl(addr.sin_addr.s_addr), ntohs(addr.sin_port));
396                 
397                 client = Server_GetClient(req->ClientID);
398                 // NOTE: Hack - Should check if all zero
399                 if( req->ClientID == 0 || client->ClientAddr.sin_port == 0 )
400                 {
401                         memcpy(&client->ClientAddr, &addr, sizeof(addr));
402                 }
403                 else if( memcmp(&client->ClientAddr, &addr, sizeof(addr)) != 0 )
404                 {
405                         printf("ClientID %i used by %x:%i\n",
406                                 client->ClientID, ntohl(addr.sin_addr.s_addr), ntohs(addr.sin_port));
407                         printf(" actually owned by %x:%i\n",
408                                 ntohl(client->ClientAddr.sin_addr.s_addr), ntohs(client->ClientAddr.sin_port));
409                         continue;
410                 }
411                 
412                 if( client->CurrentRequest ) {
413                         printf("Worker thread for %x:%i is busy\n",
414                                 ntohl(client->ClientAddr.sin_addr.s_addr), ntohs(client->ClientAddr.sin_port));
415                         continue;
416                 }
417                 
418 //              Log_Debug("AcessSrv", "Message from Client %i (%p)",
419 //                      client->ClientID, client);
420
421                 // Make a copy of the request data      
422                 req = malloc(length);
423                 memcpy(req, data, length);
424                 client->CurrentRequest = req;
425                 SDL_CondSignal(client->WaitFlag);
426                 #endif
427         }
428         return -1;
429 }

UCC git Repository :: git.ucc.asn.au