diff --git a/wgsd.go b/wgsd.go index 6cc2aa0..1674544 100644 --- a/wgsd.go +++ b/wgsd.go @@ -29,6 +29,10 @@ type WGSD struct { zone string // the WireGuard device name, e.g. wg0 device string + // overrides the self endpoint value + selfEndpoint *net.UDPAddr + // self allowed IPs + selfAllowedIPs []net.IPNet } type wgctrlClient interface { @@ -145,6 +149,36 @@ func handleHostOrTXT(ctx context.Context, state request.Request, return nxDomain(state) } +func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) { + self := wgtypes.Peer{ + PublicKey: device.PublicKey, + } + if p.selfEndpoint != nil { + self.Endpoint = p.selfEndpoint + } else { + self.Endpoint = &net.UDPAddr{ + IP: net.ParseIP(state.LocalIP()), + Port: device.ListenPort, + } + } + self.AllowedIPs = p.selfAllowedIPs + return self, nil +} + +func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) { + peers := make([]wgtypes.Peer, 0) + device, err := p.client.Device(p.device) + if err != nil { + return nil, err + } + peers = append(peers, device.Peers...) + self, err := p.getSelfPeer(device, state) + if err != nil { + return nil, err + } + return append(peers, self), nil +} + func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { // request.Request is a convenience struct we wrap around the msg and @@ -169,15 +203,12 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, return nxDomain(state) } - device, err := p.client.Device(p.device) + peers, err := p.getPeers(state) if err != nil { return dns.RcodeServerFailure, err } - if len(device.Peers) == 0 { - return nxDomain(state) - } - return handler(ctx, state, device.Peers) + return handler(ctx, state, peers) } func getHostRR(name string, endpoint *net.UDPAddr) dns.RR { diff --git a/wgsd_test.go b/wgsd_test.go index 15741b5..f90160d 100644 --- a/wgsd_test.go +++ b/wgsd_test.go @@ -16,14 +16,11 @@ import ( ) type mockClient struct { - peers []wgtypes.Peer + device *wgtypes.Device } func (m *mockClient) Device(d string) (*wgtypes.Device, error) { - return &wgtypes.Device{ - Name: d, - Peers: m.peers, - }, nil + return m.device, nil } func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) { @@ -44,6 +41,11 @@ func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) } func TestWGSD(t *testing.T) { + selfKey := [32]byte{} + selfKey[0] = 99 + selfb32 := strings.ToLower(base32.StdEncoding.EncodeToString(selfKey[:])) + selfb64 := base64.StdEncoding.EncodeToString(selfKey[:]) + selfAllowed, selfAllowedString := constructAllowedIPs(t, []string{"10.0.0.99/32", "10.0.0.100/32"}) key1 := [32]byte{} key1[0] = 1 peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"}) @@ -73,10 +75,16 @@ func TestWGSD(t *testing.T) { p := &WGSD{ Next: test.ErrorHandler(), client: &mockClient{ - peers: []wgtypes.Peer{peer1, peer2}, + device: &wgtypes.Device{ + Name: "wg0", + PublicKey: selfKey, + ListenPort: 51820, + Peers: []wgtypes.Peer{peer1, peer2}, + }, }, - zone: "example.com.", - device: "wg0", + zone: "example.com.", + device: "wg0", + selfAllowedIPs: selfAllowed, } testCases := []test.Case{ @@ -87,6 +95,19 @@ func TestWGSD(t *testing.T) { Answer: []dns.RR{ test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)), test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)), + test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", selfb32)), + }, + }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32), + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 51820 %s._wireguard._udp.example.com.", selfb32, selfb32)), + }, + Extra: []dns.RR{ + test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")), + test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)), }, }, { @@ -113,6 +134,14 @@ func TestWGSD(t *testing.T) { test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)), }, }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32), + Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")), + }, + }, { Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32), Qtype: dns.TypeA, @@ -129,6 +158,14 @@ func TestWGSD(t *testing.T) { test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())), }, }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32), + Qtype: dns.TypeTXT, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)), + }, + }, { Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32), Qtype: dns.TypeTXT,