AcessNative - Fixed CLIShell
[tpg/acess2.git] / AcessNative / ld-acess_src / request.c
index 75565f1..12a26d1 100644 (file)
@@ -1,8 +1,19 @@
 /*
  */
+#define DEBUG  0
+
+
+#if DEBUG
+# define DEBUG_S       printf
+#else
+# define DEBUG_S(...)
+# define DONT_INCLUDE_SYSCALL_NAMES
+#endif
+
 #include <stdlib.h>
 #include <string.h>
 #include <stdio.h>
+#include <inttypes.h>
 #ifdef __WIN32__
 # include <windows.h>
 # include <winsock.h>
@@ -14,7 +25,7 @@
 #include "request.h"
 #include "../syscalls.h"
 
-#define USE_TCP        0
+#define USE_TCP        1
 
 // === PROTOTYPES ===
 void   SendData(void *Data, int Length);
@@ -34,9 +45,16 @@ SOCKET       gSocket = INVALID_SOCKET;
 struct sockaddr_in     gSyscall_ServerAddr;
 
 // === CODE ===
-int _InitSyscalls()
+void Request_Preinit(void)
+{
+       // Set server address
+       memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
+       gSyscall_ServerAddr.sin_family = AF_INET;
+       gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
+}
+
+int _InitSyscalls(void)
 {
-       
        #ifdef __WIN32__
        /* Open windows connection */
        if (WSAStartup(0x0101, &gWinsock) != 0)
@@ -62,12 +80,6 @@ int _InitSyscalls()
                exit(0);
        }
        
-       // Set server address
-       memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
-       gSyscall_ServerAddr.sin_family = AF_INET;
-       gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
-       gSyscall_ServerAddr.sin_addr.s_addr = htonl(0x7F000001);
-       
        #if 0
        // Set client address
        memset((void *)&client, '\0', sizeof(struct sockaddr_in));
@@ -79,7 +91,7 @@ int _InitSyscalls()
        #if USE_TCP
        if( connect(gSocket, (struct sockaddr *)&gSyscall_ServerAddr, sizeof(struct sockaddr_in)) < 0 )
        {
-               fprintf(stderr, "Cannot connect to server (localhost:%i)\n", SERVER_PORT);
+               fprintf(stderr, "[ERROR -] Cannot connect to server (localhost:%i)\n", SERVER_PORT);
                perror("_InitSyscalls");
                #if __WIN32__
                closesocket(gSocket);
@@ -89,7 +101,6 @@ int _InitSyscalls()
                #endif
                exit(0);
        }
-       giSyscall_ClientID = gSocket;   // A bit of a hack really :(
        #endif
        
        #if 0
@@ -107,8 +118,21 @@ int _InitSyscalls()
        }
        #endif
        
-       #if !USE_TCP
+       #if USE_TCP
+       {
+               tRequestAuthHdr auth;
+               auth.pid = giSyscall_ClientID;
+               SendData(&auth, sizeof(auth));
+               int len = ReadData(&auth, sizeof(auth), 5);
+               if( len == 0 ) { 
+                       fprintf(stderr, "Timeout waiting for auth response\n");
+                       exit(-1);
+               }
+               giSyscall_ClientID = auth.pid;
+       }
+       #else
        // Ask server for a client ID
+       if( !giSyscall_ClientID )
        {
                tRequestHeader  req;
                 int    len;
@@ -131,6 +155,20 @@ int _InitSyscalls()
        return 0;
 }
 
+/**
+ * \brief Close the syscall socket
+ * \note Used in acess_fork to get a different port number
+ */
+void _CloseSyscalls(void)
+{
+       #if __WIN32__
+       closesocket(gSocket);
+       WSACleanup();
+       #else
+       close(gSocket);
+       #endif
+}
+
 int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
 {
        if( gSocket == INVALID_SOCKET )
@@ -154,36 +192,38 @@ int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
        {
                 int    i;
                char    *data = (char*)&Request->Params[Request->NParams];
-               printf("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
+               DEBUG_S("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
                for( i = 0; i < Request->NParams; i ++ )
                {
                        switch(Request->Params[i].Type)
                        {
                        case ARG_TYPE_INT32:
-                               printf(" 0x%08x", *(uint32_t*)data);
+                               DEBUG_S(" 0x%08x", *(uint32_t*)data);
                                data += sizeof(uint32_t);
                                break;
                        case ARG_TYPE_INT64:
-                               printf(" 0x%016llx", *(uint64_t*)data);
+                               DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
                                data += sizeof(uint64_t);
                                break;
                        case ARG_TYPE_STRING:
-                               printf(" '%s'", (char*)data);
+                               DEBUG_S(" '%s'", (char*)data);
                                data += Request->Params[i].Length;
                                break;
                        case ARG_TYPE_DATA:
-                               printf(" %p:0x%x", (char*)data, Request->Params[i].Length);
+                               DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
                                if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
                                        data += Request->Params[i].Length;
                                break;
                        }
                }
-               printf("\n");
+               DEBUG_S("\n");
        }
        
        // Send it off
        SendData(Request, RequestSize);
-       
+
+       if( Request->CallID == SYS_EXIT )       return 0;
+
        // Wait for a response (no timeout)
        return ReadData(Request, ResponseSize, 0);
 }
@@ -193,13 +233,14 @@ void SendData(void *Data, int Length)
         int    len;
        
        #if USE_TCP
-       len = send(Data, Length, 0);
+       len = send(gSocket, Data, Length, 0);
        #else
        len = sendto(gSocket, Data, Length, 0,
                (struct sockaddr*)&gSyscall_ServerAddr, sizeof(gSyscall_ServerAddr));
        #endif
        
        if( len != Length ) {
+               fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
                perror("SendData");
                exit(-1);
        }
@@ -226,12 +267,13 @@ int ReadData(void *Dest, int MaxLength, int Timeout)
        
        ret = select(gSocket+1, &fds, NULL, NULL, timeoutPtr);
        if( ret == -1 ) {
+               fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
                perror("ReadData - select");
                exit(-1);
        }
        
        if( !ret ) {
-               printf("Timeout reading from socket\n");
+               printf("[ERROR %i] Timeout reading from socket\n", giSyscall_ClientID);
                return 0;       // Timeout
        }
        
@@ -242,11 +284,12 @@ int ReadData(void *Dest, int MaxLength, int Timeout)
        #endif
        
        if( ret < 0 ) {
+               fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
                perror("ReadData");
                exit(-1);
        }
        
-       printf("%i bytes read from socket\n", ret);
+       DEBUG_S("%i bytes read from socket\n", ret);
        
        return ret;
 }

UCC git Repository :: git.ucc.asn.au