X-Git-Url: https://git.ucc.asn.au/?a=blobdiff_plain;f=Usermode%2FLibraries%2Flibnet.so_src%2Fdns.c;h=9ac9198b2ec1929e12de61566838cfcdf589057c;hb=db55040ba8814edf681d4ccc12ad8955d8aa404a;hp=001ec0006162a7f76ce5c987525ab95266b5557a;hpb=07e446727e54a17327b53928ce8582ba10eec619;p=tpg%2Facess2.git diff --git a/Usermode/Libraries/libnet.so_src/dns.c b/Usermode/Libraries/libnet.so_src/dns.c index 001ec000..9ac9198b 100644 --- a/Usermode/Libraries/libnet.so_src/dns.c +++ b/Usermode/Libraries/libnet.so_src/dns.c @@ -9,38 +9,20 @@ #include // uint*_t #include // memcpy, strchr #include +#include // for _SysSelect +#include // FD_SET #include #include "include/dns.h" +#include "include/dns_int.h" // === PROTOTYPES === -size_t DNS_EncodeName(void *buf, const char *dotted_name); -int DNS_DecodeName(char dotted_name[256], const void *buf, size_t space); -int DNS_int_ParseRR(const void *buf, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p); -static uint16_t get16(const void *buf); -static size_t put16(void *buf, uint16_t val); - +//int DNS_Query(int ServerAType, const void *ServerAddr, const char *name, enum eTypes type, enum eClass class, handle_record_t* handle_record, void *info); // === CODE === int DNS_Query(int ServerAType, const void *ServerAddr, const char *name, enum eTypes type, enum eClass class, handle_record_t* handle_record, void *info) { - int namelen = DNS_EncodeName(NULL, name); - assert(namelen < 256); - size_t pos = 0; - char packet[ 512 ]; - assert( (6*2) + (namelen + 2*2) < 512 ); - // - Header - pos += put16(packet + pos, 0xAC00); // Identifier (arbitary) - pos += put16(packet + pos, (0 << 0) | (0 << 1) ); // Op : Query, Standard, no other flags - pos += put16(packet + pos, 1); // QDCount - pos += put16(packet + pos, 0); // ANCount - pos += put16(packet + pos, 0); // NSCount - pos += put16(packet + pos, 0); // ARCount - // - Question - pos += DNS_EncodeName(packet + pos, name); - pos += put16(packet + pos, type); // QType - pos += put16(packet + pos, class); // QClass - - assert(pos <= sizeof(packet)); + char packet[512]; + size_t packlen = DNS_int_EncodeQuery(packet, sizeof(packet), name, type, class); // Send and wait for reply // - Lock @@ -49,197 +31,50 @@ int DNS_Query(int ServerAType, const void *ServerAddr, const char *name, enum eT int sock = Net_OpenSocket_UDP(ServerAType, ServerAddr, 53, 0); if( sock < 0 ) { // Connection failed + _SysDebug("DNS_Query - UDP open failed"); // TODO: Correctly report this failure with a useful error code return 1; } - int rv = _SysWrite(sock, packet, pos); + int rv = Net_UDP_SendTo(sock, 53, ServerAType, ServerAddr, pos, packet); if( rv != pos ) { + _SysDebug("DNS_Query - Write failed"); // TODO: Error reporting _SysClose(sock); return 1; } // - Wait - int return_len = 0; - do { - return_len = _SysRead(sock, packet, sizeof(packet)); - } while( return_len == 0 ); - if( return_len < 0 ) { - // TODO: Error reporting - _SysClose(sock); - return 1; - } - _SysClose(sock); - // - Release - // > TODO: Lock DNS queries - - // For each response in the answer (and additional) sections, call the passed callback - char rr_name[256]; - unsigned int qd_count = get16(packet + 4); - unsigned int an_count = get16(packet + 6); - unsigned int ns_count = get16(packet + 8); - unsigned int ar_count = get16(packet + 10); - pos = 6*2; - // TODO: Can I safely assert / fail if qd_count is non-zero? - // - Questions, ignored - for( unsigned int i = 0; i < qd_count; i ++ ) { - pos += DNS_DecodeName(NULL, packet + pos, return_len - pos); - pos += 2*2; - } - // - Answers, pass on to handler - for( unsigned int i = 0; i < an_count; i ++ ) { - enum eTypes type; - enum eClass class; - uint32_t ttl; - size_t rdlength; - int rv = DNS_int_ParseRR(packet + pos, return_len - pos, rr_name, &type, &class, &ttl, &rdlength); - if( rv < 0 ) { + int nfd = sock + 1; + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + int64_t timeout = 2000; // Give it two seconds, should be long enough + rv = _SysSelect(nfd, &fds, NULL, NULL, &timeout, 0); + if( rv == 0 ) { + // Timeout with no reply, give up + _SysDebug("DNS_Query - Timeout"); + _SysClose(sock); return 1; } - pos += rv; - - handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos); - } - // Authority Records (should all be NS records) - for( unsigned int i = 0; i < ns_count; i ++ ) - { - size_t rdlength; - int rv = DNS_int_ParseRR(packet + pos, return_len - pos, rr_name, NULL, NULL, NULL, &rdlength); if( rv < 0 ) { + // Oops, select failed + _SysDebug("DNS_Query - Select failure"); + _SysClose(sock); return 1; } - pos += rv; } - // - Additional records, pass to handler - for( unsigned int i = 0; i < ar_count; i ++ ) - { - enum eTypes type; - enum eClass class; - uint32_t ttl; - size_t rdlength; - int rv = DNS_int_ParseRR(packet + pos, return_len - pos, rr_name, &type, &class, &ttl, &rdlength); - if( rv < 0 ) { - return 1; - } - pos += rv; - - handle_record(info, rr_name, type, class, ttl, rdlength, packet + pos); - } - - return 0; -} - -/// Encode a dotted name as a DNS name -size_t DNS_EncodeName(void *buf, const char *dotted_name) -{ - size_t ret = 0; - const char *str = dotted_name; - uint8_t *buf8 = buf; - while( *str ) - { - const char *next = strchr(str, '.'); - size_t seg_len = (next ? next - str : strlen(str)); - if( seg_len > 63 ) { - // Oops, too long (truncate) - seg_len = 63; - } - if( seg_len == 0 && next != NULL ) { - // '..' encountered, invalid (skip) - str = next+1; - continue ; - } - - if( buf8 ) - { - buf8[ret] = seg_len; - memcpy(buf8+ret+1, str, seg_len); - } - ret += 1 + seg_len; - - if( next == NULL ) { - // No trailing '.', assume it's there? Yes, need to be NUL terminated - if(buf8) buf8[ret] = 0; - ret ++; - break; - } - else { - str = next + 1; - } - } - return ret; -} - -// Decode a name (including trailing . for root) -int DNS_DecodeName(char dotted_name[256], const void *buf, size_t space) -{ - int consumed = 0; - int out_pos = 0; - const uint8_t *buf8 = buf; - while( *buf8 && space > 0 ) - { - if( consumed + 1 > space ) return -1; - uint8_t seg_len = *buf8; - buf8 ++; - consumed ++; - // Protocol violation (overflowed end of buffer) - if( consumed + seg_len > space ) - return -1; - // Protocol violation (segment too long) - if( seg_len >= 64 ) - return -1; - // Protocol violation (name was too long) - if( out_pos + seg_len + 1 > sizeof(dotted_name)-1 ) - return -1; - - // Read segment - memcpy(dotted_name + out_pos, buf8, seg_len); - buf8 += seg_len; - consumed += seg_len; - - // Place '.' - dotted_name[out_pos+seg_len+1] = '.'; - // Increment output counter - out_pos += seg_len + 1; + int return_len = Net_UDP_RecvFrom(sock, NULL, NULL, NULL, sizeof(packet), packet); + if( return_len <= 0 ) { + // TODO: Error reporting + _SysDebug("DNS_Query - Read failure"); + _SysClose(sock); + return 1; } + _SysClose(sock); + // - Release + // > TODO: Lock DNS queries - dotted_name[out_pos] = '\0'; - return consumed; -} - -// Parse a Resource Record -int DNS_int_ParseRR(const void *buf, size_t space, char* name_p, enum eTypes* type_p, enum eClass* class_p, uint32_t* ttl_p, size_t* rdlength_p) -{ - const uint8_t *buf8 = buf; - size_t consumed = 0; - - // 1. Name - int rv = DNS_DecodeName(name_p, buf8, space); - if(rv < 0) return -1; - - buf8 += rv, consumed += rv; - - if( type_p ) - *type_p = get16(buf8); - buf8 += 2, consumed += 2; - - if( class_p ) - *class_p = get16(buf8); - buf8 += 2, consumed += 2; - - return consumed; -} - -static uint16_t get16(const void *buf) { - const uint8_t* buf8 = buf; - uint16_t rv = 0; - rv |= buf8[0]; - rv |= (uint16_t)buf8[1] << 8; - return rv; -} -static size_t put16(void *buf, uint16_t val) { - uint8_t* buf8 = buf; - buf8[0] = val & 0xFF; - buf8[1] = val >> 8; - return 2; + // For each response in the answer (and additional) sections, call the passed callback + return DNS_int_ParseResponse(packet, return_len, info, handle_record); }