dohproxy/dohproxy.go

207 lines
5.4 KiB
Go
Raw Normal View History

2024-08-31 11:36:28 +01:00
package main
import (
"bytes"
"crypto/tls"
"flag"
"io"
"log"
"net"
"net/http"
"os"
"time"
"github.com/miekg/dns"
"golang.org/x/net/http2"
)
// ProtocolType selects the DNS transport a listener is started on.
type ProtocolType int

// Supported transports. The numeric values are explicit: TCP is 0, UDP is 1.
const (
	// TCP serves plain DNS over TCP.
	TCP ProtocolType = 0
	// UDP serves plain DNS over UDP.
	UDP ProtocolType = 1
)
// DoHProxy accepts plain DNS queries on a local address and relays each one
// to a list of DNS-over-HTTPS upstream servers.
type DoHProxy struct {
	// listenAddress and port together form the local address to bind.
	listenAddress, port string
	// upstreamURLs are the DoH endpoints, tried in the order listed.
	upstreamURLs []string
	// client is the single HTTP client used for every upstream request.
	client *http.Client
	// protocols lists the transports (TCP and/or UDP) to listen on.
	protocols []ProtocolType
	// logRequests enables per-request log output.
	logRequests bool
}
// NewDoHProxy builds a DoHProxy that will listen on listenAddress:port over
// the given protocols and forward queries to upstreamURLs, tried in order.
//
// All upstream traffic goes through one http.Client. Its transport is passed
// through http2.ConfigureTransport so DoH servers can use HTTP/2. The client
// enforces a 5 second timeout on each whole request/response exchange.
// logRequests enables per-request log output in the handler.
//
// The upstreamURLs and protocols slices are copied, so later changes to the
// caller's slices cannot affect a running proxy.
func NewDoHProxy(listenAddress, port string, upstreamURLs []string, protocols []ProtocolType, logRequests bool) *DoHProxy {
	transport := &http.Transport{
		TLSClientConfig: &tls.Config{},
	}
	// A failure here does not make the transport unusable for HTTPS, so it
	// is reported instead of aborting startup (or being silently dropped).
	if err := http2.ConfigureTransport(transport); err != nil {
		log.Printf("Failed to enable HTTP/2 for upstream requests: %v", err)
	}

	client := &http.Client{
		Transport: transport,
		Timeout:   5 * time.Second, // covers connect, request and reading the body
	}

	// The proxy's handlers read these slices concurrently for its whole
	// lifetime, so keep private copies rather than aliasing the caller's.
	return &DoHProxy{
		listenAddress: listenAddress,
		port:          port,
		upstreamURLs:  append([]string(nil), upstreamURLs...),
		client:        client,
		protocols:     append([]ProtocolType(nil), protocols...),
		logRequests:   logRequests,
	}
}
// HandleDNSRequest answers one client DNS query. It forwards the query in
// wire format to the configured DoH upstreams, in order, until one returns a
// usable reply, then relays that reply to the client.
//
// If no answer can be produced, the client is sent SERVFAIL (dns.HandleFailed)
// instead of being left to time out. That covers three cases: the query cannot
// be re-packed, every upstream fails, or the chosen reply cannot be decoded.
func (p *DoHProxy) HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	dnsQuery, err := r.Pack()
	if err != nil {
		log.Printf("Failed to pack DNS request: %v", err)
		dns.HandleFailed(w, r)
		return
	}

	// Try each upstream in turn; the first one that yields a body wins.
	var body []byte
	var currentUpstream string
	answered := false
	for _, upstream := range p.upstreamURLs {
		if reply, ok := p.exchangeDoH(upstream, dnsQuery); ok {
			body, currentUpstream, answered = reply, upstream, true
			break
		}
	}
	if !answered {
		if p.logRequests {
			log.Printf("Failed to query any upstream")
		}
		dns.HandleFailed(w, r)
		return
	}

	dnsResponse := new(dns.Msg)
	if err := dnsResponse.Unpack(body); err != nil {
		log.Printf("Failed to unpack DNS response: %v", err)
		dns.HandleFailed(w, r)
		return
	}
	// The query was forwarded with the client's ID. Pin the reply's ID to it
	// anyway so the client can always match the answer to its question, even
	// if the upstream rewrote the ID.
	dnsResponse.Id = r.Id

	if p.logRequests {
		log.Printf("Successfully proxied request through %s", currentUpstream)
	}
	if err := w.WriteMsg(dnsResponse); err != nil {
		log.Printf("Failed to write DNS response to client: %v", err)
	}
}

