From 53c17947984b7b1b5d78a64f6faf3c73332f9d71 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 15 May 2020 12:50:16 -0700 Subject: [PATCH] add client --- cmd/client/main.go | 139 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 cmd/client/main.go diff --git a/cmd/client/main.go b/cmd/client/main.go new file mode 100644 index 0000000..43cac71 --- /dev/null +++ b/cmd/client/main.go @@ -0,0 +1,139 @@ +package main + +import ( + "context" + "encoding/base32" + "encoding/base64" + "flag" + "fmt" + "log" + "net" + "os" + "os/signal" + "syscall" + + "github.com/miekg/dns" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var ( + deviceFlag = flag.String("device", "", + "name of Wireguard device to manage") + dnsServerFlag = flag.String("dns", "", + "ip:port of DNS server; defaults to OS resolver") + dnsZoneFlag = flag.String("zone", "", "dns zone name") +) + +func main() { + if len(*deviceFlag) < 1 { + log.Fatal("missing device flag") + } + if len(*dnsZoneFlag) < 1 { + log.Fatal("missing dns flag") + } + resolver := net.DefaultResolver + if len(*dnsServerFlag) > 0 { + _, _, err := net.SplitHostPort(*dnsServerFlag) + if err != nil { + log.Fatalf("invalid dns server flag: %v", err) + } + dialer := net.Dialer{} + dialFn := func(ctx context.Context, network, address string) (net.Conn, + error) { + return dialer.DialContext(ctx, network, *dnsServerFlag) + } + resolver = &net.Resolver{ + PreferGo: true, + Dial: dialFn, + } + } + wgClient, err := wgctrl.New() + if err != nil { + log.Fatalf("error constructing Wireguard control client: %v", + err) + } + wgDevice, err := wgClient.Device(*deviceFlag) + if err != nil { + log.Fatalf( + "error retrieving Wireguard device '%s' info: %v", + *deviceFlag, err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan struct{}) + go func() { + defer close(done) + for _, peer := range wgDevice.Peers { + select { + case <-ctx.Done(): + return + default: + } + srvCtx, srvCancel := context.WithCancel(ctx) + pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:]) + pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:]) + queryName := fmt.Sprintf("%s._wireguard._udp.%s", + pubKeyBase32, dns.Fqdn(*dnsZoneFlag)) + _, srvs, err := resolver.LookupSRV(srvCtx, "", "", + queryName) + srvCancel() + if err != nil { + log.Printf( + "failed to lookup SRV for peer %s error: %v", + pubKeyBase64, err) + continue + } + if len(srvs) < 1 { + log.Printf("no SRV records found for peer %s", + pubKeyBase64) + continue + } + hostCtx, hostCancel := context.WithCancel(ctx) + addrs, err := resolver.LookupIPAddr(hostCtx, srvs[1].Target) + hostCancel() + if err != nil { + log.Printf( + "failed to lookup A/AAAA for peer %s error: %v", + pubKeyBase64, err) + continue + } + if len(addrs) < 1 { + log.Printf("no A/AAAA records found for peer %s", + pubKeyBase64) + continue + } + peerConfig := wgtypes.PeerConfig{ + PublicKey: peer.PublicKey, + UpdateOnly: true, + Endpoint: &net.UDPAddr{ + IP: addrs[0].IP, + Port: int(srvs[0].Port), + }, + } + deviceConfig := wgtypes.Config{ + PrivateKey: &wgDevice.PrivateKey, + ReplacePeers: false, + Peers: []wgtypes.PeerConfig{peerConfig}, + } + if wgDevice.FirewallMark > 0 { + deviceConfig.FirewallMark = &wgDevice.FirewallMark + } + err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig) + if err != nil { + log.Printf( + "failed to configure peer %s on %s, error: %v", + pubKeyBase64, *deviceFlag, err) + } + } + }() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + select { + case sig := <-sigCh: + log.Printf("exiting due to signal %s", sig) + cancel() + <-done + case <-done: + } +}