#!/usr/bin/env ruby
# Copyright, 2009, 2012, by Samuel G. D. Williams. <http://www.codeotaku.com>
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# Usage examples. Each action option (-f, -c, -q, -p, -r) runs as soon as it
# is parsed, so -s and -d must appear before it on the command line.

# Pull down DNS data from old-dns
# rd-dns-check -s old-dns.mydomain.com -d mydomain.com. -f old-dns.yml

# Check data against old-dns
# rd-dns-check -s old-dns.mydomain.com -d mydomain.com. -c old-dns.yml

# Check data against new DNS server
# rd-dns-check -s 10.0.0.36 -d mydomain.com. -c old-dns.yml

require 'yaml'
require 'optparse'
require 'set'

# One DNS resource record, held as a four element array of strings:
# [hostname, class, type, value].
#
# Hostname and value are stored lower-cased, class and type upper-cased, so
# two records compare equal regardless of how they were capitalised.
#
# NOTE(review): records are written with YAML::dump elsewhere in this file,
# which presumably stores the @record instance variable by name - renaming it
# would likely orphan existing dumps; confirm before changing the layout.
class DNSRecord
  # arr:: the four fields in the order hostname, class, type, value.
  def initialize(arr)
    # Work on a copy so normalizing never rewrites the caller's array.
    @record = arr.dup
    normalize
  end

  # Applies the canonical letter case to each field.
  def normalize
    @record[0] = @record[0].downcase
    @record[1] = @record[1].upcase
    @record[2] = @record[2].upcase
    @record[3] = @record[3].downcase
  end

  def hostname
    @record[0]
  end

  def klass
    @record[1]
  end

  def type
    @record[2]
  end

  def value
    @record[3]
  end

  # True for "A" and "AAAA" records.
  def is_address?
    ["A", "AAAA"].include?(type)
  end

  # True for "CNAME" records.
  def is_cname?
    type == "CNAME"
  end

  # Column-aligned single line representation.
  def to_s
    "#{hostname.ljust(50)} #{klass.rjust(4)} #{type.rjust(5)} #{value}"
  end

  # Lookup key that identifies a record by hostname, class and type (not value).
  def key
    "#{hostname}:#{klass}:#{type}".downcase
  end

  # The underlying [hostname, class, type, value] array.
  def to_a
    @record
  end

  # Compares the four fields with anything array-like (another record, a
  # plain array, nil). Objects that cannot be converted with to_a are simply
  # unequal instead of raising NoMethodError.
  def ==(other)
    other.respond_to?(:to_a) && @record == other.to_a
  end

  # Hash-key equality, kept consistent with #== so records behave in Hash/Set.
  def eql?(other)
    other.is_a?(DNSRecord) && self == other
  end

  def hash
    @record.hash
  end
end

# Runs the external `dig` program against +dns_server+ and turns the answer
# section into records.
#
# dns_server:: server to query; passed to dig as "@server".
# cmd::        the remaining dig arguments as one whitespace separated string,
#              for example "example.com. AXFR" or "-x 10.0.0.1".
# exclude::    record types to leave out of the result.
#
# Returns an Array of DNSRecord (empty when dig produced no usable lines or is
# not installed).
def dig(dns_server, cmd, exclude = ["TXT", "HINFO", "SOA", "NS"])
  records = []

  # An argument list bypasses the shell, so characters such as ";" or "$" in a
  # server name or query cannot be interpreted as shell syntax.
  command = ["dig", "@#{dns_server}", "+nottlid", "+nocmd", "+noall", "+answer", *cmd.split]

  IO.popen(command) do |io|
    io.each_line do |line|
      line = line.strip

      # Skip blank lines and dig's ";"-prefixed diagnostics (e.g. a failed transfer).
      next if line.empty? || line.start_with?(";")

      # Columns are padded with a variable number of tabs/spaces, so split on
      # runs of whitespace; the limit keeps spaces inside the value intact.
      fields = line.split(/\s+/, 4)
      next if fields.size < 4
      next if exclude.include?(fields[2])

      records << DNSRecord.new(fields)
    end
  end

  records
