AcessNative - Added message length to simplify message reception
[tpg/acess2.git] / AcessNative / ld-acess_src / request.c
/*
 * AcessNative ld-acess - syscall request transport
 * - Connects to the AcessNative server, sends syscall requests over the
 *   socket and reads back the replies.
 */
// Debug output switch; may be overridden from the build (e.g. -DDEBUG=1)
#ifndef DEBUG
# define DEBUG	0
#endif


#if DEBUG
# define DEBUG_S	printf
#else
# define DEBUG_S(...)
// In this file casSYSCALL_NAMES is only used by the DEBUG dump in SendRequest;
// presumably ../syscalls.h omits the name table when this is defined - confirm
# define DONT_INCLUDE_SYSCALL_NAMES
#endif
12
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <inttypes.h>
#include <errno.h>
#include <limits.h>
#ifdef __WIN32__
# include <windows.h>
# include <winsock.h>
#else
# include <unistd.h>
# include <sys/select.h>
# include <sys/socket.h>
# include <netinet/in.h>
#endif
#include "request.h"
#include "../syscalls.h"
27
// Transport selection: 1 = TCP stream (default), 0 = UDP datagrams.
// May be overridden from the build (e.g. -DUSE_TCP=0).
#ifndef USE_TCP
# define USE_TCP	1
#endif

// === PROTOTYPES ===
void	SendData(void *Data, int Length);
 int	ReadData(void *Dest, int MaxLen, int Timeout);

// === GLOBALS ===
#ifdef __WIN32__
WSADATA	gWinsock;
SOCKET	gSocket = INVALID_SOCKET;
#else
// Mirror Winsock's name for "no socket"; parenthesised so it expands safely
# define INVALID_SOCKET	(-1)
 int	gSocket = INVALID_SOCKET;
#endif
// Client ID to pass to server (replaced by the ID the server returns in _InitSyscalls)
// TODO: Implement such that each thread gets a different one
 int	giSyscall_ClientID = 0;
