use std::net::{Ipv4Addr, Ipv6Addr};
use std::net::UdpSocket;



/// A complete DNS message: the header plus its four sections.
pub struct DnsPacket {
    /// Message header. Its section counts are overwritten by [`DnsPacket::write`].
    pub header: DnsHeader,
    /// Question section.
    pub questions: Vec<DnsQuestion>,
    /// Answer section.
    pub answers: Vec<DnsRecord>,
    /// Authority section.
    pub authorities: Vec<DnsRecord>,
    /// Additional section (counted by the header's `add_count`).
    pub resources: Vec<DnsRecord>,
}

impl DnsPacket {
    /// Transaction ID given to every packet constructed by this type.
    const DEFAULT_ID: u16 = 6677;

    /// Builds a query for `qname` / `query_type` with recursion desired.
    pub fn new(query_type: QueryType, qname: &str) -> Self {
        let mut packet = DnsPacket::empty();
        packet.header.recursion_desired = true;
        packet
            .questions
            .push(DnsQuestion::new(String::from(qname), query_type));
        packet
    }

    /// Builds a packet with a default header and no questions or records.
    pub fn empty() -> Self {
        DnsPacket {
            header: DnsHeader::new(Self::DEFAULT_ID),
            questions: Vec::new(),
            answers: Vec::new(),
            authorities: Vec::new(),
            resources: Vec::new(),
        }
    }

    /// Parses a whole message starting at the buffer's current position.
    ///
    /// Each section is read using the count announced for it in the header.
    ///
    /// # Panics
    /// Panics if the message is truncated or malformed (the underlying parsers
    /// unwrap their reads).
    pub fn from_bytes(buffer: &mut BytePacketBuffer) -> Self {
        let header = DnsHeader::from(buffer).unwrap();

        let questions = (0..header.question_count)
            .map(|_| DnsQuestion::from(buffer))
            .collect();
        let answers = (0..header.answer_count)
            .map(|_| DnsRecord::from(buffer))
            .collect();
        let authorities = (0..header.authority_count)
            .map(|_| DnsRecord::from(buffer))
            .collect();
        // The additional section is sized by ARCOUNT (`add_count`), not by the
        // authority count.
        let resources = (0..header.add_count)
            .map(|_| DnsRecord::from(buffer))
            .collect();

        DnsPacket {
            header,
            questions,
            answers,
            authorities,
            resources,
        }
    }

    /// Serialises the packet into `buffer` at its current position.
    ///
    /// The header's four section counts are first set from the lengths of the
    /// corresponding vectors, so callers only need to fill the vectors.
    ///
    /// # Errors
    /// Propagates the first error reported by a header, question or record
    /// writer (e.g. a `DnsRecord::UNKNOWN`, which cannot be serialised). The
    /// buffer may then contain a partially written message.
    pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), &str> {
        self.header.question_count = self.questions.len() as u16;
        self.header.answer_count = self.answers.len() as u16;
        self.header.authority_count = self.authorities.len() as u16;
        self.header.add_count = self.resources.len() as u16;

        self.header.write(buffer)?;

        for question in &self.questions {
            question.write(buffer)?;
        }
        // Answer, authority and additional records are laid out back to back.
        for rec in self
            .answers
            .iter()
            .chain(&self.authorities)
            .chain(&self.resources)
        {
            rec.write(buffer)?;
        }

        Ok(())
    }

    /// Sends this packet over UDP to `server` and returns the parsed reply.
    ///
    /// The local socket is bound to port 0 so the OS picks a free ephemeral
    /// port. A fixed port fails when already in use (e.g. by a concurrent call)
    /// and makes the source port predictable.
    ///
    /// The header's section counts are updated as a side effect of
    /// serialisation.
    ///
    /// # Panics
    /// Panics if serialisation fails or on any socket error. No read timeout is
    /// set, so the call blocks until some datagram arrives. The reply's sender
    /// address and transaction ID are not checked.
    pub fn send(&mut self, server: (&str, u16)) -> DnsPacket {
        let mut req_buffer = BytePacketBuffer::new();
        self.write(&mut req_buffer).unwrap();

        let socket = UdpSocket::bind(("0.0.0.0", 0)).unwrap();
        socket
            .send_to(&req_buffer.buf[0..req_buffer.pos], server)
            .unwrap();

        let mut res_buffer = BytePacketBuffer::new();
        socket.recv_from(&mut res_buffer.buf).unwrap();

        DnsPacket::from_bytes(&mut res_buffer)
    }
}



