fdf57485e860e61e02aa3fe0a1d87a459245bc56
[tpg/acess2.git] / AcessNative / ld-acess_src / request.c
1 /*
2  * AcessNative ld-acess dynamic linker
3  * - By John Hodge (thePowersGang)
4  *
5  * request.c
6  * - IPC interface
7  */
8 #define DEBUG   0
9
10 #if DEBUG
11 # define DEBUG_S        printf
12 #else
13 # define DEBUG_S(...)
14 # define DONT_INCLUDE_SYSCALL_NAMES
15 #endif
16
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <inttypes.h>
21 #ifdef __WIN32__
22 # include <windows.h>
23 # include <winsock.h>
24 #else
25 # include <unistd.h>
26 # include <sys/socket.h>
27 # include <netinet/in.h>
28 #endif
29 #include "request.h"
30 #include "../syscalls.h"
31
32 #define USE_TCP 1
33
34 // === PROTOTYPES ===
35 void    SendData(void *Data, int Length);
36  int    ReadData(void *Dest, int MaxLen, int Timeout);
37
38 // === GLOBALS ===
39 #ifdef __WIN32__
40 WSADATA gWinsock;
41 SOCKET  gSocket = INVALID_SOCKET;
42 #else
43 # define INVALID_SOCKET -1
44  int    gSocket = INVALID_SOCKET;
45 #endif
46 // Client ID to pass to server
47 // TODO: Implement such that each thread gets a different one
48  int    giSyscall_ClientID = 0;
49 struct sockaddr_in      gSyscall_ServerAddr;
50
51 // === CODE ===
52 void Request_Preinit(void)
53 {
54         // Set server address
55         memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
56         gSyscall_ServerAddr.sin_family = AF_INET;
57         gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
58         gSyscall_ServerAddr.sin_addr.s_addr = htonl(0x7F000001);
59 }
60
61 int _InitSyscalls(void)
62 {
63         #ifdef __WIN32__
64         /* Open windows connection */
65         if (WSAStartup(0x0101, &gWinsock) != 0)
66         {
67                 fprintf(stderr, "Could not open Windows connection.\n");
68                 exit(0);
69         }
70         #endif
71         
72         #if USE_TCP
73         // Open TCP Connection
74         gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
75         #else
76         // Open UDP Connection
77         gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
78         #endif
79         if (gSocket == INVALID_SOCKET)
80         {
81                 fprintf(stderr, "Could not create socket.\n");
82                 #if __WIN32__
83                 WSACleanup();
84                 #endif
85                 exit(0);
86         }
87         
88         #if USE_TCP
89         if( connect(gSocket, (struct sockaddr *)&gSyscall_ServerAddr, sizeof(struct sockaddr_in)) < 0 )
90         {
91                 fprintf(stderr, "[ERROR -] Cannot connect to server (localhost:%i)\n", SERVER_PORT);
92                 perror("_InitSyscalls");
93                 #if __WIN32__
94                 fprintf(stderr, "[ERROR -] - WSAGetLastError said %i", WSAGetLastError());
95                 closesocket(gSocket);
96                 WSACleanup();
97                 #else
98                 close(gSocket);
99                 #endif
100                 exit(0);
101         }
102         #endif
103         
104         #if 0
105         // Set client address
106         memset((void *)&client, '\0', sizeof(struct sockaddr_in));
107         client.sin_family = AF_INET;
108         client.sin_port = htons(0);
109         client.sin_addr.s_addr = htonl(0x7F000001);
110         // Bind
111         if( bind(gSocket, (struct sockaddr *)&client, sizeof(struct sockaddr_in)) == -1 )
112         {
113                 fprintf(stderr, "Cannot bind address to socket.\n");
114                 #if __WIN32__
115                 closesocket(gSocket);
116                 WSACleanup();
117                 #else
118                 close(gSocket);
119                 #endif
120                 exit(0);
121         }
122         #endif
123         
124         #if USE_TCP
125         {
126                 tRequestAuthHdr auth;
127                 auth.pid = giSyscall_ClientID;
128                 auth.key = 0;
129                 SendData(&auth, sizeof(auth));
130                 int len = ReadData(&auth, sizeof(auth), 5);
131                 if( len == 0 ) { 
132                         fprintf(stderr, "Timeout waiting for auth response\n");
133                         exit(-1);
134                 }
135                 giSyscall_ClientID = auth.pid;
136         }
137         #else
138         // Ask server for a client ID
139         if( !giSyscall_ClientID )
140         {
141                 tRequestHeader  req;
142                  int    len;
143                 req.ClientID = 0;
144                 req.CallID = 0;
145                 req.NParams = 0;
146                 
147                 SendData(&req, sizeof(req));
148                 
149                 len = ReadData(&req, sizeof(req), 5);
150                 if( len == 0 ) {
151                         fprintf(stderr, "Unable to connect to server (localhost:%i)\n", SERVER_PORT);
152                         exit(-1);
153                 }
154                 
155                 giSyscall_ClientID = req.ClientID;
156         }
157         #endif
158         
159         return 0;
160 }
161
162 /**
163  * \brief Close the syscall socket
164  * \note Used in acess_fork to get a different port number
165  */
166 void _CloseSyscalls(void)
167 {
168         #if __WIN32__
169         closesocket(gSocket);
170         WSACleanup();
171         #else
172         close(gSocket);
173         #endif
174 }
175
176 int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
177 {
178         if( gSocket == INVALID_SOCKET )
179         {
180                 _InitSyscalls();                
181         }
182         
183         // Set header
184         Request->ClientID = giSyscall_ClientID;
185         
186         #if 0
187         {
188                 for(i=0;i<RequestSize;i++)
189                 {
190                         printf("%02x ", ((uint8_t*)Request)[i]);
191                         if( i % 16 == 15 )      printf("\n");
192                 }
193                 printf("\n");
194         }
195         #endif
196         #if DEBUG
197         {
198                  int    i;
199                 char    *data = (char*)&Request->Params[Request->NParams];
200                 DEBUG_S("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
201                 for( i = 0; i < Request->NParams; i ++ )
202                 {
203                         switch(Request->Params[i].Type)
204                         {
205                         case ARG_TYPE_INT32:
206                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
207                                 data += sizeof(uint32_t);
208                                 break;
209                         case ARG_TYPE_INT64:
210                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
211                                 data += sizeof(uint64_t);
212                                 break;
213                         case ARG_TYPE_STRING:
214                                 DEBUG_S(" '%s'", (char*)data);
215                                 data += Request->Params[i].Length;
216                                 break;
217                         case ARG_TYPE_DATA:
218                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
219                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
220                                         data += Request->Params[i].Length;
221                                 break;
222                         }
223                 }
224                 DEBUG_S("\n");
225         }
226         #endif
227         
228         // Send it off
229         SendData(Request, RequestSize);
230
231         if( Request->CallID == SYS_EXIT )       return 0;
232
233         // Wait for a response (no timeout)
234         ReadData(Request, sizeof(*Request), 0);
235         // TODO: Sanity
236         size_t  recvbytes = sizeof(*Request), expbytes = Request->MessageLength;
237         char    *ptr = (void*)Request->Params;
238         while( recvbytes < expbytes )
239         {
240                 size_t  len = ReadData(ptr, expbytes - recvbytes, 1000);
241                 if( len == -1 ) {
242                         return -1;
243                 }
244                 recvbytes += len;
245                 ptr += len;
246         }
247         if( recvbytes > expbytes ) {
248                 // TODO: Warning
249         }
250         
251         #if DEBUG
252         {
253                  int    i;
254                 char    *data = (char*)&Request->Params[Request->NParams];
255                 DEBUG_S(" Reply:");
256                 for( i = 0; i < Request->NParams; i ++ )
257                 {
258                         switch(Request->Params[i].Type)
259                         {
260                         case ARG_TYPE_INT32:
261                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
262                                 data += sizeof(uint32_t);
263                                 break;
264                         case ARG_TYPE_INT64:
265                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
266                                 data += sizeof(uint64_t);
267                                 break;
268                         case ARG_TYPE_STRING:
269                                 DEBUG_S(" '%s'", (char*)data);
270                                 data += Request->Params[i].Length;
271                                 break;
272                         case ARG_TYPE_DATA:
273                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
274                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
275                                         data += Request->Params[i].Length;
276                                 break;
277                         }
278                 }
279                 DEBUG_S("\n");
280         }
281         #endif
282         return recvbytes;
283 }
284
285 void SendData(void *Data, int Length)
286 {
287          int    len;
288         
289         #if USE_TCP
290         len = send(gSocket, Data, Length, 0);
291         #else
292         len = sendto(gSocket, Data, Length, 0,
293                 (struct sockaddr*)&gSyscall_ServerAddr, sizeof(gSyscall_ServerAddr));
294         #endif
295         
296         if( len != Length ) {
297                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
298                 perror("SendData");
299                 exit(-1);
300         }
301 }
302
303 int ReadData(void *Dest, int MaxLength, int Timeout)
304 {
305          int    ret;
306         fd_set  fds;
307         struct timeval  tv;
308         struct timeval  *timeoutPtr;
309         
310         FD_ZERO(&fds);
311         FD_SET(gSocket, &fds);
312         
313         if( Timeout ) {
314                 tv.tv_sec = Timeout;
315                 tv.tv_usec = 0;
316                 timeoutPtr = &tv;
317         }
318         else {
319                 timeoutPtr = NULL;
320         }
321         
322         ret = select(gSocket+1, &fds, NULL, NULL, timeoutPtr);
323         if( ret == -1 ) {
324                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
325                 perror("ReadData - select");
326                 exit(-1);
327         }
328         
329         if( !ret ) {
330                 printf("[ERROR %i] Timeout reading from socket\n", giSyscall_ClientID);
331                 return -2;      // Timeout
332         }
333         
334         #if USE_TCP
335         ret = recv(gSocket, Dest, MaxLength, 0);
336         #else
337         ret = recvfrom(gSocket, Dest, MaxLength, 0, NULL, 0);
338         #endif
339         
340         if( ret < 0 ) {
341                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
342                 perror("ReadData");
343                 exit(-1);
344         }
345         if( ret == 0 ) {
346                 fprintf(stderr, "[ERROR %i] Connection closed.\n", giSyscall_ClientID);
347                 #if __WIN32__
348                 closesocket(gSocket);
349                 #else
350                 close(gSocket);
351                 #endif
352                 exit(0);
353         }
354         
355         DEBUG_S("%i bytes read from socket\n", ret);
356         
357         return ret;
358 }

UCC git Repository :: git.ucc.asn.au