690c7f7ff35221542b3a1b8e5c4064d8adc9b6ce
[tpg/acess2.git] / AcessNative / ld-acess_src / request.c
1 /*
2  * AcessNative ld-acess dynamic linker
3  * - By John Hodge (thePowersGang)
4  *
5  * request.c
6  * - IPC interface
7  */
// Set to 1 to trace every request/reply on stdout
#define DEBUG	0

#if !DEBUG
// Release build: trace calls compile to nothing, and the syscall-name table
// (only referenced by the trace output) is not pulled in.
# define DEBUG_S(...)
# define DONT_INCLUDE_SYSCALL_NAMES
#else
// Debug build: trace output is forwarded straight to printf (stdout)
# define DEBUG_S(...)	printf(__VA_ARGS__)
#endif
16
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <inttypes.h>
21 #ifdef __WIN32__
22 # include <windows.h>
23 # include <winsock.h>
24 #else
25 # include <unistd.h>
26 # include <sys/socket.h>
27 # include <netinet/in.h>
28 # include <sys/select.h>
29 #endif
30 #include "request.h"
31 #include "../syscalls.h"
32
// 1: talk to the server over a TCP stream; 0: use UDP datagrams
#define USE_TCP 1

// === PROTOTYPES ===
void	SendData(void *Data, int Length);
 int	ReadData(void *Dest, int MaxLen, int Timeout);

// === GLOBALS ===
#ifdef __WIN32__
WSADATA	gWinsock;
SOCKET	gSocket = INVALID_SOCKET;
#else
// POSIX socket() reports failure as -1; reuse the Winsock name so the rest of
// the file can compare against INVALID_SOCKET on every platform.
// Parenthesized so the negative literal cannot merge with surrounding tokens.
# define INVALID_SOCKET	(-1)
 int	gSocket = INVALID_SOCKET;
#endif
// Client ID to pass to server
// TODO: Implement such that each thread gets a different one
 int	giSyscall_ClientID = 0;
