Modules/IPStack - Fixed UDP port collision detection
[tpg/acess2.git] / KernelLand / Modules / IPStack / udp.c
1 /*
2  * Acess2 IP Stack
3  * - By John Hodge (thePowersGang)
4  *
5  * udp.c
6  * - UDP Protocol handling
7  */
8 #define DEBUG   1
9 #include "ipstack.h"
10 #include <api_drv_common.h>
11 #include "udp.h"
12
13 #define UDP_ALLOC_BASE  0xC000
14
15 // === PROTOTYPES ===
16 void    UDP_Initialise();
17 void    UDP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffer);
18 void    UDP_Unreachable(tInterface *Interface, int Code, void *Address, int Length, void *Buffer);
19 void    UDP_SendPacketTo(tUDPChannel *Channel, int AddrType, const void *Address, Uint16 Port, const void *Data, size_t Length);
20 // --- Client Channels
21 tVFS_Node       *UDP_Channel_Init(tInterface *Interface);
22 size_t  UDP_Channel_Read(tVFS_Node *Node, off_t Offset, size_t Length, void *Buffer, Uint Flags);
23 size_t  UDP_Channel_Write(tVFS_Node *Node, off_t Offset, size_t Length, const void *Buffer, Uint Flags);
24  int    UDP_Channel_IOCtl(tVFS_Node *Node, int ID, void *Data);
25 void    UDP_Channel_Close(tVFS_Node *Node);
26 // --- Helpers
27 Uint16  UDP_int_AllocatePort(tUDPChannel *Channel);
28  int    UDP_int_ClaimPort(tUDPChannel *Channel, Uint16 Port);
29 void    UDP_int_FreePort(Uint16 Port);
30
31 // === GLOBALS ===
32 tVFS_NodeType   gUDP_NodeType = {
33         .Read = UDP_Channel_Read,
34         .Write = UDP_Channel_Write,
35         .IOCtl = UDP_Channel_IOCtl,
36         .Close = UDP_Channel_Close
37 };
38 tMutex  glUDP_Channels;
39 tUDPChannel     *gpUDP_Channels;
40
41 tMutex  glUDP_Ports;
42 Uint32  gUDP_Ports[0x10000/32];
43
44 tSocketFile     gUDP_SocketFile = {NULL, "udp", UDP_Channel_Init};
45
46 // === CODE ===
47 /**
48  * \fn void TCP_Initialise()
49  * \brief Initialise the TCP Layer
50  */
51 void UDP_Initialise()
52 {
53         IPStack_AddFile(&gUDP_SocketFile);
54         //IPv4_RegisterCallback(IP4PROT_UDP, UDP_GetPacket, UDP_Unreachable);
55         IPv4_RegisterCallback(IP4PROT_UDP, UDP_GetPacket);
56 }
57
58 /**
59  * \brief Scan a list of tUDPChannels and find process the first match
60  * \return 0 if no match was found, -1 on error and 1 if a match was found
61  */
62 int UDP_int_ScanList(tUDPChannel *List, tInterface *Interface, void *Address, int Length, void *Buffer)
63 {
64         tUDPHeader      *hdr = Buffer;
65         tUDPChannel     *chan;
66         tUDPPacket      *pack;
67          int    len;
68         
69         for(chan = List; chan; chan = chan->Next)
70         {
71                 // Match local endpoint
72                 if(chan->Interface && chan->Interface != Interface)     continue;
73                 if(chan->LocalPort != ntohs(hdr->DestPort))     continue;
74                 
75                 // Check for remote port restriction
76                 if(chan->Remote.Port && chan->Remote.Port != ntohs(hdr->SourcePort))
77                         continue;
78                 // Check for remote address restriction
79                 if(chan->RemoteMask)
80                 {
81                         if(chan->Remote.AddrType != Interface->Type)
82                                 continue;
83                         if(!IPStack_CompareAddress(Interface->Type, Address,
84                                 &chan->Remote.Addr, chan->RemoteMask)
85                                 )
86                                 continue;
87                 }
88                 
89                 Log_Log("UDP", "Recieved packet for %p", chan);
90                 // Create the cached packet
91                 len = ntohs(hdr->Length);
92                 pack = malloc(sizeof(tUDPPacket) + len);
93                 pack->Next = NULL;
94                 memcpy(&pack->Remote.Addr, Address, IPStack_GetAddressSize(Interface->Type));
95                 pack->Remote.Port = ntohs(hdr->SourcePort);
96                 pack->Remote.AddrType = Interface->Type;
97                 pack->Length = len;
98                 memcpy(pack->Data, hdr->Data, len);
99                 
100                 // Add the packet to the channel's queue
101                 SHORTLOCK(&chan->lQueue);
102                 if(chan->Queue)
103                         chan->QueueEnd->Next = pack;
104                 else
105                         chan->QueueEnd = chan->Queue = pack;
106                 SHORTREL(&chan->lQueue);
107                 VFS_MarkAvaliable(&chan->Node, 1);
108                 Mutex_Release(&glUDP_Channels);
109                 return 1;
110         }
111         return 0;
112 }
113
114 /**
115  * \fn void UDP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffer)
116  * \brief Handles a packet from the IP Layer
117  */
118 void UDP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffer)
119 {
120         tUDPHeader      *hdr = Buffer;
121         
122         Log_Debug("UDP", "%i bytes :%i->:%i (Cksum 0x%04x)",
123                 ntohs(hdr->Length), ntohs(hdr->SourcePort), ntohs(hdr->Length), ntohs(hdr->Checksum));
124         
125         // Check registered connections
126         Mutex_Acquire(&glUDP_Channels);
127         UDP_int_ScanList(gpUDP_Channels, Interface, Address, Length, Buffer);
128         Mutex_Release(&glUDP_Channels);
129 }
130
131 /**
132  * \brief Handle an ICMP Unrechable Error
133  */
134 void UDP_Unreachable(tInterface *Interface, int Code, void *Address, int Length, void *Buffer)
135 {
136         
137 }
138
139 /**
140  * \brief Send a packet
141  * \param Channel       Channel to send the packet from
142  * \param Data  Packet data
143  * \param Length        Length in bytes of packet data
144  */
145 void UDP_SendPacketTo(tUDPChannel *Channel, int AddrType, const void *Address, Uint16 Port, const void *Data, size_t Length)
146 {
147         tUDPHeader      hdr;
148
149         if(Channel->Interface && Channel->Interface->Type != AddrType)  return ;
150         
151         switch(AddrType)
152         {
153         case 4:
154                 // Create the packet
155                 hdr.SourcePort = htons( Channel->LocalPort );
156                 hdr.DestPort = htons( Port );
157                 hdr.Length = htons( sizeof(tUDPHeader) + Length );
158                 hdr.Checksum = 0;       // Checksum can be zero on IPv4
159                 // Pass on the the IPv4 Layer
160                 tIPStackBuffer  *buffer = IPStack_Buffer_CreateBuffer(2 + IPV4_BUFFERS);
161                 IPStack_Buffer_AppendSubBuffer(buffer, Length, 0, Data, NULL, NULL);
162                 IPStack_Buffer_AppendSubBuffer(buffer, sizeof(hdr), 0, &hdr, NULL, NULL);
163                 // TODO: What if Channel->Interface is NULL here?
164                 IPv4_SendPacket(Channel->Interface, *(tIPv4*)Address, IP4PROT_UDP, 0, buffer);
165                 IPStack_Buffer_DestroyBuffer(buffer);
166                 break;
167         }
168 }
169
170 // --- Client Channels
171 tVFS_Node *UDP_Channel_Init(tInterface *Interface)
172 {
173         tUDPChannel     *new;
174         new = calloc( sizeof(tUDPChannel), 1 );
175         new->Interface = Interface;
176         new->Node.ImplPtr = new;
177         new->Node.NumACLs = 1;
178         new->Node.ACLs = &gVFS_ACL_EveryoneRW;
179         new->Node.Type = &gUDP_NodeType;
180         
181         Mutex_Acquire(&glUDP_Channels);
182         new->Next = gpUDP_Channels;
183         gpUDP_Channels = new;
184         Mutex_Release(&glUDP_Channels);
185         
186         return &new->Node;
187 }
188
189 /**
190  * \brief Read from the channel file (wait for a packet)
191  */
192 size_t UDP_Channel_Read(tVFS_Node *Node, off_t Offset, size_t Length, void *Buffer, Uint Flags)
193 {
194         tUDPChannel     *chan = Node->ImplPtr;
195         tUDPPacket      *pack;
196         tUDPEndpoint    *ep;
197          int    ofs, addrlen;
198         
199         if(chan->LocalPort == 0) {
200                 Log_Notice("UDP", "Channel %p sent with no local port", chan);
201                 return 0;
202         }
203         
204         while(chan->Queue == NULL)      Threads_Yield();
205         
206         for(;;)
207         {
208                 tTime   timeout_z = 0, *timeout = (Flags & VFS_IOFLAG_NOBLOCK) ? &timeout_z : NULL;
209                 int rv = VFS_SelectNode(Node, VFS_SELECT_READ, timeout, "UDP_Channel_Read");
210                 if( rv ) {
211                         errno = (Flags & VFS_IOFLAG_NOBLOCK) ? EWOULDBLOCK : EINTR;
212                 }
213                 SHORTLOCK(&chan->lQueue);
214                 if(chan->Queue == NULL) {
215                         SHORTREL(&chan->lQueue);
216                         continue;
217                 }
218                 pack = chan->Queue;
219                 chan->Queue = pack->Next;
220                 if(!chan->Queue) {
221                         chan->QueueEnd = NULL;
222                         VFS_MarkAvaliable(Node, 0);     // Nothing left
223                 }
224                 SHORTREL(&chan->lQueue);
225                 break;
226         }
227
228         // Check that the header fits
229         addrlen = IPStack_GetAddressSize(pack->Remote.AddrType);
230         ep = Buffer;
231         ofs = 4 + addrlen;
232         if(Length < ofs) {
233                 free(pack);
234                 Log_Notice("UDP", "Insuficient space for header in buffer (%i < %i)", (int)Length, ofs);
235                 return 0;
236         }
237         
238         // Fill header
239         ep->Port = pack->Remote.Port;
240         ep->AddrType = pack->Remote.AddrType;
241         memcpy(&ep->Addr, &pack->Remote.Addr, addrlen);
242         
243         // Copy packet data
244         if(Length > ofs + pack->Length) Length = ofs + pack->Length;
245         memcpy((char*)Buffer + ofs, pack->Data, Length - ofs);
246
247         // Free cached packet
248         free(pack);
249         
250         return Length;
251 }
252
253 /**
254  * \brief Write to the channel file (send a packet)
255  */
256 size_t UDP_Channel_Write(tVFS_Node *Node, off_t Offset, size_t Length, const void *Buffer, Uint Flags)
257 {
258         tUDPChannel     *chan = Node->ImplPtr;
259         const tUDPEndpoint      *ep;
260         const void      *data;
261          int    ofs;
262         
263         if(chan->LocalPort == 0) {
264                 Log_Notice("UDP", "Write to channel %p with zero local port", chan);
265                 return 0;
266         }
267         
268         ep = Buffer;    
269         ofs = 2 + 2 + IPStack_GetAddressSize( ep->AddrType );
270
271         data = (const char *)Buffer + ofs;
272
273         UDP_SendPacketTo(chan, ep->AddrType, &ep->Addr, ep->Port, data, (size_t)Length - ofs);
274         
275         return Length;
276 }
277
278 /**
279  * \brief Names for channel IOCtl Calls
280  */
281 static const char *casIOCtls_Channel[] = {
282         DRV_IOCTLNAMES,
283         "getset_localport",
284         "getset_remoteport",
285         "getset_remotemask",
286         "set_remoteaddr",
287         NULL
288         };
289 /**
290  * \brief Channel IOCtls
291  */
292 int UDP_Channel_IOCtl(tVFS_Node *Node, int ID, void *Data)
293 {
294         tUDPChannel     *chan = Node->ImplPtr;
295         ENTER("pNode iID pData", Node, ID, Data);
296         switch(ID)
297         {
298         BASE_IOCTLS(DRV_TYPE_MISC, "UDP Channel", 0x100, casIOCtls_Channel);
299         
300         case 4: { // getset_localport (returns bool success)
301                 if(!Data)       LEAVE_RET('i', chan->LocalPort);
302                 if(!CheckMem( Data, sizeof(Uint16) ) ) {
303                         LOG("Invalid pointer %p", Data);
304                         LEAVE_RET('i', -1);
305                 }
306                 // Set port
307                 int req_port = *(Uint16*)Data;
308                 // Permissions check (Ports lower than 1024 are root-only)
309                 if(req_port != 0 && req_port < 1024) {
310                         if( Threads_GetUID() != 0 ) {
311                                 LOG("Attempt by non-superuser to listen on port %i", req_port);
312                                 LEAVE_RET('i', -1);
313                         }
314                 }
315                 // Allocate a random port if requested
316                 if( req_port == 0 )
317                         UDP_int_AllocatePort(chan);
318                 // Else, mark the requested port as used
319                 else if( UDP_int_ClaimPort(chan, req_port) ) {
320                         LOG("Port %i is currently in use", req_port);
321                         LEAVE_RET('i', 0);
322                 }
323                 LEAVE_RET('i', chan->LocalPort);
324                 }
325         
326         case 5: // getset_remoteport (returns bool success)
327                 if(!Data)       LEAVE_RET('i', chan->Remote.Port);
328                 if(!CheckMem( Data, sizeof(Uint16) ) ) {
329                         LOG("Invalid pointer %p", Data);
330                         LEAVE_RET('i', -1);
331                 }
332                 chan->Remote.Port = *(Uint16*)Data;
333                 LEAVE('i', chan->Remote.Port);
334                 return chan->Remote.Port;
335         
336         case 6: // getset_remotemask (returns bool success)
337                 if(!Data)       LEAVE_RET('i', chan->RemoteMask);
338                 if(!CheckMem(Data, sizeof(int)))        LEAVE_RET('i', -1);
339                 if( !chan->Interface ) {
340                         LOG("Can't set remote mask on NULL interface");
341                         LEAVE_RET('i', -1);
342                 }
343                 if( *(int*)Data > IPStack_GetAddressSize(chan->Interface->Type) )
344                         LEAVE_RET('i', -1);
345                 chan->RemoteMask = *(int*)Data;
346                 LEAVE('i', chan->RemoteMask);
347                 return chan->RemoteMask;        
348
349         case 7: // set_remoteaddr (returns bool success)
350                 if( !chan->Interface ) {
351                         LOG("Can't set remote address on NULL interface");
352                         LEAVE_RET('i', -1);
353                 }
354                 if(!CheckMem(Data, IPStack_GetAddressSize(chan->Interface->Type))) {
355                         LOG("Invalid pointer");
356                         LEAVE_RET('i', -1);
357                 }
358                 memcpy(&chan->Remote.Addr, Data, IPStack_GetAddressSize(chan->Interface->Type));
359                 LEAVE('i', 0);
360                 return 0;
361         }
362         LEAVE_RET('i', 0);
363 }
364
365 /**
366  * \brief Close and destroy an open channel
367  */
368 void UDP_Channel_Close(tVFS_Node *Node)
369 {
370         tUDPChannel     *chan = Node->ImplPtr;
371         tUDPChannel     *prev;
372         
373         // Remove from the main list first
374         Mutex_Acquire(&glUDP_Channels);
375         if(gpUDP_Channels == chan)
376                 gpUDP_Channels = gpUDP_Channels->Next;
377         else
378         {
379                 for(prev = gpUDP_Channels;
380                         prev->Next && prev->Next != chan;
381                         prev = prev->Next);
382                 if(!prev->Next)
383                         Log_Warning("UDP", "Bookeeping Fail, channel %p is not in main list", chan);
384                 else
385                         prev->Next = prev->Next->Next;
386         }
387         Mutex_Release(&glUDP_Channels);
388         
389         // Clear Queue
390         SHORTLOCK(&chan->lQueue);
391         while(chan->Queue)
392         {
393                 tUDPPacket      *tmp;
394                 tmp = chan->Queue;
395                 chan->Queue = tmp->Next;
396                 free(tmp);
397         }
398         SHORTREL(&chan->lQueue);
399         
400         // Free channel structure
401         free(chan);
402 }
403
404 /**
405  * \return Port Number on success, or zero on failure
406  */
407 Uint16 UDP_int_AllocatePort(tUDPChannel *Channel)
408 {
409         Mutex_Acquire(&glUDP_Ports);
410          int    i;
411         // Fast Search
412         for( int base = UDP_ALLOC_BASE; base < 0x10000; base += 32 )
413         {
414                 if( gUDP_Ports[base/32] == 0xFFFFFFFF )
415                         continue ;
416                 for( int i = 0; i < 32; i++ )
417                 {
418                         if( gUDP_Ports[base/32] & (1 << i) )
419                                 continue ;
420                         gUDP_Ports[base/32] |= (1 << i);
421                         Mutex_Release(&glUDP_Ports);
422                         // If claim succeeds, good
423                         if( UDP_int_ClaimPort(Channel, base + i) == 0 )
424                                 return base + i;
425                         // otherwise keep looking
426                         Mutex_Acquire(&glUDP_Ports);
427                         break;
428                 }
429         }
430         Mutex_Release(&glUDP_Ports);
431         return 0;
432 }
433
434 /**
435  * \brief Allocate a specific port
436  * \return Boolean Success
437  */
438 int UDP_int_ClaimPort(tUDPChannel *Channel, Uint16 Port)
439 {
440         // Search channel list for a connection with same (or wildcard)
441         // interface, and same port
442         Mutex_Acquire(&glUDP_Channels);
443         for( tUDPChannel *ch = gpUDP_Channels; ch; ch = ch->Next)
444         {
445                 if( ch == Channel )
446                         continue ;
447                 if( ch->Interface && ch->Interface != Channel->Interface )
448                         continue ;
449                 if( ch->LocalPort != Port )
450                         continue ;
451                 Mutex_Release(&glUDP_Channels);
452                 return 1;
453         }
454         Channel->LocalPort = Port;
455         Mutex_Release(&glUDP_Channels);
456         return 0;
457 }
458
459 /**
460  * \brief Free an allocated port
461  */
462 void UDP_int_FreePort(Uint16 Port)
463 {
464         Mutex_Acquire(&glUDP_Ports);
465         gUDP_Ports[Port/32] &= ~(1 << (Port%32));
466         Mutex_Release(&glUDP_Ports);
467 }

UCC git Repository :: git.ucc.asn.au