rescue Errno::ENOENT
  # Without a shell a missing binary raises; keep returning an empty result as before.
  warn "dig: command not found"
  records
end

# Fetches the records of the +dns_root+ zone from +dns_server+ by asking dig
# for a zone transfer (AXFR). Returns whatever dig returns for that query.
def retrieve_records(dns_server, dns_root)
  zone_transfer = "#{dns_root} AXFR"
  dig(dns_server, zone_transfer)
end

# Asks +dns_server+ for the "A" record of +hostname+ and returns the first
# record of dig's answer, or nil when the answer is empty.
def resolve_hostname(dns_server, hostname)
  answers = dig(dns_server, "#{hostname} A")
  answers.first
end

# Performs a reverse lookup (dig -x) of +address+ on +dns_server+ and returns
# the first record of dig's answer, or nil when the answer is empty.
def resolve_address(dns_server, address)
  answers = dig(dns_server, "-x #{address}")
  answers.first
end

# Prints the closing report of a check run.
#
# records:: everything that was checked; only its size is reported.
# errors::  number of problems found.
# okay::    the entries that passed.
#
# When there were errors the passing entries are listed too: each one is
# handed to the given block, or printed indented when no block was given.
def print_summary(records, errors, okay, &block)
  puts "[ Summary ]".center(72, "=")
  puts "Checked #{records.size} record(s). #{errors} errors."

  # Nothing went wrong: no need to list what passed.
  return puts("Everything seemed okay.") if errors == 0

  puts "The following records are okay:"
  okay.each do |entry|
    unless block_given?
      puts "".rjust(12) + entry.to_s
      next
    end

    yield entry
  end
end

# Builds a hostname -> record table for every name that can be traced to an
# address. "A" and "AAAA" records map to themselves; each CNAME is followed
# through further CNAMEs until an address record is reached. A chain that
# dead-ends or loops is reported on standard output, and that hostname is
# mapped to the last CNAME that was followed instead.
#
# Returns two hashes keyed by hostname: [addresses, cnames].
def resolve_addresses(records)
  addresses = {}
  cnames = {}

  # Index address and CNAME records by hostname; other types are ignored.
  records.each do |record|
    if record.is_address?
      addresses[record.hostname] = record
    elsif record.is_cname?
      cnames[record.hostname] = record
    end
  end

  cnames.each do |hostname, cname|
    chain = []
    target = cname
    unresolved = false

    while target.is_cname?
      # Remember every CNAME visited so a cycle can be recognised.
      chain.push(target)

      # A CNAME for the target name wins over an address record.
      target = cnames[target.value] || addresses[target.value]

      # nil: nothing is known about the target name.
      # Already in the chain: the CNAMEs point back at each other.
      if target.nil? || chain.include?(target)
        unresolved = true
        break
      end
    end

    if unresolved
      target = chain.last
      puts "*** Warning: CNAME record #{hostname} does not point to actual address!"
      chain.each_with_index do |link, position|
        puts "#{position.to_s.rjust(10)}: #{link}"
      end
    end

    addresses[cname.hostname] = target
  end

  [addresses, cnames]
end

# Checks that every address record has a reverse entry on +dns_server+ that
# names the same host.
#
# For each "A"/"AAAA" record the address is looked up in reverse with
# resolve_address. A missing answer, or an answer whose value differs from the
# record's hostname, is reported and counted as an error. Records of other
# types are skipped. Finishes by printing the summary.
def check_reverse(records, dns_server)
  errors = 0
  okay = []

  puts "[ Checking Reverse Lookups ]".center(72, "=")

  records.each do |r|
    next unless r.is_address?

    sr = resolve_address(dns_server, r.value)

    if sr.nil?
      puts "*** Could not resolve host"
      puts "".rjust(12) + r.to_s
      errors += 1
    elsif r.hostname != sr.value
      puts "*** Hostname does not match"
      puts "Primary: ".rjust(12) + r.to_s
      puts "Secondary: ".rjust(12) + sr.to_s
      errors += 1
    else
      # Keep both halves so the summary can show the matching pair.
      okay << [r, sr]
    end
  end

  print_summary(records, errors, okay) do |pair|
    forward, reverse = pair
    # Same labels (with trailing space) as the error reports above, so the
    # label no longer runs straight into the record text.
    puts "Primary: ".rjust(12) + forward.to_s
    puts "Secondary: ".rjust(12) + reverse.to_s
  end