/// DNS response code: the 4-bit RCODE field of the message header.
///
/// Each discriminant equals the on-the-wire value, so `code as u8` encodes a
/// variant.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ResultCode {
    /// No error.
    NOERROR = 0,
    /// Format error.
    FORMERR = 1,
    /// Server failure.
    SERVFAIL = 2,
    /// Non-existent domain.
    NXDOMAIN = 3,
    /// Not implemented.
    NOTIMP = 4,
    /// Query refused.
    REFUSED = 5,
}

impl ResultCode {
    /// Decodes a numeric RCODE.
    ///
    /// Values without a dedicated variant (anything above 5) map to `NOERROR`.
    pub fn from_num(num: u8) -> Self {
        // Position `i` holds the variant whose wire value is `i`.
        const BY_CODE: [ResultCode; 6] = [
            ResultCode::NOERROR,
            ResultCode::FORMERR,
            ResultCode::SERVFAIL,
            ResultCode::NXDOMAIN,
            ResultCode::NOTIMP,
            ResultCode::REFUSED,
        ];

        match BY_CODE.get(usize::from(num)) {
            Some(&code) => code,
            None => ResultCode::NOERROR,
        }
    }
}

/// The fixed 12-byte header at the start of every DNS message.
///
/// The two flag bytes are unpacked into individual fields. The bit width noted
/// on each field is the width it occupies on the wire.
#[derive(Clone, Debug)]
pub struct DnsHeader {
    /// 16 bits: transaction ID.
    pub id: u16,
    /// 1 bit: QR flag (query vs. response).
    pub query_response: bool,
    /// 4 bits: OPCODE.
    pub operation_code: u8,
    /// 1 bit: AA flag.
    pub authoritative_answer: bool,
    /// 1 bit: TC flag.
    pub truncated_message: bool,
    /// 1 bit: RD flag.
    pub recursion_desired: bool,
    /// 1 bit: RA flag.
    pub recursion_available: bool,
    /// 3 bits: Z field.
    pub z: u8,
    /// 4 bits: RCODE.
    pub result_code: ResultCode,
    /// 16 bits: number of entries in the question section.
    pub question_count: u16,
    /// 16 bits: number of records in the answer section.
    pub answer_count: u16,
    /// 16 bits: number of records in the authority section.
    pub authority_count: u16,
    /// 16 bits: number of records in the additional section.
    pub add_count: u16,
}

impl DnsHeader {
    /// Returns a header with the given `id`, all flags cleared, all counts
    /// zero and `result_code` set to `NOERROR`.
    pub fn new(id: u16) -> Self {
        DnsHeader {
            id,
            query_response: false,
            operation_code: 0,
            authoritative_answer: false,
            truncated_message: false,
            recursion_desired: false,
            recursion_available: false,
            z: 0,
            result_code: ResultCode::NOERROR,
            question_count: 0,
            answer_count: 0,
            authority_count: 0,
            add_count: 0,
        }
    }

    /// Parses the 12-byte header at the buffer's current position.
    ///
    /// This always returns `Ok`; a failed read panics via `unwrap`.
    pub fn from(buffer: &mut BytePacketBuffer) -> Result<Self, &str> {
        let id = buffer.read_u16().unwrap();
        let flags = buffer.read_u16().unwrap();
        // High flag byte, MSB first: QR | OPCODE(4) | AA | TC | RD
        let hi = (flags >> 8) as u8;
        // Low flag byte, MSB first:  RA | Z(3) | RCODE(4)
        let lo = (flags & 0xFF) as u8;

        Ok(DnsHeader {
            id,
            recursion_desired: (hi & (1 << 0)) > 0,
            truncated_message: (hi & (1 << 1)) > 0,
            authoritative_answer: (hi & (1 << 2)) > 0,
            operation_code: (hi >> 3) & 0x0F,
            query_response: (hi >> 7) > 0,

            result_code: ResultCode::from_num(lo & 0x0F),
            z: (lo >> 4) & 0x07,
            recursion_available: (lo >> 7) > 0,

            // Read in wire order: QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT.
            question_count: buffer.read_u16().unwrap(),
            answer_count: buffer.read_u16().unwrap(),
            authority_count: buffer.read_u16().unwrap(),
            add_count: buffer.read_u16().unwrap(),
        })
    }

