mirror of
https://github.com/jwhited/wgsd.git
synced 2025-04-04 11:09:31 +08:00
add support for peer auto-discovery in wgsd-client
This patch allows wgsd-client to auto-discover other mesh peers: - use peer allowed ip and pubkey from TXT record in SRV answer - iterate on peers extracted from PTR answer instead of local wireguard configuration: the client no longer need to pre-configured with all the peers - allow to connect to a specific peer using its pubkey or allowed ip
This commit is contained in:
parent
6dd954687e
commit
5c11196a75
@ -3,13 +3,12 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"encoding/base64"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -24,6 +23,12 @@ var (
|
|||||||
dnsServerFlag = flag.String("dns", "",
|
dnsServerFlag = flag.String("dns", "",
|
||||||
"ip:port of DNS server")
|
"ip:port of DNS server")
|
||||||
dnsZoneFlag = flag.String("zone", "", "dns zone name")
|
dnsZoneFlag = flag.String("zone", "", "dns zone name")
|
||||||
|
serviceFlag = flag.String("service", "", "service ip or public key")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
|
||||||
|
spPrefix = "_wireguard._udp."
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -64,76 +69,45 @@ func main() {
|
|||||||
dnsClient := &dns.Client{
|
dnsClient := &dns.Client{
|
||||||
Timeout: time.Second * 5,
|
Timeout: time.Second * 5,
|
||||||
}
|
}
|
||||||
for _, peer := range wgDevice.Peers {
|
|
||||||
|
suffix := spPrefix + dns.Fqdn(*dnsZoneFlag)
|
||||||
|
|
||||||
|
if len(*serviceFlag) >= 1 {
|
||||||
|
fqdn := *serviceFlag
|
||||||
|
suffix = "." + suffix
|
||||||
|
if !strings.HasSuffix(fqdn, suffix) {
|
||||||
|
fqdn += suffix
|
||||||
|
}
|
||||||
|
ConnectPeer(ctx, wgClient, wgDevice, dnsClient, fqdn, *dnsServerFlag)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srvCtx, srvCancel := context.WithCancel(ctx)
|
||||||
|
m := &dns.Msg{}
|
||||||
|
m.SetQuestion(suffix, dns.TypePTR)
|
||||||
|
r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
|
||||||
|
srvCancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to lookup PTR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(r.Answer) < 1 {
|
||||||
|
log.Printf("no PTR records found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, answer := range r.Answer {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
srvCtx, srvCancel := context.WithCancel(ctx)
|
ptr, ok := answer.(*dns.PTR)
|
||||||
pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
|
if !ok {
|
||||||
pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])
|
log.Printf("non-PTR answer in response to PTR query: %s", answer.String())
|
||||||
m := &dns.Msg{}
|
|
||||||
question := fmt.Sprintf("%s._wireguard._udp.%s",
|
|
||||||
pubKeyBase32, dns.Fqdn(*dnsZoneFlag))
|
|
||||||
m.SetQuestion(question, dns.TypeSRV)
|
|
||||||
r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
|
|
||||||
srvCancel()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf(
|
|
||||||
"[%s] failed to lookup SRV: %v", pubKeyBase64, err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if len(r.Answer) < 1 {
|
ConnectPeer(ctx, wgClient, wgDevice, dnsClient, ptr.Ptr, *dnsServerFlag)
|
||||||
log.Printf("[%s] no SRV records found", pubKeyBase64)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
srv, ok := r.Answer[0].(*dns.SRV)
|
|
||||||
if !ok {
|
|
||||||
log.Printf(
|
|
||||||
"[%s] non-SRV answer in response to SRV query: %s",
|
|
||||||
pubKeyBase64, r.Answer[0].String())
|
|
||||||
}
|
|
||||||
if len(r.Extra) < 1 {
|
|
||||||
log.Printf("[%s] SRV response missing extra A/AAAA",
|
|
||||||
pubKeyBase64)
|
|
||||||
}
|
|
||||||
var endpointIP net.IP
|
|
||||||
hostA, ok := r.Extra[0].(*dns.A)
|
|
||||||
if !ok {
|
|
||||||
hostAAAA, ok := r.Extra[0].(*dns.AAAA)
|
|
||||||
if !ok {
|
|
||||||
log.Printf(
|
|
||||||
"[%s] non-A/AAAA extra in SRV response: %s",
|
|
||||||
pubKeyBase64, r.Extra[0].String())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
endpointIP = hostAAAA.AAAA
|
|
||||||
} else {
|
|
||||||
endpointIP = hostA.A
|
|
||||||
}
|
|
||||||
peerConfig := wgtypes.PeerConfig{
|
|
||||||
PublicKey: peer.PublicKey,
|
|
||||||
UpdateOnly: true,
|
|
||||||
Endpoint: &net.UDPAddr{
|
|
||||||
IP: endpointIP,
|
|
||||||
Port: int(srv.Port),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
deviceConfig := wgtypes.Config{
|
|
||||||
PrivateKey: &wgDevice.PrivateKey,
|
|
||||||
ReplacePeers: false,
|
|
||||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
|
||||||
}
|
|
||||||
if wgDevice.FirewallMark > 0 {
|
|
||||||
deviceConfig.FirewallMark = &wgDevice.FirewallMark
|
|
||||||
}
|
|
||||||
err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf(
|
|
||||||
"[%s] failed to configure peer on %s, error: %v",
|
|
||||||
pubKeyBase64, *deviceFlag, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
@ -146,3 +120,104 @@ func main() {
|
|||||||
case <-done:
|
case <-done:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConnectPeer(ctx context.Context, wgClient *wgctrl.Client, wgDevice *wgtypes.Device, dnsClient *dns.Client, serviceFqdn string, dnsServer string) {
|
||||||
|
srvCtx, srvCancel := context.WithCancel(ctx)
|
||||||
|
m := &dns.Msg{}
|
||||||
|
m.SetQuestion(serviceFqdn, dns.TypeSRV)
|
||||||
|
r, _, err := dnsClient.ExchangeContext(srvCtx, m, dnsServer)
|
||||||
|
srvCancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf(
|
||||||
|
"[%s] failed to lookup SRV: %v", serviceFqdn, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(r.Answer) < 1 {
|
||||||
|
log.Printf("[%s] no SRV records found", serviceFqdn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv, ok := r.Answer[0].(*dns.SRV)
|
||||||
|
if !ok {
|
||||||
|
log.Printf(
|
||||||
|
"[%s] non-SRV answer in response to SRV query: %s",
|
||||||
|
serviceFqdn, r.Answer[0].String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(r.Extra) < 2 {
|
||||||
|
log.Printf("[%s] SRV response missing extra A/AAAA and TXT",
|
||||||
|
serviceFqdn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var endpointIP net.IP
|
||||||
|
hostA, ok := r.Extra[0].(*dns.A)
|
||||||
|
if !ok {
|
||||||
|
hostAAAA, ok := r.Extra[0].(*dns.AAAA)
|
||||||
|
if !ok {
|
||||||
|
log.Printf(
|
||||||
|
"[%s] non-A/AAAA extra in SRV response: %s",
|
||||||
|
serviceFqdn, r.Extra[0].String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
endpointIP = hostAAAA.AAAA
|
||||||
|
} else {
|
||||||
|
endpointIP = hostA.A
|
||||||
|
}
|
||||||
|
txt, ok := r.Extra[1].(*dns.TXT)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("[%s] non-TXT extra in SRV response: %s",
|
||||||
|
serviceFqdn, r.Extra[1].String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
allowedIPsString := strings.TrimPrefix(strings.ToLower(txt.Txt[0]), "allowedip=")
|
||||||
|
_, allowedIPs, err := net.ParseCIDR(allowedIPsString)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[%s] failed to parse allowedip in TXT extra: %s", serviceFqdn, r.Extra[1].String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pubKeyString := strings.TrimPrefix(strings.ToUpper(txt.Txt[1]), "PUBKEY=")
|
||||||
|
if len(pubKeyString) < keyLen {
|
||||||
|
pubKeyString += strings.Repeat("=", keyLen-len(pubKeyString))
|
||||||
|
}
|
||||||
|
pubKeyBytes, err := base32.StdEncoding.DecodeString(strings.ToUpper(pubKeyString))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[%s] failed to decode base32 key %s: %v", serviceFqdn, pubKeyString, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pubKeyWg, err := wgtypes.NewKey(pubKeyBytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[%s] failed to create wg key: %v", serviceFqdn, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pubKeyWg == wgDevice.PublicKey {
|
||||||
|
log.Printf("[%s] skipping ourself", serviceFqdn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peerConfig := wgtypes.PeerConfig{
|
||||||
|
PublicKey: pubKeyWg,
|
||||||
|
UpdateOnly: false,
|
||||||
|
Endpoint: &net.UDPAddr{
|
||||||
|
IP: endpointIP,
|
||||||
|
Port: int(srv.Port),
|
||||||
|
},
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: []net.IPNet{*allowedIPs},
|
||||||
|
}
|
||||||
|
deviceConfig := wgtypes.Config{
|
||||||
|
PrivateKey: &wgDevice.PrivateKey,
|
||||||
|
ReplacePeers: false,
|
||||||
|
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||||
|
}
|
||||||
|
if wgDevice.FirewallMark > 0 {
|
||||||
|
deviceConfig.FirewallMark = &wgDevice.FirewallMark
|
||||||
|
}
|
||||||
|
err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf(
|
||||||
|
"[%s] failed to configure peer on %s, error: %v",
|
||||||
|
serviceFqdn, *deviceFlag, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[%s] configure peer on %s", serviceFqdn, *deviceFlag)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user