package main import ( "context" "encoding/base32" "flag" "log" "net" "os" "os/signal" "strings" "syscall" "time" "github.com/miekg/dns" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) var ( deviceFlag = flag.String("device", "", "name of Wireguard device to manage") dnsServerFlag = flag.String("dns", "", "ip:port of DNS server") dnsZoneFlag = flag.String("zone", "", "dns zone name") serviceFlag = flag.String("service", "", "service ip or public key") ) const ( keyLen = 56 // the number of characters in a base32-encoded Wireguard public key spPrefix = "_wireguard._udp." ) func main() { flag.Parse() if len(*deviceFlag) < 1 { log.Fatal("missing device flag") } if len(*dnsZoneFlag) < 1 { log.Fatal("missing zone flag") } if len(*dnsServerFlag) < 1 { log.Fatal("missing dns flag") } _, _, err := net.SplitHostPort(*dnsServerFlag) if err != nil { log.Fatalf("invalid dns flag value: %v", err) } wgClient, err := wgctrl.New() if err != nil { log.Fatalf("error constructing Wireguard control client: %v", err) } wgDevice, err := wgClient.Device(*deviceFlag) if err != nil { log.Fatalf( "error retrieving Wireguard device '%s' info: %v", *deviceFlag, err) } if len(wgDevice.Peers) < 1 { log.Println("no peers found") os.Exit(0) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan struct{}) go func() { defer close(done) dnsClient := &dns.Client{ Timeout: time.Second * 5, } suffix := spPrefix + dns.Fqdn(*dnsZoneFlag) if len(*serviceFlag) >= 1 { fqdn := *serviceFlag suffix = "." + suffix if !strings.HasSuffix(fqdn, suffix) { fqdn += suffix } ConnectPeer(ctx, wgClient, wgDevice, dnsClient, fqdn, *dnsServerFlag) return } srvCtx, srvCancel := context.WithCancel(ctx) m := &dns.Msg{} m.SetQuestion(suffix, dns.TypePTR) r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag) srvCancel() if err != nil { log.Printf("failed to lookup PTR: %v", err) return } if len(r.Answer) < 1 { log.Printf("no PTR records found") return } for _, answer := range r.Answer { select { case <-ctx.Done(): return default: } ptr, ok := answer.(*dns.PTR) if !ok { log.Printf("non-PTR answer in response to PTR query: %s", answer.String()) continue } ConnectPeer(ctx, wgClient, wgDevice, dnsClient, ptr.Ptr, *dnsServerFlag) } }() sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) select { case sig := <-sigCh: log.Printf("exiting due to signal %s", sig) cancel() <-done case <-done: } } func ConnectPeer(ctx context.Context, wgClient *wgctrl.Client, wgDevice *wgtypes.Device, dnsClient *dns.Client, serviceFqdn string, dnsServer string) { srvCtx, srvCancel := context.WithCancel(ctx) m := &dns.Msg{} m.SetQuestion(serviceFqdn, dns.TypeSRV) r, _, err := dnsClient.ExchangeContext(srvCtx, m, dnsServer) srvCancel() if err != nil { log.Printf( "[%s] failed to lookup SRV: %v", serviceFqdn, err) return } if len(r.Answer) < 1 { log.Printf("[%s] no SRV records found", serviceFqdn) return } srv, ok := r.Answer[0].(*dns.SRV) if !ok { log.Printf( "[%s] non-SRV answer in response to SRV query: %s", serviceFqdn, r.Answer[0].String()) return } if len(r.Extra) < 2 { log.Printf("[%s] SRV response missing extra A/AAAA and TXT", serviceFqdn) return } var endpointIP net.IP hostA, ok := r.Extra[0].(*dns.A) if !ok { hostAAAA, ok := r.Extra[0].(*dns.AAAA) if !ok { log.Printf( "[%s] non-A/AAAA extra in SRV response: %s", serviceFqdn, r.Extra[0].String()) return } endpointIP = hostAAAA.AAAA } else { endpointIP = hostA.A } txt, ok := r.Extra[1].(*dns.TXT) if !ok { log.Printf("[%s] non-TXT extra in SRV response: %s", serviceFqdn, r.Extra[1].String()) return } allowedIPsString := strings.TrimPrefix(strings.ToLower(txt.Txt[0]), "allowedip=") _, allowedIPs, err := net.ParseCIDR(allowedIPsString) if err != nil { log.Printf("[%s] failed to parse allowedip in TXT extra: %s", serviceFqdn, r.Extra[1].String()) return } pubKeyString := strings.TrimPrefix(strings.ToUpper(txt.Txt[1]), "PUBKEY=") if len(pubKeyString) < keyLen { pubKeyString += strings.Repeat("=", keyLen-len(pubKeyString)) } pubKeyBytes, err := base32.StdEncoding.DecodeString(strings.ToUpper(pubKeyString)) if err != nil { log.Printf("[%s] failed to decode base32 key %s: %v", serviceFqdn, pubKeyString, err) return } pubKeyWg, err := wgtypes.NewKey(pubKeyBytes) if err != nil { log.Printf("[%s] failed to create wg key: %v", serviceFqdn, err) return } if pubKeyWg == wgDevice.PublicKey { log.Printf("[%s] skipping ourself", serviceFqdn) return } peerConfig := wgtypes.PeerConfig{ PublicKey: pubKeyWg, UpdateOnly: false, Endpoint: &net.UDPAddr{ IP: endpointIP, Port: int(srv.Port), }, ReplaceAllowedIPs: true, AllowedIPs: []net.IPNet{*allowedIPs}, } deviceConfig := wgtypes.Config{ PrivateKey: &wgDevice.PrivateKey, ReplacePeers: false, Peers: []wgtypes.PeerConfig{peerConfig}, } if wgDevice.FirewallMark > 0 { deviceConfig.FirewallMark = &wgDevice.FirewallMark } err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig) if err != nil { log.Printf( "[%s] failed to configure peer on %s, error: %v", serviceFqdn, *deviceFlag, err) return } log.Printf("[%s] configure peer on %s", serviceFqdn, *deviceFlag) }