AcessNative - Fixed CLIShell
[tpg/acess2.git] / AcessNative / ld-acess_src / request.c
index f9526b3..12a26d1 100644 (file)
@@ -1,8 +1,19 @@
 /*
  */
+#define DEBUG  0
+
+
+#if DEBUG
+# define DEBUG_S       printf
+#else
+# define DEBUG_S(...)
+# define DONT_INCLUDE_SYSCALL_NAMES
+#endif
+
 #include <stdlib.h>
 #include <string.h>
 #include <stdio.h>
+#include <inttypes.h>
 #ifdef __WIN32__
 # include <windows.h>
 # include <winsock.h>
 #include "request.h"
 #include "../syscalls.h"
 
-#define        SERVER_PORT     0xACE
+#define USE_TCP        1
+
+// === PROTOTYPES ===
+void   SendData(void *Data, int Length);
+ int   ReadData(void *Dest, int MaxLen, int Timeout);
 
 // === GLOBALS ===
 #ifdef __WIN32__
@@ -26,14 +41,20 @@ SOCKET      gSocket = INVALID_SOCKET;
 #endif
 // Client ID to pass to server
 // TODO: Implement such that each thread gets a different one
-static int     siSyscall_ClientID = 0;
+ int   giSyscall_ClientID = 0;
+struct sockaddr_in     gSyscall_ServerAddr;
 
 // === CODE ===
-int _InitSyscalls()
+void Request_Preinit(void)
+{
+       // Set server address
+       memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
+       gSyscall_ServerAddr.sin_family = AF_INET;
+       gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
+}
+
+int _InitSyscalls(void)
 {
-       struct sockaddr_in      server;
-       struct sockaddr_in      client;
-       
        #ifdef __WIN32__
        /* Open windows connection */
        if (WSAStartup(0x0101, &gWinsock) != 0)
@@ -43,8 +64,13 @@ int _InitSyscalls()
        }
        #endif
        
+       #if USE_TCP
+       // Open TCP Connection
+       gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+       #else
        // Open UDP Connection
-       gSocket = socket(AF_INET, SOCK_DGRAM, 0);
+       gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+       #endif
        if (gSocket == INVALID_SOCKET)
        {
                fprintf(stderr, "Could not create socket.\n");
@@ -54,18 +80,30 @@ int _InitSyscalls()
                exit(0);
        }
        
-       // Set server address
-       memset((void *)&server, '\0', sizeof(struct sockaddr_in));
-       server.sin_family = AF_INET;
-       server.sin_port = htons(SERVER_PORT);
-       server.sin_addr.s_addr = htonl(0x7F00001);
-       
+       #if 0
        // Set client address
        memset((void *)&client, '\0', sizeof(struct sockaddr_in));
        client.sin_family = AF_INET;
        client.sin_port = htons(0);
-       client.sin_addr.s_addr = htonl(0x7F00001);
+       client.sin_addr.s_addr = htonl(0x7F000001);
+       #endif
        
+       #if USE_TCP
+       if( connect(gSocket, (struct sockaddr *)&gSyscall_ServerAddr, sizeof(struct sockaddr_in)) < 0 )
+       {
+               fprintf(stderr, "[ERROR -] Cannot connect to server (localhost:%i)\n", SERVER_PORT);
+               perror("_InitSyscalls");
+               #if __WIN32__
+               closesocket(gSocket);
+               WSACleanup();
+               #else
+               close(gSocket);
+               #endif
+               exit(0);
+       }
+       #endif
+       
+       #if 0
        // Bind
        if( bind(gSocket, (struct sockaddr *)&client, sizeof(struct sockaddr_in)) == -1 )
        {
@@ -78,74 +116,180 @@ int _InitSyscalls()
                #endif
                exit(0);
        }
+       #endif
+       
+       #if USE_TCP
+       {
+               tRequestAuthHdr auth;
+               auth.pid = giSyscall_ClientID;
+               SendData(&auth, sizeof(auth));
+               int len = ReadData(&auth, sizeof(auth), 5);
+               if( len == 0 ) { 
+                       fprintf(stderr, "Timeout waiting for auth response\n");
+                       exit(-1);
+               }
+               giSyscall_ClientID = auth.pid;
+       }
+       #else
+       // Ask server for a client ID
+       if( !giSyscall_ClientID )
+       {
+               tRequestHeader  req;
+                int    len;
+               req.ClientID = 0;
+               req.CallID = 0;
+               req.NParams = 0;
+               
+               SendData(&req, sizeof(req));
+               
+               len = ReadData(&req, sizeof(req), 5);
+               if( len == 0 ) {
+                       fprintf(stderr, "Unable to connect to server (localhost:%i)\n", SERVER_PORT);
+                       exit(-1);
+               }
+               
+               giSyscall_ClientID = req.ClientID;
+       }
+       #endif
+       
        return 0;
 }
 
