Usermode/utests - Fix DNS utest, update libc utests to new format (no more EXP files)
[tpg/acess2.git] / Usermode / Libraries / libnet.so_src / dns_proto.c
1 /*
2  */
3
4 #include "include/dns.h"
5 #include "include/dns_int.h"
6 #include <stdint.h>
7 #include <string.h>
8 #include <assert.h>
9
10 // === PROTOTYPES ===
11 extern int DNS_int_ParseRR(const void *buf, size_t ofs, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p);
12
13 static uint16_t get16(const void *buf);
14 static uint32_t get32(const void *buf);
15 static size_t put16(void *buf, uint16_t val);
16
17 // === CODE ===
18 size_t DNS_int_EncodeQuery(void *buf, size_t bufsize, const char *name, enum eTypes type, enum eClass class)
19 {
20         int namelen = DNS_EncodeName(NULL, name);
21         if( namelen >= 256 ) {
22                 _SysDebug("DNS_int_EncodeQuery - ERROR: Name encoded to >= 256 bytes");
23                 return 0;
24         }
25         size_t  pos = 0;
26         uint8_t *packet = buf;
27         if( (6*2) + (namelen + 2*2) > bufsize ) {
28                 _SysDebug("DNS_int_EncodeQuery - ERROR: Passed buffer too small");
29                 return 0;
30         }
31         // - Header
32         pos += put16(packet + pos, 0xAC00);     // Identifier (arbitary)
33         pos += put16(packet + pos, (0 << 0) | (0 << 1) | (1 << 8) );    // Op : Query, Standard, Recursion
34         pos += put16(packet + pos, 1);  // QDCount
35         pos += put16(packet + pos, 0);  // ANCount
36         pos += put16(packet + pos, 0);  // NSCount
37         pos += put16(packet + pos, 0);  // ARCount
38         // - Question
39         pos += DNS_EncodeName(packet + pos, name);
40         pos += put16(packet + pos, type);       // QType
41         pos += put16(packet + pos, class);      // QClass
42         
43         assert(pos <= bufsize);
44         return pos;
45 }
46
47 int DNS_int_ParseResponse(const void* buf, size_t return_len, void *info, handle_record_t* handle_record)
48 {
49         const uint8_t* packet = buf;
50         char    rr_name[256];
51         unsigned int id = get16(packet + 0);
52         if( id != 0xAC00 ) {
53                 _SysDebug("DNS_Query - Packet ID mismatch");
54                 return 2;
55         }
56         unsigned int flags = get16(packet + 2);
57         unsigned int qd_count = get16(packet + 4);
58         unsigned int an_count = get16(packet + 6);
59         unsigned int ns_count = get16(packet + 8);
60         unsigned int ar_count = get16(packet + 10);
61         size_t pos = 6*2;
62         // TODO: Can I safely assert / fail if qd_count is non-zero?
63         // - Questions, ignored
64         for( unsigned int i = 0; i < qd_count; i ++ ) {
65                 int rv = DNS_DecodeName(rr_name, packet, pos, return_len);
66                 if( rv < 0 ) {
67                         _SysDebug("DNS_Query - Parse error in QD");
68                         return 1;
69                 }
70                 pos += rv + 2*2;
71         }
72         // - Answers, pass on to handler
73         for( unsigned int i = 0; i < an_count; i ++ )
74         {
75                 enum eTypes     type;
76                 enum eClass     class;
77                 uint32_t        ttl;
78                 size_t  rdlength;
79                 int rv = DNS_int_ParseRR(packet, pos, return_len, rr_name, &type, &class, &ttl, &rdlength);
80                 if( rv < 0 ) {
81                         _SysDebug("DNS_Query - Parse error in AN");
82                         return 1;
83                 }
84                 pos += rv;
85                 
86                 if( handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos - rdlength) )
87                         return 0;
88         }
89         // Authority Records (should all be NS records)
90         for( unsigned int i = 0; i < ns_count; i ++ )
91         {
92                 size_t  rdlength;
93                 int rv = DNS_int_ParseRR(packet, pos, return_len, rr_name, NULL, NULL, NULL, &rdlength);
94                 if( rv < 0 ) {
95                         _SysDebug("DNS_Query - Parse error in NS");
96                         return 1;
97                 }
98                 pos += rv;
99         }
100         // - Additional records, pass to handler
101         for( unsigned int i = 0; i < ar_count; i ++ )
102         {
103                 enum eTypes     type;
104                 enum eClass     class;
105                 uint32_t        ttl;
106                 size_t  rdlength;
107                 int rv = DNS_int_ParseRR(packet, pos, return_len, rr_name, &type, &class, &ttl, &rdlength);
108                 if( rv < 0 ) {
109                         _SysDebug("DNS_Query - Parse error in AR");
110                         return 1;
111                 }
112                 pos += rv;
113                 
114                 if( handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos - rdlength) )
115                         return 0;
116         }
117         
118         return 0;
119 }
120
/// Encode a dotted name as a DNS wire-format name: a sequence of
/// length-prefixed labels terminated by the zero-length root label.
///
/// \param buf	Output buffer, or NULL to only compute the encoded length
/// \param dotted_name	NUL-terminated dotted name; a trailing '.' is optional
/// \return Number of bytes the encoded name occupies (always >= 1, the root label)
/// \note Empty labels ("..", a leading '.') are skipped and labels longer than
///       63 bytes are truncated to 63 bytes.
size_t	DNS_EncodeName(void *buf, const char *dotted_name)
{
	size_t	ret = 0;
	const char *str = dotted_name;
	uint8_t *buf8 = buf;
	while( *str )
	{
		const char *next = strchr(str, '.');
		size_t seg_len = (next ? (size_t)(next - str) : strlen(str));
		if( seg_len == 0 && next != NULL ) {
			// '..' (or leading '.') encountered, invalid (skip)
			str = next + 1;
			continue ;
		}
		if( seg_len > 63 ) {
			// Oops, too long (truncate)
			seg_len = 63;
		}
		
		if( buf8 )
		{
			buf8[ret] = (uint8_t)seg_len;
			memcpy(buf8 + ret + 1, str, seg_len);
		}
		ret += 1 + seg_len;
		
		if( next == NULL )
			break;
		str = next + 1;
	}
	// Always terminate with the root label, whether or not the dotted name
	// ended with '.', and even for an empty name.
	if( buf8 )
		buf8[ret] = 0;
	ret ++;
	return ret;
}
160
// Decode a (possibly compressed) DNS name at `ofs` in buf into dotted form,
// e.g. "\3www\7example\3com\0" -> "www.example.com." (every label is followed
// by '.', so a bare root name decodes to "").
//
// dotted_name: output, at least 256 bytes; NUL-terminated on success
// buf:   start of the DNS message (compression pointers are offsets from it)
// ofs:   offset of the name within buf
// space: total number of valid bytes in buf
// Returns the number of bytes the name occupies at `ofs` (a compression
// pointer counts as 2 bytes and ends the name), or -1 if it is malformed.
int DNS_DecodeName(char dotted_name[256], const void *buf, size_t ofs, size_t space)
{
	// Each label adds >= 2 output chars and output is capped at 255, so a
	// valid name has < 128 labels; a longer pointer chain can only be a loop.
	const unsigned int	max_jumps = 128;
	const uint8_t	*buf8 = buf;
	size_t	pos = ofs;	// Current read position (follows compression pointers)
	int	consumed = -1;	// Bytes used at `ofs`; fixed when the first pointer is followed
	int	out_pos = 0;
	unsigned int	jumps = 0;
	for( ;; )
	{
		if( pos >= space ) {
			_SysDebug("DNS_DecodeName - Len byte OOR space=%zi", space);
			return -1;
		}
		uint8_t seg_len = buf8[pos];
		pos ++;
		// Done
		if( seg_len == 0 )
			break;
		if( (seg_len & 0xC0) == 0xC0 )
		{
			// Compression pointer: 14-bit offset from the start of the message,
			// the rest of the name is read from there.
			if( pos >= space ) {
				_SysDebug("DNS_DecodeName - Backref OOR space=%zi", space);
				return -1;
			}
			size_t ref_ofs = ((size_t)(seg_len & 0x3F) << 8) | buf8[pos];
			pos ++;
			// Only the bytes up to and including the first pointer belong to this name
			if( consumed < 0 )
				consumed = (int)(pos - ofs);
			if( ++jumps > max_jumps ) {
				_SysDebug("DNS_DecodeName - Backref loop detected");
				return -1;
			}
			pos = ref_ofs;
			continue;
		}
		// Protocol violation (segment too long, or reserved 0x40/0x80 label type)
		if( seg_len >= 64 ) {
			_SysDebug("DNS_DecodeName - Seg too long %i", seg_len);
			return -1;
		}
		// Protocol violation (overflowed end of buffer); pos <= space holds here
		if( seg_len > space - pos ) {
			_SysDebug("DNS_DecodeName - Seg OOR %zu+%i>%zi", pos, seg_len, space);
			return -1;
		}
		// Protocol violation (name was too long); checked across pointer jumps
		// too, so dotted_name[256] can never overflow
		if( out_pos + seg_len + 1 > 255 ) {
			_SysDebug("DNS_DecodeName - Dotted name too long %i+%i+1 > %i",
				out_pos, seg_len, 255);
			return -1;
		}
		
		// Read segment
		memcpy(dotted_name + out_pos, buf8 + pos, seg_len);
		pos += seg_len;
		out_pos += seg_len;
		
		// Place '.'
		dotted_name[out_pos] = '.';
		out_pos ++;
	}
	// No pointer followed: the name ended with its own root label
	if( consumed < 0 )
		consumed = (int)(pos - ofs);
	dotted_name[out_pos] = '\0';
	return consumed;
}
225
226 // Parse a Resource Record
227 int DNS_int_ParseRR(const void *buf, size_t ofs, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p)
228 {
229         const uint8_t   *buf8 = buf;
230         size_t  consumed = 0;
231         
232         // 1. Name
233         int rv = DNS_DecodeName(name_p, buf, ofs, space);
234         if(rv < 0)      return -1;
235         
236         ofs += rv, consumed += rv;
237         
238         if( type_p )
239                 *type_p = get16(buf8 + ofs);
240         ofs += 2, consumed += 2;
241         
242         if( class_p )
243                 *class_p = get16(buf8 + ofs);
244         ofs += 2, consumed += 2;
245         
246         if( ttl_p )
247                 *ttl_p = get32(buf + ofs);
248         ofs += 4, consumed += 4;
249         
250         size_t rdlength = get16(buf + ofs);
251         if( rdlength_p )
252                 *rdlength_p = rdlength;
253         ofs += 2, consumed += 2;
254         
255         _SysDebug("DNS_int_ParseRR - name='%s', rdlength=%zi", name_p, rdlength);
256         
257         return consumed + rdlength;
258 }
259
/// Read a big-endian (network byte order) 16-bit value from a possibly unaligned buffer.
static uint16_t get16(const void *buf) {
	const uint8_t *bytes = buf;
	unsigned int hi = bytes[0];
	unsigned int lo = bytes[1];
	return (uint16_t)((hi << 8) | lo);
}
/// Read a big-endian (network byte order) 32-bit value from a possibly unaligned buffer.
static uint32_t get32(const void *buf) {
	const uint8_t *bytes = buf;
	uint32_t value = 0;
	// Most significant byte first
	for( int i = 0; i < 4; i ++ )
		value = (value << 8) | bytes[i];
	return value;
}
/// Store a 16-bit value in big-endian (network byte order); returns the number of bytes written.
static size_t put16(void *buf, uint16_t val) {
	uint8_t *out = buf;
	out[1] = (uint8_t)(val & 0xFF);
	out[0] = (uint8_t)(val >> 8);
	return sizeof(uint16_t);
}

UCC git Repository :: git.ucc.asn.au