AcessNative - Set SO_REUSEADDR
[tpg/acess2.git] / AcessNative / acesskernel_src / server.c
index 7919979..4dd964c 100644 (file)
@@ -6,6 +6,7 @@
  */
 #include <stdio.h>
 #include <stdlib.h>
+#include <stdbool.h>
 #include <string.h>
 #include <SDL/SDL.h>
 #ifdef __WIN32__
@@ -19,7 +20,7 @@ typedef int   socklen_t;
 # include <unistd.h>
 # include <sys/socket.h>
 # include <netinet/in.h>
-# include <arpa/inet.h>        // inet_ntop
+# include <netdb.h>    // getaddrinfo
 #endif
 #define DONT_INCLUDE_SYSCALL_NAMES
 #include "../syscalls.h"
@@ -33,22 +34,24 @@ typedef int socklen_t;
 typedef struct {
         int    ClientID;
        SDL_Thread      *WorkerThread;
+       tRequestHeader  *CurrentRequest;
+       SDL_cond        *WaitFlag;
+       SDL_mutex       *Mutex;
        #if USE_TCP
         int    Socket;
        #else
-       tRequestHeader  *CurrentRequest;
        struct sockaddr_in      ClientAddr;
-       SDL_cond        *WaitFlag;
-       SDL_mutex       *Mutex;
        #endif
 }      tClient;
 
 // === IMPORTS ===
-extern tRequestHeader *SyscallRecieve(tRequestHeader *Request, int *ReturnLength);
+// TODO: Move these to headers
+extern tRequestHeader *SyscallRecieve(tRequestHeader *Request, size_t *ReturnLength);
 extern int     Threads_CreateRootProcess(void);
-extern void    Threads_SetThread(int TID);
+extern void    Threads_SetThread(int TID, void *ClientPtr);
 extern void    *Threads_GetThread(int TID);
 extern void    Threads_PostEvent(void *Thread, uint32_t Event);
+extern void    Threads_int_Terminate(void *Thread);
 
 // === PROTOTYPES ===
 tClient        *Server_GetClient(int ClientID);