end

# Pings the address behind every hostname that resolve_addresses can resolve
# and reports the hosts that do not answer, followed by the summary.
#
# NOTE(review): "-o" and the timeout meaning of "-t" look like BSD/macOS ping
# flags - confirm the behaviour on other platforms before relying on it.
def ping_records(records)
  addresses, _cnames = resolve_addresses(records)

  errors = 0
  okay = []

  puts "[ Pinging Records ]".center(72, "=")

  addresses.each do |hostname, r|
    # Record values originate from zone data, so hand them to ping as a plain
    # argument instead of splicing them into a shell command line.
    command = ["ping", "-c", "5", "-t", "5", "-i", "1", "-o", r.value]

    # system returns true only for a zero exit status; ping's standard output
    # is discarded while its error output stays visible.
    if system(*command, :out => File::NULL)
      okay << r
    else
      shown = command.join(" ") + " > /dev/null"
      puts "*** Could not ping host #{hostname.dump}: #{shown.dump}"
      puts "".rjust(12) + r.to_s
      errors += 1
    end
  end

  print_summary(records, errors, okay)
end

# Looks every primary record's hostname up on +secondary_server+ and compares
# the answer with the primary value.
#
# primary::          the reference records.
# secondary_server:: the server to query with resolve_hostname.
#
# A record passes when the secondary answer has the same value, or when both
# sides end up at the same value once CNAMEs are followed through the primary
# data. Unresolvable hostnames and differing values are reported and counted
# as errors. Finishes by printing the summary.
def query_records(primary, secondary_server)
  addresses, _cnames = resolve_addresses(primary)

  okay = []
  errors = 0

  primary.each do |r|
    sr = resolve_hostname(secondary_server, r.hostname)

    if sr.nil?
      puts "*** Could not resolve hostname #{r.hostname.dump}"
      puts "Primary: ".rjust(12) + r.to_s

      # As a hint, show what the secondary's reverse lookup says about the address.
      rsr = resolve_address(secondary_server, (addresses[r.value] || r).value)
      puts "Address: ".rjust(12) + rsr.to_s if rsr

      errors += 1
    elsif sr.value == r.value
      okay << r
    else
      # The raw values differ; compare again after following CNAMEs on each side.
      ra = addresses[r.value] if r.is_cname?
      sra = addresses[sr.value] if sr.is_cname?

      if (sra || sr).value == (ra || r).value
        # Same final value: this record is fine and belongs in the okay list
        # (it used to be counted neither as okay nor as an error).
        okay << r
      else
        puts "*** IP Address does not match"
        puts "Primary: ".rjust(12) + r.to_s
        puts "Resolved: ".rjust(12) + ra.to_s if ra
        puts "Secondary: ".rjust(12) + sr.to_s
        puts "Resolved: ".rjust(12) + sra.to_s if sra
        errors += 1
      end
    end
  end

  print_summary(primary, errors, okay)
end