// exchangeDoH sends one wire-format DNS query to a single DoH upstream and
// returns the raw response body. The request is an HTTP POST with the
// application/dns-message content type (RFC 8484).
//
// ok is false when any of these happens:
//   - the request cannot be built (for example, a malformed URL);
//   - the request fails;
//   - the upstream answers with a non-200 status;
//   - the body cannot be read or is larger than a DNS message can be.
//
// The response body is closed on every path, including non-200 replies, so
// failed attempts do not leak connections.
func (p *DoHProxy) exchangeDoH(upstream string, query []byte) (body []byte, ok bool) {
	// A DNS message is at most 65535 bytes (RFC 1035 two-byte length
	// prefix). Reading one byte past that limit detects an oversized body
	// instead of silently truncating it.
	const maxDNSMessageSize = 65535

	req, err := http.NewRequest(http.MethodPost, upstream, bytes.NewReader(query))
	if err != nil {
		// Without this check a malformed URL yields a nil request, and the
		// header assignments below would panic.
		log.Printf("Invalid upstream DoH server URL %s: %v", upstream, err)
		return nil, false
	}
	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")

	resp, err := p.client.Do(req)
	if err != nil {
		if p.logRequests {
			log.Printf("Upstream %s request failed: %v", upstream, err)
		}
		return nil, false
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		if p.logRequests {
			log.Printf("Upstream %s returned HTTP status %s", upstream, resp.Status)
		}
		return nil, false
	}

	body, err = io.ReadAll(io.LimitReader(resp.Body, maxDNSMessageSize+1))
	if err != nil {
		log.Printf("Failed to read response from upstream: %v", err)
		return nil, false
	}
	if len(body) > maxDNSMessageSize {
		log.Printf("Upstream %s returned an oversized response (> %d bytes)", upstream, maxDNSMessageSize)
		return nil, false
	}
	return body, true
}
// Run starts one DNS listener per configured protocol on
// listenAddress:port. Every listener is served by HandleDNSRequest. Run then
// blocks forever.
//
// If a listener fails to start, or stops with an error, the process is
// terminated via log.Fatalf. Unknown protocol values are logged and skipped.
// If no supported protocol is configured, Run exits fatally rather than
// blocking with nothing to serve.
func (p *DoHProxy) Run() {
	addr := net.JoinHostPort(p.listenAddress, p.port)

	// Give each server its own handler rather than registering "." on the
	// package-global dns.DefaultServeMux once per listener goroutine.
	handler := dns.HandlerFunc(p.HandleDNSRequest)

	started := 0
	for _, proto := range p.protocols {
		var network, label string
		switch proto {
		case TCP:
			network, label = "tcp", "TCP"
		case UDP:
			network, label = "udp", "UDP"
		default:
			log.Printf("Ignoring unknown protocol type %d", proto)
			continue
		}

		// server and label are declared inside the loop body, so each
		// goroutine captures its own copy.
		server := &dns.Server{Addr: addr, Net: network, Handler: handler}
		log.Printf("Starting DoH Proxy on %s:%s over %s", p.listenAddress, p.port, label)
		go func() {
			if err := server.ListenAndServe(); err != nil {
				log.Fatalf("Failed to start DNS server (%s): %v", label, err)
			}
		}()
		started++
	}

	if started == 0 {
		log.Fatalf("No supported protocols configured; nothing to listen on")
	}

	// Keep the main goroutine running; the listeners run in their own goroutines.
	select {}
}
// main parses the command-line flags, chooses the listening transports and
// runs the proxy. At least one -u upstream is required; without one, main
// prints the flag defaults and exits with status 1.
func main() {
	// Where to listen.
	listenAddress := flag.String("l", "127.0.0.1", "Listen address for the DNS server")
	port := flag.String("p", "53", "Port for the DNS server")

	// Transport selection; see the switch below for how the two combine.
	useTCP := flag.Bool("tcp", false, "Listen on TCP")
	useUDP := flag.Bool("udp", false, "Listen on UDP")

	logEnabled := flag.Bool("log", false, "Log each request proxied through an upstream")

	// -u may be repeated; each occurrence appends one upstream, preserving
	// command-line order.
	var upstreams []string
	addUpstream := func(value string) error {
		log.Printf("Added %s as an upstream DoH server\n", value)
		upstreams = append(upstreams, value)
		return nil
	}
	flag.Func("u", `Upstream DoH server URL (can be specified multiple times)
Example:
-u https://dns.quad9.net/dns-query -u https://1.1.1.1/dns-query
WARNING:
If this is your system's default DNS resolver
and the server URL is a domain name, at least one other
DNS server after this one must be specified as an IP address
in order to resolve the domain name of the first one.`, addUpstream)

	flag.Parse()

	if len(upstreams) == 0 {
		flag.PrintDefaults()
		os.Exit(1)
	}

	var protocols []ProtocolType
	switch {
	case *useTCP == *useUDP:
		// Both flags given, or neither: serve both transports.
		protocols = []ProtocolType{TCP, UDP}
	case *useTCP:
		protocols = []ProtocolType{TCP}
	default:
		protocols = []ProtocolType{UDP}
	}

	if *logEnabled {
		log.Println("Logging requests")
	}

	NewDoHProxy(*listenAddress, *port, upstreams, protocols, *logEnabled).Run()
}