mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 22:09:34 +08:00
149 lines
3.5 KiB
Go
149 lines
3.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base32"
|
|
"encoding/base64"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"golang.zx2c4.com/wireguard/wgctrl"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
// Command-line flags. All three are required; main exits with an error if any
// of them is left empty.
var (
	// deviceFlag is the name of the WireGuard device whose peers are updated.
	deviceFlag = flag.String("device", "",
		"name of Wireguard device to manage")
	// dnsServerFlag is the DNS server to query; main validates it with
	// net.SplitHostPort, so it must be in host:port form.
	dnsServerFlag = flag.String("dns", "",
		"ip:port of DNS server")
	// dnsZoneFlag is the DNS zone under which peer SRV records are looked up.
	dnsZoneFlag = flag.String("zone", "", "dns zone name")
)
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
if len(*deviceFlag) < 1 {
|
|
log.Fatal("missing device flag")
|
|
}
|
|
if len(*dnsZoneFlag) < 1 {
|
|
log.Fatal("missing zone flag")
|
|
}
|
|
if len(*dnsServerFlag) < 1 {
|
|
log.Fatal("missing dns flag")
|
|
}
|
|
_, _, err := net.SplitHostPort(*dnsServerFlag)
|
|
if err != nil {
|
|
log.Fatalf("invalid dns flag value: %v", err)
|
|
}
|
|
wgClient, err := wgctrl.New()
|
|
if err != nil {
|
|
log.Fatalf("error constructing Wireguard control client: %v",
|
|
err)
|
|
}
|
|
wgDevice, err := wgClient.Device(*deviceFlag)
|
|
if err != nil {
|
|
log.Fatalf(
|
|
"error retrieving Wireguard device '%s' info: %v",
|
|
*deviceFlag, err)
|
|
}
|
|
if len(wgDevice.Peers) < 1 {
|
|
log.Println("no peers found")
|
|
os.Exit(0)
|
|
}
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
dnsClient := &dns.Client{
|
|
Timeout: time.Second * 5,
|
|
}
|
|
for _, peer := range wgDevice.Peers {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
srvCtx, srvCancel := context.WithCancel(ctx)
|
|
pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
|
|
pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])
|
|
m := &dns.Msg{}
|
|
question := fmt.Sprintf("%s._wireguard._udp.%s",
|
|
pubKeyBase32, dns.Fqdn(*dnsZoneFlag))
|
|
m.SetQuestion(question, dns.TypeSRV)
|
|
r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
|
|
srvCancel()
|
|
if err != nil {
|
|
log.Printf(
|
|
"[%s] failed to lookup SRV: %v", pubKeyBase64, err)
|
|
continue
|
|
}
|
|
if len(r.Answer) < 1 {
|
|
log.Printf("[%s] no SRV records found", pubKeyBase64)
|
|
continue
|
|
}
|
|
srv, ok := r.Answer[0].(*dns.SRV)
|
|
if !ok {
|
|
log.Printf(
|
|
"[%s] non-SRV answer in response to SRV query: %s",
|
|
pubKeyBase64, r.Answer[0].String())
|
|
}
|
|
if len(r.Extra) < 1 {
|
|
log.Printf("[%s] SRV response missing extra A/AAAA",
|
|
pubKeyBase64)
|
|
}
|
|
var endpointIP net.IP
|
|
hostA, ok := r.Extra[0].(*dns.A)
|
|
if !ok {
|
|
hostAAAA, ok := r.Extra[0].(*dns.AAAA)
|
|
if !ok {
|
|
log.Printf(
|
|
"[%s] non-A/AAAA extra in SRV response: %s",
|
|
pubKeyBase64, r.Extra[0].String())
|
|
continue
|
|
}
|
|
endpointIP = hostAAAA.AAAA
|
|
} else {
|
|
endpointIP = hostA.A
|
|
}
|
|
peerConfig := wgtypes.PeerConfig{
|
|
PublicKey: peer.PublicKey,
|
|
UpdateOnly: true,
|
|
Endpoint: &net.UDPAddr{
|
|
IP: endpointIP,
|
|
Port: int(srv.Port),
|
|
},
|
|
}
|
|
deviceConfig := wgtypes.Config{
|
|
PrivateKey: &wgDevice.PrivateKey,
|
|
ReplacePeers: false,
|
|
Peers: []wgtypes.PeerConfig{peerConfig},
|
|
}
|
|
if wgDevice.FirewallMark > 0 {
|
|
deviceConfig.FirewallMark = &wgDevice.FirewallMark
|
|
}
|
|
err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig)
|
|
if err != nil {
|
|
log.Printf(
|
|
"[%s] failed to configure peer on %s, error: %v",
|
|
pubKeyBase64, *deviceFlag, err)
|
|
}
|
|
}
|
|
}()
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
select {
|
|
case sig := <-sigCh:
|
|
log.Printf("exiting due to signal %s", sig)
|
|
cancel()
|
|
<-done
|
|
case <-done:
|
|
}
|
|
}
|