67f6056128718099cef4cddc6316ae11a1e5f9e2
[tpg/acess2.git] / KernelLand / Modules / IPStack / udp.c
1 /*
2  * Acess2 IP Stack
3  * - By John Hodge (thePowersGang)
4  *
5  * udp.c
6  * - UDP Protocol handling
7  */
8 #define DEBUG   1
9 #include "ipstack.h"
10 #include <api_drv_common.h>
11 #include "udp.h"
12
13 #define UDP_ALLOC_BASE  0xC000
14
15 // === PROTOTYPES ===
16 void    UDP_Initialise();
17 void    UDP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffer);
18 void    UDP_Unreachable(tInterface *Interface, int Code, void *Address, int Length, void *Buffer);
19 void    UDP_SendPacketTo(tUDPChannel *Channel, int AddrType, const void *Address, Uint16 Port, const void *Data, size_t Length);
20 // --- Client Channels
21 tVFS_Node       *UDP_Channel_Init(tInterface *Interface);
22 size_t  UDP_Channel_Read(tVFS_Node *Node, off_t Offset, size_t Length, void *Buffer, Uint Flags);
23 size_t  UDP_Channel_Write(tVFS_Node *Node, off_t Offset, size_t Length, const void *Buffer, Uint Flags);
24  int    UDP_Channel_IOCtl(tVFS_Node *Node, int ID, void *Data);
25 void    UDP_Channel_Close(tVFS_Node *Node);
26 // --- Helpers
27 Uint16  UDP_int_AllocatePort(tUDPChannel *Channel);
28  int    UDP_int_ClaimPort(tUDPChannel *Channel, Uint16 Port);
29 void    UDP_int_FreePort(Uint16 Port);
30 Uint16  UDP_int_MakeChecksum(tInterface *Iface, const void *Dest, tUDPHeader *Hdr, size_t Len, const void *Data); 
31 Uint16  UDP_int_PartialChecksum(Uint16 Prev, size_t Len, const void *Data);
32 Uint16  UDP_int_FinaliseChecksum(Uint16 Value);
33
34 // === GLOBALS ===
35 tVFS_NodeType   gUDP_NodeType = {
36         .TypeName = "UDP",
37         .Flags = VFS_NODETYPEFLAG_STREAM,
38         .Read = UDP_Channel_Read,
39         .Write = UDP_Channel_Write,
40         .IOCtl = UDP_Channel_IOCtl,
41         .Close = UDP_Channel_Close
42 };
43 tMutex  glUDP_Channels; // TODO: Replace with a RWLock
44 tUDPChannel     *gpUDP_Channels;
45
46 tMutex  glUDP_Ports;
47 Uint32  gUDP_Ports[0x10000/32];
48
49 tSocketFile     gUDP_SocketFile = {NULL, "udp", UDP_Channel_Init};
50
51 // === CODE ===
52 /**
53  * \fn void TCP_Initialise()
54  * \brief Initialise the TCP Layer
55  */
56 void UDP_Initialise()
57 {
58         IPStack_AddFile(&gUDP_SocketFile);
59         //IPv4_RegisterCallback(IP4PROT_UDP, UDP_GetPacket, UDP_Unreachable);
60         IPv4_RegisterCallback(IP4PROT_UDP, UDP_GetPacket);
61 }
62
63 /**
64  * \brief Scan a list of tUDPChannels and find process the first match
65  * \return 0 if no match was found, -1 on error and 1 if a match was found
66  */
67 int UDP_int_ScanList(tUDPChannel *List, tInterface *Interface, void *Address, int Length, void *Buffer)
68 {
69         tUDPHeader      *hdr = Buffer;
70         tUDPChannel     *chan;
71         tUDPPacket      *pack;
72          int    len;
73         
74         for(chan = List; chan; chan = chan->Next)
75         {
76                 // Match local endpoint
77                 if(chan->Interface && chan->Interface != Interface)     continue;
78                 if(chan->LocalPort != ntohs(hdr->DestPort))     continue;
79                 
80                 // Check for remote port restriction
81                 if(chan->Remote.Port && chan->Remote.Port != ntohs(hdr->SourcePort))
82                         continue;
83                 // Check for remote address restriction
84                 if(chan->RemoteMask)
85                 {
86                         if(chan->Remote.AddrType != Interface->Type)
87                                 continue;
88                         if(!IPStack_CompareAddress(Interface->Type, Address,
89                                 &chan->Remote.Addr, chan->RemoteMask)
90                                 )
91                                 continue;
92                 }
93                 
94                 Log_Log("UDP", "Recieved packet for %p", chan);
95                 // Create the cached packet
96                 len = ntohs(hdr->Length);
97                 pack = malloc(sizeof(tUDPPacket) + len);
98                 pack->Next = NULL;
99                 memcpy(&pack->Remote.Addr, Address, IPStack_GetAddressSize(Interface->Type));
100                 pack->Remote.Port = ntohs(hdr->SourcePort);
101                 pack->Remote.AddrType = Interface->Type;
102                 pack->Length = len;
103                 memcpy(pack->Data, hdr->Data, len);
104                 
105                 // Add the packet to the channel's queue
106                 SHORTLOCK(&chan->lQueue);
107                 if(chan->Queue)
108                         chan->QueueEnd->Next = pack;
109                 else
110                         chan->QueueEnd = chan->Queue = pack;
111                 SHORTREL(&chan->lQueue);
112                 VFS_MarkAvaliable(&chan->Node, 1);
113                 Mutex_Release(&glUDP_Channels);
114                 return 1;
115         }
116         return 0;
117 }
118
119 /**
120  * \fn void UDP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffer)
121  * \brief Handles a packet from the IP Layer
122  */
123 void UDP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffer)
124 {
125         tUDPHeader      *hdr = Buffer;
126         
127         #if 1
128         size_t len = strlen( IPStack_PrintAddress(Interface->Type, Address) );
129         char    tmp[len+1];
130         strcpy(tmp, IPStack_PrintAddress(Interface->Type, Address));
131         Log_Debug("UDP", "%i bytes %s:%i -> %s:%i (Cksum 0x%04x)",
132                 ntohs(hdr->Length),
133                 tmp, ntohs(hdr->SourcePort),
134                 IPStack_PrintAddress(Interface->Type, Interface->Address), ntohs(hdr->DestPort),
135                 ntohs(hdr->Checksum));
136         #endif
137         
138         // Check registered connections
139         Mutex_Acquire(&glUDP_Channels);
140         UDP_int_ScanList(gpUDP_Channels, Interface, Address, Length, Buffer);
141         Mutex_Release(&glUDP_Channels);
142 }
143
144 /**
145  * \brief Handle an ICMP Unrechable Error
146  */
147 void UDP_Unreachable(tInterface *Interface, int Code, void *Address, int Length, void *Buffer)
148 {
149         
150 }
151
152 /**
153  * \brief Send a packet
154  * \param Channel       Channel to send the packet from
155  * \param Data  Packet data
156  * \param Length        Length in bytes of packet data
157  */
158 void UDP_SendPacketTo(tUDPChannel *Channel, int AddrType, const void *Address, Uint16 Port, const void *Data, size_t Length)
159 {
160         tUDPHeader      hdr;
161
162         if(Channel->Interface && Channel->Interface->Type != AddrType)  return ;
163         
164         // Create the packet
165         hdr.SourcePort = htons( Channel->LocalPort );
166         hdr.DestPort = htons( Port );
167         hdr.Length = htons( sizeof(tUDPHeader) + Length );
168         hdr.Checksum = 0;
169         hdr.Checksum = htons( UDP_int_MakeChecksum(Channel->Interface, Address, &hdr, Length, Data) );
170         
171         tIPStackBuffer  *buffer;
172         switch(AddrType)
173         {
174         case 4:
175                 // Pass on the the IPv4 Layer
176                 buffer = IPStack_Buffer_CreateBuffer(2 + IPV4_BUFFERS);
177                 IPStack_Buffer_AppendSubBuffer(buffer, Length, 0, Data, NULL, NULL);
178                 IPStack_Buffer_AppendSubBuffer(buffer, sizeof(hdr), 0, &hdr, NULL, NULL);
179                 // TODO: What if Channel->Interface is NULL here?
180                 IPv4_SendPacket(Channel->Interface, *(tIPv4*)Address, IP4PROT_UDP, 0, buffer);
181                 break;
182         default:
183                 Log_Warning("UDP", "TODO: Implement on proto %i", AddrType);
184                 break;
185         }
186 }
187
188 // --- Client Channels
189 tVFS_Node *UDP_Channel_Init(tInterface *Interface)
190 {
191         tUDPChannel     *new;
192         new = calloc( sizeof(tUDPChannel), 1 );
193         new->Interface = Interface;
194         new->Node.Size = -1;
195         new->Node.ImplPtr = new;
196         new->Node.NumACLs = 1;
197         new->Node.ACLs = &gVFS_ACL_EveryoneRW;
198         new->Node.Type = &gUDP_NodeType;
199         
200         Mutex_Acquire(&glUDP_Channels);
201         new->Next = gpUDP_Channels;
202         gpUDP_Channels = new;
203         Mutex_Release(&glUDP_Channels);
204         
205         return &new->Node;
206 }
207
208 /**
209  * \brief Read from the channel file (wait for a packet)
210  */
211 size_t UDP_Channel_Read(tVFS_Node *Node, off_t Offset, size_t Length, void *Buffer, Uint Flags)
212 {
213         tUDPChannel     *chan = Node->ImplPtr;
214         tUDPPacket      *pack;
215         tUDPEndpoint    *ep;
216          int    ofs, addrlen;
217         
218         if(chan->LocalPort == 0) {
219                 Log_Notice("UDP", "Channel %p sent with no local port", chan);
220                 return 0;
221         }
222         
223         while(chan->Queue == NULL)      Threads_Yield();
224         
225         for(;;)
226         {
227                 tTime   timeout_z = 0, *timeout = (Flags & VFS_IOFLAG_NOBLOCK) ? &timeout_z : NULL;
228                 int rv = VFS_SelectNode(Node, VFS_SELECT_READ, timeout, "UDP_Channel_Read");
229                 if( rv ) {
230                         errno = (Flags & VFS_IOFLAG_NOBLOCK) ? EWOULDBLOCK : EINTR;
231                 }
232                 SHORTLOCK(&chan->lQueue);
233                 if(chan->Queue == NULL) {
234                         SHORTREL(&chan->lQueue);
235                         continue;
236                 }
237                 pack = chan->Queue;
238                 chan->Queue = pack->Next;
239                 if(!chan->Queue) {
240                         chan->QueueEnd = NULL;
241                         VFS_MarkAvaliable(Node, 0);     // Nothing left
242                 }
243                 SHORTREL(&chan->lQueue);
244                 break;
245         }
246
247         // Check that the header fits
248         addrlen = IPStack_GetAddressSize(pack->Remote.AddrType);
249         ep = Buffer;
250         ofs = 4 + addrlen;
251         if(Length < ofs) {
252                 free(pack);
253                 Log_Notice("UDP", "Insuficient space for header in buffer (%i < %i)", (int)Length, ofs);
254                 return 0;
255         }
256         
257         // Fill header
258         ep->Port = pack->Remote.Port;
259         ep->AddrType = pack->Remote.AddrType;
260         memcpy(&ep->Addr, &pack->Remote.Addr, addrlen);
261         
262         // Copy packet data
263         if(Length > ofs + pack->Length) Length = ofs + pack->Length;
264         memcpy((char*)Buffer + ofs, pack->Data, Length - ofs);
265
266         // Free cached packet
267         free(pack);
268         
269         return Length;
270 }
271
272 /**
273  * \brief Write to the channel file (send a packet)
274  */
275 size_t UDP_Channel_Write(tVFS_Node *Node, off_t Offset, size_t Length, const void *Buffer, Uint Flags)
276 {
277         tUDPChannel     *chan = Node->ImplPtr;
278         const tUDPEndpoint      *ep;
279         const void      *data;
280          int    ofs;
281         
282         if(chan->LocalPort == 0) {
283                 Log_Notice("UDP", "Write to channel %p with zero local port", chan);
284                 return 0;
285         }
286         
287         ep = Buffer;    
288         ofs = 2 + 2 + IPStack_GetAddressSize( ep->AddrType );
289
290         data = (const char *)Buffer + ofs;
291
292         UDP_SendPacketTo(chan, ep->AddrType, &ep->Addr, ep->Port, data, (size_t)Length - ofs);
293         
294         return Length;
295 }
296
297 /**
298  * \brief Names for channel IOCtl Calls
299  */
300 static const char *casIOCtls_Channel[] = {
301         DRV_IOCTLNAMES,
302         "getset_localport",
303         "getset_remoteport",
304         "getset_remotemask",
305         "set_remoteaddr",
306         NULL
307         };
308 /**
309  * \brief Channel IOCtls
310  */
311 int UDP_Channel_IOCtl(tVFS_Node *Node, int ID, void *Data)
312 {
313         tUDPChannel     *chan = Node->ImplPtr;
314         ENTER("pNode iID pData", Node, ID, Data);
315         switch(ID)
316         {
317         BASE_IOCTLS(DRV_TYPE_MISC, "UDP Channel", 0x100, casIOCtls_Channel);
318         
319         case 4: { // getset_localport (returns bool success)
320                 if(!Data)       LEAVE_RET('i', chan->LocalPort);
321                 if(!CheckMem( Data, sizeof(Uint16) ) ) {
322                         LOG("Invalid pointer %p", Data);
323                         LEAVE_RET('i', -1);
324                 }
325                 // Set port
326                 int req_port = *(Uint16*)Data;
327                 // Permissions check (Ports lower than 1024 are root-only)
328                 if(req_port != 0 && req_port < 1024) {
329                         if( Threads_GetUID() != 0 ) {
330                                 LOG("Attempt by non-superuser to listen on port %i", req_port);
331                                 LEAVE_RET('i', -1);
332                         }
333                 }
334                 // Allocate a random port if requested
335                 if( req_port == 0 )
336                         UDP_int_AllocatePort(chan);
337                 // Else, mark the requested port as used
338                 else if( UDP_int_ClaimPort(chan, req_port) ) {
339                         LOG("Port %i is currently in use", req_port);
340                         LEAVE_RET('i', 0);
341                 }
342                 LEAVE_RET('i', chan->LocalPort);
343                 }
344         
345         case 5: // getset_remoteport (returns bool success)
346                 if(!Data)       LEAVE_RET('i', chan->Remote.Port);
347                 if(!CheckMem( Data, sizeof(Uint16) ) ) {
348                         LOG("Invalid pointer %p", Data);
349                         LEAVE_RET('i', -1);
350                 }
351                 chan->Remote.Port = *(Uint16*)Data;
352                 LEAVE('i', chan->Remote.Port);
353                 return chan->Remote.Port;
354         
355         case 6: // getset_remotemask (returns bool success)
356                 if(!Data)       LEAVE_RET('i', chan->RemoteMask);
357                 if(!CheckMem(Data, sizeof(int)))        LEAVE_RET('i', -1);
358                 if( !chan->Interface ) {
359                         LOG("Can't set remote mask on NULL interface");
360                         LEAVE_RET('i', -1);
361                 }
362                 if( *(int*)Data > IPStack_GetAddressSize(chan->Interface->Type) )
363                         LEAVE_RET('i', -1);
364                 chan->RemoteMask = *(int*)Data;
365                 LEAVE('i', chan->RemoteMask);
366                 return chan->RemoteMask;        
367
368         case 7: // set_remoteaddr (returns bool success)
369                 if( !chan->Interface ) {
370                         LOG("Can't set remote address on NULL interface");
371                         LEAVE_RET('i', -1);
372                 }
373                 if(!CheckMem(Data, IPStack_GetAddressSize(chan->Interface->Type))) {
374                         LOG("Invalid pointer");
375                         LEAVE_RET('i', -1);
376                 }
377                 memcpy(&chan->Remote.Addr, Data, IPStack_GetAddressSize(chan->Interface->Type));
378                 LEAVE('i', 0);
379                 return 0;
380         }
381         LEAVE_RET('i', 0);
382 }
383
384 /**
385  * \brief Close and destroy an open channel
386  */
387 void UDP_Channel_Close(tVFS_Node *Node)
388 {
389         tUDPChannel     *chan = Node->ImplPtr;
390         tUDPChannel     *prev;
391         
392         // Remove from the main list first
393         Mutex_Acquire(&glUDP_Channels);
394         if(gpUDP_Channels == chan)
395                 gpUDP_Channels = gpUDP_Channels->Next;
396         else
397         {
398                 for(prev = gpUDP_Channels;
399                         prev->Next && prev->Next != chan;
400                         prev = prev->Next);
401                 if(!prev->Next)
402                         Log_Warning("UDP", "Bookeeping Fail, channel %p is not in main list", chan);
403                 else
404                         prev->Next = prev->Next->Next;
405         }
406         Mutex_Release(&glUDP_Channels);
407         
408         // Clear Queue
409         SHORTLOCK(&chan->lQueue);
410         while(chan->Queue)
411         {
412                 tUDPPacket      *tmp;
413                 tmp = chan->Queue;
414                 chan->Queue = tmp->Next;
415                 free(tmp);
416         }
417         SHORTREL(&chan->lQueue);
418         
419         // Free channel structure
420         free(chan);
421 }
422
423 /**
424  * \return Port Number on success, or zero on failure
425  */
426 Uint16 UDP_int_AllocatePort(tUDPChannel *Channel)
427 {
428         Mutex_Acquire(&glUDP_Ports);
429         // Fast Search
430         for( int base = UDP_ALLOC_BASE; base < 0x10000; base += 32 )
431         {
432                 if( gUDP_Ports[base/32] == 0xFFFFFFFF )
433                         continue ;
434                 for( int i = 0; i < 32; i++ )
435                 {
436                         if( gUDP_Ports[base/32] & (1 << i) )
437                                 continue ;
438                         gUDP_Ports[base/32] |= (1 << i);
439                         Mutex_Release(&glUDP_Ports);
440                         // If claim succeeds, good
441                         if( UDP_int_ClaimPort(Channel, base + i) == 0 )
442                                 return base + i;
443                         // otherwise keep looking
444                         Mutex_Acquire(&glUDP_Ports);
445                         break;
446                 }
447         }
448         Mutex_Release(&glUDP_Ports);
449         return 0;
450 }
451
452 /**
453  * \brief Allocate a specific port
454  * \return Boolean Success
455  */
456 int UDP_int_ClaimPort(tUDPChannel *Channel, Uint16 Port)
457 {
458         // Search channel list for a connection with same (or wildcard)
459         // interface, and same port
460         Mutex_Acquire(&glUDP_Channels);
461         for( tUDPChannel *ch = gpUDP_Channels; ch; ch = ch->Next)
462         {
463                 if( ch == Channel )
464                         continue ;
465                 if( ch->Interface && ch->Interface != Channel->Interface )
466                         continue ;
467                 if( ch->LocalPort != Port )
468                         continue ;
469                 Mutex_Release(&glUDP_Channels);
470                 return 1;
471         }
472         Channel->LocalPort = Port;
473         Mutex_Release(&glUDP_Channels);
474         return 0;
475 }
476
477 /**
478  * \brief Free an allocated port
479  */
480 void UDP_int_FreePort(Uint16 Port)
481 {
482         Mutex_Acquire(&glUDP_Ports);
483         gUDP_Ports[Port/32] &= ~(1 << (Port%32));
484         Mutex_Release(&glUDP_Ports);
485 }
486
487 /**
488  *
489  */
490 Uint16 UDP_int_MakeChecksum(tInterface *Interface, const void *Dest, tUDPHeader *Hdr, size_t Len, const void *Data)
491 {
492         size_t  addrsize = IPStack_GetAddressSize(Interface->Type);
493         struct {
494                 Uint8   Zeroes;
495                 Uint8   Protocol;
496                 Uint16  UDPLength;
497         } pheader;
498         
499         pheader.Zeroes = 0;
500         switch(Interface->Type)
501         {
502         case 4: pheader.Protocol = IP4PROT_UDP; break;
503         //case 6:       pheader.Protocol = IP6PROT_UDP; break;
504         default:
505                 Log_Warning("UDP", "Unimplemented _MakeChecksum proto %i", Interface->Type);
506                 return 0;
507         }
508         pheader.UDPLength = Hdr->Length;
509         
510         Uint16  csum = 0;
511         csum = UDP_int_PartialChecksum(csum, addrsize, Interface->Address);
512         csum = UDP_int_PartialChecksum(csum, addrsize, Dest);
513         csum = UDP_int_PartialChecksum(csum, sizeof(pheader), &pheader);
514         csum = UDP_int_PartialChecksum(csum, sizeof(tUDPHeader), Hdr);
515         csum = UDP_int_PartialChecksum(csum, Len, Data);
516         
517         return UDP_int_FinaliseChecksum(csum);
518 }
519
520 static inline Uint16 _add_ones_complement16(Uint16 a, Uint16 b)
521 {
522         // One's complement arithmatic, overflows increment bottom bit
523         return a + b + (b > 0xFFFF - a ? 1 : 0);
524 }
525
526 Uint16 UDP_int_PartialChecksum(Uint16 Prev, size_t Len, const void *Data)
527 {
528         Uint16  ret = Prev;
529         const Uint16    *data = Data;
530         for( int i = 0; i < Len/2; i ++ )
531                 ret = _add_ones_complement16(ret, htons(*data++));
532         if( Len % 2 == 1 )
533                 ret = _add_ones_complement16(ret, htons(*(const Uint8*)data));
534         return ret;
535 }
536
537 Uint16 UDP_int_FinaliseChecksum(Uint16 Value)
538 {
539         Value = ~Value; // One's complement it
540         return (Value == 0 ? 0xFFFF : Value);
541 }

UCC git Repository :: git.ucc.asn.au