    /// Serialises the header (12 bytes) at the buffer's current position.
    ///
    /// Multi-bit fields are masked to their wire width, so an out-of-range
    /// `operation_code` or `z` cannot spill into neighbouring flag bits.
    /// This always returns `Ok`; a failed write panics via `unwrap`.
    pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), &str> {
        let hi = (self.recursion_desired as u8)
            | ((self.truncated_message as u8) << 1)
            | ((self.authoritative_answer as u8) << 2)
            | ((self.operation_code & 0x0F) << 3)
            | ((self.query_response as u8) << 7);
        // Z sits in bits 4-6, directly above the 4-bit RCODE (it used to be
        // shifted by 6, which overlapped the RA bit).
        let lo = ((self.result_code as u8) & 0x0F)
            | ((self.z & 0x07) << 4)
            | ((self.recursion_available as u8) << 7);

        buffer.write_u16(self.id).unwrap();
        buffer.write_u8(hi).unwrap();
        buffer.write_u8(lo).unwrap();

        buffer.write_u16(self.question_count).unwrap();
        buffer.write_u16(self.answer_count).unwrap();
        buffer.write_u16(self.authority_count).unwrap();
        buffer.write_u16(self.add_count).unwrap();
        Ok(())
    }
}


/// Record type code (TYPE / QTYPE) for the kinds of record this module decodes.
/// Any other code is carried through unchanged as `UNKNOWN(code)`.
///
/// Declaration order is significant: the derived `Ord`/`PartialOrd` compare
/// variants in the order they are listed here.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Ord, PartialOrd, Hash)]
pub enum QueryType {
    /// A type code without a dedicated variant; the raw value is kept.
    UNKNOWN(u16),
    /// Wire code 1.
    A,
    /// Wire code 2.
    NS,
    /// Wire code 5.
    CNAME,
    /// Wire code 15.
    MX,
    /// Wire code 28.
    AAAA,
    /// Wire code 16.
    TXT,
}

impl QueryType {
    /// Returns the numeric wire code of this type.
    pub fn to_num(&self) -> u16 {
        match *self {
            Self::A => 1,
            Self::NS => 2,
            Self::CNAME => 5,
            Self::MX => 15,
            Self::TXT => 16,
            Self::AAAA => 28,
            Self::UNKNOWN(code) => code,
        }
    }

    /// Maps a numeric wire code back to a variant.
    ///
    /// Codes that no named variant uses become `UNKNOWN(num)`.
    pub fn from_num(num: u16) -> Self {
        // Every variant that has a fixed wire code.
        const KNOWN: [QueryType; 6] = [
            QueryType::A,
            QueryType::NS,
            QueryType::CNAME,
            QueryType::MX,
            QueryType::TXT,
            QueryType::AAAA,
        ];

        KNOWN
            .iter()
            .copied()
            .find(|candidate| candidate.to_num() == num)
            .unwrap_or(QueryType::UNKNOWN(num))
    }
}



/// One entry of a DNS message's question section.
///
/// Only the name and the record type are stored. The class field is discarded
/// when parsing and always written back as `1`.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct DnsQuestion {
    /// The domain name being asked about.
    pub name: String,
    /// The record type being asked for.
    pub qtype: QueryType,
}

impl DnsQuestion {
    /// Creates a question for `name` with record type `qtype`.
    pub fn new(name: String, qtype: QueryType) -> Self {
        Self { name, qtype }
    }

    /// Parses one question entry at the buffer's current position.
    ///
    /// # Panics
    /// Panics if the name or either 16-bit field cannot be read.
    pub fn from(buffer: &mut BytePacketBuffer) -> Self {
        let name = buffer.read_qname().unwrap();
        let type_code = buffer.read_u16().unwrap();
        // Class field: consumed to advance the cursor, value not kept.
        buffer.read_u16().unwrap();

        Self {
            name,
            qtype: QueryType::from_num(type_code),
        }
    }

