From 5c11196a75cd04ecf4f2ab53c52d4d9808a0ff66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Ganne?= Date: Mon, 7 Dec 2020 10:56:31 +0100 Subject: [PATCH] add support for peer auto-discovery in wgsd-client This patch allows wgsd-client to auto-discover other mesh peers: - use peer allowed ip and pubkey from TXT record in SRV answer - iterate on peers extracted from PTR answer instead of local wireguard configuration: the client no longer need to pre-configured with all the peers - allow to connect to a specific peer using its pubkey or allowed ip --- cmd/wgsd-client/main.go | 205 +++++++++++++++++++++++++++------------- 1 file changed, 140 insertions(+), 65 deletions(-) diff --git a/cmd/wgsd-client/main.go b/cmd/wgsd-client/main.go index 27e5f4a..7897671 100644 --- a/cmd/wgsd-client/main.go +++ b/cmd/wgsd-client/main.go @@ -3,13 +3,12 @@ package main import ( "context" "encoding/base32" - "encoding/base64" "flag" - "fmt" "log" "net" "os" "os/signal" + "strings" "syscall" "time" @@ -24,6 +23,12 @@ var ( dnsServerFlag = flag.String("dns", "", "ip:port of DNS server") dnsZoneFlag = flag.String("zone", "", "dns zone name") + serviceFlag = flag.String("service", "", "service ip or public key") +) + +const ( + keyLen = 56 // the number of characters in a base32-encoded Wireguard public key + spPrefix = "_wireguard._udp." ) func main() { @@ -64,76 +69,45 @@ func main() { dnsClient := &dns.Client{ Timeout: time.Second * 5, } - for _, peer := range wgDevice.Peers { + + suffix := spPrefix + dns.Fqdn(*dnsZoneFlag) + + if len(*serviceFlag) >= 1 { + fqdn := *serviceFlag + suffix = "." + suffix + if !strings.HasSuffix(fqdn, suffix) { + fqdn += suffix + } + ConnectPeer(ctx, wgClient, wgDevice, dnsClient, fqdn, *dnsServerFlag) + return + } + + srvCtx, srvCancel := context.WithCancel(ctx) + m := &dns.Msg{} + m.SetQuestion(suffix, dns.TypePTR) + r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag) + srvCancel() + if err != nil { + log.Printf("failed to lookup PTR: %v", err) + return + } + if len(r.Answer) < 1 { + log.Printf("no PTR records found") + return + } + + for _, answer := range r.Answer { select { case <-ctx.Done(): return default: } - srvCtx, srvCancel := context.WithCancel(ctx) - pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:]) - pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:]) - m := &dns.Msg{} - question := fmt.Sprintf("%s._wireguard._udp.%s", - pubKeyBase32, dns.Fqdn(*dnsZoneFlag)) - m.SetQuestion(question, dns.TypeSRV) - r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag) - srvCancel() - if err != nil { - log.Printf( - "[%s] failed to lookup SRV: %v", pubKeyBase64, err) + ptr, ok := answer.(*dns.PTR) + if !ok { + log.Printf("non-PTR answer in response to PTR query: %s", answer.String()) continue } - if len(r.Answer) < 1 { - log.Printf("[%s] no SRV records found", pubKeyBase64) - continue - } - srv, ok := r.Answer[0].(*dns.SRV) - if !ok { - log.Printf( - "[%s] non-SRV answer in response to SRV query: %s", - pubKeyBase64, r.Answer[0].String()) - } - if len(r.Extra) < 1 { - log.Printf("[%s] SRV response missing extra A/AAAA", - pubKeyBase64) - } - var endpointIP net.IP - hostA, ok := r.Extra[0].(*dns.A) - if !ok { - hostAAAA, ok := r.Extra[0].(*dns.AAAA) - if !ok { - log.Printf( - "[%s] non-A/AAAA extra in SRV response: %s", - pubKeyBase64, r.Extra[0].String()) - continue - } - endpointIP = hostAAAA.AAAA - } else { - endpointIP = hostA.A - } - peerConfig := wgtypes.PeerConfig{ - PublicKey: peer.PublicKey, - UpdateOnly: true, - Endpoint: &net.UDPAddr{ - IP: endpointIP, - Port: int(srv.Port), - }, - } - deviceConfig := wgtypes.Config{ - PrivateKey: &wgDevice.PrivateKey, - ReplacePeers: false, - Peers: []wgtypes.PeerConfig{peerConfig}, - } - if wgDevice.FirewallMark > 0 { - deviceConfig.FirewallMark = &wgDevice.FirewallMark - } - err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig) - if err != nil { - log.Printf( - "[%s] failed to configure peer on %s, error: %v", - pubKeyBase64, *deviceFlag, err) - } + ConnectPeer(ctx, wgClient, wgDevice, dnsClient, ptr.Ptr, *dnsServerFlag) } }() sigCh := make(chan os.Signal, 1) @@ -146,3 +120,104 @@ func main() { case <-done: } } + +func ConnectPeer(ctx context.Context, wgClient *wgctrl.Client, wgDevice *wgtypes.Device, dnsClient *dns.Client, serviceFqdn string, dnsServer string) { + srvCtx, srvCancel := context.WithCancel(ctx) + m := &dns.Msg{} + m.SetQuestion(serviceFqdn, dns.TypeSRV) + r, _, err := dnsClient.ExchangeContext(srvCtx, m, dnsServer) + srvCancel() + if err != nil { + log.Printf( + "[%s] failed to lookup SRV: %v", serviceFqdn, err) + return + } + if len(r.Answer) < 1 { + log.Printf("[%s] no SRV records found", serviceFqdn) + return + } + srv, ok := r.Answer[0].(*dns.SRV) + if !ok { + log.Printf( + "[%s] non-SRV answer in response to SRV query: %s", + serviceFqdn, r.Answer[0].String()) + return + } + if len(r.Extra) < 2 { + log.Printf("[%s] SRV response missing extra A/AAAA and TXT", + serviceFqdn) + return + } + var endpointIP net.IP + hostA, ok := r.Extra[0].(*dns.A) + if !ok { + hostAAAA, ok := r.Extra[0].(*dns.AAAA) + if !ok { + log.Printf( + "[%s] non-A/AAAA extra in SRV response: %s", + serviceFqdn, r.Extra[0].String()) + return + } + endpointIP = hostAAAA.AAAA + } else { + endpointIP = hostA.A + } + txt, ok := r.Extra[1].(*dns.TXT) + if !ok { + log.Printf("[%s] non-TXT extra in SRV response: %s", + serviceFqdn, r.Extra[1].String()) + return + } + allowedIPsString := strings.TrimPrefix(strings.ToLower(txt.Txt[0]), "allowedip=") + _, allowedIPs, err := net.ParseCIDR(allowedIPsString) + if err != nil { + log.Printf("[%s] failed to parse allowedip in TXT extra: %s", serviceFqdn, r.Extra[1].String()) + return + } + pubKeyString := strings.TrimPrefix(strings.ToUpper(txt.Txt[1]), "PUBKEY=") + if len(pubKeyString) < keyLen { + pubKeyString += strings.Repeat("=", keyLen-len(pubKeyString)) + } + pubKeyBytes, err := base32.StdEncoding.DecodeString(strings.ToUpper(pubKeyString)) + if err != nil { + log.Printf("[%s] failed to decode base32 key %s: %v", serviceFqdn, pubKeyString, err) + return + } + pubKeyWg, err := wgtypes.NewKey(pubKeyBytes) + if err != nil { + log.Printf("[%s] failed to create wg key: %v", serviceFqdn, err) + return + } + if pubKeyWg == wgDevice.PublicKey { + log.Printf("[%s] skipping ourself", serviceFqdn) + return + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKeyWg, + UpdateOnly: false, + Endpoint: &net.UDPAddr{ + IP: endpointIP, + Port: int(srv.Port), + }, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{*allowedIPs}, + } + deviceConfig := wgtypes.Config{ + PrivateKey: &wgDevice.PrivateKey, + ReplacePeers: false, + Peers: []wgtypes.PeerConfig{peerConfig}, + } + if wgDevice.FirewallMark > 0 { + deviceConfig.FirewallMark = &wgDevice.FirewallMark + } + err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig) + if err != nil { + log.Printf( + "[%s] failed to configure peer on %s, error: %v", + serviceFqdn, *deviceFlag, err) + return + } + + log.Printf("[%s] configure peer on %s", serviceFqdn, *deviceFlag) +}