Modules/IPStack - Add ICMPv6 (not tested), fix TCP packet caching
[tpg/acess2.git] / KernelLand / Modules / IPStack / tcp.c
index 8f16da8..ac58ec3 100644 (file)
@@ -2,7 +2,7 @@
  * Acess2 IP Stack
  * - TCP Handling
  */
-#define DEBUG  0
+#define DEBUG  1
 #include "ipstack.h"
 #include "ipv4.h"
 #include "ipv6.h"
@@ -52,6 +52,7 @@ size_t        TCP_Client_Write(tVFS_Node *Node, off_t Offset, size_t Length, const void
 void   TCP_Client_Close(tVFS_Node *Node);
 // --- Helpers
  int   WrapBetween(Uint32 Lower, Uint32 Value, Uint32 Higher, Uint32 MaxValue);
+Uint32 GetRelative(Uint32 Base, Uint32 Value);
 
 // === TEMPLATES ===
 tSocketFile    gTCP_ServerFile = {NULL, "tcps", TCP_Server_Init};
@@ -119,7 +120,7 @@ void TCP_int_SendPacket(tInterface *Interface, const void *Dest, tTCPHeader *Hea
 
        LOG("Sending %i+%i to %s:%i", sizeof(*Header), Length,
                IPStack_PrintAddress(Interface->Type, Dest),
-               ntohs(Header->RemotePort)
+               ntohs(Header->DestPort)
                );
 
        Header->Checksum = 0;
@@ -289,6 +290,7 @@ void TCP_GetPacket(tInterface *Interface, void *Address, int Length, void *Buffe
                conn->NextSequenceRcv = ntohl( hdr->SequenceNumber ) + 1;
                conn->HighestSequenceRcvd = conn->NextSequenceRcv;
                conn->NextSequenceSend = rand();
+               conn->LastACKSequence = ntohl( hdr->SequenceNumber );
                
                conn->Node.ImplInt = srv->NextID ++;
                
@@ -516,7 +518,9 @@ void TCP_INT_HandleConnectionPacket(tTCPConnection *Connection, tTCPHeader *Head
 
                        #if 1
                        // - Only send an ACK if we've had a burst
-                       if( Connection->NextSequenceRcv > (Uint32)(TCP_DACK_THRESHOLD + Connection->LastACKSequence) )
+                       Uint32  bytes_since_last_ack = Connection->NextSequenceRcv - Connection->LastACKSequence;
+                       LOG("bytes_since_last_ack = 0x%x", bytes_since_last_ack);
+                       if( bytes_since_last_ack > TCP_DACK_THRESHOLD )
                        {
                                TCP_INT_SendACK(Connection, "DACK Burst");
                                // - Extend TCP deferred ACK timer
@@ -529,14 +533,13 @@ void TCP_INT_HandleConnectionPacket(tTCPConnection *Connection, tTCPHeader *Head
                        #endif
                }
                // Check if the packet is in window
-               else if( WrapBetween(Connection->NextSequenceRcv, sequence_num,
-                               Connection->NextSequenceRcv+TCP_WINDOW_SIZE, 0xFFFFFFFF) )
+               else if( sequence_num - Connection->NextSequenceRcv < TCP_WINDOW_SIZE )
                {
                        Uint8   *dataptr = (Uint8*)Header + (Header->DataOffset>>4)*4;
-                       #if CACHE_FUTURE_PACKETS_IN_BYTES
-                       Uint32  index;
-                       
-                       index = sequence_num % TCP_WINDOW_SIZE;
+                       Uint32  index = sequence_num % TCP_WINDOW_SIZE;
+                       Uint32  max = Connection->NextSequenceRcv % TCP_WINDOW_SIZE;
+                       if( !(Connection->FuturePacketValidBytes[index/8] & (1 << (index%8))) )
+                               TCP_INT_SendACK(Connection, "Lost packet");
                        for( int i = 0; i < dataLen; i ++ )
                        {
                                Connection->FuturePacketValidBytes[index/8] |= 1 << (index%8);
@@ -544,52 +547,15 @@ void TCP_INT_HandleConnectionPacket(tTCPConnection *Connection, tTCPHeader *Head
                                // Do a wrap increment
                                index ++;
                                if(index == TCP_WINDOW_SIZE)    index = 0;
+                               if(index == max)        break;
                        }
-                       #else
-                       tTCPStoredPacket        *pkt, *tmp, *prev = NULL;
-                       
-                       // Allocate and fill cached packet
-                       pkt = malloc( sizeof(tTCPStoredPacket) + dataLen );
-                       pkt->Next = NULL;
-                       pkt->Sequence = ntohl(Header->SequenceNumber);
-                       pkt->Length = dataLen;
-                       memcpy(pkt->Data, dataptr, dataLen);
-                       
-                       Log_Log("TCP", "We missed a packet, caching",
-                               pkt->Sequence, Connection->NextSequenceRcv);
-                       
-                       // No? Well, let's cache it and look at it later
-                       SHORTLOCK( &Connection->lFuturePackets );
-                       for(tmp = Connection->FuturePackets;
-                               tmp;
-                               prev = tmp, tmp = tmp->Next)
-                       {
-                               if(tmp->Sequence >= pkt->Sequence)      break;
-                       }
-                       
-                       // Add if before first, or sequences don't match 
-                       if( !tmp || tmp->Sequence != pkt->Sequence )
-                       {
-                               if(prev)
-                                       prev->Next = pkt;
-                               else
-                                       Connection->FuturePackets = pkt;
-                               pkt->Next = tmp;
-                       }
-                       // Replace if larger
-                       else if(pkt->Length > tmp->Length)
-                       {
-                               if(prev)
-                                       prev->Next = pkt;
-                               pkt->Next = tmp->Next;
-                               free(tmp);
-                       }
-                       else
+                       Uint32  rel_highest = Connection->HighestSequenceRcvd - Connection->NextSequenceRcv;
+                       Uint32  rel_this = index - Connection->NextSequenceRcv;
+                       LOG("Updating highest this(0x%x) > highest(%x)", rel_this, rel_highest);
+                       if( rel_this > rel_highest )
                        {
-                               free(pkt);      // TODO: Find some way to remove this
+                               Connection->HighestSequenceRcvd = index;
                        }
-                       SHORTREL( &Connection->lFuturePackets );
-                       #endif
                }
                // Badly out of sequence packet
                else
@@ -716,17 +682,18 @@ int TCP_INT_AppendRecieved(tTCPConnection *Connection, const void *Data, size_t
  */
 void TCP_INT_UpdateRecievedFromFuture(tTCPConnection *Connection)
 {
-       #if CACHE_FUTURE_PACKETS_IN_BYTES
        // Calculate length of contiguous bytes
-        int    length = Connection->HighestSequenceRcvd - Connection->NextSequenceRcv;
+       const int       length = Connection->HighestSequenceRcvd - Connection->NextSequenceRcv;
        Uint32  index = Connection->NextSequenceRcv % TCP_WINDOW_SIZE;
-       LOG("length=%i, index=%i", length, index);
+       size_t  runlength = length;
+       LOG("length=%i, index=0x%x", length, index);
        for( int i = 0; i < length; i ++ )
        {
                 int    bit = index % 8;
                Uint8   bitfield_byte = Connection->FuturePacketValidBytes[index / 8];
                if( (bitfield_byte & (1 << bit)) == 0 ) {
-                       length = i;
+                       runlength = i;
+                       LOG("Hit missing, break");
                        break;
                }
 
@@ -743,90 +710,51 @@ void TCP_INT_UpdateRecievedFromFuture(tTCPConnection *Connection)
        }
        
        index = Connection->NextSequenceRcv % TCP_WINDOW_SIZE;
+       Connection->NextSequenceRcv += runlength;
        
        // Write data to to the ring buffer
-       if( TCP_WINDOW_SIZE - index > length )
+       if( TCP_WINDOW_SIZE - index > runlength )
        {
                // Simple case
-               RingBuffer_Write( Connection->RecievedBuffer, Connection->FuturePacketData + index, length );
+               RingBuffer_Write( Connection->RecievedBuffer, Connection->FuturePacketData + index, runlength );
        }
        else
        {
                 int    endLen = TCP_WINDOW_SIZE - index;
                // 2-part case
                RingBuffer_Write( Connection->RecievedBuffer, Connection->FuturePacketData + index, endLen );
-               RingBuffer_Write( Connection->RecievedBuffer, Connection->FuturePacketData, endLen - length );
+               RingBuffer_Write( Connection->RecievedBuffer, Connection->FuturePacketData, endLen - runlength );
        }
        
        // Mark (now saved) bytes as invalid
        // - Align index
-       while(index % 8 && length > 0)
+       while(index % 8 && runlength > 0)
        {
                Connection->FuturePacketData[index] = 0;
                Connection->FuturePacketValidBytes[index/8] &= ~(1 << (index%8));
                index ++;
                if(index > TCP_WINDOW_SIZE)
                        index -= TCP_WINDOW_SIZE;
-               length --;
+               runlength --;
        }
-       while( length > 7 )
+       while( runlength > 7 )
        {
                Connection->FuturePacketData[index] = 0;
                Connection->FuturePacketValidBytes[index/8] = 0;
-               length -= 8;
+               runlength -= 8;
                index += 8;
                if(index > TCP_WINDOW_SIZE)
                        index -= TCP_WINDOW_SIZE;
        }
-       while(length)
+       while( runlength > 0)
        {
                Connection->FuturePacketData[index] = 0;
                Connection->FuturePacketData[index/8] &= ~(1 << (index%8));
                index ++;
                if(index > TCP_WINDOW_SIZE)
                        index -= TCP_WINDOW_SIZE;
-               length --;
-       }
-       
-       #else
-       tTCPStoredPacket        *pkt;
-       for(;;)
-       {
-               SHORTLOCK( &Connection->lFuturePackets );
-               
-               // Clear out duplicates from cache
-               // - If a packet has just been recieved, and it is expected, then
-               //   (since NextSequenceRcv = rcvd->Sequence + rcvd->Length) all
-               //   packets in cache that are smaller than the next expected
-               //   are now defunct.
-               pkt = Connection->FuturePackets;
-               while(pkt && pkt->Sequence < Connection->NextSequenceRcv)
-               {
-                       tTCPStoredPacket        *next = pkt->Next;
-                       free(pkt);
-                       pkt = next;
-               }
-               
-               // If there's no packets left in cache, stop looking
-               if(!pkt || pkt->Sequence > Connection->NextSequenceRcv) {
-                       SHORTREL( &Connection->lFuturePackets );
-                       return;
-               }
-               
-               // Delete packet from future list
-               Connection->FuturePackets = pkt->Next;
-               
-               // Release list
-               SHORTREL( &Connection->lFuturePackets );
-               
-               // Looks like we found one
-               TCP_INT_AppendRecieved(Connection, pkt->Data, pkt->Length);
-               if( Connection->HighestSequenceRcvd == Connection->NextSequenceRcv )
-                       Connection->HighestSequenceRcvd += pkt->Length;
-               Connection->NextSequenceRcv += pkt->Length;
-               free(pkt);
+               runlength --;
        }
-       #endif
 }
 
 void TCP_int_SendDelayedACK(void *ConnPtr)
@@ -1259,6 +1187,8 @@ void TCP_INT_SendDataPacket(tTCPConnection *Connection, size_t Length, const voi
 
        // - Stop Delayed ACK timer (as this data packet ACKs)
        Time_RemoveTimer(Connection->DeferredACKTimer);
+
+       // TODO: Don't exceed window size
        
        packet->SourcePort = htons(Connection->LocalPort);
        packet->DestPort = htons(Connection->RemotePort);
@@ -1534,3 +1464,10 @@ int WrapBetween(Uint32 Lower, Uint32 Value, Uint32 Higher, Uint32 MaxValue)
        
        return 0;
 }
+Uint32 GetRelative(Uint32 Base, Uint32 Value)
+{
+       if( Value < Base )
+               return Value - Base + 0xFFFFFFFF;
+       else
+               return Value - Base;
+}

UCC git Repository :: git.ucc.asn.au