use std::net::{Ipv4Addr, Ipv6Addr};
use std::net::UdpSocket;

/// A complete DNS message: the header followed by the question, answer,
/// authority and additional sections.
#[derive(Clone, Debug)]
pub struct DnsPacket {
    pub header: DnsHeader,
    pub questions: Vec<DnsQuestion>,
    pub answers: Vec<DnsRecord>,
    pub authorities: Vec<DnsRecord>,
    /// Additional-section records, counted by the header's `add_count`.
    pub resources: Vec<DnsRecord>,
}

impl DnsPacket {
    /// Builds a query for `qname` of type `query_type` with recursion desired.
    ///
    /// The packet ID is the fixed value 6677.
    pub fn new(query_type: QueryType, qname: &str) -> Self {
        let mut packet = DnsPacket::empty();
        packet.header.recursion_desired = true;
        packet
            .questions
            .push(DnsQuestion::new(String::from(qname), query_type));
        packet
    }

    /// Returns a packet with a default header (ID 6677, all flags clear) and
    /// empty sections.
    pub fn empty() -> Self {
        DnsPacket {
            header: DnsHeader::new(6677),
            questions: Vec::new(),
            answers: Vec::new(),
            authorities: Vec::new(),
            resources: Vec::new(),
        }
    }

    /// Parses a packet from `buffer`, starting at its current position.
    ///
    /// # Panics
    /// Panics if the buffer does not hold a well-formed packet. Use
    /// [`DnsPacket::try_from_bytes`] to handle malformed input gracefully.
    pub fn from_bytes(buffer: &mut BytePacketBuffer) -> Self {
        DnsPacket::try_from_bytes(buffer).expect("malformed DNS packet")
    }

    /// Fallible counterpart of [`DnsPacket::from_bytes`].
    ///
    /// # Errors
    /// Returns an error if any section runs past the end of the buffer or a
    /// domain name is malformed (e.g. a compression-pointer loop).
    pub fn try_from_bytes(buffer: &mut BytePacketBuffer) -> Result<Self, &'static str> {
        let header = DnsHeader::from(buffer)?;
        let questions = (0..header.question_count)
            .map(|_| DnsQuestion::try_from_buffer(buffer))
            .collect::<Result<Vec<_>, _>>()?;
        let answers = Self::read_records(buffer, header.answer_count)?;
        let authorities = Self::read_records(buffer, header.authority_count)?;
        // The additional section is sized by ARCOUNT (`add_count`), not NSCOUNT.
        let resources = Self::read_records(buffer, header.add_count)?;
        Ok(DnsPacket {
            header,
            questions,
            answers,
            authorities,
            resources,
        })
    }

    /// Reads `count` consecutive resource records from `buffer`.
    fn read_records(
        buffer: &mut BytePacketBuffer,
        count: u16,
    ) -> Result<Vec<DnsRecord>, &'static str> {
        (0..count)
            .map(|_| DnsRecord::try_from_buffer(buffer))
            .collect()
    }

    /// Serializes the packet into `buffer`.
    ///
    /// The header's section counts are first overwritten with the actual
    /// lengths of the four sections.
    ///
    /// # Errors
    /// Returns an error if the packet does not fit in the buffer, a label is
    /// longer than 63 bytes, or a section contains a `DnsRecord::UNKNOWN`.
    pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), &'static str> {
        // A 512-byte buffer can never hold 65536 entries, so `as u16` cannot truncate
        // for anything that actually serializes.
        self.header.question_count = self.questions.len() as u16;
        self.header.answer_count = self.answers.len() as u16;
        self.header.authority_count = self.authorities.len() as u16;
        self.header.add_count = self.resources.len() as u16;

        self.header.write(buffer)?;
        for question in &self.questions {
            question.write(buffer)?;
        }
        for rec in self
            .answers
            .iter()
            .chain(self.authorities.iter())
            .chain(self.resources.iter())
        {
            rec.write(buffer)?;
        }
        Ok(())
    }

    /// Sends this packet to `server` over UDP and blocks until a reply arrives.
    ///
    /// # Panics
    /// Panics if the packet cannot be serialized, on any socket error, or if
    /// the reply is malformed.
    // NOTE(review): no read timeout is set, so a lost reply blocks forever.
    pub fn send(&mut self, server: (&str, u16)) -> DnsPacket {
        let mut req_buffer = BytePacketBuffer::new();
        self.write(&mut req_buffer)
            .expect("failed to serialize DNS query");

        // Port 0 lets the OS choose a free ephemeral port; a fixed port makes
        // `bind` fail whenever it is already taken.
        let socket = UdpSocket::bind(("0.0.0.0", 0)).expect("failed to bind UDP socket");
        socket
            .send_to(&req_buffer.buf[..req_buffer.pos], server)
            .expect("failed to send DNS query");

        let mut res_buffer = BytePacketBuffer::new();
        socket
            .recv_from(&mut res_buffer.buf)
            .expect("failed to receive DNS response");
        DnsPacket::from_bytes(&mut res_buffer)
    }
}