    /// Serialises this question at the buffer's current position.
    ///
    /// This always returns `Ok`; a failed write panics via `unwrap`.
    pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), &str> {
        const CLASS_IN: u16 = 1;

        buffer.write_qname(&self.name).unwrap();
        buffer.write_u16(self.qtype.to_num()).unwrap();
        buffer.write_u16(CLASS_IN).unwrap();

        Ok(())
    }
}


/// A resource record from the answer, authority or additional section.
///
/// Only the variants below are decoded. Any other type is kept as
/// `DnsRecord::UNKNOWN` with its RDATA skipped. The CLASS field is not stored:
/// it is discarded when parsing and written as `1` (IN).
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)]
pub enum DnsRecord {
    /// A record of a type this module does not decode.
    UNKNOWN {
        domain: String,
        qtype: QueryType,
        /// RDLENGTH of the skipped payload.
        data_len: u16,
        ttl: u32,
    },
    /// IPv4 address record.
    A {
        domain: String,
        addr: Ipv4Addr,
        ttl: u32,
    },
    /// Name-server record.
    NS {
        domain: String,
        host: String,
        ttl: u32,
    },
    /// Alias record.
    CNAME {
        domain: String,
        host: String,
        ttl: u32,
    },
    /// Mail-exchange record.
    MX {
        domain: String,
        priority: u16,
        host: String,
        ttl: u32,
    },
    /// IPv6 address record.
    AAAA {
        domain: String,
        addr: Ipv6Addr,
        ttl: u32,
    },
    /// Text record.
    TXT {
        domain: String,
        /// Length byte of the first character-string as read from the wire.
        /// Not used when writing; the length is derived from `txt`.
        len: u8,
        txt: String,
        ttl: u32,
    },
}

impl DnsRecord {
    /// Parses one resource record at the buffer's current position.
    ///
    /// After the type-specific RDATA has been decoded, the cursor is moved to
    /// the end of the RDATA declared by RDLENGTH. A payload longer than what is
    /// decoded (e.g. a TXT record with several character-strings, of which only
    /// the first is kept) therefore does not desynchronise the next record.
    ///
    /// # Panics
    /// Panics if the buffer is exhausted or a name cannot be decoded.
    pub fn from(buffer: &mut BytePacketBuffer) -> Self {
        let domain = buffer.read_qname().unwrap();
        let qtype = QueryType::from_num(buffer.read_u16().unwrap());
        let _ = buffer.read_u16().unwrap(); // CLASS (expected to be IN)
        let ttl = buffer.read_u32().unwrap();
        let data_len = buffer.read_u16().unwrap();
        let rdata_end = buffer.pos() + data_len as usize;

        let record = match qtype {
            QueryType::A => {
                // Four octets, most significant first.
                let addr = Ipv4Addr::from(buffer.read_u32().unwrap());
                DnsRecord::A { domain, addr, ttl }
            }
            QueryType::AAAA => {
                // Eight big-endian 16-bit segments (full 16 bits each).
                let mut segments = [0u16; 8];
                for segment in segments.iter_mut() {
                    *segment = buffer.read_u16().unwrap();
                }
                DnsRecord::AAAA {
                    domain,
                    addr: Ipv6Addr::from(segments),
                    ttl,
                }
            }
            QueryType::NS => {
                let host = buffer.read_qname().unwrap();
                DnsRecord::NS { domain, host, ttl }
            }
            QueryType::CNAME => {
                let host = buffer.read_qname().unwrap();
                DnsRecord::CNAME { domain, host, ttl }
            }
            QueryType::MX => {
                let priority = buffer.read_u16().unwrap();
                let host = buffer.read_qname().unwrap();
                DnsRecord::MX {
                    domain,
                    priority,
                    host,
                    ttl,
                }
            }
            QueryType::TXT => {
                // Only the first length-prefixed character-string is decoded;
                // any further strings are skipped by the seek below.
                let txt_len = buffer.read().unwrap();
                let mut bytes = Vec::with_capacity(txt_len as usize);
                for _ in 0..txt_len {
                    bytes.push(buffer.read().unwrap());
                }
                let txt = String::from_utf8_lossy(&bytes).into_owned();
                DnsRecord::TXT {
                    domain,
                    len: txt_len,
                    txt,
                    ttl,
                }
            }
            QueryType::UNKNOWN(_) => DnsRecord::UNKNOWN {
                domain,
                qtype,
                data_len,
                ttl,
            },
        };

        buffer.seek(rdata_end).unwrap();
        record
    }

