skip peers with nil endpoints

This commit is contained in:
Jordan Whited 2021-01-18 15:08:43 -08:00 committed by Jordan Whited
parent d9845d72b8
commit 7eaacc000b
2 changed files with 56 additions and 1 deletions

View File

@ -103,6 +103,9 @@ func handleSRV(state request.Request, peers []wgtypes.Peer) (int, error) {
if strings.EqualFold( if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint endpoint := peer.Endpoint
if endpoint == nil {
return nxDomain(state)
}
hostRR := getHostRR(state.Name(), endpoint) hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil { if hostRR == nil {
return nxDomain(state) return nxDomain(state)
@ -137,6 +140,9 @@ func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) {
if strings.EqualFold( if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint endpoint := peer.Endpoint
if endpoint == nil {
return nxDomain(state)
}
if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA { if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint) hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil { if hostRR == nil {

View File

@ -72,6 +72,15 @@ func TestWGSD(t *testing.T) {
} }
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:])) peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:]) peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
key3 := [32]byte{}
key3[0] = 3
peer3Allowed, _ := constructAllowedIPs(t, []string{"10.0.0.5/32", "10.0.0.6/32"})
peer3 := wgtypes.Peer{
Endpoint: nil,
PublicKey: key3,
AllowedIPs: peer3Allowed,
}
peer3b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer3.PublicKey[:]))
p := &WGSD{ p := &WGSD{
Next: test.ErrorHandler(), Next: test.ErrorHandler(),
Zones: Zones{ Zones: Zones{
@ -91,7 +100,7 @@ func TestWGSD(t *testing.T) {
Name: "wg0", Name: "wg0",
PublicKey: selfKey, PublicKey: selfKey,
ListenPort: 51820, ListenPort: 51820,
Peers: []wgtypes.Peer{peer1, peer2}, Peers: []wgtypes.Peer{peer1, peer2, peer3},
}, },
}, },
}, },
@ -205,6 +214,46 @@ func TestWGSD(t *testing.T) {
Qtype: dns.TypeAAAA, Qtype: dns.TypeAAAA,
Rcode: dns.RcodeServerFailure, Rcode: dns.RcodeServerFailure,
}, },
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeSRV,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeA,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeAAAA,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeTXT,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(fmt.Sprintf("%s %s", tc.Qname, dns.TypeToString[tc.Qtype]), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", tc.Qname, dns.TypeToString[tc.Qtype]), func(t *testing.T) {