/// The 4-bit RCODE carried in the DNS header.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ResultCode {
    NOERROR = 0,
    FORMERR = 1,
    SERVFAIL = 2,
    NXDOMAIN = 3,
    NOTIMP = 4,
    REFUSED = 5,
}

impl ResultCode {
    /// Maps a numeric RCODE to a variant. Codes without a variant fall back to
    /// `NOERROR`.
    pub fn from_num(num: u8) -> Self {
        match num {
            1 => ResultCode::FORMERR,
            2 => ResultCode::SERVFAIL,
            3 => ResultCode::NXDOMAIN,
            4 => ResultCode::NOTIMP,
            5 => ResultCode::REFUSED,
            _ => ResultCode::NOERROR,
        }
    }
}

/// The fixed 12-byte DNS message header: ID, two flag bytes and four section
/// counts.
#[derive(Clone, Debug)]
pub struct DnsHeader {
    pub id: u16,                    // 16 bits
    pub query_response: bool,       // 1 bit
    pub operation_code: u8,         // 4 bits
    pub authoritative_answer: bool, // 1 bit
    pub truncated_message: bool,    // 1 bit
    pub recursion_desired: bool,    // 1 bit
    pub recursion_available: bool,  // 1 bit
    pub z: u8,                      // 3 bits
    pub result_code: ResultCode,    // 4 bits
    pub question_count: u16,        // 16 bits
    pub answer_count: u16,          // 16 bits
    pub authority_count: u16,       // 16 bits
    pub add_count: u16,             // 16 bits
}

impl DnsHeader {
    /// Returns a header with the given `id`, every flag cleared and all counts 0.
    pub fn new(id: u16) -> Self {
        DnsHeader {
            id,
            query_response: false,
            operation_code: 0,
            authoritative_answer: false,
            truncated_message: false,
            recursion_desired: false,
            recursion_available: false,
            z: 0,
            result_code: ResultCode::NOERROR,
            question_count: 0,
            answer_count: 0,
            authority_count: 0,
            add_count: 0,
        }
    }

    /// Parses a header from the next 12 bytes of `buffer`.
    ///
    /// # Errors
    /// Returns `"End of Buffer"` if fewer than 12 bytes remain.
    pub fn from(buffer: &mut BytePacketBuffer) -> Result<Self, &'static str> {
        let id = buffer.read_u16()?;
        let flags = buffer.read_u16()?;
        // High byte, MSB first: QR | OPCODE(4) | AA | TC | RD
        let hi = (flags >> 8) as u8;
        // Low byte, MSB first: RA | Z(3) | RCODE(4)
        let lo = (flags & 0xFF) as u8;

        // Struct-literal fields are evaluated in the order written, so the four
        // counts are read in wire order.
        Ok(DnsHeader {
            id,
            recursion_desired: (hi & 0x01) != 0,
            truncated_message: (hi & 0x02) != 0,
            authoritative_answer: (hi & 0x04) != 0,
            operation_code: (hi >> 3) & 0x0F,
            query_response: (hi & 0x80) != 0,
            result_code: ResultCode::from_num(lo & 0x0F),
            z: (lo >> 4) & 0x07,
            recursion_available: (lo & 0x80) != 0,
            question_count: buffer.read_u16()?,
            answer_count: buffer.read_u16()?,
            authority_count: buffer.read_u16()?,
            add_count: buffer.read_u16()?,
        })
    }

    /// Serializes the header (12 bytes) into `buffer`.
    ///
    /// `operation_code` and `z` are masked to their 4- and 3-bit widths so an
    /// out-of-range value cannot spill into neighbouring flag bits.
    ///
    /// # Errors
    /// Returns an error if the buffer is full.
    pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), &'static str> {
        buffer.write_u16(self.id)?;
        buffer.write_u8(
            (self.recursion_desired as u8)
                | ((self.truncated_message as u8) << 1)
                | ((self.authoritative_answer as u8) << 2)
                | ((self.operation_code & 0x0F) << 3)
                | ((self.query_response as u8) << 7),
        )?;
        // Z occupies bits 4-6, mirroring `(lo >> 4) & 0x07` in `from`.
        buffer.write_u8(
            ((self.result_code as u8) & 0x0F)
                | ((self.z & 0x07) << 4)
                | ((self.recursion_available as u8) << 7),
        )?;
        buffer.write_u16(self.question_count)?;
        buffer.write_u16(self.answer_count)?;
        buffer.write_u16(self.authority_count)?;
        buffer.write_u16(self.add_count)?;
        Ok(())
    }
}

