From b39240d5e60196e521b370c554ffc01953bf77ed Mon Sep 17 00:00:00 2001
From: Jordan Whited <jordan@jordanwhited.com>
Date: Fri, 1 Jan 2021 16:23:12 -0800
Subject: [PATCH] serve self peer info

---
 wgsd.go      | 41 +++++++++++++++++++++++++++++++++++-----
 wgsd_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/wgsd.go b/wgsd.go
index 6cc2aa0..1674544 100644
--- a/wgsd.go
+++ b/wgsd.go
@@ -29,6 +29,10 @@ type WGSD struct {
 	zone string
 	// the WireGuard device name, e.g. wg0
 	device string
+	// overrides the self endpoint value
+	selfEndpoint *net.UDPAddr
+	// self allowed IPs
+	selfAllowedIPs []net.IPNet
 }
 
 type wgctrlClient interface {
@@ -145,6 +149,36 @@ func handleHostOrTXT(ctx context.Context, state request.Request,
 	return nxDomain(state)
 }
 
+func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
+	self := wgtypes.Peer{
+		PublicKey: device.PublicKey,
+	}
+	if p.selfEndpoint != nil {
+		self.Endpoint = p.selfEndpoint
+	} else {
+		self.Endpoint = &net.UDPAddr{
+			IP:   net.ParseIP(state.LocalIP()),
+			Port: device.ListenPort,
+		}
+	}
+	self.AllowedIPs = p.selfAllowedIPs
+	return self, nil
+}
+
+func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) {
+	peers := make([]wgtypes.Peer, 0)
+	device, err := p.client.Device(p.device)
+	if err != nil {
+		return nil, err
+	}
+	peers = append(peers, device.Peers...)
+	self, err := p.getSelfPeer(device, state)
+	if err != nil {
+		return nil, err
+	}
+	return append(peers, self), nil
+}
+
 func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
 	r *dns.Msg) (int, error) {
 	// request.Request is a convenience struct we wrap around the msg and
@@ -169,15 +203,12 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
 		return nxDomain(state)
 	}
 
-	device, err := p.client.Device(p.device)
+	peers, err := p.getPeers(state)
 	if err != nil {
 		return dns.RcodeServerFailure, err
 	}
-	if len(device.Peers) == 0 {
-		return nxDomain(state)
-	}
 
-	return handler(ctx, state, device.Peers)
+	return handler(ctx, state, peers)
 }
 
 func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
diff --git a/wgsd_test.go b/wgsd_test.go
index 15741b5..f90160d 100644
--- a/wgsd_test.go
+++ b/wgsd_test.go
@@ -16,14 +16,11 @@ import (
 )
 
 type mockClient struct {
-	peers []wgtypes.Peer
+	device *wgtypes.Device
 }
 
 func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
-	return &wgtypes.Device{
-		Name:  d,
-		Peers: m.peers,
-	}, nil
+	return m.device, nil
 }
 
 func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
@@ -44,6 +41,11 @@ func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string)
 }
 
 func TestWGSD(t *testing.T) {
+	selfKey := [32]byte{}
+	selfKey[0] = 99
+	selfb32 := strings.ToLower(base32.StdEncoding.EncodeToString(selfKey[:]))
+	selfb64 := base64.StdEncoding.EncodeToString(selfKey[:])
+	selfAllowed, selfAllowedString := constructAllowedIPs(t, []string{"10.0.0.99/32", "10.0.0.100/32"})
 	key1 := [32]byte{}
 	key1[0] = 1
 	peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
@@ -73,10 +75,16 @@ func TestWGSD(t *testing.T) {
 	p := &WGSD{
 		Next: test.ErrorHandler(),
 		client: &mockClient{
-			peers: []wgtypes.Peer{peer1, peer2},
+			device: &wgtypes.Device{
+				Name:       "wg0",
+				PublicKey:  selfKey,
+				ListenPort: 51820,
+				Peers:      []wgtypes.Peer{peer1, peer2},
+			},
 		},
-		zone:   "example.com.",
-		device: "wg0",
+		zone:           "example.com.",
+		device:         "wg0",
+		selfAllowedIPs: selfAllowed,
 	}
 
 	testCases := []test.Case{
@@ -87,6 +95,19 @@ func TestWGSD(t *testing.T) {
 			Answer: []dns.RR{
 				test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)),
 				test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)),
+				test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", selfb32)),
+			},
+		},
+		{
+			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
+			Qtype: dns.TypeSRV,
+			Rcode: dns.RcodeSuccess,
+			Answer: []dns.RR{
+				test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 51820 %s._wireguard._udp.example.com.", selfb32, selfb32)),
+			},
+			Extra: []dns.RR{
+				test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
+				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
 			},
 		},
 		{
@@ -113,6 +134,14 @@ func TestWGSD(t *testing.T) {
 				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
 			},
 		},
+		{
+			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
+			Qtype: dns.TypeA,
+			Rcode: dns.RcodeSuccess,
+			Answer: []dns.RR{
+				test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
+			},
+		},
 		{
 			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
 			Qtype: dns.TypeA,
@@ -129,6 +158,14 @@ func TestWGSD(t *testing.T) {
 				test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
 			},
 		},
+		{
+			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
+			Qtype: dns.TypeTXT,
+			Rcode: dns.RcodeSuccess,
+			Answer: []dns.RR{
+				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
+			},
+		},
 		{
 			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
 			Qtype: dns.TypeTXT,