serve self peer info

This commit is contained in:
Jordan Whited 2021-01-01 16:23:12 -08:00 committed by Jordan Whited
parent 7d03ee7041
commit 6f78170fbe
2 changed files with 81 additions and 13 deletions

41
wgsd.go
View File

@ -29,6 +29,10 @@ type WGSD struct {
zone string
// the WireGuard device name, e.g. wg0
device string
// overrides the self endpoint value
selfEndpoint *net.UDPAddr
// self allowed IPs
selfAllowedIPs []net.IPNet
}
type wgctrlClient interface {
@ -145,6 +149,36 @@ func handleHostOrTXT(ctx context.Context, state request.Request,
return nxDomain(state)
}
func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
self := wgtypes.Peer{
PublicKey: device.PublicKey,
}
if p.selfEndpoint != nil {
self.Endpoint = p.selfEndpoint
} else {
self.Endpoint = &net.UDPAddr{
IP: net.ParseIP(state.LocalIP()),
Port: device.ListenPort,
}
}
self.AllowedIPs = p.selfAllowedIPs
return self, nil
}
func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) {
peers := make([]wgtypes.Peer, 0)
device, err := p.client.Device(p.device)
if err != nil {
return nil, err
}
peers = append(peers, device.Peers...)
self, err := p.getSelfPeer(device, state)
if err != nil {
return nil, err
}
return append(peers, self), nil
}
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
r *dns.Msg) (int, error) {
// request.Request is a convenience struct we wrap around the msg and
@ -169,15 +203,12 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
return nxDomain(state)
}
device, err := p.client.Device(p.device)
peers, err := p.getPeers(state)
if err != nil {
return dns.RcodeServerFailure, err
}
if len(device.Peers) == 0 {
return nxDomain(state)
}
return handler(ctx, state, device.Peers)
return handler(ctx, state, peers)
}
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {

View File

@ -16,14 +16,11 @@ import (
)
type mockClient struct {
peers []wgtypes.Peer
device *wgtypes.Device
}
func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
return &wgtypes.Device{
Name: d,
Peers: m.peers,
}, nil
return m.device, nil
}
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
@ -44,6 +41,11 @@ func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string)
}
func TestWGSD(t *testing.T) {
selfKey := [32]byte{}
selfKey[0] = 99
selfb32 := strings.ToLower(base32.StdEncoding.EncodeToString(selfKey[:]))
selfb64 := base64.StdEncoding.EncodeToString(selfKey[:])
selfAllowed, selfAllowedString := constructAllowedIPs(t, []string{"10.0.0.99/32", "10.0.0.100/32"})
key1 := [32]byte{}
key1[0] = 1
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
@ -73,10 +75,16 @@ func TestWGSD(t *testing.T) {
p := &WGSD{
Next: test.ErrorHandler(),
client: &mockClient{
peers: []wgtypes.Peer{peer1, peer2},
device: &wgtypes.Device{
Name: "wg0",
PublicKey: selfKey,
ListenPort: 51820,
Peers: []wgtypes.Peer{peer1, peer2},
},
},
zone: "example.com.",
device: "wg0",
zone: "example.com.",
device: "wg0",
selfAllowedIPs: selfAllowed,
}
testCases := []test.Case{
@ -87,6 +95,19 @@ func TestWGSD(t *testing.T) {
Answer: []dns.RR{
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)),
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)),
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", selfb32)),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 51820 %s._wireguard._udp.example.com.", selfb32, selfb32)),
},
Extra: []dns.RR{
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
},
},
{
@ -113,6 +134,14 @@ func TestWGSD(t *testing.T) {
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
Qtype: dns.TypeA,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
Qtype: dns.TypeA,
@ -129,6 +158,14 @@ func TestWGSD(t *testing.T) {
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
Qtype: dns.TypeTXT,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
Qtype: dns.TypeTXT,