// Address of the syscall server (filled in by Request_Preinit)
struct sockaddr_in	gSyscall_ServerAddr;
51
52 // === CODE ===
53 void Request_Preinit(void)
54 {
55         // Set server address
56         memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
57         gSyscall_ServerAddr.sin_family = AF_INET;
58         gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
59         gSyscall_ServerAddr.sin_addr.s_addr = htonl(0x7F000001);
60 }
61
62 int _InitSyscalls(void)
63 {
64         #ifdef __WIN32__
65         /* Open windows connection */
66         if (WSAStartup(0x0101, &gWinsock) != 0)
67         {
68                 fprintf(stderr, "Could not open Windows connection.\n");
69                 exit(0);
70         }
71         #endif
72         
73         #if USE_TCP
74         // Open TCP Connection
75         gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
76         #else
77         // Open UDP Connection
78         gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
79         #endif
80         if (gSocket == INVALID_SOCKET)
81         {
82                 fprintf(stderr, "Could not create socket.\n");
83                 #if __WIN32__
84                 WSACleanup();
85                 #endif
86                 exit(0);
87         }
88         
89         #if USE_TCP
90         if( connect(gSocket, (struct sockaddr *)&gSyscall_ServerAddr, sizeof(struct sockaddr_in)) < 0 )
91         {
92                 fprintf(stderr, "[ERROR -] Cannot connect to server (localhost:%i)\n", SERVER_PORT);
93                 perror("_InitSyscalls");
94                 #if __WIN32__
95                 fprintf(stderr, "[ERROR -] - WSAGetLastError said %i", WSAGetLastError());
96                 closesocket(gSocket);
97                 WSACleanup();
98                 #else
99                 close(gSocket);
100                 #endif
101                 exit(0);
102         }
103         #endif
104         
105         #if 0
106         // Set client address
107         memset((void *)&client, '\0', sizeof(struct sockaddr_in));
108         client.sin_family = AF_INET;
109         client.sin_port = htons(0);
110         client.sin_addr.s_addr = htonl(0x7F000001);
111         // Bind
112         if( bind(gSocket, (struct sockaddr *)&client, sizeof(struct sockaddr_in)) == -1 )
113         {
114                 fprintf(stderr, "Cannot bind address to socket.\n");
115                 #if __WIN32__
116                 closesocket(gSocket);
117                 WSACleanup();
118                 #else
119                 close(gSocket);
120                 #endif
121                 exit(0);
122         }
123         #endif
124         
125         #if USE_TCP
126         {
127                 tRequestAuthHdr auth;
128                 auth.pid = giSyscall_ClientID;
129                 auth.key = 0;
130                 SendData(&auth, sizeof(auth));
131                 int len = ReadData(&auth, sizeof(auth), 5);
132                 if( len == 0 ) { 
133                         fprintf(stderr, "Timeout waiting for auth response\n");
134                         exit(-1);
135                 }
136                 giSyscall_ClientID = auth.pid;
137         }
138         #else
139         // Ask server for a client ID
140         if( !giSyscall_ClientID )
141         {
142                 tRequestHeader  req;
143                  int    len;
144                 req.ClientID = 0;
145                 req.CallID = 0;
146                 req.NParams = 0;
147                 
148                 SendData(&req, sizeof(req));
149                 
150                 len = ReadData(&req, sizeof(req), 5);
151                 if( len == 0 ) {
152                         fprintf(stderr, "Unable to connect to server (localhost:%i)\n", SERVER_PORT);
153                         exit(-1);
154                 }
155                 
156                 giSyscall_ClientID = req.ClientID;
157         }
158         #endif
159         
160         return 0;
161 }
162
163 /**
164  * \brief Close the syscall socket
165  * \note Used in acess_fork to get a different port number
166  */
167 void _CloseSyscalls(void)
168 {
169         #if __WIN32__
170         closesocket(gSocket);
171         WSACleanup();
172         #else
173         close(gSocket);
174         #endif
175 }
176
177 int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
178 {
179         if( gSocket == INVALID_SOCKET )
180         {
181                 _InitSyscalls();                
182         }
183         
184         // Set header
185         Request->ClientID = giSyscall_ClientID;
186         
187         #if 0
188         {
189                 for(i=0;i<RequestSize;i++)
190                 {
191                         printf("%02x ", ((uint8_t*)Request)[i]);
192                         if( i % 16 == 15 )      printf("\n");
193                 }
194                 printf("\n");
195         }
196         #endif
197         #if DEBUG
198         {
199                  int    i;
200                 char    *data = (char*)&Request->Params[Request->NParams];
201                 DEBUG_S("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
202                 for( i = 0; i < Request->NParams; i ++ )
203                 {
204                         switch(Request->Params[i].Type)
205                         {
206                         case ARG_TYPE_INT32:
207                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
208                                 data += sizeof(uint32_t);
209                                 break;
210                         case ARG_TYPE_INT64:
211                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
212                                 data += sizeof(uint64_t);
213                                 break;
214                         case ARG_TYPE_STRING:
215                                 DEBUG_S(" '%s'", (char*)data);
216                                 data += Request->Params[i].Length;
217                                 break;
218                         case ARG_TYPE_DATA:
219                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
220                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
221                                         data += Request->Params[i].Length;
222                                 break;
223                         }
224                 }
225                 DEBUG_S("\n");
226         }
227         #endif
228         
229         // Send it off
230         SendData(Request, RequestSize);
231
232         // Wait for a response (no timeout)
233         ReadData(Request, sizeof(*Request), 0);
234         
235         size_t  recvbytes = sizeof(*Request);
236         // TODO: Sanity
237         size_t  expbytes = Request->MessageLength;
238         char    *ptr = (void*)Request->Params;
239         while( recvbytes < expbytes )
240         {
241                 size_t  len = ReadData(ptr, expbytes - recvbytes, 1000);
242                 if( len == -1 ) {
243                         return -1;
244                 }
245                 recvbytes += len;
246                 ptr += len;
247         }
248         if( recvbytes > expbytes ) {
249                 // TODO: Warning
250         }
251         
252         #if DEBUG
253         {
254                  int    i;
255                 char    *data = (char*)&Request->Params[Request->NParams];
256                 DEBUG_S(" Reply:");
257                 for( i = 0; i < Request->NParams; i ++ )
258                 {
259                         switch(Request->Params[i].Type)
260                         {
261                         case ARG_TYPE_INT32:
262                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
263                                 data += sizeof(uint32_t);
264                                 break;
265                         case ARG_TYPE_INT64:
266                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
267                                 data += sizeof(uint64_t);
268                                 break;
269                         case ARG_TYPE_STRING:
270                                 DEBUG_S(" '%s'", (char*)data);
271                                 data += Request->Params[i].Length;
272                                 break;
273                         case ARG_TYPE_DATA:
274                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
275                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
276                                         data += Request->Params[i].Length;
277                                 break;
278                         }
279                 }
280                 DEBUG_S("\n");
281         }
282         #endif
283         return recvbytes;
284 }
285
286 void SendData(void *Data, int Length)
287 {
288          int    len;
289         
290         #if USE_TCP
291         len = send(gSocket, Data, Length, 0);
292         #else
293         len = sendto(gSocket, Data, Length, 0,
294                 (struct sockaddr*)&gSyscall_ServerAddr, sizeof(gSyscall_ServerAddr));
295         #endif
296         
297         if( len != Length ) {
298                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
299                 perror("SendData");
300                 exit(-1);
301         }
302 }
303
304 int ReadData(void *Dest, int MaxLength, int Timeout)
305 {
306          int    ret;
307         fd_set  fds;
308         struct timeval  tv;
309         struct timeval  *timeoutPtr;
310         
311         FD_ZERO(&fds);
312         FD_SET(gSocket, &fds);
313         
314         if( Timeout ) {
315                 tv.tv_sec = Timeout;
316                 tv.tv_usec = 0;
317                 timeoutPtr = &tv;
318         }
319         else {
320                 timeoutPtr = NULL;
321         }
322         
323         ret = select(gSocket+1, &fds, NULL, NULL, timeoutPtr);
324         if( ret == -1 ) {
325                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
326                 perror("ReadData - select");
327                 exit(-1);
328         }
329         
330         if( !ret ) {
331                 printf("[ERROR %i] Timeout reading from socket\n", giSyscall_ClientID);
332                 return -2;      // Timeout
333         }
334         
335         #if USE_TCP
336         ret = recv(gSocket, Dest, MaxLength, 0);
337         #else
338         ret = recvfrom(gSocket, Dest, MaxLength, 0, NULL, 0);
339         #endif
340         
341         if( ret < 0 ) {
342                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
343                 perror("ReadData");
344                 exit(-1);
345         }
346         if( ret == 0 ) {
347                 fprintf(stderr, "[ERROR %i] Connection closed.\n", giSyscall_ClientID);
348                 #if __WIN32__
349                 closesocket(gSocket);
350                 #else
351                 close(gSocket);
352                 #endif
353                 exit(0);
354         }
355         
356         DEBUG_S("%i bytes read from socket\n", ret);
357         
358         return ret;
359 }

UCC git Repository :: git.ucc.asn.au