From 7eaacc000ba331634a8023ce0d08521267020f99 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 18 Jan 2021 15:08:43 -0800 Subject: [PATCH] skip peers with nil endpoints --- wgsd.go | 6 ++++++ wgsd_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/wgsd.go b/wgsd.go index 15f7777..d0cf8ea 100644 --- a/wgsd.go +++ b/wgsd.go @@ -103,6 +103,9 @@ func handleSRV(state request.Request, peers []wgtypes.Peer) (int, error) { if strings.EqualFold( base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { endpoint := peer.Endpoint + if endpoint == nil { + return nxDomain(state) + } hostRR := getHostRR(state.Name(), endpoint) if hostRR == nil { return nxDomain(state) @@ -137,6 +140,9 @@ func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) { if strings.EqualFold( base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { endpoint := peer.Endpoint + if endpoint == nil { + return nxDomain(state) + } if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA { hostRR := getHostRR(state.Name(), endpoint) if hostRR == nil { diff --git a/wgsd_test.go b/wgsd_test.go index e3e27c8..cbd229e 100644 --- a/wgsd_test.go +++ b/wgsd_test.go @@ -72,6 +72,15 @@ func TestWGSD(t *testing.T) { } peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:])) peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:]) + key3 := [32]byte{} + key3[0] = 3 + peer3Allowed, _ := constructAllowedIPs(t, []string{"10.0.0.5/32", "10.0.0.6/32"}) + peer3 := wgtypes.Peer{ + Endpoint: nil, + PublicKey: key3, + AllowedIPs: peer3Allowed, + } + peer3b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer3.PublicKey[:])) p := &WGSD{ Next: test.ErrorHandler(), Zones: Zones{ @@ -91,7 +100,7 @@ func TestWGSD(t *testing.T) { Name: "wg0", PublicKey: selfKey, ListenPort: 51820, - Peers: []wgtypes.Peer{peer1, peer2}, + Peers: []wgtypes.Peer{peer1, peer2, peer3}, }, }, }, @@ -205,6 +214,46 @@ func TestWGSD(t *testing.T) { Qtype: dns.TypeAAAA, Rcode: dns.RcodeServerFailure, }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32), + Qtype: dns.TypeSRV, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA(soa("example.com.").String()), + }, + Answer: []dns.RR{}, + Extra: []dns.RR{}, + }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32), + Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA(soa("example.com.").String()), + }, + Answer: []dns.RR{}, + Extra: []dns.RR{}, + }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32), + Qtype: dns.TypeAAAA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA(soa("example.com.").String()), + }, + Answer: []dns.RR{}, + Extra: []dns.RR{}, + }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32), + Qtype: dns.TypeTXT, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA(soa("example.com.").String()), + }, + Answer: []dns.RR{}, + Extra: []dns.RR{}, + }, } for _, tc := range testCases { t.Run(fmt.Sprintf("%s %s", tc.Qname, dns.TypeToString[tc.Qtype]), func(t *testing.T) {