/// Record/query types understood by this module.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Ord, PartialOrd, Hash)]
pub enum QueryType {
    /// Any type without a dedicated variant; carries the raw number.
    UNKNOWN(u16),
    /// 1
    A,
    /// 2
    NS,
    /// 5
    CNAME,
    /// 15
    MX,
    /// 28
    AAAA,
    /// 16
    TXT,
}

impl QueryType {
    /// Returns the numeric TYPE code used on the wire.
    pub fn to_num(&self) -> u16 {
        match *self {
            QueryType::UNKNOWN(x) => x,
            QueryType::A => 1,
            QueryType::NS => 2,
            QueryType::CNAME => 5,
            QueryType::MX => 15,
            QueryType::TXT => 16,
            QueryType::AAAA => 28,
        }
    }

    /// Maps a numeric TYPE code to a variant, falling back to `UNKNOWN(num)`.
    pub fn from_num(num: u16) -> Self {
        match num {
            1 => QueryType::A,
            2 => QueryType::NS,
            5 => QueryType::CNAME,
            15 => QueryType::MX,
            16 => QueryType::TXT,
            28 => QueryType::AAAA,
            _ => QueryType::UNKNOWN(num),
        }
    }
}

/// An entry of the question section. The class is always written as 1 (IN)
/// and ignored when parsing.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct DnsQuestion {
    pub name: String,
    pub qtype: QueryType,
}

impl DnsQuestion {
    /// Creates a question for `name` with type `qtype`.
    pub fn new(name: String, qtype: QueryType) -> Self {
        DnsQuestion { name, qtype }
    }

    /// Parses a question from `buffer`.
    ///
    /// # Panics
    /// Panics on malformed input; see [`DnsQuestion::try_from_buffer`].
    pub fn from(buffer: &mut BytePacketBuffer) -> Self {
        DnsQuestion::try_from_buffer(buffer).expect("malformed DNS question")
    }

    /// Fallible counterpart of [`DnsQuestion::from`].
    ///
    /// # Errors
    /// Returns an error if the name is malformed or the buffer ends early.
    pub fn try_from_buffer(buffer: &mut BytePacketBuffer) -> Result<Self, &'static str> {
        let name = buffer.read_qname()?;
        let qtype = QueryType::from_num(buffer.read_u16()?);
        let _class = buffer.read_u16()?;
        Ok(DnsQuestion { name, qtype })
    }

    /// Serializes the question (name, type, class IN) into `buffer`.
    ///
    /// # Errors
    /// Returns an error if a label exceeds 63 bytes or the buffer is full.
    pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), &'static str> {
        buffer.write_qname(&self.name)?;
        buffer.write_u16(self.qtype.to_num())?;
        buffer.write_u16(1)?; // CLASS IN
        Ok(())
    }
}

/// A resource record from the answer, authority or additional section.
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)]
pub enum DnsRecord {
    UNKNOWN { domain: String, qtype: QueryType, data_len: u16, ttl: u32 },
    A { domain: String, addr: Ipv4Addr, ttl: u32 },
    NS { domain: String, host: String, ttl: u32 },
    CNAME { domain: String, host: String, ttl: u32 },
    MX { domain: String, priority: u16, host: String, ttl: u32 },
    AAAA { domain: String, addr: Ipv6Addr, ttl: u32 },
    /// Only the first character-string of the RDATA is kept; `len` is its
    /// length as read from the wire.
    TXT { domain: String, len: u8, txt: String, ttl: u32 },
}

impl DnsRecord {
    /// Parses a record from `buffer`.
    ///
    /// # Panics
    /// Panics on malformed input; see [`DnsRecord::try_from_buffer`].
    pub fn from(buffer: &mut BytePacketBuffer) -> Self {
        DnsRecord::try_from_buffer(buffer).expect("malformed DNS record")
    }

