Networking - DNS resolution semi-working
[tpg/acess2.git] / KernelLand / Modules / IPStack / udp.c
index 8c5c6ca..b9dd323 100644 (file)
@@ -33,12 +33,14 @@ Uint16      UDP_int_FinaliseChecksum(Uint16 Value);
 
 // === GLOBALS ===
 tVFS_NodeType  gUDP_NodeType = {
+       .TypeName = "UDP",
+       .Flags = VFS_NODETYPEFLAG_STREAM,
        .Read = UDP_Channel_Read,
        .Write = UDP_Channel_Write,
        .IOCtl = UDP_Channel_IOCtl,
        .Close = UDP_Channel_Close
 };
-tMutex glUDP_Channels;
+tMutex glUDP_Channels; // TODO: Replace with a RWLock
 tUDPChannel    *gpUDP_Channels;
 
 tMutex glUDP_Ports;
@@ -65,13 +67,15 @@ void UDP_Initialise()
 int UDP_int_ScanList(tUDPChannel *List, tInterface *Interface, void *Address, int Length, void *Buffer)
 {
        tUDPHeader      *hdr = Buffer;
-       tUDPChannel     *chan;
-       tUDPPacket      *pack;
-        int    len;
        
-       for(chan = List; chan; chan = chan->Next)
+       for(tUDPChannel *chan = List; chan; chan = chan->Next)
        {
                // Match local endpoint
+               LOG("(%p):%i - %s/%i:%i",
+                       chan->Interface, chan->LocalPort,
+                       IPStack_PrintAddress(chan->Remote.AddrType, &chan->Remote.Addr), chan->RemoteMask,
+                       chan->Remote.Port
+                       );
                if(chan->Interface && chan->Interface != Interface)     continue;
                if(chan->LocalPort != ntohs(hdr->DestPort))     continue;
                
@@ -91,8 +95,8 @@ int UDP_int_ScanList(tUDPChannel *List, tInterface *Interface, void *Address, in
                
                Log_Log("UDP", "Recieved packet for %p", chan);
                // Create the cached packet
-               len = ntohs(hdr->Length);
-               pack = malloc(sizeof(tUDPPacket) + len);
+               int len = ntohs(hdr->Length);
+               tUDPPacket *pack = malloc(sizeof(tUDPPacket) + len);
                pack->Next = NULL;
                memcpy(&pack->Remote.Addr, Address, IPStack_GetAddressSize(Interface->Type));
                pack->Remote.Port = ntohs(hdr->SourcePort);
@@ -157,7 +161,11 @@ void UDP_SendPacketTo(tUDPChannel *Channel, int AddrType, const void *Address, U
 {
        tUDPHeader      hdr;
 
-       if(Channel->Interface && Channel->Interface->Type != AddrType)  return ;
+       if(Channel->Interface && Channel->Interface->Type != AddrType) {
+               LOG("Bad interface type for channel packet, IF is %i, but packet is %i",
+                       Channel->Interface->Type, AddrType);
+               return ;
+       }
        
        // Create the packet
        hdr.SourcePort = htons( Channel->LocalPort );
@@ -175,6 +183,7 @@ void UDP_SendPacketTo(tUDPChannel *Channel, int AddrType, const void *Address, U
                IPStack_Buffer_AppendSubBuffer(buffer, Length, 0, Data, NULL, NULL);
                IPStack_Buffer_AppendSubBuffer(buffer, sizeof(hdr), 0, &hdr, NULL, NULL);
                // TODO: What if Channel->Interface is NULL here?
+               ASSERT(Channel->Interface);
                IPv4_SendPacket(Channel->Interface, *(tIPv4*)Address, IP4PROT_UDP, 0, buffer);
                break;
        default:
@@ -189,6 +198,7 @@ tVFS_Node *UDP_Channel_Init(tInterface *Interface)
        tUDPChannel     *new;
        new = calloc( sizeof(tUDPChannel), 1 );
        new->Interface = Interface;
+       new->Node.Size = -1;
        new->Node.ImplPtr = new;
        new->Node.NumACLs = 1;
        new->Node.ACLs = &gVFS_ACL_EveryoneRW;
@@ -202,49 +212,59 @@ tVFS_Node *UDP_Channel_Init(tInterface *Interface)
        return &new->Node;
 }
 
-/**
- * \brief Read from the channel file (wait for a packet)
- */
-size_t UDP_Channel_Read(tVFS_Node *Node, off_t Offset, size_t Length, void *Buffer, Uint Flags)
+tUDPPacket *UDP_Channel_WaitForPacket(tUDPChannel *chan, Uint VFSFlags)
 {
-       tUDPChannel     *chan = Node->ImplPtr;
-       tUDPPacket      *pack;
-       tUDPEndpoint    *ep;
-        int    ofs, addrlen;
-       
-       if(chan->LocalPort == 0) {
-               Log_Notice("UDP", "Channel %p sent with no local port", chan);
-               return 0;
-       }
-       
-       while(chan->Queue == NULL)      Threads_Yield();
+       // EVIL - Yield until queue is created (avoids races)
+       while(chan->Queue == NULL)
+               Threads_Yield();
        
        for(;;)
        {
-               tTime   timeout_z = 0, *timeout = (Flags & VFS_IOFLAG_NOBLOCK) ? &timeout_z : NULL;
-               int rv = VFS_SelectNode(Node, VFS_SELECT_READ, timeout, "UDP_Channel_Read");
-               if( rv ) {
-                       errno = (Flags & VFS_IOFLAG_NOBLOCK) ? EWOULDBLOCK : EINTR;
+               tTime   timeout_z = 0, *timeout = (VFSFlags & VFS_IOFLAG_NOBLOCK) ? &timeout_z : NULL;
+               int rv = VFS_SelectNode(&chan->Node, VFS_SELECT_READ, timeout, "UDP_Channel_Read");
+               if( rv == 0 ) {
+                       errno = (VFSFlags & VFS_IOFLAG_NOBLOCK) ? EWOULDBLOCK : EINTR;
+                       return NULL;
                }
                SHORTLOCK(&chan->lQueue);
                if(chan->Queue == NULL) {
                        SHORTREL(&chan->lQueue);
                        continue;
                }
-               pack = chan->Queue;
+               tUDPPacket *pack = chan->Queue;
                chan->Queue = pack->Next;
                if(!chan->Queue) {
                        chan->QueueEnd = NULL;
-                       VFS_MarkAvaliable(Node, 0);     // Nothing left
+                       VFS_MarkAvaliable(&chan->Node, 0);      // Nothing left
                }
                SHORTREL(&chan->lQueue);
-               break;
+               return pack;
+       }
+       // Unreachable
+}
+
+/**
+ * \brief Read from the channel file (wait for a packet)
+ */
+size_t UDP_Channel_Read(tVFS_Node *Node, off_t Offset, size_t Length, void *Buffer, Uint Flags)
+{
+       tUDPChannel     *chan = Node->ImplPtr;
+       
+       if(chan->LocalPort == 0) {
+               Log_Notice("UDP", "Channel %p sent with no local port", chan);
+               return 0;
+       }
+       
+       tUDPPacket      *pack = UDP_Channel_WaitForPacket(chan, Flags);
+       if( !pack ) {
+               return 0;
        }
 
+       size_t addrlen = IPStack_GetAddressSize(pack->Remote.AddrType);
+       tUDPEndpoint *ep = Buffer;
+       size_t ofs = 4 + addrlen;
+       
        // Check that the header fits
-       addrlen = IPStack_GetAddressSize(pack->Remote.AddrType);
-       ep = Buffer;
-       ofs = 4 + addrlen;
        if(Length < ofs) {
                free(pack);
                Log_Notice("UDP", "Insuficient space for header in buffer (%i < %i)", (int)Length, ofs);
@@ -300,6 +320,8 @@ static const char *casIOCtls_Channel[] = {
        "getset_remoteport",
        "getset_remotemask",
        "set_remoteaddr",
+       "sendto",
+       "recvfrom",
        NULL
        };
 /**
@@ -351,14 +373,21 @@ int UDP_Channel_IOCtl(tVFS_Node *Node, int ID, void *Data)
        
        case 6: // getset_remotemask (returns bool success)
                if(!Data)       LEAVE_RET('i', chan->RemoteMask);
-               if(!CheckMem(Data, sizeof(int)))        LEAVE_RET('i', -1);
+               if(!CheckMem(Data, sizeof(int))) {
+                       LOG("Data pointer invalid");
+                       LEAVE_RET('i', -1);
+               }
                if( !chan->Interface ) {
                        LOG("Can't set remote mask on NULL interface");
                        LEAVE_RET('i', -1);
                }
-               if( *(int*)Data > IPStack_GetAddressSize(chan->Interface->Type) )
+                int    mask = *(int*)Data;
+                int    addr_bits = IPStack_GetAddressSize(chan->Interface->Type) * 8;
+               if( mask > addr_bits ) {
+                       LOG("Mask too large (%i > max %i)", mask, addr_bits);
                        LEAVE_RET('i', -1);
-               chan->RemoteMask = *(int*)Data;
+               }
+               chan->RemoteMask = mask;
                LEAVE('i', chan->RemoteMask);
                return chan->RemoteMask;        
 
@@ -371,9 +400,73 @@ int UDP_Channel_IOCtl(tVFS_Node *Node, int ID, void *Data)
                        LOG("Invalid pointer");
                        LEAVE_RET('i', -1);
                }
+               LOG("Set remote addr %s", IPStack_PrintAddress(chan->Interface->Type, Data));
+               chan->Remote.AddrType = chan->Interface->Type;
                memcpy(&chan->Remote.Addr, Data, IPStack_GetAddressSize(chan->Interface->Type));
                LEAVE('i', 0);
                return 0;
+       case 8: {       // sendto
+               if(!CheckMem(Data, 2*sizeof(void*)+2)) {
+                       LOG("Data pointer invalid");
+                       LEAVE_RET('i', -1);
+               }
+               const struct sSendToArgs {
+                       const tUDPEndpoint* ep;
+                       const void* buf;
+                       const Uint16 buflen;
+               } info = *(const struct sSendToArgs*)Data;
+               LOG("sendto(buf=%p + %u, ep=%p)", info.buf, info.buflen, info.ep);
+               if(!CheckMem(info.ep, 2+2) || !CheckMem(info.ep, 2+2+IPStack_GetAddressSize(info.ep->AddrType)) ) {
+                       LEAVE_RET('i', -1);
+               }
+               if(!CheckMem(info.buf, info.buflen)) {
+                       LEAVE_RET('i', -1);
+               }
+               
+               UDP_SendPacketTo(chan, info.ep->AddrType, &info.ep->Addr, info.ep->Port,
+                       info.buf, (size_t)info.buflen);
+               
+               LEAVE_RET('i', info.buflen); }
+       case 9: {       // recvfrom
+               if(!CheckMem(Data, 2*sizeof(void*)+2)) {
+                       LOG("Data pointer invalid");
+                       LEAVE_RET('i', -1);
+               }
+               const struct sRecvFromArgs {
+                       tUDPEndpoint* ep;
+                       void* buf;
+                       Uint16 buflen;
+               } info = *(const struct sRecvFromArgs*)Data;
+               LOG("recvfrom(buf=%p + %u, ep=%p)", info.buf, info.buflen, info.ep);
+               if(!CheckMem(info.ep, 2+2)) {
+                       LEAVE_RET('i', -1);
+               }
+               if(!CheckMem(info.buf, info.buflen)) {
+                       LEAVE_RET('i', -1);
+               }
+               
+               tUDPPacket      *pack = UDP_Channel_WaitForPacket(chan, 0);
+               if( pack == NULL ) {
+                       LOG("No packet");
+                       LEAVE_RET('i', 0);
+               }
+               
+               size_t  addrsize = IPStack_GetAddressSize(pack->Remote.AddrType);
+               if( !CheckMem(info.ep, 2+2+addrsize) ) {
+                       LOG("Insufficient space for source address");
+                       free(pack);
+                       LEAVE_RET('i', -1);
+               }
+               info.ep->Port = pack->Remote.Port;
+               info.ep->AddrType = pack->Remote.AddrType;
+               memcpy(&info.ep->Addr, &pack->Remote.Addr, addrsize);
+               
+               size_t  retlen = (info.buflen < pack->Length ? info.buflen : pack->Length);
+               memcpy(info.buf, pack->Data, retlen);
+
+               free(pack);
+       
+               LEAVE_RET('i', retlen); }
        }
        LEAVE_RET('i', 0);
 }

UCC git Repository :: git.ucc.asn.au