From 577c703785d6618dcb68fca494e744364f14d726 Mon Sep 17 00:00:00 2001
From: Jordan Whited <jordan@jordanwhited.com>
Date: Fri, 1 Jan 2021 17:15:04 -0800
Subject: [PATCH] add configuration support for self overrides

---
 setup.go      | 92 ++++++++++++++++++++++++++++++++++++---------------
 setup_test.go | 63 +++++++++++++++++++++++++++++++++--
 wgsd.go       |  8 +++--
 3 files changed, 132 insertions(+), 31 deletions(-)

diff --git a/setup.go b/setup.go
index 8e35436..44958dc 100644
--- a/setup.go
+++ b/setup.go
@@ -2,6 +2,8 @@ package wgsd
 
 import (
 	"fmt"
+	"net"
+	"strconv"
 
 	"github.com/coredns/caddy"
 	"github.com/coredns/coredns/core/dnsserver"
@@ -11,45 +13,83 @@ import (
 )
 
 func init() {
-	plugin.Register("wgsd", setup)
+	plugin.Register(pluginName, setup)
+}
+
+const (
+	optionSelfAllowedIPs = "self-allowed-ips"
+	optionSelfEndpoint   = "self-endpoint"
+)
+
+func parse(c *caddy.Controller) (*WGSD, error) {
+	p := &WGSD{}
+	for c.Next() {
+		args := c.RemainingArgs()
+		if len(args) != 2 {
+			return nil, fmt.Errorf("expected 2 args, got %d", len(args))
+		}
+		p.zone = dns.Fqdn(args[0])
+		p.device = args[1]
+
+		for c.NextBlock() {
+			switch c.Val() {
+			case optionSelfAllowedIPs:
+				p.selfAllowedIPs = make([]net.IPNet, 0)
+				for _, aip := range c.RemainingArgs() {
+					_, prefix, err := net.ParseCIDR(aip)
+					if err != nil {
+						return nil, fmt.Errorf("invalid self-allowed-ips: %s err: %v", c.Val(), err)
+					}
+					p.selfAllowedIPs = append(p.selfAllowedIPs, *prefix)
+				}
+			case optionSelfEndpoint:
+				endpoint := c.RemainingArgs()
+				if len(endpoint) != 1 {
+					return nil, fmt.Errorf("expected 1 arg, got %d", len(endpoint))
+				}
+				host, portS, err := net.SplitHostPort(endpoint[0])
+				if err != nil {
+					return nil, fmt.Errorf("invalid self-endpoint, err: %v", err)
+				}
+				port, err := strconv.Atoi(portS)
+				if err != nil {
+					return nil, fmt.Errorf("error converting self-endpoint port: %v", err)
+				}
+				ip := net.ParseIP(host)
+				if ip == nil {
+					return nil, fmt.Errorf("invalid self-endpoint, invalid IP address: %s", host)
+				}
+				p.selfEndpoint = &net.UDPAddr{
+					IP:   ip,
+					Port: port,
+				}
+			default:
+				return nil, c.ArgErr()
+			}
+		}
+	}
+
+	return p, nil
 }
 
 func setup(c *caddy.Controller) error {
-	c.Next() // Ignore "wgsd" and give us the next token.
-
-	// return an error if there is no zone specified
-	if !c.NextArg() {
-		return plugin.Error("wgsd", c.ArgErr())
+	wgsd, err := parse(c)
+	if err != nil {
+		return plugin.Error(pluginName, err)
 	}
-	zone := dns.Fqdn(c.Val())
-
-	// return an error if there is no device name specified
-	if !c.NextArg() {
-		return plugin.Error("wgsd", c.ArgErr())
-	}
-	device := c.Val()
-
-	// return an error if there are more tokens on this line
-	if c.NextArg() {
-		return plugin.Error("wgsd", c.ArgErr())
-	}
-
 	client, err := wgctrl.New()
 	if err != nil {
-		return plugin.Error("wgsd",
+		return plugin.Error(pluginName,
 			fmt.Errorf("error constructing wgctrl client: %v",
 				err))
 	}
 	c.OnFinalShutdown(client.Close)
+	wgsd.client = client
 
 	// Add the Plugin to CoreDNS, so Servers can use it in their plugin chain.
 	dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
-		return &WGSD{
-			Next:   next,
-			client: client,
-			zone:   zone,
-			device: device,
-		}
+		wgsd.Next = next
+		return wgsd
 	})
 
 	return nil
diff --git a/setup_test.go b/setup_test.go
index 39e7410..6b39a37 100644
--- a/setup_test.go
+++ b/setup_test.go
@@ -8,24 +8,78 @@ import (
 
 func TestSetup(t *testing.T) {
 	testCases := []struct {
-		name      string
-		input     string
-		expectErr bool
+		name                 string
+		input                string
+		expectErr            bool
+		expectSelfAllowedIPs []string
+		expectSelfEndpoint   []string
 	}{
 		{
 			"valid input",
 			"wgsd example.com. wg0",
 			false,
+			nil,
+			nil,
 		},
 		{
 			"missing token",
 			"wgsd example.com.",
 			true,
+			nil,
+			nil,
 		},
 		{
 			"too many tokens",
 			"wgsd example.com. wg0 extra",
 			true,
+			nil,
+			nil,
+		},
+		{
+			"valid self-allowed-ips",
+			`wgsd example.com. wg0 {
+						self-allowed-ips 10.0.0.1/32 10.0.0.2/32
+					}`,
+			false,
+			nil,
+			nil,
+		},
+		{
+			"invalid self-allowed-ips",
+			`wgsd example.com. wg0 {
+						self-allowed-ips 10.0.01/32 10.0.0.2/32
+					}`,
+			true,
+			nil,
+			nil,
+		},
+		{
+			"valid self-endpoint",
+			`wgsd example.com. wg0 {
+						self-endpoint 127.0.0.1:51820
+					}`,
+			false,
+			nil,
+			nil,
+		},
+		{
+			"invalid self-endpoint",
+			`wgsd example.com. wg0 {
+						self-endpoint hostname:51820
+					}`,
+			true,
+			nil,
+			nil,
+		},
+		{
+			"all options",
+			`wgsd example.com. wg0 {
+						self-allowed-ips 10.0.0.1/32 10.0.0.2/32
+						self-endpoint 127.0.0.1:51820
+					}`,
+			false,
+			nil,
+			nil,
 		},
 	}
 
@@ -36,6 +90,9 @@ func TestSetup(t *testing.T) {
 			if (err != nil) != tc.expectErr {
 				t.Fatalf("expectErr: %v, got err=%v", tc.expectErr, err)
 			}
+			if tc.expectErr {
+				return
+			}
 		})
 	}
 }
diff --git a/wgsd.go b/wgsd.go
index 1674544..ab572ce 100644
--- a/wgsd.go
+++ b/wgsd.go
@@ -16,7 +16,11 @@ import (
 )
 
 // coredns plugin-specific logger
-var logger = clog.NewWithPlugin("wgsd")
+var logger = clog.NewWithPlugin(pluginName)
+
+const (
+	pluginName = "wgsd"
+)
 
 // WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
 // semantics. WGSD implements the plugin.Handler interface.
@@ -300,5 +304,5 @@ func soa(zone string) dns.RR {
 }
 
 func (p *WGSD) Name() string {
-	return "wgsd"
+	return pluginName
 }