add support for peer auto-discovery in wgsd-client

This patch allows wgsd-client to auto-discover other mesh peers:
 - use peer allowed ip and pubkey from TXT record in SRV answer
 - iterate on peers extracted from PTR answer instead of local wireguard
configuration: the client no longer need to pre-configured with all the
peers
 - allow to connect to a specific peer using its pubkey or allowed ip
This commit is contained in:
Benoît Ganne 2020-12-07 10:56:31 +01:00
parent 6dd954687e
commit 5c11196a75

View File

@ -3,13 +3,12 @@ package main
import ( import (
"context" "context"
"encoding/base32" "encoding/base32"
"encoding/base64"
"flag" "flag"
"fmt"
"log" "log"
"net" "net"
"os" "os"
"os/signal" "os/signal"
"strings"
"syscall" "syscall"
"time" "time"
@ -24,6 +23,12 @@ var (
dnsServerFlag = flag.String("dns", "", dnsServerFlag = flag.String("dns", "",
"ip:port of DNS server") "ip:port of DNS server")
dnsZoneFlag = flag.String("zone", "", "dns zone name") dnsZoneFlag = flag.String("zone", "", "dns zone name")
serviceFlag = flag.String("service", "", "service ip or public key")
)
const (
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
spPrefix = "_wireguard._udp."
) )
func main() { func main() {
@ -64,39 +69,84 @@ func main() {
dnsClient := &dns.Client{ dnsClient := &dns.Client{
Timeout: time.Second * 5, Timeout: time.Second * 5,
} }
for _, peer := range wgDevice.Peers {
suffix := spPrefix + dns.Fqdn(*dnsZoneFlag)
if len(*serviceFlag) >= 1 {
fqdn := *serviceFlag
suffix = "." + suffix
if !strings.HasSuffix(fqdn, suffix) {
fqdn += suffix
}
ConnectPeer(ctx, wgClient, wgDevice, dnsClient, fqdn, *dnsServerFlag)
return
}
srvCtx, srvCancel := context.WithCancel(ctx)
m := &dns.Msg{}
m.SetQuestion(suffix, dns.TypePTR)
r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
srvCancel()
if err != nil {
log.Printf("failed to lookup PTR: %v", err)
return
}
if len(r.Answer) < 1 {
log.Printf("no PTR records found")
return
}
for _, answer := range r.Answer {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
default: default:
} }
ptr, ok := answer.(*dns.PTR)
if !ok {
log.Printf("non-PTR answer in response to PTR query: %s", answer.String())
continue
}
ConnectPeer(ctx, wgClient, wgDevice, dnsClient, ptr.Ptr, *dnsServerFlag)
}
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case sig := <-sigCh:
log.Printf("exiting due to signal %s", sig)
cancel()
<-done
case <-done:
}
}
func ConnectPeer(ctx context.Context, wgClient *wgctrl.Client, wgDevice *wgtypes.Device, dnsClient *dns.Client, serviceFqdn string, dnsServer string) {
srvCtx, srvCancel := context.WithCancel(ctx) srvCtx, srvCancel := context.WithCancel(ctx)
pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])
m := &dns.Msg{} m := &dns.Msg{}
question := fmt.Sprintf("%s._wireguard._udp.%s", m.SetQuestion(serviceFqdn, dns.TypeSRV)
pubKeyBase32, dns.Fqdn(*dnsZoneFlag)) r, _, err := dnsClient.ExchangeContext(srvCtx, m, dnsServer)
m.SetQuestion(question, dns.TypeSRV)
r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
srvCancel() srvCancel()
if err != nil { if err != nil {
log.Printf( log.Printf(
"[%s] failed to lookup SRV: %v", pubKeyBase64, err) "[%s] failed to lookup SRV: %v", serviceFqdn, err)
continue return
} }
if len(r.Answer) < 1 { if len(r.Answer) < 1 {
log.Printf("[%s] no SRV records found", pubKeyBase64) log.Printf("[%s] no SRV records found", serviceFqdn)
continue return
} }
srv, ok := r.Answer[0].(*dns.SRV) srv, ok := r.Answer[0].(*dns.SRV)
if !ok { if !ok {
log.Printf( log.Printf(
"[%s] non-SRV answer in response to SRV query: %s", "[%s] non-SRV answer in response to SRV query: %s",
pubKeyBase64, r.Answer[0].String()) serviceFqdn, r.Answer[0].String())
return
} }
if len(r.Extra) < 1 { if len(r.Extra) < 2 {
log.Printf("[%s] SRV response missing extra A/AAAA", log.Printf("[%s] SRV response missing extra A/AAAA and TXT",
pubKeyBase64) serviceFqdn)
return
} }
var endpointIP net.IP var endpointIP net.IP
hostA, ok := r.Extra[0].(*dns.A) hostA, ok := r.Extra[0].(*dns.A)
@ -105,20 +155,53 @@ func main() {
if !ok { if !ok {
log.Printf( log.Printf(
"[%s] non-A/AAAA extra in SRV response: %s", "[%s] non-A/AAAA extra in SRV response: %s",
pubKeyBase64, r.Extra[0].String()) serviceFqdn, r.Extra[0].String())
continue return
} }
endpointIP = hostAAAA.AAAA endpointIP = hostAAAA.AAAA
} else { } else {
endpointIP = hostA.A endpointIP = hostA.A
} }
txt, ok := r.Extra[1].(*dns.TXT)
if !ok {
log.Printf("[%s] non-TXT extra in SRV response: %s",
serviceFqdn, r.Extra[1].String())
return
}
allowedIPsString := strings.TrimPrefix(strings.ToLower(txt.Txt[0]), "allowedip=")
_, allowedIPs, err := net.ParseCIDR(allowedIPsString)
if err != nil {
log.Printf("[%s] failed to parse allowedip in TXT extra: %s", serviceFqdn, r.Extra[1].String())
return
}
pubKeyString := strings.TrimPrefix(strings.ToUpper(txt.Txt[1]), "PUBKEY=")
if len(pubKeyString) < keyLen {
pubKeyString += strings.Repeat("=", keyLen-len(pubKeyString))
}
pubKeyBytes, err := base32.StdEncoding.DecodeString(strings.ToUpper(pubKeyString))
if err != nil {
log.Printf("[%s] failed to decode base32 key %s: %v", serviceFqdn, pubKeyString, err)
return
}
pubKeyWg, err := wgtypes.NewKey(pubKeyBytes)
if err != nil {
log.Printf("[%s] failed to create wg key: %v", serviceFqdn, err)
return
}
if pubKeyWg == wgDevice.PublicKey {
log.Printf("[%s] skipping ourself", serviceFqdn)
return
}
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: peer.PublicKey, PublicKey: pubKeyWg,
UpdateOnly: true, UpdateOnly: false,
Endpoint: &net.UDPAddr{ Endpoint: &net.UDPAddr{
IP: endpointIP, IP: endpointIP,
Port: int(srv.Port), Port: int(srv.Port),
}, },
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*allowedIPs},
} }
deviceConfig := wgtypes.Config{ deviceConfig := wgtypes.Config{
PrivateKey: &wgDevice.PrivateKey, PrivateKey: &wgDevice.PrivateKey,
@ -132,17 +215,9 @@ func main() {
if err != nil { if err != nil {
log.Printf( log.Printf(
"[%s] failed to configure peer on %s, error: %v", "[%s] failed to configure peer on %s, error: %v",
pubKeyBase64, *deviceFlag, err) serviceFqdn, *deviceFlag, err)
} return
}
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case sig := <-sigCh:
log.Printf("exiting due to signal %s", sig)
cancel()
<-done
case <-done:
} }
log.Printf("[%s] configure peer on %s", serviceFqdn, *deviceFlag)
} }