# Verifies that every primary record is also present in +secondary+.
#
# Records are matched by their key (hostname, class and type). A primary
# record with no secondary record under its key is reported as missing; one
# whose key exists but whose content matches none of the secondary records
# under that key is reported as different. Finishes by printing the summary.
def check_records(primary, secondary)
  okay = []
  errors = 0

  # Several records may legitimately share one key (round-robin addresses,
  # multiple MX entries), so keep all of them instead of only the last one.
  by_key = secondary.group_by { |r| r.key }

  puts "[ Checking Records ]".center(72, "=")

  primary.each do |r|
    candidates = by_key[r.key]

    if candidates.nil?
      puts "*** Could not find record"
      puts "Primary: ".rjust(12) + r.to_s
      errors += 1
    elsif candidates.include?(r)
      okay << r
    else
      puts "*** Records are different"
      puts "Primary: ".rjust(12) + r.to_s
      candidates.each { |sr| puts "Secondary: ".rjust(12) + sr.to_s }
      errors += 1
    end
  end

  print_summary(primary, errors, okay)
end

# Settings shared by the command-line actions below; filled in by -s and -d.
OPTIONS = {
  :DNSServer => nil,
  :DNSRoot => ".",
}

# Reads a record list previously written by --fetch.
#
# path:: file to read; standard input is used when nil.
#
# Returns whatever object the YAML document describes.
def load_records(path)
  source = path ? File.read(path) : STDIN.read

  # The dump contains serialized record objects, which a plain YAML.load
  # refuses on newer Psych versions; unsafe_load keeps the historical
  # behaviour where it exists.
  # NOTE(security): this instantiates arbitrary objects named in the file -
  # only feed it files produced by this tool.
  if YAML.respond_to?(:unsafe_load)
    YAML.unsafe_load(source)
  else
    YAML.load(source)
  end
end

# Every action option performs its work as soon as it is parsed, using the
# server/domain seen so far - hence -s and -d have to come first.
ARGV.options do |o|
  script_name = File.basename($0)

  o.set_summary_indent('  ')
  o.banner = "Usage: #{script_name} [options]"
  o.define_head "This script is designed to test and check DNS servers."

  o.on("-s ns.my.domain.", "--server ns.my.domain.", String, "The DNS server to query.") { |host| OPTIONS[:DNSServer] = host }
  o.on("-d my.domain.", "--domain my.domain.", String, "The DNS zone to transfer/test.") { |host| OPTIONS[:DNSRoot] = host }

  o.on("-f output.yml", "--fetch output.yml", String, "Pull down a list of hosts. Filters TXT, HINFO, SOA and NS records. DNS transfers must be enabled.") { |f|
    records = retrieve_records(OPTIONS[:DNSServer], OPTIONS[:DNSRoot])
    yaml = YAML::dump(records)

    if f
      # The block form closes (and thereby flushes) the file right away.
      File.open(f, "w") { |file| file.write(yaml) }
    else
      STDOUT.write(yaml)
    end

    puts "#{records.size} record(s) retrieved."
  }

  o.on("-c input.yml", "--check input.yml", String, "Check that the DNS server returns results as specified by the file.") { |f|
    master_records = load_records(f)
    secondary_records = retrieve_records(OPTIONS[:DNSServer], OPTIONS[:DNSRoot])

    check_records(master_records, secondary_records)
  }

  o.on("-q input.yml", "--query input.yml", String, "Query the remote DNS server with all hostnames in the given file, and checks the IP addresses are consistent.") { |f|
    query_records(load_records(f), OPTIONS[:DNSServer])
  }

  o.on("-p input.yml", "--ping input.yml", String, "Ping all hosts to check if they are available or not.") { |f|
    ping_records(load_records(f))
  }

  o.on("-r input.yml", "--reverse input.yml", String, "Check that all address records have appropriate reverse entries.") { |f|
    check_reverse(load_records(f), OPTIONS[:DNSServer])
  }

  o.separator ""
  o.separator "Help and Copyright information"

  o.on_tail("--copy", "Display copyright information") {
    puts "#{script_name}. Copyright (c) 2009, 2011 Samuel Williams. Released under the MIT license."
    puts "See http://www.oriontransfer.co.nz/ for more information."
    exit
  }

  # Keep this call last: its return value (the parser) is what parse! is sent to.
  o.on_tail("-h", "--help", "Show this help message.") { puts o; exit }
end.parse!