    /// Fallible counterpart of [`DnsRecord::from`].
    ///
    /// After parsing, the buffer is always left just past the RDATA as
    /// declared by RDLENGTH. That keeps the following records aligned even when
    /// this record's data is only partly understood (unknown types, TXT records
    /// with several strings).
    ///
    /// # Errors
    /// Returns an error if a name is malformed or the buffer ends early.
    pub fn try_from_buffer(buffer: &mut BytePacketBuffer) -> Result<Self, &'static str> {
        let domain = buffer.read_qname()?;
        let qtype = QueryType::from_num(buffer.read_u16()?);
        let _class = buffer.read_u16()?; // CLASS (expected to be IN)
        let ttl = buffer.read_u32()?;
        let data_len = buffer.read_u16()?;
        let rdata_start = buffer.pos();

        let record = match qtype {
            QueryType::A => {
                // `From<u32>` interprets the value big-endian, matching wire order.
                let addr = Ipv4Addr::from(buffer.read_u32()?);
                DnsRecord::A { domain, addr, ttl }
            }
            QueryType::AAAA => {
                // Eight full 16-bit segments; masking to 8 bits would corrupt them.
                let mut segments = [0u16; 8];
                for segment in segments.iter_mut() {
                    *segment = buffer.read_u16()?;
                }
                DnsRecord::AAAA {
                    domain,
                    addr: Ipv6Addr::from(segments),
                    ttl,
                }
            }
            QueryType::NS => {
                let host = buffer.read_qname()?;
                DnsRecord::NS { domain, host, ttl }
            }
            QueryType::CNAME => {
                let host = buffer.read_qname()?;
                DnsRecord::CNAME { domain, host, ttl }
            }
            QueryType::MX => {
                let priority = buffer.read_u16()?;
                let host = buffer.read_qname()?;
                DnsRecord::MX { domain, priority, host, ttl }
            }
            QueryType::TXT => {
                let len = buffer.read()?;
                let mut raw = Vec::with_capacity(len as usize);
                for _ in 0..len {
                    raw.push(buffer.read()?);
                }
                let txt = String::from_utf8_lossy(&raw).into_owned();
                DnsRecord::TXT { domain, len, txt, ttl }
            }
            QueryType::UNKNOWN(_) => DnsRecord::UNKNOWN {
                domain,
                qtype,
                data_len,
                ttl,
            },
        };