-int SendRequest(int RequestID, int NumOutput, tOutValue **Output, int NumInput, tInValue **Input)
+/**
+ * \brief Close the syscall socket
+ * \note Used in acess_fork to get a different port number
+ */
+void _CloseSyscalls(void)
 {
-       tRequestHeader  *request;
-       tRequestValue   *value;
-       char    *data;
-        int    requestLen;
-        int    i;
-       
-       // See ../syscalls.h for details of request format
-       requestLen = sizeof(tRequestHeader) + (NumOutput + NumInput) * sizeof(tRequestValue);
-       
-       // Get total param length
-       for( i = 0; i < NumOutput; i ++ )
-               requestLen += Output[i]->Length;
-       
-       // Allocate request
-       request = malloc( requestLen );
-       value = request->Params;
-       data = (char*)&request->Params[ NumOutput + NumInput ];
+       #if __WIN32__
+       closesocket(gSocket);
+       WSACleanup();
+       #else
+       close(gSocket);
+       #endif
+}
+
+int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
+{
+       if( gSocket == INVALID_SOCKET )
+       {
+               _InitSyscalls();                
+       }
        
        // Set header
-       request->ClientID = siSyscall_ClientID;
-       request->CallID = RequestID;    // Syscall
-       request->NParams = NumOutput;
-       request->NReturn = NumInput;
+       Request->ClientID = giSyscall_ClientID;
        
-       // Set parameters
-       for( i = 0; i < NumOutput; i ++ )
+       #if 0
        {
-               switch(Output[i]->Type)
+               for(i=0;i<RequestSize;i++)
                {
-               case 'i':       value->Type = ARG_TYPE_INT32;   break;
-               case 'I':       value->Type = ARG_TYPE_INT64;   break;
-               case 'd':       value->Type = ARG_TYPE_DATA;    break;
-               default:
-                       return -1;
+                       printf("%02x ", ((uint8_t*)Request)[i]);
+                       if( i % 16 == 15 )      printf("\n");
                }
-               value->Length = Output[i]->Length;
-               
-               memcpy(data, Output[i]->Data, Output[i]->Length);
-               
-               data += Output[i]->Length;
+               printf("\n");
        }
-       
-       // Set return values
-       for( i = 0; i < NumInput; i ++ )
+       #endif
        {
-               switch(Input[i]->Type)
+                int    i;
+               char    *data = (char*)&Request->Params[Request->NParams];
+               DEBUG_S("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
+               for( i = 0; i < Request->NParams; i ++ )
                {
-               case 'i':       value->Type = ARG_TYPE_INT32;   break;
-               case 'I':       value->Type = ARG_TYPE_INT64;   break;
-               case 'd':       value->Type = ARG_TYPE_DATA;    break;
-               default:
-                       return -1;
+                       switch(Request->Params[i].Type)
+                       {
+                       case ARG_TYPE_INT32:
+                               DEBUG_S(" 0x%08x", *(uint32_t*)data);
+                               data += sizeof(uint32_t);
+                               break;
+                       case ARG_TYPE_INT64:
+                               DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
+                               data += sizeof(uint64_t);
+                               break;
+                       case ARG_TYPE_STRING:
+                               DEBUG_S(" '%s'", (char*)data);
+                               data += Request->Params[i].Length;
+                               break;
+                       case ARG_TYPE_DATA:
+                               DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
+                               if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
+                                       data += Request->Params[i].Length;
+                               break;
+                       }
                }
-               value->Length = Input[i]->Length;
+               DEBUG_S("\n");
        }
        
        // Send it off
-       send(gSocket, request, requestLen, 0);
+       SendData(Request, RequestSize);
+
+       if( Request->CallID == SYS_EXIT )       return 0;
+
+       // Wait for a response (no timeout)
+       return ReadData(Request, ResponseSize, 0);
+}
+
+void SendData(void *Data, int Length)
+{
+        int    len;
        
-       // Wait for a response
-       recv(gSocket, request, requestLen, 0);
+       #if USE_TCP
+       len = send(gSocket, Data, Length, 0);
+       #else
+       len = sendto(gSocket, Data, Length, 0,
+               (struct sockaddr*)&gSyscall_ServerAddr, sizeof(gSyscall_ServerAddr));
+       #endif
+       
+       if( len != Length ) {
+               fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
+               perror("SendData");
+               exit(-1);
+       }
+}
+
+int ReadData(void *Dest, int MaxLength, int Timeout)
+{
+        int    ret;
+       fd_set  fds;
+       struct timeval  tv;
+       struct timeval  *timeoutPtr;
        
-       // Parse response out
+       FD_ZERO(&fds);
+       FD_SET(gSocket, &fds);
        
-       return 0;
+       if( Timeout ) {
+               tv.tv_sec = Timeout;
+               tv.tv_usec = 0;
+               timeoutPtr = &tv;
+       }
+       else {
+               timeoutPtr = NULL;
+       }
+       
+       ret = select(gSocket+1, &fds, NULL, NULL, timeoutPtr);
+       if( ret == -1 ) {
+               fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
+               perror("ReadData - select");
+               exit(-1);
+       }
+       
+       if( !ret ) {
+               printf("[ERROR %i] Timeout reading from socket\n", giSyscall_ClientID);
+               return 0;       // Timeout
+       }
+       
+       #if USE_TCP
+       ret = recv(gSocket, Dest, MaxLength, 0);
+       #else
+       ret = recvfrom(gSocket, Dest, MaxLength, 0, NULL, 0);
+       #endif
+       
+       if( ret < 0 ) {
+               fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
+               perror("ReadData");
+               exit(-1);
+       }
+       
+       DEBUG_S("%i bytes read from socket\n", ret);
+       
+       return ret;
 }

UCC git Repository :: git.ucc.asn.au