Merge branch 'master' of git://cadel.mutabah.net/acess2
[tpg/acess2.git] / AcessNative / ld-acess_src / request.c
/*
 * AcessNative - ld-acess
 * request.c - Transport for syscall requests sent to the AcessNative server
 */
// Per-request tracing; can be overridden from the build (e.g. -DDEBUG=0)
#ifndef DEBUG
# define DEBUG	1
#endif

#if DEBUG
// Trace output for each request/reply (printed to stdout)
# define DEBUG_S(...)	printf(__VA_ARGS__)
#else
# define DEBUG_S(...)
// NOTE(review): presumably stops ../syscalls.h from emitting casSYSCALL_NAMES,
// which is only referenced by the DEBUG trace - confirm in syscalls.h
# define DONT_INCLUDE_SYSCALL_NAMES
#endif
12
13 #include <stdlib.h>
14 #include <string.h>
15 #include <stdio.h>
16 #include <inttypes.h>
17 #ifdef __WIN32__
18 # include <windows.h>
19 # include <winsock.h>
20 #else
21 # include <unistd.h>
22 # include <sys/socket.h>
23 # include <netinet/in.h>
24 #endif
25 #include "request.h"
26 #include "../syscalls.h"
27
28 #define USE_TCP 1
29
30 // === PROTOTYPES ===
31 void    SendData(void *Data, int Length);
32  int    ReadData(void *Dest, int MaxLen, int Timeout);
33
34 // === GLOBALS ===
35 #ifdef __WIN32__
36 WSADATA gWinsock;
37 SOCKET  gSocket = INVALID_SOCKET;
38 #else
39 # define INVALID_SOCKET -1
40  int    gSocket = INVALID_SOCKET;
41 #endif
42 // Client ID to pass to server
43 // TODO: Implement such that each thread gets a different one
44  int    giSyscall_ClientID = 0;
45 struct sockaddr_in      gSyscall_ServerAddr;
46
47 // === CODE ===
48 void Request_Preinit(void)
49 {
50         // Set server address
51         memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
52         gSyscall_ServerAddr.sin_family = AF_INET;
53         gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
54         gSyscall_ServerAddr.sin_addr.s_addr = htonl(0x7F000001);
55 }
56
57 int _InitSyscalls(void)
58 {
59         #ifdef __WIN32__
60         /* Open windows connection */
61         if (WSAStartup(0x0101, &gWinsock) != 0)
62         {
63                 fprintf(stderr, "Could not open Windows connection.\n");
64                 exit(0);
65         }
66         #endif
67         
68         #if USE_TCP
69         // Open TCP Connection
70         gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
71         #else
72         // Open UDP Connection
73         gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
74         #endif
75         if (gSocket == INVALID_SOCKET)
76         {
77                 fprintf(stderr, "Could not create socket.\n");
78                 #if __WIN32__
79                 WSACleanup();
80                 #endif
81                 exit(0);
82         }
83         
84         #if USE_TCP
85         if( connect(gSocket, (struct sockaddr *)&gSyscall_ServerAddr, sizeof(struct sockaddr_in)) < 0 )
86         {
87                 fprintf(stderr, "[ERROR -] Cannot connect to server (localhost:%i)\n", SERVER_PORT);
88                 perror("_InitSyscalls");
89                 #if __WIN32__
90                 fprintf(stderr, "[ERROR -] - WSAGetLastError said %i", WSAGetLastError());
91                 closesocket(gSocket);
92                 WSACleanup();
93                 #else
94                 close(gSocket);
95                 #endif
96                 exit(0);
97         }
98         #endif
99         
100         #if 0
101         // Set client address
102         memset((void *)&client, '\0', sizeof(struct sockaddr_in));
103         client.sin_family = AF_INET;
104         client.sin_port = htons(0);
105         client.sin_addr.s_addr = htonl(0x7F000001);
106         // Bind
107         if( bind(gSocket, (struct sockaddr *)&client, sizeof(struct sockaddr_in)) == -1 )
108         {
109                 fprintf(stderr, "Cannot bind address to socket.\n");
110                 #if __WIN32__
111                 closesocket(gSocket);
112                 WSACleanup();
113                 #else
114                 close(gSocket);
115                 #endif
116                 exit(0);
117         }
118         #endif
119         
120         #if USE_TCP
121         {
122                 tRequestAuthHdr auth;
123                 auth.pid = giSyscall_ClientID;
124                 auth.key = 0;
125                 SendData(&auth, sizeof(auth));
126                 int len = ReadData(&auth, sizeof(auth), 5);
127                 if( len == 0 ) { 
128                         fprintf(stderr, "Timeout waiting for auth response\n");
129                         exit(-1);
130                 }
131                 giSyscall_ClientID = auth.pid;
132         }
133         #else
134         // Ask server for a client ID
135         if( !giSyscall_ClientID )
136         {
137                 tRequestHeader  req;
138                  int    len;
139                 req.ClientID = 0;
140                 req.CallID = 0;
141                 req.NParams = 0;
142                 
143                 SendData(&req, sizeof(req));
144                 
145                 len = ReadData(&req, sizeof(req), 5);
146                 if( len == 0 ) {
147                         fprintf(stderr, "Unable to connect to server (localhost:%i)\n", SERVER_PORT);
148                         exit(-1);
149                 }
150                 
151                 giSyscall_ClientID = req.ClientID;
152         }
153         #endif
154         
155         return 0;
156 }
157
158 /**
159  * \brief Close the syscall socket
160  * \note Used in acess_fork to get a different port number
161  */
162 void _CloseSyscalls(void)
163 {
164         #if __WIN32__
165         closesocket(gSocket);
166         WSACleanup();
167         #else
168         close(gSocket);
169         #endif
170 }
171
172 int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
173 {
174         if( gSocket == INVALID_SOCKET )
175         {
176                 _InitSyscalls();                
177         }
178         
179         // Set header
180         Request->ClientID = giSyscall_ClientID;
181         
182         #if 0
183         {
184                 for(i=0;i<RequestSize;i++)
185                 {
186                         printf("%02x ", ((uint8_t*)Request)[i]);
187                         if( i % 16 == 15 )      printf("\n");
188                 }
189                 printf("\n");
190         }
191         #endif
192         #if DEBUG
193         {
194                  int    i;
195                 char    *data = (char*)&Request->Params[Request->NParams];
196                 DEBUG_S("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
197                 for( i = 0; i < Request->NParams; i ++ )
198                 {
199                         switch(Request->Params[i].Type)
200                         {
201                         case ARG_TYPE_INT32:
202                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
203                                 data += sizeof(uint32_t);
204                                 break;
205                         case ARG_TYPE_INT64:
206                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
207                                 data += sizeof(uint64_t);
208                                 break;
209                         case ARG_TYPE_STRING:
210                                 DEBUG_S(" '%s'", (char*)data);
211                                 data += Request->Params[i].Length;
212                                 break;
213                         case ARG_TYPE_DATA:
214                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
215                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
216                                         data += Request->Params[i].Length;
217                                 break;
218                         }
219                 }
220                 DEBUG_S("\n");
221         }
222         #endif
223         
224         // Send it off
225         SendData(Request, RequestSize);
226
227         if( Request->CallID == SYS_EXIT )       return 0;
228
229         // Wait for a response (no timeout)
230         ReadData(Request, sizeof(*Request), 0);
231         // TODO: Sanity
232         size_t  recvbytes = sizeof(*Request), expbytes = Request->MessageLength;
233         char    *ptr = (void*)Request->Params;
234         while( recvbytes < expbytes )
235         {
236                 size_t  len = ReadData(ptr, expbytes - recvbytes, 1000);
237                 if( len == -1 ) {
238                         return -1;
239                 }
240                 recvbytes += len;
241                 ptr += len;
242         }
243         if( recvbytes > expbytes ) {
244                 // TODO: Warning
245         }
246         
247         #if DEBUG
248         {
249                  int    i;
250                 char    *data = (char*)&Request->Params[Request->NParams];
251                 DEBUG_S(" Reply:");
252                 for( i = 0; i < Request->NParams; i ++ )
253                 {
254                         switch(Request->Params[i].Type)
255                         {
256                         case ARG_TYPE_INT32:
257                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
258                                 data += sizeof(uint32_t);
259                                 break;
260                         case ARG_TYPE_INT64:
261                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
262                                 data += sizeof(uint64_t);
263                                 break;
264                         case ARG_TYPE_STRING:
265                                 DEBUG_S(" '%s'", (char*)data);
266                                 data += Request->Params[i].Length;
267                                 break;
268                         case ARG_TYPE_DATA:
269                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
270                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
271                                         data += Request->Params[i].Length;
272                                 break;
273                         }
274                 }
275                 DEBUG_S("\n");
276         }
277         #endif
278         return recvbytes;
279 }
280
281 void SendData(void *Data, int Length)
282 {
283          int    len;
284         
285         #if USE_TCP
286         len = send(gSocket, Data, Length, 0);
287         #else
288         len = sendto(gSocket, Data, Length, 0,
289                 (struct sockaddr*)&gSyscall_ServerAddr, sizeof(gSyscall_ServerAddr));
290         #endif
291         
292         if( len != Length ) {
293                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
294                 perror("SendData");
295                 exit(-1);
296         }
297 }
298
299 int ReadData(void *Dest, int MaxLength, int Timeout)
300 {
301          int    ret;
302         fd_set  fds;
303         struct timeval  tv;
304         struct timeval  *timeoutPtr;
305         
306         FD_ZERO(&fds);
307         FD_SET(gSocket, &fds);
308         
309         if( Timeout ) {
310                 tv.tv_sec = Timeout;
311                 tv.tv_usec = 0;
312                 timeoutPtr = &tv;
313         }
314         else {
315                 timeoutPtr = NULL;
316         }
317         
318         ret = select(gSocket+1, &fds, NULL, NULL, timeoutPtr);
319         if( ret == -1 ) {
320                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
321                 perror("ReadData - select");
322                 exit(-1);
323         }
324         
325         if( !ret ) {
326                 printf("[ERROR %i] Timeout reading from socket\n", giSyscall_ClientID);
327                 return -2;      // Timeout
328         }
329         
330         #if USE_TCP
331         ret = recv(gSocket, Dest, MaxLength, 0);
332         #else
333         ret = recvfrom(gSocket, Dest, MaxLength, 0, NULL, 0);
334         #endif
335         
336         if( ret < 0 ) {
337                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
338                 perror("ReadData");
339                 exit(-1);
340         }
341         if( ret == 0 ) {
342                 fprintf(stderr, "[ERROR %i] Connection closed.\n", giSyscall_ClientID);
343                 #if __WIN32__
344                 closesocket(gSocket);
345                 #else
346                 close(gSocket);
347                 #endif
348                 exit(0);
349         }
350         
351         DEBUG_S("%i bytes read from socket\n", ret);
352         
353         return ret;
354 }

UCC git Repository :: git.ucc.asn.au