        buffer.seek(rdata_start + data_len as usize)?;
        Ok(record)
    }

    /// Serializes the record into `buffer` and returns the number of bytes written.
    ///
    /// RDLENGTH is computed from the bytes actually emitted, never hard-coded.
    /// For TXT, the length prefix is the byte length of `txt` (the stored `len`
    /// is ignored, since a lossily-decoded string may no longer match it).
    ///
    /// # Errors
    /// Returns `"Unknown DnsRecord"` for `UNKNOWN` records (their data was not
    /// retained). Also returns an error if a label or TXT string is too long,
    /// or if the buffer is full.
    pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<usize, &'static str> {
        let start_pos = buffer.pos();
        match self {
            DnsRecord::A { domain, addr, ttl } => {
                Self::write_preamble(buffer, domain, QueryType::A, *ttl)?;
                let octets = addr.octets();
                Self::write_sized_rdata(buffer, |b| {
                    for octet in octets.iter() {
                        b.write_u8(*octet)?;
                    }
                    Ok(())
                })?;
            }
            DnsRecord::AAAA { domain, addr, ttl } => {
                Self::write_preamble(buffer, domain, QueryType::AAAA, *ttl)?;
                let segments = addr.segments();
                Self::write_sized_rdata(buffer, |b| {
                    for segment in segments.iter() {
                        b.write_u16(*segment)?;
                    }
                    Ok(())
                })?;
            }
            DnsRecord::NS { domain, host, ttl } => {
                Self::write_preamble(buffer, domain, QueryType::NS, *ttl)?;
                Self::write_sized_rdata(buffer, |b| b.write_qname(host))?;
            }
            DnsRecord::CNAME { domain, host, ttl } => {
                Self::write_preamble(buffer, domain, QueryType::CNAME, *ttl)?;
                Self::write_sized_rdata(buffer, |b| b.write_qname(host))?;
            }
            DnsRecord::MX {
                domain,
                priority,
                host,
                ttl,
            } => {
                Self::write_preamble(buffer, domain, QueryType::MX, *ttl)?;
                Self::write_sized_rdata(buffer, |b| {
                    b.write_u16(*priority)?;
                    b.write_qname(host)
                })?;
            }
            DnsRecord::TXT { domain, txt, ttl, .. } => {
                let bytes = txt.as_bytes();
                // A single character-string has a one-byte length prefix.
                if bytes.len() > 255 {
                    return Err("TXT string exceeds 255 bytes");
                }
                Self::write_preamble(buffer, domain, QueryType::TXT, *ttl)?;
                Self::write_sized_rdata(buffer, |b| {
                    b.write_u8(bytes.len() as u8)?;
                    for byte in bytes {
                        b.write_u8(*byte)?;
                    }
                    Ok(())
                })?;
            }
            DnsRecord::UNKNOWN { .. } => {
                return Err("Unknown DnsRecord");
            }
        }
        Ok(buffer.pos() - start_pos)
    }

    /// Writes the fields every record starts with: NAME, TYPE, CLASS (IN), TTL.
    fn write_preamble(
        buffer: &mut BytePacketBuffer,
        domain: &str,
        qtype: QueryType,
        ttl: u32,
    ) -> Result<(), &'static str> {
        buffer.write_qname(domain)?;
        buffer.write_u16(qtype.to_num())?;
        buffer.write_u16(1)?; // CLASS IN
        buffer.write_u32(ttl)?;
        Ok(())
    }

    /// Writes RDLENGTH followed by the RDATA produced by `write_rdata`, then
    /// back-patches RDLENGTH with the number of bytes actually written.
    fn write_sized_rdata<F>(buffer: &mut BytePacketBuffer, write_rdata: F) -> Result<(), &'static str>
    where
        F: FnOnce(&mut BytePacketBuffer) -> Result<(), &'static str>,
    {
        let len_pos = buffer.pos();
        buffer.write_u16(0)?; // placeholder, patched below
        write_rdata(buffer)?;
        // Bounded by the 512-byte buffer, so it always fits in a u16.
        let size = buffer.pos() - (len_pos + 2);
        buffer.set_u16(len_pos, size as u16)
    }
}

/// Fixed 512-byte buffer holding one DNS message, with a read/write cursor.
///
/// Every error returned by its methods is a `'static` string. Because the error
/// does not borrow the buffer, `?` can be used freely between calls.
pub struct BytePacketBuffer {
    pub buf: [u8; 512],
    pub pos: usize,
}

impl Default for BytePacketBuffer {
    fn default() -> Self {
        BytePacketBuffer::new()
    }
}

impl BytePacketBuffer {
    /// Returns a zero-filled buffer with the cursor at 0.
    pub fn new() -> Self {
        BytePacketBuffer {
            buf: [0; 512],
            pos: 0,
        }
    }

    /// Current cursor position.
    pub fn pos(&self) -> usize {
        self.pos
    }

