use miekg/dns instead of net.Resolver

This commit is contained in:
Jordan Whited 2020-05-15 14:18:45 -07:00
parent dbe0623624
commit 27efcca09f

View File

@ -21,7 +21,7 @@ var (
deviceFlag = flag.String("device", "", deviceFlag = flag.String("device", "",
"name of Wireguard device to manage") "name of Wireguard device to manage")
dnsServerFlag = flag.String("dns", "", dnsServerFlag = flag.String("dns", "",
"ip:port of DNS server; defaults to OS resolver") "ip:port of DNS server")
dnsZoneFlag = flag.String("zone", "", "dns zone name") dnsZoneFlag = flag.String("zone", "", "dns zone name")
) )
@ -33,21 +33,12 @@ func main() {
if len(*dnsZoneFlag) < 1 { if len(*dnsZoneFlag) < 1 {
log.Fatal("missing zone flag") log.Fatal("missing zone flag")
} }
resolver := net.DefaultResolver if len(*dnsServerFlag) < 1 {
if len(*dnsServerFlag) > 0 { log.Fatal("missing dns flag")
_, _, err := net.SplitHostPort(*dnsServerFlag) }
if err != nil { _, _, err := net.SplitHostPort(*dnsServerFlag)
log.Fatalf("invalid dns server flag: %v", err) if err != nil {
} log.Fatalf("invalid dns flag value: %v", err)
dialer := net.Dialer{}
dialFn := func(ctx context.Context, network, address string) (net.Conn,
error) {
return dialer.DialContext(ctx, network, *dnsServerFlag)
}
resolver = &net.Resolver{
PreferGo: true,
Dial: dialFn,
}
} }
wgClient, err := wgctrl.New() wgClient, err := wgctrl.New()
if err != nil { if err != nil {
@ -69,6 +60,7 @@ func main() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
dnsClient := &dns.Client{}
for _, peer := range wgDevice.Peers { for _, peer := range wgDevice.Peers {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -78,42 +70,51 @@ func main() {
srvCtx, srvCancel := context.WithCancel(ctx) srvCtx, srvCancel := context.WithCancel(ctx)
pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:]) pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:]) pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])
queryName := fmt.Sprintf("%s._wireguard._udp.%s", m := &dns.Msg{}
question := fmt.Sprintf("%s._wireguard._udp.%s",
pubKeyBase32, dns.Fqdn(*dnsZoneFlag)) pubKeyBase32, dns.Fqdn(*dnsZoneFlag))
_, srvs, err := resolver.LookupSRV(srvCtx, "", "", m.SetQuestion(question, dns.TypeSRV).RecursionDesired = false
queryName) r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
srvCancel() srvCancel()
if err != nil { if err != nil {
log.Printf( log.Printf(
"failed to lookup SRV for peer %s error: %v", "[%s] failed to lookup SRV: %v", pubKeyBase64, err)
pubKeyBase64, err)
continue continue
} }
if len(srvs) < 1 { if len(r.Answer) < 1 {
log.Printf("no SRV records found for peer %s", log.Printf("[%s] no SRV records found", pubKeyBase64)
pubKeyBase64)
continue continue
} }
hostCtx, hostCancel := context.WithCancel(ctx) srv, ok := r.Answer[0].(*dns.SRV)
addrs, err := resolver.LookupIPAddr(hostCtx, srvs[1].Target) if !ok {
hostCancel()
if err != nil {
log.Printf( log.Printf(
"failed to lookup A/AAAA for peer %s error: %v", "[%s] non-SRV answer in response to SRV query: %s",
pubKeyBase64, err) pubKeyBase64, r.Answer[0].String())
continue
} }
if len(addrs) < 1 { if len(r.Extra) < 1 {
log.Printf("no A/AAAA records found for peer %s", log.Printf("[%s] SRV response missing extra A/AAAA",
pubKeyBase64) pubKeyBase64)
continue }
var endpointIP net.IP
hostA, ok := r.Answer[0].(*dns.A)
if !ok {
hostAAAA, ok := r.Answer[0].(*dns.AAAA)
if !ok {
log.Printf(
"[%s] non-A/AAAA extra in SRV response: %s",
pubKeyBase64, r.Extra[0].Header())
continue
}
endpointIP = hostAAAA.AAAA
} else {
endpointIP = hostA.A
} }
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: peer.PublicKey, PublicKey: peer.PublicKey,
UpdateOnly: true, UpdateOnly: true,
Endpoint: &net.UDPAddr{ Endpoint: &net.UDPAddr{
IP: addrs[0].IP, IP: endpointIP,
Port: int(srvs[0].Port), Port: int(srv.Port),
}, },
} }
deviceConfig := wgtypes.Config{ deviceConfig := wgtypes.Config{