From 27efcca09f1b52c19affc5b05252ae97b09652bd Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 15 May 2020 14:18:45 -0700 Subject: [PATCH] use miekg/dns instead of net.Resolver --- cmd/client/main.go | 73 +++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/cmd/client/main.go b/cmd/client/main.go index 6755904..b519ac6 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -21,7 +21,7 @@ var ( deviceFlag = flag.String("device", "", "name of Wireguard device to manage") dnsServerFlag = flag.String("dns", "", - "ip:port of DNS server; defaults to OS resolver") + "ip:port of DNS server") dnsZoneFlag = flag.String("zone", "", "dns zone name") ) @@ -33,21 +33,12 @@ func main() { if len(*dnsZoneFlag) < 1 { log.Fatal("missing zone flag") } - resolver := net.DefaultResolver - if len(*dnsServerFlag) > 0 { - _, _, err := net.SplitHostPort(*dnsServerFlag) - if err != nil { - log.Fatalf("invalid dns server flag: %v", err) - } - dialer := net.Dialer{} - dialFn := func(ctx context.Context, network, address string) (net.Conn, - error) { - return dialer.DialContext(ctx, network, *dnsServerFlag) - } - resolver = &net.Resolver{ - PreferGo: true, - Dial: dialFn, - } + if len(*dnsServerFlag) < 1 { + log.Fatal("missing dns flag") + } + _, _, err := net.SplitHostPort(*dnsServerFlag) + if err != nil { + log.Fatalf("invalid dns flag value: %v", err) } wgClient, err := wgctrl.New() if err != nil { @@ -69,6 +60,7 @@ func main() { done := make(chan struct{}) go func() { defer close(done) + dnsClient := &dns.Client{} for _, peer := range wgDevice.Peers { select { case <-ctx.Done(): @@ -78,42 +70,51 @@ func main() { srvCtx, srvCancel := context.WithCancel(ctx) pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:]) pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:]) - queryName := fmt.Sprintf("%s._wireguard._udp.%s", + m := &dns.Msg{} + question := fmt.Sprintf("%s._wireguard._udp.%s", pubKeyBase32, dns.Fqdn(*dnsZoneFlag)) - _, srvs, err := resolver.LookupSRV(srvCtx, "", "", - queryName) + m.SetQuestion(question, dns.TypeSRV).RecursionDesired = false + r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag) srvCancel() if err != nil { log.Printf( - "failed to lookup SRV for peer %s error: %v", - pubKeyBase64, err) + "[%s] failed to lookup SRV: %v", pubKeyBase64, err) continue } - if len(srvs) < 1 { - log.Printf("no SRV records found for peer %s", - pubKeyBase64) + if len(r.Answer) < 1 { + log.Printf("[%s] no SRV records found", pubKeyBase64) continue } - hostCtx, hostCancel := context.WithCancel(ctx) - addrs, err := resolver.LookupIPAddr(hostCtx, srvs[1].Target) - hostCancel() - if err != nil { + srv, ok := r.Answer[0].(*dns.SRV) + if !ok { log.Printf( - "failed to lookup A/AAAA for peer %s error: %v", - pubKeyBase64, err) - continue + "[%s] non-SRV answer in response to SRV query: %s", + pubKeyBase64, r.Answer[0].String()) } - if len(addrs) < 1 { - log.Printf("no A/AAAA records found for peer %s", + if len(r.Extra) < 1 { + log.Printf("[%s] SRV response missing extra A/AAAA", pubKeyBase64) - continue + } + var endpointIP net.IP + hostA, ok := r.Answer[0].(*dns.A) + if !ok { + hostAAAA, ok := r.Answer[0].(*dns.AAAA) + if !ok { + log.Printf( + "[%s] non-A/AAAA extra in SRV response: %s", + pubKeyBase64, r.Extra[0].Header()) + continue + } + endpointIP = hostAAAA.AAAA + } else { + endpointIP = hostA.A } peerConfig := wgtypes.PeerConfig{ PublicKey: peer.PublicKey, UpdateOnly: true, Endpoint: &net.UDPAddr{ - IP: addrs[0].IP, - Port: int(srvs[0].Port), + IP: endpointIP, + Port: int(srv.Port), }, } deviceConfig := wgtypes.Config{