    /// Advances the cursor by `steps` bytes. No bounds check is done here;
    /// a later read past the end reports the error.
    pub fn step(&mut self, steps: usize) -> Result<(), &'static str> {
        self.pos += steps;
        Ok(())
    }

    /// Moves the cursor to `pos`. No bounds check is done here.
    pub fn seek(&mut self, pos: usize) -> Result<(), &'static str> {
        self.pos = pos;
        Ok(())
    }

    /// Reads a single byte and advances the cursor.
    pub fn read(&mut self) -> Result<u8, &'static str> {
        let byte = self.get(self.pos)?;
        self.pos += 1;
        Ok(byte)
    }

    /// Returns the byte at `pos` without moving the cursor.
    pub fn get(&self, pos: usize) -> Result<u8, &'static str> {
        self.buf.get(pos).copied().ok_or("End of Buffer")
    }

    /// Returns `len` bytes starting at `start` without moving the cursor.
    /// A range ending exactly at the end of the buffer is valid.
    pub fn get_range(&self, start: usize, len: usize) -> Result<&[u8], &'static str> {
        match start.checked_add(len) {
            Some(end) if end <= self.buf.len() => Ok(&self.buf[start..end]),
            _ => Err("End of Buffer"),
        }
    }

    /// Reads a big-endian `u16` and advances the cursor by 2.
    pub fn read_u16(&mut self) -> Result<u16, &'static str> {
        let hi = self.read()? as u16;
        let lo = self.read()? as u16;
        Ok((hi << 8) | lo)
    }

    /// Reads a big-endian `u32` and advances the cursor by 4.
    pub fn read_u32(&mut self) -> Result<u32, &'static str> {
        let mut value = 0u32;
        for _ in 0..4 {
            value = (value << 8) | self.read()? as u32;
        }
        Ok(value)
    }

    /// Reads a domain name in label format: `[len][label][len][label]...[0]`.
    ///
    /// A length byte with its two high bits set (`0xC0`) is a compression
    /// pointer. Its low 14 bits, together with the next byte, give the offset
    /// where the name continues. The cursor ends just past the name as it
    /// appears here, i.e. just after the first pointer if one was followed.
    /// Labels are lowercased.
    ///
    /// # Errors
    /// Returns `"too many jumps"` after 5 pointers (guards against pointer
    /// loops), or `"End of Buffer"` if a label or pointer target lies outside
    /// the buffer.
    fn read_qname(&mut self) -> Result<String, &'static str> {
        const MAX_JUMPS: usize = 5;

        let mut lpos = self.pos;
        let mut jumped = false;
        let mut num_jumps = 0;
        let mut labels: Vec<String> = Vec::new();

        loop {
            if num_jumps >= MAX_JUMPS {
                return Err("too many jumps");
            }

            let len = self.get(lpos)?;
            if (len & 0xC0) == 0xC0 {
                // The caller resumes after the first pointer, not after the target.
                if !jumped {
                    self.seek(lpos + 2)?;
                }
                let low = self.get(lpos + 1)? as u16;
                let offset = (((len as u16) & 0x3F) << 8) | low;
                lpos = offset as usize;
                jumped = true;
                num_jumps += 1;
                continue;
            }

            lpos += 1;
            // A zero-length label terminates the name.
            if len == 0 {
                break;
            }
            let label = self.get_range(lpos, len as usize)?;
            labels.push(String::from_utf8_lossy(label).to_lowercase());
            lpos += len as usize;
        }

        if !jumped {
            self.seek(lpos)?;
        }
        Ok(labels.join("."))
    }

    /// Writes one byte at the cursor and advances it.
    fn write(&mut self, val: u8) -> Result<(), &'static str> {
        if self.pos >= self.buf.len() {
            return Err("End of buffer");
        }
        self.buf[self.pos] = val;
        self.pos += 1;
        Ok(())
    }

    /// Writes one byte at the cursor and advances it.
    fn write_u8(&mut self, val: u8) -> Result<(), &'static str> {
        self.write(val)
    }

    /// Writes a big-endian `u16`.
    fn write_u16(&mut self, val: u16) -> Result<(), &'static str> {
        self.write_u8((val >> 8) as u8)?;
        self.write_u8((val & 0xFF) as u8)
    }

    /// Writes a big-endian `u32`.
    fn write_u32(&mut self, val: u32) -> Result<(), &'static str> {
        for shift in [24u32, 16, 8, 0].iter() {
            self.write_u8(((val >> *shift) & 0xFF) as u8)?;
        }
        Ok(())
    }

    /// Writes `qname` as length-prefixed labels followed by a zero byte
    /// (no compression).
    ///
    /// Empty labels are skipped. That way the root name `""` and fully-qualified
    /// names with a trailing dot (`"example.com."`) do not emit a premature
    /// zero-length terminator.
    fn write_qname(&mut self, qname: &str) -> Result<(), &'static str> {
        for label in qname.split('.').filter(|label| !label.is_empty()) {
            if label.len() > 0x3F {
                return Err("Single Label exeeds 63 characters");
            }
            self.write_u8(label.len() as u8)?;
            for byte in label.as_bytes() {
                self.write_u8(*byte)?;
            }
        }
        self.write_u8(0)
    }

    /// Overwrites the byte at `pos` without moving the cursor.
    fn set(&mut self, pos: usize, val: u8) -> Result<(), &'static str> {
        if pos >= self.buf.len() {
            return Err("End of Buffer");
        }
        self.buf[pos] = val;
        Ok(())
    }

    /// Overwrites the big-endian `u16` at `pos..pos + 2` without moving the cursor.
    fn set_u16(&mut self, pos: usize, val: u16) -> Result<(), &'static str> {
        if pos + 1 >= self.buf.len() {
            return Err("End of Buffer");
        }
        self.buf[pos] = ((val >> 8) & 0xFF) as u8;
        self.buf[pos + 1] = (val & 0xFF) as u8;
        Ok(())
    }
}