    /// Serialises this record at the buffer's current position.
    ///
    /// RDLENGTH is computed from the bytes actually written, never hard-coded.
    /// TXT text is emitted as one or more length-prefixed character-strings of
    /// at most 255 bytes each; an empty text becomes a single empty string.
    ///
    /// # Returns
    /// The number of bytes written.
    ///
    /// # Errors
    /// `Err("Unknown DnsRecord")` for `DnsRecord::UNKNOWN`, which has no stored
    /// payload to write. Nothing is written in that case.
    ///
    /// # Panics
    /// Panics if the buffer runs out of space or a name cannot be encoded.
    pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<usize, &str> {
        let start_pos = buffer.pos();

        match *self {
            DnsRecord::A {
                ref domain,
                addr,
                ttl,
            } => {
                Self::write_preamble(buffer, domain, QueryType::A, ttl);
                Self::write_rdata(buffer, |buf| {
                    for octet in addr.octets().iter() {
                        buf.write_u8(*octet).unwrap();
                    }
                });
            }
            DnsRecord::AAAA {
                ref domain,
                addr,
                ttl,
            } => {
                Self::write_preamble(buffer, domain, QueryType::AAAA, ttl);
                // Eight 16-bit segments: RDLENGTH comes out as 16.
                Self::write_rdata(buffer, |buf| {
                    for segment in addr.segments().iter() {
                        buf.write_u16(*segment).unwrap();
                    }
                });
            }
            DnsRecord::NS {
                ref domain,
                ref host,
                ttl,
            } => {
                Self::write_preamble(buffer, domain, QueryType::NS, ttl);
                Self::write_rdata(buffer, |buf| {
                    buf.write_qname(host).unwrap();
                });
            }
            DnsRecord::CNAME {
                ref domain,
                ref host,
                ttl,
            } => {
                Self::write_preamble(buffer, domain, QueryType::CNAME, ttl);
                Self::write_rdata(buffer, |buf| {
                    buf.write_qname(host).unwrap();
                });
            }
            DnsRecord::MX {
                ref domain,
                priority,
                ref host,
                ttl,
            } => {
                Self::write_preamble(buffer, domain, QueryType::MX, ttl);
                Self::write_rdata(buffer, |buf| {
                    buf.write_u16(priority).unwrap();
                    buf.write_qname(host).unwrap();
                });
            }
            DnsRecord::TXT {
                ref domain,
                ref txt,
                ttl,
                ..
            } => {
                Self::write_preamble(buffer, domain, QueryType::TXT, ttl);
                Self::write_rdata(buffer, |buf| {
                    let bytes = txt.as_bytes();
                    // RDATA must hold at least one character-string.
                    if bytes.is_empty() {
                        buf.write_u8(0).unwrap();
                    }
                    // The length prefix is one byte, so longer text is split;
                    // the prefix always matches the bytes that follow it.
                    for chunk in bytes.chunks(255) {
                        buf.write_u8(chunk.len() as u8).unwrap();
                        for byte in chunk {
                            buf.write_u8(*byte).unwrap();
                        }
                    }
                });
            }
            DnsRecord::UNKNOWN { .. } => {
                return Err("Unknown DnsRecord");
            }
        }

        Ok(buffer.pos() - start_pos)
    }

    /// Writes NAME, TYPE, CLASS (always 1 = IN) and TTL: everything before RDLENGTH.
    fn write_preamble(buffer: &mut BytePacketBuffer, domain: &str, qtype: QueryType, ttl: u32) {
        buffer.write_qname(domain).unwrap();
        buffer.write_u16(qtype.to_num()).unwrap();
        buffer.write_u16(1).unwrap();
        buffer.write_u32(ttl).unwrap();
    }