@@ -70,10 +73,9 @@ SDL_Thread   *gpServer_ListenThread;
 // === CODE ===
 int Server_GetClientID(void)
 {
-        int    i;
        Uint32  thisId = SDL_ThreadID();
        
-       for( i = 0; i < MAX_CLIENTS; i ++ )
+       for( int i = 0; i < MAX_CLIENTS; i ++ )
        {
                if( SDL_GetThreadID(gaServer_Clients[i].WorkerThread) == thisId )
                        return gaServer_Clients[i].ClientID;
@@ -87,13 +89,12 @@ int Server_GetClientID(void)
 tClient *Server_GetClient(int ClientID)
 {
        tClient *ret = NULL;
-        int    i;
        
        // Allocate an ID if needed
        if(ClientID == 0)
                ClientID = Threads_CreateRootProcess();
        
-       for( i = 0; i < MAX_CLIENTS; i ++ )
+       for( int i = 0; i < MAX_CLIENTS; i ++ )
        {
                if( gaServer_Clients[i].ClientID == ClientID ) {
                        return &gaServer_Clients[i];
@@ -104,25 +105,23 @@ tClient *Server_GetClient(int ClientID)
        
        // Uh oh, no free slots
        // TODO: Dynamic allocation
-       if( !ret )
+       if( !ret ) {
+               Log_Error("Server", "Ran out of static client slots (%i)", MAX_CLIENTS);
                return NULL;
+       }
        
        // Allocate a thread for the process
        ret->ClientID = ClientID;
+       ret->CurrentRequest = NULL;
        #if USE_TCP
        ret->Socket = 0;
-       #else
-       ret->CurrentRequest = NULL;
        #endif
                
        if( !ret->WorkerThread ) {
-               #if USE_TCP
-               #else
+               Log_Debug("Server", "Creating worker for %p", ret);
                ret->WaitFlag = SDL_CreateCond();
                ret->Mutex = SDL_CreateMutex();
                SDL_mutexP( ret->Mutex );
-               #endif
-               Log_Debug("Server", "Creating worker for %p", ret);
                ret->WorkerThread = SDL_CreateThread( Server_WorkerThread, ret );
        }
        
@@ -133,133 +132,12 @@ int Server_WorkerThread(void *ClientPtr)
 {
        tClient *Client = ClientPtr;
 
-       Log_Debug("Server", "Worker %p", ClientPtr);    
-
-       #if USE_TCP
-
-       while( *((volatile typeof(Client->Socket)*)&Client->Socket) == 0 )
-               ;
-       Threads_SetThread( Client->ClientID );
-       
-       while( Client->ClientID != -1 )
-       {
-               fd_set  fds;
-                int    nfd = Client->Socket+1;
-               FD_ZERO(&fds);
-               FD_SET(Client->Socket, &fds);
-               
-               int rv = select(nfd, &fds, NULL, NULL, NULL);   // TODO: Timeouts?
-               if(rv < 0) {
-                       perror("select");
-                       continue ;
-               }
-               Log_Debug("Server", "%p: rv=%i", Client, rv);           
-
-               if( FD_ISSET(Client->Socket, &fds) )
-               {
-                       const int       ciMaxParamCount = 6;
-                       char    lbuf[sizeof(tRequestHeader) + ciMaxParamCount*sizeof(tRequestValue)];
-                       tRequestHeader  *hdr = (void*)lbuf;
-                       size_t  len = recv(Client->Socket, (void*)hdr, sizeof(*hdr), 0);
-                       Log_Debug("Server", "%i bytes of header", len);
-                       if( len == 0 )  break;
-                       if( len == -1 ) {
-                               perror("recv header");
-//                             Log_Warning("Server", "recv() error - %s", strerror(errno));
-                               break;
-                       }
-                       if( len != sizeof(*hdr) ) {
-                               // Oops?
-                               Log_Warning("Server", "FD%i bad sized (%i != exp %i)",
-                                       Client->Socket, len, sizeof(*hdr));
-                               continue ;
-                       }
-
-                       if( hdr->NParams > ciMaxParamCount ) {
-                               // Oops.
-                               Log_Warning("Server", "FD%i too many params (%i > max %i)",
-                                       Client->Socket, hdr->NParams, ciMaxParamCount);
-                               break ;
-                       }
-
-                       if( hdr->NParams > 0 )
-                       {
-                               len = recv(Client->Socket, (void*)hdr->Params, hdr->NParams*sizeof(tRequestValue), 0);
-                               Log_Debug("Server", "%i bytes of params", len);
-                               if( len != hdr->NParams*sizeof(tRequestValue) ) {
-                                       // Oops.
-                                       perror("recv params");
-                                       Log_Warning("Sever", "Recieving params failed");
-                                       break ;
-                               }
-                       }
-                       else
-                       {
-                               Log_Debug("Server", "No params?");
-                       }
-
-                       // Get buffer size
-                       size_t  hdrsize = sizeof(tRequestHeader) + hdr->NParams*sizeof(tRequestValue);
-                       size_t  bufsize = hdrsize;
-                        int    i;
-                       for( i = 0; i < hdr->NParams; i ++ )
-                       {
-                               if( hdr->Params[i].Flags & ARG_FLAG_ZEROED )
-                                       ;
-                               else {
-                                       bufsize += hdr->Params[i].Length;
-                               }
-                       }
-
-                       // Allocate full buffer
-                       hdr = malloc(bufsize);
-                       memcpy(hdr, lbuf, hdrsize);
-                       if( bufsize > hdrsize )
-                       {
-                               size_t  rem = bufsize - hdrsize;
-                               char    *ptr = (void*)( hdr->Params + hdr->NParams );
-                               while( rem )
-                               {
-                                       len = recv(Client->Socket, ptr, rem, 0);
-                                       Log_Debug("Server", "%i bytes of data", len);
-                                       if( len == -1 ) {
-                                               // Oops?
-                                               perror("recv data");
-                                               Log_Warning("Sever", "Recieving data failed");
-                                               break ;
-                                       }
-                                       rem -= len;
-                                       ptr += len;
-                               }
-                               if( rem ) {
-                                       break;
-                               }
-                       }
-                       else
-                               Log_Debug("Server", "no data");
-
-                        int    retlen;
-                       tRequestHeader  *retHeader;
-                       retHeader = SyscallRecieve(hdr, &retlen);
-                       if( !retHeader ) {
-                               // Some sort of error
-                               Log_Warning("Server", "SyscallRecieve failed?");
-                               continue ;
-                       }
-                       
-                       send(Client->Socket, (void*)retHeader, retlen, 0); 
+       Log_Debug("Server", "Worker %p active", ClientPtr);     
 
-                       // Clean up
-                       free(hdr);
-               }
-       }
-       #else
-       tRequestHeader  *retHeader;
        tRequestHeader  errorHeader;
-        int    retSize = 0;
-        int    sentSize;
+       size_t  retSize = 0;
         int    cur_client_id = 0;
-       while( Client->ClientID != -1 )
+       while( Client->ClientID != 0 )
        {
                // Wait for something to do
                if( Client->CurrentRequest == NULL )
@@ -267,26 +145,13 @@ int Server_WorkerThread(void *ClientPtr)
                if( Client->CurrentRequest == NULL )
                        continue ;
                
-//             Log_Debug("AcessSrv", "Worker got message %p", Client->CurrentRequest);
-               
                if(Client->ClientID != cur_client_id) {
-//                     Log_Debug("AcessSrv", "Client thread ID changed from %i to %i",
-//                             cur_client_id, Client->ClientID);
-                       Threads_SetThread( Client->ClientID );
+                       Threads_SetThread( Client->ClientID, Client );
                        cur_client_id = Client->ClientID;
                }
                
-               // Debug
-               {
-                       int     callid = Client->CurrentRequest->CallID;
-                       Log_Debug("AcessSrv", "Client %i request %i %s",
-                               Client->ClientID, callid,
-                               callid < N_SYSCALLS ? casSYSCALL_NAMES[callid] : "UNK"
-                               );
-               }
-               
                // Get the response
-               retHeader = SyscallRecieve(Client->CurrentRequest, &retSize);
+               tRequestHeader  *retHeader = SyscallRecieve(Client->CurrentRequest, &retSize);
 
                if( !retHeader ) {
                        // Return an error to the client
@@ -303,26 +168,27 @@ int Server_WorkerThread(void *ClientPtr)
                // Mark the thread as ready for another job
                free(Client->CurrentRequest);
                Client->CurrentRequest = 0;
-               
-//             Log_Debug("AcessSrv", "Sending %i to %x:%i (Client %i)",
-//                     retSize, ntohl(Client->ClientAddr.sin_addr.s_addr),
-//                     ntohs(Client->ClientAddr.sin_port),
-//                     Client->ClientID
-//                     );
-               
-               // Return the data
-               sentSize = sendto(gSocket, retHeader, retSize, 0,
-                       (struct sockaddr*)&Client->ClientAddr, sizeof(Client->ClientAddr)
-                       );
-               if( sentSize != retSize ) {
-                       perror("Server_WorkerThread - send");
+
+               // If the thread is being terminated, don't send reply
+               if( Client->ClientID > 0 )
+               {
+                       // Return the data
+                       #if USE_TCP
+                       size_t sentSize = send(Client->Socket, retHeader, retSize, 0); 
+                       #else
+                       size_t sentSize = sendto(gSocket, retHeader, retSize, 0,
+                               (struct sockaddr*)&Client->ClientAddr, sizeof(Client->ClientAddr)
+                               );
+                       #endif
+                       if( sentSize != retSize ) {
+                               perror("Server_WorkerThread - send");
+                       }
                }
                
                // Free allocated header
                if( retHeader != &errorHeader )
                        free( retHeader );
        }
-       #endif
        Log_Notice("Server", "Terminated Worker %p", ClientPtr);        
        return 0;
 }
@@ -362,6 +228,13 @@ int SyscallServer(void)
        server.sin_port = htons(SERVER_PORT);
        server.sin_addr.s_addr = htonl(INADDR_ANY);
        
+       #if USE_TCP
+       {
+               int val = 1;
+               setsockopt(gSocket, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val);
+       }
+       #endif
+       
        // Bind
        if( bind(gSocket, (struct sockaddr *)&server, sizeof(struct sockaddr_in)) == -1 )
        {
@@ -403,83 +276,243 @@ int Server_Shutdown(void)
        return 0;
 }
 
+#if USE_TCP
+int Server_int_HandleRx(tClient *Client)
+{
+       const int       ciMaxParamCount = 6;
+       char    lbuf[sizeof(tRequestHeader) + ciMaxParamCount*sizeof(tRequestValue)];
+       tRequestHeader  *hdr = (void*)lbuf;
+       size_t  len = recv(Client->Socket, (void*)hdr, sizeof(*hdr), 0);
+       if( len == 0 ) {
+               Log_Notice("Server", "Zero RX on %i (worker %p)", Client->Socket, Client);
+               return 1;
+       }
+       if( len == -1 ) {
+               perror("recv header");
+               return 2;
+       }
+       if( len != sizeof(*hdr) ) {
+               // Oops?
+               Log_Warning("Server", "FD%i bad sized (%i != exp %i)",
+                       Client->Socket, len, sizeof(*hdr));
+               return 0;
+       }
+
+       if( hdr->NParams > ciMaxParamCount ) {
+               // Oops.
+               Log_Warning("Server", "FD%i too many params (%i > max %i)",
+                       Client->Socket, hdr->NParams, ciMaxParamCount);
+               return 0;
+       }
+
+       if( hdr->NParams > 0 )
+       {
+               len = recv(Client->Socket, (void*)hdr->Params, hdr->NParams*sizeof(tRequestValue), 0);
+               if( len != hdr->NParams*sizeof(tRequestValue) ) {
+                       // Oops.
+                       perror("recv params");
+                       Log_Warning("Sever", "Recieving params failed");
+                       return 0;
+               }
+       }
+       else
+       {
+               //Log_Debug("Server", "No params?");
+       }
+
+       // Get buffer size
+       size_t  hdrsize = sizeof(tRequestHeader) + hdr->NParams*sizeof(tRequestValue);
+       size_t  bufsize = hdrsize;
+       for( int i = 0; i < hdr->NParams; i ++ )
+       {
+               if( hdr->Params[i].Flags & ARG_FLAG_ZEROED )
+                       ;
+               else {
+                       bufsize += hdr->Params[i].Length;
+               }
+       }
+
+       // Allocate full buffer
+       hdr = malloc(bufsize);
+       memcpy(hdr, lbuf, hdrsize);
+       if( bufsize > hdrsize )
+       {
+               size_t  rem = bufsize - hdrsize;
+               char    *ptr = (void*)( hdr->Params + hdr->NParams );
+               while( rem )
+               {
+                       len = recv(Client->Socket, ptr, rem, 0);
+                       if( len == -1 ) {
+                               // Oops?
+                               perror("recv data");
+                               Log_Warning("Sever", "Recieving data failed");
+                               return 2;
+                       }
+                       rem -= len;
+                       ptr += len;
+               }
+               if( rem ) {
+                       // Extra data?
+                       return 0;
+               }
+       }
+       else {
+               //Log_Debug("Server", "no data");
+       }
+       
+       // Dispatch to worker
+       if( Client->CurrentRequest ) {
+               printf("Worker thread for client ID %i is busy\n", Client->ClientID);
+               return 1;
+       }
+
+       // Give to worker
+       Log_Debug("Server", "Message from Client %i (%p)", Client->ClientID, Client);
+       Client->CurrentRequest = hdr;
+       SDL_CondSignal(Client->WaitFlag);
+
+       return 0;
+}
+
+int Server_int_HandshakeClient(int Socket, struct sockaddr_in *addr, socklen_t addr_size)
+{
+       ENTER("iSocket paddr iaddr_size",
+               Socket, addr, addr_size);
+       unsigned short  port = ntohs(addr->sin_port);
+       char    addrstr[4*8+8+1];
+       getnameinfo((struct sockaddr*)addr, addr_size, addrstr, sizeof(addrstr), NULL, 0, NI_NUMERICHOST);
+       Log_Debug("Server", "Client connection %s:%i", addrstr, port);
+       
+       // Perform handshake
+       tRequestAuthHdr authhdr;
+       size_t  len  = recv(Socket, &authhdr, sizeof(authhdr), 0);
+       if( len != sizeof(authhdr) ) {
+               // Some form of error?
+               Log_Warning("Server", "Client auth block bad size (%i != exp %i)",
+                       len, sizeof(authhdr));
+               LEAVE('i', 1);
+               return 1;
+       }
+       
+       LOG("authhdr.pid = %i", authhdr.pid);
+       tClient *client = Server_GetClient(authhdr.pid);
+       if( authhdr.pid == 0 ) {
+               // Allocate PID and client structure/thread
+               client->Socket = Socket;
+               authhdr.pid = client->ClientID;
+       }
+       else {
+               Log_Debug("Server", "Client assumed PID %i", authhdr.pid);
+               
+               // Get client structure and make sure it's unused
+               // - Auth token / verifcation?
+               if( !client ) {
+                       Log_Warning("Server", "Can't allocate a client struct for %s:%i",
+                               addrstr, port);
+                       LEAVE('i', 1);
+                       return 1;
+               }
+               if( client->Socket != 0 ) {
+                       Log_Warning("Server", "Client (%i)%p owned by FD%i but %s:%i tried to use it",
+                               authhdr.pid, client, addrstr, port);
+                       LEAVE('i', 1);
+                       return 1;
+               }
+               
+               client->Socket = Socket;
+       }
+
+       LOG("Sending auth reply");      
+       len = send(Socket, (void*)&authhdr, sizeof(authhdr), 0);
+       if( len != sizeof(authhdr) ) {
+               // Ok, this is an error
+               perror("Sending auth reply");
+               LEAVE('i', 1);
+               return 1;
+       }
+
+       // All done, client thread should be watching now               
+       
+       LEAVE('i', 0);
+       return 0;
+}
+
+void Server_int_RemoveClient(tClient *Client)
+{
+       // Trigger the thread to kill itself
+       Threads_int_Terminate( Threads_GetThread(Client->ClientID) );
+       Client->ClientID = 0;
+       close(Client->Socket);
+}
+
+#endif
+
 int Server_ListenThread(void *Unused)
 {      
        // Wait for something to do :)
        for( ;; )
        {
                #if USE_TCP
-               struct sockaddr_in      clientaddr;
-               socklen_t       clientSize = sizeof(clientaddr);
-                int    clientSock = accept(gSocket, (struct sockaddr*)&clientaddr, &clientSize);
-               if( clientSock < 0 ) {
-                       perror("SyscallServer - accept");
-                       break ;
-               }
+               fd_set  fds;
+                int    maxfd = gSocket;
+               FD_ZERO(&fds);
+               FD_SET(gSocket, &fds);
 
-               char    addrstr[4*8+8+1];
-               getnameinfo((struct sockaddr*)&clientaddr, sizeof(clientaddr),
-                       addrstr, sizeof(addrstr), NULL, 0, NI_NUMERICHOST);
-               Log_Debug("Server", "Client connection %s:%i", addrstr, ntohs(clientaddr.sin_port));
-               
-               // Perform auth
-               size_t  len;
-               tRequestAuthHdr authhdr;
-               len = recv(clientSock, (void*)&authhdr, sizeof(authhdr), 0);
-               if( len != sizeof(authhdr) ) {
-                       // Some form of error?
-                       Log_Warning("Server", "Client auth block bad size (%i != exp %i)",
-                               len, sizeof(authhdr));
-                       close(clientSock);
-                       continue ;
+               for( int i = 0; i < MAX_CLIENTS; i ++ ) {
+                       tClient *client = &gaServer_Clients[i];
+                       if( client->ClientID == 0 )
+                               continue ;
+                       FD_SET(client->Socket, &fds);
+                       if(client->Socket > maxfd)
+                               maxfd = client->Socket;
                }
                
-               Log_Debug("Server", "Client assumed PID %i", authhdr.pid);
-
-               tClient *client;
-               if( authhdr.pid == 0 ) {
-                       // Allocate PID and client structure/thread
-                       client = Server_GetClient(0);
-                       client->Socket = clientSock;
-                       authhdr.pid = client->ClientID;
+               int rv = select(maxfd+1, &fds, NULL, NULL, NULL);
+               Log_Debug("Server", "Select rv = %i", rv);
+               if( rv <= 0 ) {
+                       perror("select");
+                       return 1;
                }
-               else {
-                       // Get client structure and make sure it's unused
-                       // - Auth token / verifcation?
-                       client = Server_GetClient(authhdr.pid);
-                       if( !client ) {
-                               Log_Warning("Server", "Can't allocate a client struct for %s:%i",
-                                       addrstr, clientaddr.sin_port);
-                               close(clientSock);
-                               continue ;
+               
+               // Incoming connection
+               if( FD_ISSET(gSocket, &fds) )
+               {
+                       struct sockaddr_in      clientaddr;
+                       socklen_t       clientSize = sizeof(clientaddr);
+                        int    clientSock = accept(gSocket, (struct sockaddr*)&clientaddr, &clientSize);
+                       if( clientSock < 0 ) {
+                               perror("SyscallServer - accept");
+                               break ;
                        }
-                       if( client->Socket != 0 ) {
-                               Log_Warning("Server", "Client (%i)%p owned by FD%i but %s:%i tried to use it",
-                                       authhdr.pid, client, addrstr, clientaddr.sin_port);
+                       if( Server_int_HandshakeClient(clientSock, &clientaddr, clientSize) ) {
+                               Log_Warning("Server", "Client handshake failed :(");
                                close(clientSock);
-                               continue;
-                       }
-                       else {
-                               client->Socket = clientSock;
                        }
                }
-               Log_Debug("Server", "Client given PID %i - info %p", authhdr.pid, client);
                
-               len = send(clientSock, (void*)&authhdr, sizeof(authhdr), 0);
-               if( len != sizeof(authhdr) ) {
-                       // Ok, this is an error
-                       perror("Sending auth reply");
+               for( int i = 0; i < MAX_CLIENTS; i ++ )
+               {
+                       tClient *client = &gaServer_Clients[i];
+                       if( client->ClientID == 0 )
+                               continue ;
+                       //Debug("Server_ListenThread: Idx %i ID %i FD %i",
+                       //      i, client->ClientID, client->Socket);
+                       if( !FD_ISSET(client->Socket, &fds) )
+                               continue ;
+                       
+                       if( Server_int_HandleRx( client ) )
+                       {
+                               Log_Warning("Server", "Client %p dropped, TODO: clean up", client);
+                               Server_int_RemoveClient(client);
+                       }
                }
-
-               // All done, client thread should be watching now               
-
+       
                #else
                char    data[BUFSIZ];
                tRequestHeader  *req = (void*)data;
                struct sockaddr_in      addr;
                uint    clientSize = sizeof(addr);
                 int    length;
-               tClient *client;
                
                length = recvfrom(gSocket, data, BUFSIZ, 0, (struct sockaddr*)&addr, &clientSize);
                
@@ -488,13 +521,13 @@ int Server_ListenThread(void *Unused)
                        break;
                }
                
-               // Hand off to a worker thread
-               // - TODO: Actually have worker threads
+               // Recive data
 //             Log_Debug("Server", "%i bytes from %x:%i", length,
 //                     ntohl(addr.sin_addr.s_addr), ntohs(addr.sin_port));
                
-               client = Server_GetClient(req->ClientID);
-               // NOTE: Hack - Should check if all zero
+               tClient *client = Server_GetClient(req->ClientID);
+               // NOTE: I should really check if the sin_addr is zero, but meh
+               // Shouldn't matter much
                if( req->ClientID == 0 || client->ClientAddr.sin_port == 0 )
                {
                        memcpy(&client->ClientAddr, &addr, sizeof(addr));
@@ -508,16 +541,15 @@ int Server_ListenThread(void *Unused)
                        continue;
                }
                
+//             Log_Debug("AcessSrv", "Message from Client %i (%p)",
+//                     client->ClientID, client);
                if( client->CurrentRequest ) {
                        printf("Worker thread for %x:%i is busy\n",
                                ntohl(client->ClientAddr.sin_addr.s_addr), ntohs(client->ClientAddr.sin_port));
                        continue;
                }
                
-//             Log_Debug("AcessSrv", "Message from Client %i (%p)",
-//                     client->ClientID, client);
-
-               // Make a copy of the request data      
+               // Duplicate the data currently on the stack, and dispatch to worker
                req = malloc(length);
                memcpy(req, data, length);
                client->CurrentRequest = req;

UCC git Repository :: git.ucc.asn.au