package main

import (
	"context"
	"encoding/base32"
	"encoding/base64"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/miekg/dns"
	"golang.zx2c4.com/wireguard/wgctrl"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var (
	deviceFlag = flag.String("device", "",
		"name of Wireguard device to manage")
	dnsServerFlag = flag.String("dns", "",
		"ip:port of DNS server")
	dnsZoneFlag = flag.String("zone", "", "dns zone name")
)

func main() {
	flag.Parse()
	if len(*deviceFlag) < 1 {
		log.Fatal("missing device flag")
	}
	if len(*dnsZoneFlag) < 1 {
		log.Fatal("missing zone flag")
	}
	if len(*dnsServerFlag) < 1 {
		log.Fatal("missing dns flag")
	}
	_, _, err := net.SplitHostPort(*dnsServerFlag)
	if err != nil {
		log.Fatalf("invalid dns flag value: %v", err)
	}
	wgClient, err := wgctrl.New()
	if err != nil {
		log.Fatalf("error constructing Wireguard control client: %v",
			err)
	}
	wgDevice, err := wgClient.Device(*deviceFlag)
	if err != nil {
		log.Fatalf(
			"error retrieving Wireguard device '%s' info: %v",
			*deviceFlag, err)
	}
	if len(wgDevice.Peers) < 1 {
		log.Println("no peers found")
		os.Exit(0)
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	done := make(chan struct{})
	go func() {
		defer close(done)
		dnsClient := &dns.Client{
			Timeout: time.Second * 5,
		}
		for _, peer := range wgDevice.Peers {
			select {
			case <-ctx.Done():
				return
			default:
			}
			srvCtx, srvCancel := context.WithCancel(ctx)
			pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
			pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])
			m := &dns.Msg{}
			question := fmt.Sprintf("%s._wireguard._udp.%s",
				pubKeyBase32, dns.Fqdn(*dnsZoneFlag))
			m.SetQuestion(question, dns.TypeSRV)
			r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
			srvCancel()
			if err != nil {
				log.Printf(
					"[%s] failed to lookup SRV: %v", pubKeyBase64, err)
				continue
			}
			if len(r.Answer) < 1 {
				log.Printf("[%s] no SRV records found", pubKeyBase64)
				continue
			}
			srv, ok := r.Answer[0].(*dns.SRV)
			if !ok {
				log.Printf(
					"[%s] non-SRV answer in response to SRV query: %s",
					pubKeyBase64, r.Answer[0].String())
			}
			if len(r.Extra) < 1 {
				log.Printf("[%s] SRV response missing extra A/AAAA",
					pubKeyBase64)
			}
			var endpointIP net.IP
			hostA, ok := r.Extra[0].(*dns.A)
			if !ok {
				hostAAAA, ok := r.Extra[0].(*dns.AAAA)
				if !ok {
					log.Printf(
						"[%s] non-A/AAAA extra in SRV response: %s",
						pubKeyBase64, r.Extra[0].String())
					continue
				}
				endpointIP = hostAAAA.AAAA
			} else {
				endpointIP = hostA.A
			}
			peerConfig := wgtypes.PeerConfig{
				PublicKey:  peer.PublicKey,
				UpdateOnly: true,
				Endpoint: &net.UDPAddr{
					IP:   endpointIP,
					Port: int(srv.Port),
				},
			}
			deviceConfig := wgtypes.Config{
				PrivateKey:   &wgDevice.PrivateKey,
				ReplacePeers: false,
				Peers:        []wgtypes.PeerConfig{peerConfig},
			}
			if wgDevice.FirewallMark > 0 {
				deviceConfig.FirewallMark = &wgDevice.FirewallMark
			}
			err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig)
			if err != nil {
				log.Printf(
					"[%s] failed to configure peer on %s, error: %v",
					pubKeyBase64, *deviceFlag, err)
			}
		}
	}()
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	select {
	case sig := <-sigCh:
		log.Printf("exiting due to signal %s", sig)
		cancel()
		<-done
	case <-done:
	}
}