// Server address, filled in by Request_Preinit
struct sockaddr_in	gSyscall_ServerAddr;
46
47 // === CODE ===
48 void Request_Preinit(void)
49 {
50         // Set server address
51         memset((void *)&gSyscall_ServerAddr, '\0', sizeof(struct sockaddr_in));
52         gSyscall_ServerAddr.sin_family = AF_INET;
53         gSyscall_ServerAddr.sin_port = htons(SERVER_PORT);
54 }
55
56 int _InitSyscalls(void)
57 {
58         #ifdef __WIN32__
59         /* Open windows connection */
60         if (WSAStartup(0x0101, &gWinsock) != 0)
61         {
62                 fprintf(stderr, "Could not open Windows connection.\n");
63                 exit(0);
64         }
65         #endif
66         
67         #if USE_TCP
68         // Open TCP Connection
69         gSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
70         #else
71         // Open UDP Connection
72         gSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
73         #endif
74         if (gSocket == INVALID_SOCKET)
75         {
76                 fprintf(stderr, "Could not create socket.\n");
77                 #if __WIN32__
78                 WSACleanup();
79                 #endif
80                 exit(0);
81         }
82         
83         #if 0
84         // Set client address
85         memset((void *)&client, '\0', sizeof(struct sockaddr_in));
86         client.sin_family = AF_INET;
87         client.sin_port = htons(0);
88         client.sin_addr.s_addr = htonl(0x7F000001);
89         #endif
90         
91         #if USE_TCP
92         if( connect(gSocket, (struct sockaddr *)&gSyscall_ServerAddr, sizeof(struct sockaddr_in)) < 0 )
93         {
94                 fprintf(stderr, "[ERROR -] Cannot connect to server (localhost:%i)\n", SERVER_PORT);
95                 perror("_InitSyscalls");
96                 #if __WIN32__
97                 closesocket(gSocket);
98                 WSACleanup();
99                 #else
100                 close(gSocket);
101                 #endif
102                 exit(0);
103         }
104         #endif
105         
106         #if 0
107         // Bind
108         if( bind(gSocket, (struct sockaddr *)&client, sizeof(struct sockaddr_in)) == -1 )
109         {
110                 fprintf(stderr, "Cannot bind address to socket.\n");
111                 #if __WIN32__
112                 closesocket(gSocket);
113                 WSACleanup();
114                 #else
115                 close(gSocket);
116                 #endif
117                 exit(0);
118         }
119         #endif
120         
121         #if USE_TCP
122         {
123                 tRequestAuthHdr auth;
124                 auth.pid = giSyscall_ClientID;
125                 auth.key = 0;
126                 SendData(&auth, sizeof(auth));
127                 int len = ReadData(&auth, sizeof(auth), 5);
128                 if( len == 0 ) { 
129                         fprintf(stderr, "Timeout waiting for auth response\n");
130                         exit(-1);
131                 }
132                 giSyscall_ClientID = auth.pid;
133         }
134         #else
135         // Ask server for a client ID
136         if( !giSyscall_ClientID )
137         {
138                 tRequestHeader  req;
139                  int    len;
140                 req.ClientID = 0;
141                 req.CallID = 0;
142                 req.NParams = 0;
143                 
144                 SendData(&req, sizeof(req));
145                 
146                 len = ReadData(&req, sizeof(req), 5);
147                 if( len == 0 ) {
148                         fprintf(stderr, "Unable to connect to server (localhost:%i)\n", SERVER_PORT);
149                         exit(-1);
150                 }
151                 
152                 giSyscall_ClientID = req.ClientID;
153         }
154         #endif
155         
156         return 0;
157 }
158
159 /**
160  * \brief Close the syscall socket
161  * \note Used in acess_fork to get a different port number
162  */
163 void _CloseSyscalls(void)
164 {
165         #if __WIN32__
166         closesocket(gSocket);
167         WSACleanup();
168         #else
169         close(gSocket);
170         #endif
171 }
172
173 int SendRequest(tRequestHeader *Request, int RequestSize, int ResponseSize)
174 {
175         if( gSocket == INVALID_SOCKET )
176         {
177                 _InitSyscalls();                
178         }
179         
180         // Set header
181         Request->ClientID = giSyscall_ClientID;
182         
183         #if 0
184         {
185                 for(i=0;i<RequestSize;i++)
186                 {
187                         printf("%02x ", ((uint8_t*)Request)[i]);
188                         if( i % 16 == 15 )      printf("\n");
189                 }
190                 printf("\n");
191         }
192         #endif
193         #if DEBUG
194         {
195                  int    i;
196                 char    *data = (char*)&Request->Params[Request->NParams];
197                 DEBUG_S("Request #%i (%s) -", Request->CallID, casSYSCALL_NAMES[Request->CallID]);
198                 for( i = 0; i < Request->NParams; i ++ )
199                 {
200                         switch(Request->Params[i].Type)
201                         {
202                         case ARG_TYPE_INT32:
203                                 DEBUG_S(" 0x%08x", *(uint32_t*)data);
204                                 data += sizeof(uint32_t);
205                                 break;
206                         case ARG_TYPE_INT64:
207                                 DEBUG_S(" 0x%016"PRIx64"", *(uint64_t*)data);
208                                 data += sizeof(uint64_t);
209                                 break;
210                         case ARG_TYPE_STRING:
211                                 DEBUG_S(" '%s'", (char*)data);
212                                 data += Request->Params[i].Length;
213                                 break;
214                         case ARG_TYPE_DATA:
215                                 DEBUG_S(" %p:0x%x", (char*)data, Request->Params[i].Length);
216                                 if( !(Request->Params[i].Flags & ARG_FLAG_ZEROED) )
217                                         data += Request->Params[i].Length;
218                                 break;
219                         }
220                 }
221                 DEBUG_S("\n");
222         }
223         #endif
224         
225         // Send it off
226         SendData(Request, RequestSize);
227
228         if( Request->CallID == SYS_EXIT )       return 0;
229
230         // Wait for a response (no timeout)
231         ReadData(Request, sizeof(*Request), 0);
232         // TODO: Sanity
233         size_t  recvbytes = sizeof(*Request), expbytes = Request->MessageLength;
234         char    *ptr = (void*)Request->Params;
235         while( recvbytes < expbytes )
236         {
237                 size_t  len = ReadData(ptr, expbytes - recvbytes, 1000);
238                 if( len == -1 ) {
239                         return -1;
240                 }
241                 recvbytes += len;
242                 ptr += len;
243         }
244         if( recvbytes > expbytes ) {
245                 // TODO: Warning
246         }
247         return recvbytes;
248 }
249
250 void SendData(void *Data, int Length)
251 {
252          int    len;
253         
254         #if USE_TCP
255         len = send(gSocket, Data, Length, 0);
256         #else
257         len = sendto(gSocket, Data, Length, 0,
258                 (struct sockaddr*)&gSyscall_ServerAddr, sizeof(gSyscall_ServerAddr));
259         #endif
260         
261         if( len != Length ) {
262                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
263                 perror("SendData");
264                 exit(-1);
265         }
266 }
267
268 int ReadData(void *Dest, int MaxLength, int Timeout)
269 {
270          int    ret;
271         fd_set  fds;
272         struct timeval  tv;
273         struct timeval  *timeoutPtr;
274         
275         FD_ZERO(&fds);
276         FD_SET(gSocket, &fds);
277         
278         if( Timeout ) {
279                 tv.tv_sec = Timeout;
280                 tv.tv_usec = 0;
281                 timeoutPtr = &tv;
282         }
283         else {
284                 timeoutPtr = NULL;
285         }
286         
287         ret = select(gSocket+1, &fds, NULL, NULL, timeoutPtr);
288         if( ret == -1 ) {
289                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
290                 perror("ReadData - select");
291                 exit(-1);
292         }
293         
294         if( !ret ) {
295                 printf("[ERROR %i] Timeout reading from socket\n", giSyscall_ClientID);
296                 return 0;       // Timeout
297         }
298         
299         #if USE_TCP
300         ret = recv(gSocket, Dest, MaxLength, 0);
301         #else
302         ret = recvfrom(gSocket, Dest, MaxLength, 0, NULL, 0);
303         #endif
304         
305         if( ret < 0 ) {
306                 fprintf(stderr, "[ERROR %i] ", giSyscall_ClientID);
307                 perror("ReadData");
308                 exit(-1);
309         }
310         
311         DEBUG_S("%i bytes read from socket\n", ret);
312         
313         return ret;
314 }

UCC git Repository :: git.ucc.asn.au