149 lines
3.5 KiB
Go
Raw Normal View History

2020-05-15 12:50:16 -07:00
package main
import (
"context"
"encoding/base32"
"encoding/base64"
"flag"
"fmt"
"log"
"net"
"os"
"os/signal"
"syscall"
2020-05-26 15:37:38 -07:00
"time"
2020-05-15 12:50:16 -07:00
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
	// deviceFlag names the WireGuard interface whose peer endpoints are
	// refreshed from DNS. Required; main exits if it is empty.
	deviceFlag = flag.String("device", "", "name of Wireguard device to manage")

	// dnsServerFlag is the DNS server queried for peer SRV records, given as
	// host:port. Required; main rejects values net.SplitHostPort cannot parse.
	dnsServerFlag = flag.String("dns", "", "ip:port of DNS server")

	// dnsZoneFlag is the zone under which each peer's SRV record is looked up
	// (<base32 public key>._wireguard._udp.<zone>). Required.
	dnsZoneFlag = flag.String("zone", "", "dns zone name")
)
func main() {
2020-05-15 13:01:58 -07:00
flag.Parse()
2020-05-15 12:50:16 -07:00
if len(*deviceFlag) < 1 {
log.Fatal("missing device flag")
}
if len(*dnsZoneFlag) < 1 {
2020-05-15 13:01:58 -07:00
log.Fatal("missing zone flag")
2020-05-15 12:50:16 -07:00
}
2020-05-15 14:18:45 -07:00
if len(*dnsServerFlag) < 1 {
log.Fatal("missing dns flag")
}
_, _, err := net.SplitHostPort(*dnsServerFlag)
if err != nil {
log.Fatalf("invalid dns flag value: %v", err)
2020-05-15 12:50:16 -07:00
}
wgClient, err := wgctrl.New()
if err != nil {
log.Fatalf("error constructing Wireguard control client: %v",
err)
}
wgDevice, err := wgClient.Device(*deviceFlag)
if err != nil {
log.Fatalf(
"error retrieving Wireguard device '%s' info: %v",
*deviceFlag, err)
}
2020-05-15 13:01:58 -07:00
if len(wgDevice.Peers) < 1 {
log.Println("no peers found")
os.Exit(0)
}
2020-05-15 12:50:16 -07:00
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
2020-05-26 15:37:38 -07:00
dnsClient := &dns.Client{
Timeout: time.Second * 5,
}
2020-05-15 12:50:16 -07:00
for _, peer := range wgDevice.Peers {
select {
case <-ctx.Done():
return
default:
}
srvCtx, srvCancel := context.WithCancel(ctx)
pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])
2020-05-15 14:18:45 -07:00
m := &dns.Msg{}
question := fmt.Sprintf("%s._wireguard._udp.%s",
2020-05-15 12:50:16 -07:00
pubKeyBase32, dns.Fqdn(*dnsZoneFlag))
2020-05-15 14:20:39 -07:00
m.SetQuestion(question, dns.TypeSRV)
2020-05-15 14:18:45 -07:00
r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
2020-05-15 12:50:16 -07:00
srvCancel()
if err != nil {
log.Printf(
2020-05-15 14:18:45 -07:00
"[%s] failed to lookup SRV: %v", pubKeyBase64, err)
2020-05-15 12:50:16 -07:00
continue
}
2020-05-15 14:18:45 -07:00
if len(r.Answer) < 1 {
log.Printf("[%s] no SRV records found", pubKeyBase64)
2020-05-15 12:50:16 -07:00
continue
}
2020-05-15 14:18:45 -07:00
srv, ok := r.Answer[0].(*dns.SRV)
if !ok {
2020-05-15 12:50:16 -07:00
log.Printf(
2020-05-15 14:18:45 -07:00
"[%s] non-SRV answer in response to SRV query: %s",
pubKeyBase64, r.Answer[0].String())
2020-05-15 12:50:16 -07:00
}
2020-05-15 14:18:45 -07:00
if len(r.Extra) < 1 {
log.Printf("[%s] SRV response missing extra A/AAAA",
2020-05-15 12:50:16 -07:00
pubKeyBase64)
2020-05-15 14:18:45 -07:00
}
var endpointIP net.IP
2020-05-15 16:30:33 -07:00
hostA, ok := r.Extra[0].(*dns.A)
2020-05-15 14:18:45 -07:00
if !ok {
2020-05-15 16:30:33 -07:00
hostAAAA, ok := r.Extra[0].(*dns.AAAA)
2020-05-15 14:18:45 -07:00
if !ok {
log.Printf(
"[%s] non-A/AAAA extra in SRV response: %s",
2020-05-15 16:30:33 -07:00
pubKeyBase64, r.Extra[0].String())
2020-05-15 14:18:45 -07:00
continue
}
endpointIP = hostAAAA.AAAA
} else {
endpointIP = hostA.A
2020-05-15 12:50:16 -07:00
}
peerConfig := wgtypes.PeerConfig{
PublicKey: peer.PublicKey,
UpdateOnly: true,
Endpoint: &net.UDPAddr{
2020-05-15 14:18:45 -07:00
IP: endpointIP,
Port: int(srv.Port),
2020-05-15 12:50:16 -07:00
},
}
deviceConfig := wgtypes.Config{
PrivateKey: &wgDevice.PrivateKey,
ReplacePeers: false,
Peers: []wgtypes.PeerConfig{peerConfig},
}
if wgDevice.FirewallMark > 0 {
deviceConfig.FirewallMark = &wgDevice.FirewallMark
}
err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig)
if err != nil {
log.Printf(
2020-05-26 15:37:38 -07:00
"[%s] failed to configure peer on %s, error: %v",
2020-05-15 12:50:16 -07:00
pubKeyBase64, *deviceFlag, err)
}
}
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case sig := <-sigCh:
log.Printf("exiting due to signal %s", sig)
cancel()
<-done
case <-done:
}
}