Files
seanime-docker/seanime-2.9.10/internal/doh/doh_test.go
2025-09-20 14:08:38 +01:00

87 lines
2.2 KiB
Go

package doh
import (
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ncruces/go-dns"
)
func TestDoHResolver(t *testing.T) {
// Start a temporary HTTP test server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("hello via DoH"))
}))
defer ts.Close()
// Extract hostname and port from the test server URL
host, port, err := net.SplitHostPort(ts.Listener.Addr().String())
if err != nil {
t.Fatalf("failed to parse test server address: %v", err)
}
// Mock a "DNS record" by pointing a custom hostname to the test server's IP
fakeHostname := "test.local"
dohURL := "https://cloudflare-dns.com/dns-query{?dns}"
// Set up the DoH resolver
resolver, err := dns.NewDoHResolver(dohURL, dns.DoHCache())
if err != nil {
t.Fatalf("failed to create DoH resolver: %v", err)
}
// Override the default resolver
net.DefaultResolver = resolver
// Use a custom DialContext to redirect fakeHostname to test server IP
dialer := &net.Dialer{
Timeout: 3 * time.Second,
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
// Shortcut: Always return test server's IP for "test.local"
d := net.Dialer{}
if network == "udp" || network == "tcp" {
return d.Dial(network, net.JoinHostPort(host, "53"))
}
return d.Dial(network, address)
},
},
}
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// Intercept DNS for fakeHostname only
if addr[:len(fakeHostname)] == fakeHostname {
addr = net.JoinHostPort(host, port)
}
return dialer.DialContext(ctx, network, addr)
},
},
}
// Make a request to the fake hostname (which we route to the test server)
resp, err := client.Get("http://" + fakeHostname + ":" + port)
if err != nil {
t.Fatalf("failed to GET: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected 200 OK, got %v", resp.Status)
}
// Read the response body
bodyR, err := io.ReadAll(resp.Body)
t.Log(string(bodyR))
}