Usermode/libnet - Starting work on DNS resolver
[tpg/acess2.git] / Usermode / Libraries / libnet.so_src / dns.c
1 /*
2  * Acess2 Networking Toolkit
3  * By John Hodge (thePowersGang)
4  * 
5  * dns.c
6  * - Hostname<->Address resolution
7  */
8 #include <stddef.h>     // size_t / NULL
9 #include <stdint.h>     // uint*_t
10 #include <string.h>     // memcpy, strchrnul
11 #include <assert.h>
12 #include <net.h>
13 #include "include/dns.h"
14
15 // === PROTOTYPES ===
16 size_t  DNS_EncodeName(void *buf, const char *dotted_name);
17 int DNS_DecodeName(char dotted_name[256], const void *buf, size_t space);
18 int DNS_int_ParseRR(const void *buf, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p);
19 static uint16_t get16(const void *buf);
20 static size_t put16(void *buf, uint16_t val);
21
22
23 // === CODE ===
24 int DNS_Query(int ServerAType, const void *ServerAddr, const char *name, enum eTypes type, enum eClass class, handle_record_t* handle_record, void *info)
25 {
26         int namelen = DNS_EncodeName(NULL, name);
27         assert(namelen < 256);
28         size_t  pos = 0;
29         char    packet[ 512 ];
30         assert( (6*2) + (namelen + 2*2) < 512 );
31         // - Header
32         pos += put16(packet + pos, 0xAC00);     // Identifier (arbitary)
33         pos += put16(packet + pos, (0 << 0) | (0 << 1) );       // Op : Query, Standard, no other flags
34         pos += put16(packet + pos, 1);  // QDCount
35         pos += put16(packet + pos, 0);  // ANCount
36         pos += put16(packet + pos, 0);  // NSCount
37         pos += put16(packet + pos, 0);  // ARCount
38         // - Question
39         pos += DNS_EncodeName(packet + pos, name);
40         pos += put16(packet + pos, type);       // QType
41         pos += put16(packet + pos, class);      // QClass
42         
43         assert(pos <= sizeof(packet));
44         
45         // Send and wait for reply
46         // - Lock
47         //  > TODO: Lock DNS queries
48         // - Send
49         int sock = Net_OpenSocket_UDP(ServerAType, ServerAddr, 53, 0);
50         if( sock < 0 ) {
51                 // Connection failed
52                 // TODO: Correctly report this failure with a useful error code
53                 return 1;
54         }
55         int rv = _SysWrite(sock, packet, pos);
56         if( rv != pos ) {
57                 // TODO: Error reporting
58                 _SysClose(sock);
59                 return 1;
60         }
61         // - Wait
62         int return_len = 0;
63         do {
64                 return_len = _SysRead(sock, packet, sizeof(packet));
65         } while( return_len == 0 );
66         if( return_len < 0 ) {
67                 // TODO: Error reporting
68                 _SysClose(sock);
69                 return 1;
70         }
71         _SysClose(sock);
72         // - Release
73         //  > TODO: Lock DNS queries
74         
75         // For each response in the answer (and additional) sections, call the passed callback
76         char    rr_name[256];
77         unsigned int qd_count = get16(packet + 4);
78         unsigned int an_count = get16(packet + 6);
79         unsigned int ns_count = get16(packet + 8);
80         unsigned int ar_count = get16(packet + 10);
81         pos = 6*2;
82         // TODO: Can I safely assert / fail if qd_count is non-zero?
83         // - Questions, ignored
84         for( unsigned int i = 0; i < qd_count; i ++ ) {
85                 pos += DNS_DecodeName(NULL, packet + pos, return_len - pos);
86                 pos += 2*2;
87         }
88         // - Answers, pass on to handler
89         for( unsigned int i = 0; i < an_count; i ++ )
90         {
91                 enum eTypes     type;
92                 enum eClass     class;
93                 uint32_t        ttl;
94                 size_t  rdlength;
95                 int rv = DNS_int_ParseRR(packet + pos, return_len - pos, rr_name, &type, &class, &ttl, &rdlength);
96                 if( rv < 0 ) {
97                         return 1;
98                 }
99                 pos += rv;
100                 
101                 handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos);
102         }
103         // Authority Records (should all be NS records)
104         for( unsigned int i = 0; i < ns_count; i ++ )
105         {
106                 size_t  rdlength;
107                 int rv = DNS_int_ParseRR(packet + pos, return_len - pos, rr_name, NULL, NULL, NULL, &rdlength);
108                 if( rv < 0 ) {
109                         return 1;
110                 }
111                 pos += rv;
112         }
113         // - Additional records, pass to handler
114         for( unsigned int i = 0; i < ar_count; i ++ )
115         {
116                 enum eTypes     type;
117                 enum eClass     class;
118                 uint32_t        ttl;
119                 size_t  rdlength;
120                 int rv = DNS_int_ParseRR(packet + pos, return_len - pos, rr_name, &type, &class, &ttl, &rdlength);
121                 if( rv < 0 ) {
122                         return 1;
123                 }
124                 pos += rv;
125                 
126                 handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos);
127         }
128
129         return 0;
130 }
131
/// Encode a dotted name as a DNS name
///
/// Each '.'-separated segment becomes a length byte followed by the segment
/// bytes, and the name is always terminated by a zero-length (root) label,
/// whether or not `dotted_name` ends in '.'. An empty string or "." encodes
/// to the root name (a single zero byte).
///
/// - Segments longer than 63 bytes are truncated to 63.
/// - Empty segments ("..", or a leading '.') are skipped.
///
/// If `buf` is NULL nothing is written and only the length is computed.
/// Returns the number of bytes the encoded name occupies.
size_t  DNS_EncodeName(void *buf, const char *dotted_name)
{
        size_t  ret = 0;
        const char *str = dotted_name;
        uint8_t *buf8 = buf;
        while( *str )
        {
                const char *next = strchr(str, '.');
                size_t seg_len = (next ? (size_t)(next - str) : strlen(str));
                if( seg_len == 0 ) {
                        // `*str` is non-NUL here, so a zero length means `str`
                        // points at a '.': empty segment, invalid (skip)
                        str = next + 1;
                        continue ;
                }
                if( seg_len > 63 ) {
                        // Oops, too long (truncate)
                        seg_len = 63;
                }

                if( buf8 )
                {
                        buf8[ret] = (uint8_t)seg_len;
                        memcpy(buf8 + ret + 1, str, seg_len);
                }
                ret += 1 + seg_len;

                if( next == NULL )
                        break;
                str = next + 1;
        }

        // Terminating root label. Emitted unconditionally so that names written
        // with a trailing '.' are terminated too.
        if( buf8 )
                buf8[ret] = 0;
        ret ++;
        return ret;
}
171
// Decode a name (including trailing . for root)
//
// Reads length-prefixed labels from `buf` until the zero-length root label
// and writes them to `dotted_name` as "label.label.", NUL terminated. A bare
// root name decodes to an empty string.
//
// dotted_name - output buffer of 256 bytes, or NULL to only measure the
//               encoded length
// buf / space - encoded data and the number of readable bytes in it
//
// Returns the number of bytes consumed from `buf`, INCLUDING the terminating
// zero byte, or -1 if the name is truncated, has a label length >= 64, or
// does not fit in 256 bytes.
//
// NOTE: compression pointers (length bytes with the top two bits set) fall
// under the ">= 64" rejection; following them needs the start of the
// message, which this interface does not receive.
int DNS_DecodeName(char dotted_name[256], const void *buf, size_t space)
{
        // `dotted_name` is a pointer once passed, so sizeof() on it would give
        // the pointer size; the capacity has to be stated explicitly.
        enum { NAME_BUF_SIZE = 256 };
        size_t  consumed = 0;
        size_t  out_pos = 0;
        const uint8_t *buf8 = buf;

        for( ;; )
        {
                // Protocol violation (ran out of data before the root label)
                if( consumed >= space )
                        return -1;
                uint8_t seg_len = buf8[consumed];
                consumed ++;

                // Zero-length label terminates the name
                if( seg_len == 0 )
                        break;
                // Protocol violation (segment too long)
                if( seg_len >= 64 )
                        return -1;
                // Protocol violation (overflowed end of buffer)
                if( seg_len > space - consumed )
                        return -1;
                // Protocol violation (name was too long): need room for the
                // segment, its '.', and the final NUL
                if( out_pos + seg_len + 2 > NAME_BUF_SIZE )
                        return -1;

                if( dotted_name )
                {
                        // Read segment
                        memcpy(dotted_name + out_pos, buf8 + consumed, seg_len);
                        // Place '.' directly after the segment
                        dotted_name[out_pos + seg_len] = '.';
                }
                consumed += seg_len;
                // Increment output counter
                out_pos += seg_len + 1;
        }

        if( dotted_name )
                dotted_name[out_pos] = '\0';
        return (int)consumed;
}
208
209 // Parse a Resource Record
210 int DNS_int_ParseRR(const void *buf, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p)
211 {
212         const uint8_t   *buf8 = buf;
213         size_t  consumed = 0;
214         
215         // 1. Name
216         int rv = DNS_DecodeName(name_p, buf8, space);
217         if(rv < 0)      return -1;
218         
219         buf8 += rv, consumed += rv;
220         
221         if( type_p )
222                 *type_p = get16(buf8);
223         buf8 += 2, consumed += 2;
224         
225         if( class_p )
226                 *class_p = get16(buf8);
227         buf8 += 2, consumed += 2;
228         
229         return consumed;
230 }
231
// Read a 16-bit value stored most-significant byte first (the byte order DNS
// uses for all multi-octet integers) from `buf`, which must hold >= 2 bytes.
static uint16_t get16(const void *buf) {
        const uint8_t* buf8 = buf;
        uint16_t rv = 0;
        rv |= (uint16_t)buf8[0] << 8;
        rv |= (uint16_t)buf8[1];
        return rv;
}
// Write `val` to `buf` most-significant byte first (the byte order DNS uses
// for all multi-octet integers). `buf` must have room for 2 bytes.
// Returns the number of bytes written (always 2) so callers can advance a
// write offset with `pos += put16(...)`.
static size_t put16(void *buf, uint16_t val) {
        uint8_t* buf8 = buf;
        buf8[0] = (uint8_t)(val >> 8);
        buf8[1] = (uint8_t)(val & 0xFF);
        return 2;
}
245

UCC git Repository :: git.ucc.asn.au