    /// Writes a placeholder RDLENGTH, lets `body` write the RDATA, then patches
    /// RDLENGTH with the number of bytes `body` wrote.
    fn write_rdata<F: FnOnce(&mut BytePacketBuffer)>(buffer: &mut BytePacketBuffer, body: F) {
        let len_pos = buffer.pos();
        buffer.write_u16(0).unwrap();

        body(buffer);

        let size = buffer.pos() - (len_pos + 2);
        buffer.set_u16(len_pos, size as u16).unwrap();
    }
}


/// Fixed-size buffer holding one DNS message, plus a read/write cursor.
///
/// The 512-byte capacity is the classic size limit for a DNS message over UDP.
/// Every accessor is bounds-checked and reports failure as `Err` with a
/// `'static` message. The error therefore does not keep the buffer borrowed,
/// and `?` can be used between successive reads and writes.
pub struct BytePacketBuffer {
    /// Raw message bytes.
    pub buf: [u8; 512],
    /// Cursor: index of the next byte to be read or written.
    pub pos: usize,
}

impl BytePacketBuffer {
    /// Returns a zero-filled buffer with the cursor at offset 0.
    pub fn new() -> Self {
        BytePacketBuffer { buf: [0; 512], pos: 0 }
    }

    /// Returns the current cursor position.
    pub fn pos(&self) -> usize {
        self.pos
    }

