Networking - DNS resolution semi-working
[tpg/acess2.git] / Usermode / Libraries / libnet.so_src / dns.c
1 /*
2  * Acess2 Networking Toolkit
3  * By John Hodge (thePowersGang)
4  * 
5  * dns.c
6  * - Hostname<->Address resolution
7  */
8 #include <stddef.h>     // size_t / NULL
9 #include <stdint.h>     // uint*_t
10 #include <string.h>     // memcpy, strchr
11 #include <assert.h>
12 #include <acess/sys.h>  // for _SysSelect
13 #include <acess/fd_set.h>       // FD_SET
14 #include <net.h>
15 #include "include/dns.h"
16
17 // === PROTOTYPES ===
18 //int DNS_Query(int ServerAType, const void *ServerAddr, const char *name, enum eTypes type, enum eClass class, handle_record_t* handle_record, void *info);
19 int DNS_int_ParseResponse(const void* packet, size_t return_len, void *info, handle_record_t* handle_record_t);
20 size_t  DNS_EncodeName(void *buf, const char *dotted_name);
21 int DNS_DecodeName(char dotted_name[256], const void *buf, size_t ofs, size_t space);
22 int DNS_int_ParseRR(const void *buf, size_t ofs, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p);
23
24 static uint16_t get16(const void *buf);
25 static uint32_t get32(const void *buf);
26 static size_t put16(void *buf, uint16_t val);
27
28
29 // === CODE ===
30 int DNS_Query(int ServerAType, const void *ServerAddr, const char *name, enum eTypes type, enum eClass class, handle_record_t* handle_record, void *info)
31 {
32         int namelen = DNS_EncodeName(NULL, name);
33         assert(namelen < 256);
34         size_t  pos = 0;
35         char    packet[ 512 ];
36         assert( (6*2) + (namelen + 2*2) < 512 );
37         // - Header
38         pos += put16(packet + pos, 0xAC00);     // Identifier (arbitary)
39         pos += put16(packet + pos, (0 << 0) | (0 << 1) );       // Op : Query, Standard, no other flags
40         pos += put16(packet + pos, 1);  // QDCount
41         pos += put16(packet + pos, 0);  // ANCount
42         pos += put16(packet + pos, 0);  // NSCount
43         pos += put16(packet + pos, 0);  // ARCount
44         // - Question
45         pos += DNS_EncodeName(packet + pos, name);
46         pos += put16(packet + pos, type);       // QType
47         pos += put16(packet + pos, class);      // QClass
48         
49         assert(pos <= sizeof(packet));
50         
51         // Send and wait for reply
52         // - Lock
53         //  > TODO: Lock DNS queries
54         // - Send
55         int sock = Net_OpenSocket_UDP(ServerAType, ServerAddr, 53, 0);
56         if( sock < 0 ) {
57                 // Connection failed
58                 _SysDebug("DNS_Query - UDP open failed");
59                 // TODO: Correctly report this failure with a useful error code
60                 return 1;
61         }
62         int rv = Net_UDP_SendTo(sock, 53, ServerAType, ServerAddr, pos, packet);
63         if( rv != pos ) {
64                 _SysDebug("DNS_Query - Write failed");
65                 // TODO: Error reporting
66                 _SysClose(sock);
67                 return 1;
68         }
69         // - Wait
70         {
71                  int    nfd = sock + 1;
72                 fd_set  fds;
73                 FD_ZERO(&fds);
74                 FD_SET(sock, &fds);
75                 int64_t timeout = 2000; // Give it two seconds, should be long enough
76                 rv = _SysSelect(nfd, &fds, NULL, NULL, &timeout, 0);
77                 if( rv == 0 ) {
78                         // Timeout with no reply, give up
79                         _SysDebug("DNS_Query - Timeout");
80                         _SysClose(sock);
81                         return 1;
82                 }
83                 if( rv < 0 ) {
84                         // Oops, select failed
85                         _SysDebug("DNS_Query - Select failure");
86                         _SysClose(sock);
87                         return 1;
88                 }
89         }
90         int return_len = Net_UDP_RecvFrom(sock, NULL, NULL, NULL, sizeof(packet), packet);
91         if( return_len <= 0 ) {
92                 // TODO: Error reporting
93                 _SysDebug("DNS_Query - Read failure");
94                 _SysClose(sock);
95                 return 1;
96         }
97         _SysClose(sock);
98         // - Release
99         //  > TODO: Lock DNS queries
100         
101         // For each response in the answer (and additional) sections, call the passed callback
102         return DNS_int_ParseResponse(packet, return_len, info, handle_record);
103 }
104
105 int DNS_int_ParseResponse(const void* buf, size_t return_len, void *info, handle_record_t* handle_record)
106 {
107         const uint8_t* packet = buf;
108         char    rr_name[256];
109         unsigned int id = get16(packet + 0);
110         if( id != 0xAC00 ) {
111                 _SysDebug("DNS_Query - Packet ID mismatch");
112                 return 2;
113         }
114         unsigned int flags = get16(packet + 2);
115         unsigned int qd_count = get16(packet + 4);
116         unsigned int an_count = get16(packet + 6);
117         unsigned int ns_count = get16(packet + 8);
118         unsigned int ar_count = get16(packet + 10);
119         size_t pos = 6*2;
120         // TODO: Can I safely assert / fail if qd_count is non-zero?
121         // - Questions, ignored
122         for( unsigned int i = 0; i < qd_count; i ++ ) {
123                 int rv = DNS_DecodeName(rr_name, packet, pos, return_len);
124                 if( rv < 0 ) {
125                         _SysDebug("DNS_Query - Parse error in QD");
126                         return 1;
127                 }
128                 pos += rv + 2*2;
129         }
130         // - Answers, pass on to handler
131         for( unsigned int i = 0; i < an_count; i ++ )
132         {
133                 enum eTypes     type;
134                 enum eClass     class;
135                 uint32_t        ttl;
136                 size_t  rdlength;
137                 int rv = DNS_int_ParseRR(packet, pos, return_len, rr_name, &type, &class, &ttl, &rdlength);
138                 if( rv < 0 ) {
139                         _SysDebug("DNS_Query - Parse error in AN");
140                         return 1;
141                 }
142                 pos += rv;
143                 
144                 handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos - rdlength);
145         }
146         // Authority Records (should all be NS records)
147         for( unsigned int i = 0; i < ns_count; i ++ )
148         {
149                 size_t  rdlength;
150                 int rv = DNS_int_ParseRR(packet, pos, return_len, rr_name, NULL, NULL, NULL, &rdlength);
151                 if( rv < 0 ) {
152                         _SysDebug("DNS_Query - Parse error in NS");
153                         return 1;
154                 }
155                 pos += rv;
156         }
157         // - Additional records, pass to handler
158         for( unsigned int i = 0; i < ar_count; i ++ )
159         {
160                 enum eTypes     type;
161                 enum eClass     class;
162                 uint32_t        ttl;
163                 size_t  rdlength;
164                 int rv = DNS_int_ParseRR(packet, pos, return_len, rr_name, &type, &class, &ttl, &rdlength);
165                 if( rv < 0 ) {
166                         _SysDebug("DNS_Query - Parse error in AR");
167                         return 1;
168                 }
169                 pos += rv;
170                 
171                 handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos - rdlength);
172         }
173         
174         return 0;
175 }
176
/// Encode a dotted name as a DNS wire-format name (length-prefixed labels + root byte)
///
/// buf:         destination, or NULL to only compute the encoded size
/// dotted_name: name such as "www.example.com" or "www.example.com."
///
/// Labels longer than 63 bytes are truncated to 63 and empty labels ("..") are
/// skipped. The terminating zero (root) byte is always written, including for names
/// with a trailing '.' and for the empty name.
/// Returns the number of bytes written (or that would be written when buf is NULL).
size_t	DNS_EncodeName(void *buf, const char *dotted_name)
{
	uint8_t	*buf8 = buf;
	size_t	ret = 0;
	const char *str = dotted_name;
	while( *str )
	{
		const char *next = strchr(str, '.');
		size_t seg_len = (next ? (size_t)(next - str) : strlen(str));
		// Empty labels ('..' or a leading '.') are invalid and are skipped
		if( seg_len > 0 )
		{
			// Oops, too long (truncate to the 63-byte label limit)
			size_t enc_len = (seg_len > 63 ? 63 : seg_len);
			if( buf8 )
			{
				buf8[ret] = (uint8_t)enc_len;
				memcpy(buf8+ret+1, str, enc_len);
			}
			ret += 1 + enc_len;
		}
		if( next == NULL )
			break;
		str = next + 1;
	}
	// Root label terminator, needed whether or not the name had a trailing '.'
	if( buf8 )
		buf8[ret] = 0;
	ret ++;
	return ret;
}
216
/// Decode a (possibly compressed) wire-format name into dotted form
///
/// dotted_name: output, receives the NUL-terminated name (at most 255 chars + NUL)
/// buf/space:   the complete DNS message and its length; compression pointers are
///              offsets from the start of `buf`
/// ofs:         offset of the name within `buf`
///
/// Every label is emitted followed by '.', so "www.example.com" decodes as
/// "www.example.com." and the root name decodes as "".
/// Returns the number of bytes the name occupies at `ofs` (a compression pointer
/// counts as two bytes and ends the name), or -1 if the name is malformed.
int DNS_DecodeName(char dotted_name[256], const void *buf, size_t ofs, size_t space)
{
	int consumed = 0;
	int out_pos = 0;
	const uint8_t *buf8 = (const uint8_t*)buf + ofs;
	for( ;; )
	{
		if( ofs + consumed + 1 > space ) {
			_SysDebug("DNS_DecodeName - Len byte OOR space=%i", (int)space);
			return -1;
		}
		uint8_t seg_len = *buf8;
		buf8 ++;
		consumed ++;
		// Done (root label)
		if( seg_len == 0 )
			break;
		if( (seg_len & 0xC0) == 0xC0 )
		{
			// Compression pointer: a 14-bit offset spanning this byte and the next.
			// The rest of the name is whatever is stored at that offset.
			if( ofs + consumed + 1 > space ) {
				_SysDebug("DNS_DecodeName - Pointer byte OOR space=%i", (int)space);
				return -1;
			}
			size_t ref_ofs = get16(buf8 - 1) & 0x3FFF;
			consumed += 1, buf8 += 1;	// Only one, previous inc still applies
			// Only accept pointers to before the start of this name. Each nested call
			// then starts strictly earlier, so a crafted pointer loop cannot recurse forever.
			if( ref_ofs >= ofs ) {
				_SysDebug("DNS_DecodeName - Bad pointer %i (name at %i)", (int)ref_ofs, (int)ofs);
				return -1;
			}
			_SysDebug("DNS_DecodeName - Nested at %i", (int)ref_ofs);
			char tmp[256];
			if( DNS_DecodeName(tmp, buf, ref_ofs, space) < 0 )
				return -1;
			size_t tmp_len = strlen(tmp);
			// Prefix + pointed-to suffix must still fit in dotted_name[256] with its NUL
			if( out_pos + tmp_len > 255 ) {
				_SysDebug("DNS_DecodeName - Dotted name too long %i+%i > %i",
					out_pos, (int)tmp_len, 255);
				return -1;
			}
			memcpy(dotted_name + out_pos, tmp, tmp_len);
			out_pos += (int)tmp_len;
			break;
		}
		// Protocol violation (segment too long, also catches reserved 0x40/0x80 forms)
		if( seg_len >= 64 ) {
			_SysDebug("DNS_DecodeName - Seg too long %i", seg_len);
			return -1;
		}
		// Protocol violation (overflowed end of buffer)
		if( ofs + consumed + seg_len > space ) {
			_SysDebug("DNS_DecodeName - Seg OOR %i+%i>%i", consumed, seg_len, (int)space);
			return -1;
		}
		// Protocol violation (name was too long); leaves room for the '.' and final NUL
		if( out_pos + seg_len + 1 > 255 ) {
			_SysDebug("DNS_DecodeName - Dotted name too long %i+%i+1 > %i",
				out_pos, seg_len, 255);
			return -1;
		}
		
		_SysDebug("DNS_DecodeName : Seg %i '%.*s'", seg_len, seg_len, buf8);
		
		// Read segment
		memcpy(dotted_name + out_pos, buf8, seg_len);
		buf8 += seg_len;
		consumed += seg_len;
		out_pos += seg_len;
		
		// Place '.'
		dotted_name[out_pos] = '.';
		out_pos ++;
	}
	dotted_name[out_pos] = '\0';
	_SysDebug("DNS_DecodeName - '%s', consumed = %i", dotted_name, consumed);
	return consumed;
}
281
282 // Parse a Resource Record
283 int DNS_int_ParseRR(const void *buf, size_t ofs, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p)
284 {
285         const uint8_t   *buf8 = buf;
286         size_t  consumed = 0;
287         
288         // 1. Name
289         int rv = DNS_DecodeName(name_p, buf, ofs, space);
290         if(rv < 0)      return -1;
291         
292         ofs += rv, consumed += rv;
293         
294         if( type_p )
295                 *type_p = get16(buf8 + ofs);
296         ofs += 2, consumed += 2;
297         
298         if( class_p )
299                 *class_p = get16(buf8 + ofs);
300         ofs += 2, consumed += 2;
301         
302         if( ttl_p )
303                 *ttl_p = get32(buf + ofs);
304         ofs += 4, consumed += 4;
305         
306         size_t rdlength = get16(buf + ofs);
307         if( rdlength_p )
308                 *rdlength_p = rdlength;
309         ofs += 2, consumed += 2;
310         
311         _SysDebug("DNS_int_ParseRR - name='%s', rdlength=%i", name_p, rdlength);
312         
313         return consumed + rdlength;
314 }
315
/// Read a big-endian (network order) 16-bit value from `buf`
static uint16_t get16(const void *buf) {
	const uint8_t *bytes = buf;
	return (uint16_t)((bytes[0] << 8) | bytes[1]);
}
/// Read a big-endian (network order) 32-bit value from `buf`
static uint32_t get32(const void *buf) {
	const uint8_t *bytes = buf;
	uint32_t val = 0;
	for( int i = 0; i < 4; i ++ )
		val = (val << 8) | bytes[i];
	return val;
}
/// Store `val` big-endian (network order) at `buf`; returns the bytes written (2)
static size_t put16(void *buf, uint16_t val) {
	uint8_t *bytes = buf;
	bytes[1] = (uint8_t)(val & 0xFF);
	bytes[0] = (uint8_t)(val >> 8);
	return sizeof(uint16_t);
}
338

UCC git Repository :: git.ucc.asn.au