    /// Advances the cursor by `steps` bytes.
    ///
    /// Never fails. No bounds check is made here, so moving past the end only
    /// surfaces as an error on the next read or write.
    pub fn step(&mut self, steps: usize) -> Result<(), &'static str> {
        self.pos += steps;
        Ok(())
    }

    /// Moves the cursor to absolute offset `pos`. Never fails; it is unchecked
    /// like `step`.
    pub fn seek(&mut self, pos: usize) -> Result<(), &'static str> {
        self.pos = pos;
        Ok(())
    }

    /// Reads the byte at the cursor and advances the cursor by one.
    pub fn read(&mut self) -> Result<u8, &'static str> {
        let byte = self.get(self.pos)?;
        self.pos += 1;
        Ok(byte)
    }

    /// Returns the byte at `pos` without moving the cursor.
    pub fn get(&self, pos: usize) -> Result<u8, &'static str> {
        self.buf.get(pos).copied().ok_or("End of Buffer")
    }

    /// Returns the `len` bytes starting at `start` without moving the cursor.
    ///
    /// A range that ends exactly at the end of the buffer is valid.
    pub fn get_range(&self, start: usize, len: usize) -> Result<&[u8], &'static str> {
        start
            .checked_add(len)
            .and_then(|end| self.buf.get(start..end))
            .ok_or("End of Buffer")
    }

    /// Reads a big-endian `u16` at the cursor and advances the cursor by 2.
    pub fn read_u16(&mut self) -> Result<u16, &'static str> {
        let hi = self.read()?;
        let lo = self.read()?;
        Ok(u16::from_be_bytes([hi, lo]))
    }

    /// Reads a big-endian `u32` at the cursor and advances the cursor by 4.
    pub fn read_u32(&mut self) -> Result<u32, &'static str> {
        let mut bytes = [0u8; 4];
        for byte in bytes.iter_mut() {
            *byte = self.read()?;
        }
        Ok(u32::from_be_bytes(bytes))
    }

    /// Reads a possibly compressed domain name starting at the cursor.
    ///
    /// A name is a run of `[len][label bytes]` entries ended by a zero length
    /// byte. A length byte with both top bits set (`0xC0`) is instead the first
    /// half of a 14-bit pointer to another offset in the message, where the
    /// name continues.
    ///
    /// Labels are lower-cased and joined with `.`. The cursor ends just past
    /// the name as stored at the starting position: after the terminating zero
    /// byte, or after the first pointer if one was followed.
    ///
    /// # Errors
    /// `Err("too many jumps")` once 5 pointers have been followed (protects
    /// against pointer loops), or `Err("End of Buffer")` if the name runs off
    /// the end of the buffer.
    fn read_qname(&mut self) -> Result<String, &'static str> {
        const MAX_JUMPS: usize = 5;

        // Local read position; `self.pos` is updated separately (see below).
        let mut lpos = self.pos;
        let mut jumped = false;
        let mut num_jmps = 0;
        let mut delim = "";
        let mut outstr = String::new();

        loop {
            if num_jmps >= MAX_JUMPS {
                return Err("too many jumps");
            }

            let len = self.get(lpos)? as usize;

            if (len & 0xC0) == 0xC0 {
                // Compression pointer. The cursor must end after the first
                // pointer, not wherever the pointed-to name ends.
                if !jumped {
                    self.seek(lpos + 2)?;
                }
                let low = self.get(lpos + 1)? as usize;
                lpos = ((len ^ 0xC0) << 8) | low;
                jumped = true;
                num_jmps += 1;
                continue;
            }

            lpos += 1;
            // A zero length byte terminates the name.
            if len == 0 {
                break;
            }

            outstr.push_str(delim);
            let label = self.get_range(lpos, len)?;
            outstr.push_str(&String::from_utf8_lossy(label).to_lowercase());
            delim = ".";
            lpos += len;
        }

        if !jumped {
            self.seek(lpos)?;
        }
        Ok(outstr)
    }

    /// Writes one byte at the cursor and advances the cursor by one.
    fn write(&mut self, val: u8) -> Result<(), &'static str> {
        if self.pos >= self.buf.len() {
            return Err("End of buffer");
        }
        self.buf[self.pos] = val;
        self.pos += 1;
        Ok(())
    }

    /// Writes one byte at the cursor (alias of `write`).
    fn write_u8(&mut self, val: u8) -> Result<(), &'static str> {
        self.write(val)
    }

    /// Writes `val` big-endian at the cursor (2 bytes).
    fn write_u16(&mut self, val: u16) -> Result<(), &'static str> {
        for &byte in val.to_be_bytes().iter() {
            self.write_u8(byte)?;
        }
        Ok(())
    }

    /// Writes `val` big-endian at the cursor (4 bytes).
    fn write_u32(&mut self, val: u32) -> Result<(), &'static str> {
        for &byte in val.to_be_bytes().iter() {
            self.write_u8(byte)?;
        }
        Ok(())
    }

    /// Writes `qname` as length-prefixed labels followed by a zero byte.
    ///
    /// A single trailing dot (fully-qualified form, e.g. `"example.com."`) is
    /// accepted. `""` and `"."` encode the root name as a lone zero byte. All
    /// labels are validated before anything is written, so a validation failure
    /// leaves the buffer untouched.
    ///
    /// # Errors
    /// Fails on an empty label (e.g. `"a..b"`), on a label longer than 63
    /// bytes, or when the buffer is full.
    fn write_qname(&mut self, qname: &str) -> Result<(), &'static str> {
        let name = if qname.ends_with('.') {
            &qname[..qname.len() - 1]
        } else {
            qname
        };

        if !name.is_empty() {
            for label in name.split('.') {
                if label.is_empty() {
                    return Err("Empty label in domain name");
                }
                if label.len() > 0x3F {
                    return Err("Single Label exeeds 63 characters");
                }
            }

            for label in name.split('.') {
                self.write_u8(label.len() as u8)?;
                for &b in label.as_bytes() {
                    self.write_u8(b)?;
                }
            }
        }

        self.write_u8(0)
    }

    /// Overwrites the byte at `pos` without moving the cursor.
    fn set(&mut self, pos: usize, val: u8) -> Result<(), &'static str> {
        let slot = self.buf.get_mut(pos).ok_or("End of Buffer")?;
        *slot = val;
        Ok(())
    }

    /// Overwrites the two bytes at `pos` with `val` (big-endian) without moving
    /// the cursor. Both bytes are checked up front, so nothing is half-written.
    fn set_u16(&mut self, pos: usize, val: u16) -> Result<(), &'static str> {
        if pos >= self.buf.len() - 1 {
            return Err("End of Buffer");
        }
        let [hi, lo] = val.to_be_bytes();
        self.buf[pos] = hi;
        self.buf[pos + 1] = lo;
        Ok(())
    }
}