serve allowed ips and public key via TXT RR

This commit is contained in:
Jordan Whited 2020-12-31 14:14:27 -08:00 committed by Jordan Whited
parent 016a366d0f
commit a928f85a58
2 changed files with 101 additions and 24 deletions

73
wgsd.go
View File

@ -3,6 +3,7 @@ package wgsd
import (
"context"
"encoding/base32"
"encoding/base64"
"fmt"
"net"
"strings"
@ -17,16 +18,16 @@ import (
// coredns plugin-specific logger
var logger = clog.NewWithPlugin("wgsd")
// WGSD is a CoreDNS plugin that provides Wireguard peer information via DNS-SD
// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
// semantics. WGSD implements the plugin.Handler interface.
type WGSD struct {
Next plugin.Handler
// the client for retrieving Wireguard peer information
// the client for retrieving WireGuard peer information
client wgctrlClient
// the DNS zone we are serving records for
zone string
// the Wireguard device name, e.g. wg0
// the WireGuard device name, e.g. wg0
device string
}
@ -35,7 +36,7 @@ type wgctrlClient interface {
}
const (
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
keyLen = 56 // the number of characters in a base32-encoded WireGuard public key
spPrefix = "_wireguard._udp."
spSubPrefix = "." + spPrefix
serviceInstanceLen = keyLen + len(spSubPrefix)
@ -55,10 +56,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
// strip zone from name
name := strings.TrimSuffix(state.Name(), p.zone)
qtype := state.QType()
qType := state.QType()
logger.Debugf("received query for: %s type: %s", name,
dns.TypeToString[qtype])
dns.TypeToString[qType])
device, err := p.client.Device(p.device)
if err != nil {
@ -75,7 +76,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
switch {
// TODO: handle SOA
case name == spPrefix && qtype == dns.TypePTR:
case name == spPrefix && qType == dns.TypePTR:
for _, peer := range device.Peers {
if peer.Endpoint == nil {
continue
@ -94,17 +95,18 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
case len(name) == serviceInstanceLen && qType == dns.TypeSRV:
pubKey := name[:keyLen]
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
txtRR := getTXTRR(state.Name(), peer)
m.Extra = append(m.Extra, hostRR, txtRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
@ -115,25 +117,30 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: strings.ToLower(pubKey) + spSubPrefix + p.zone,
Target: state.Name(),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(p.zone, w, r)
case len(name) == len(spSubPrefix)+keyLen && (qtype == dns.TypeA ||
qtype == dns.TypeAAAA):
case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA ||
qType == dns.TypeAAAA || qType == dns.TypeTXT):
pubKey := name[:keyLen]
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if qType == dns.TypeA || qType == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Answer = append(m.Answer, hostRR)
} else {
txtRR := getTXTRR(state.Name(), peer)
m.Answer = append(m.Answer, txtRR)
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
@ -144,11 +151,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
}
}
func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
if endpoint == nil || endpoint.IP == nil {
return nil
}
name := strings.ToLower(pubKey) + spSubPrefix + zone
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
switch {
case endpoint.IP.To4() != nil:
return &dns.A{
@ -176,6 +179,38 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
}
}
const (
// txtVersion is the first key/value pair in the TXT RR. Its serves to aid
// clients with maintaining backwards compatibility.
//
// https://tools.ietf.org/html/rfc6763#section-6.7
txtVersion = 1
)
func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
var allowedIPs string
for i, prefix := range peer.AllowedIPs {
if i != 0 {
allowedIPs += ","
}
allowedIPs += prefix.String()
}
return &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 0,
},
Txt: []string{
fmt.Sprintf("txtvers=%d", txtVersion),
fmt.Sprintf("pub=%s",
base64.StdEncoding.EncodeToString(peer.PublicKey[:])),
fmt.Sprintf("allowed=%s", allowedIPs),
},
}
}
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetReply(r)

View File

@ -3,6 +3,7 @@ package wgsd
import (
"context"
"encoding/base32"
"encoding/base64"
"fmt"
"net"
"strings"
@ -25,27 +26,50 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
}, nil
}
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
var allowed []net.IPNet
var allowedString string
for i, s := range prefixes {
_, prefix, err := net.ParseCIDR(s)
if err != nil {
t.Fatalf("error parsing cidr: %v", err)
}
allowed = append(allowed, *prefix)
if i != 0 {
allowedString += ","
}
allowedString += prefix.String()
}
return allowed, allowedString
}
func TestWGSD(t *testing.T) {
key1 := [32]byte{}
key1[0] = 1
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
peer1 := wgtypes.Peer{
Endpoint: &net.UDPAddr{
IP: net.ParseIP("1.1.1.1"),
Port: 1,
},
PublicKey: key1,
AllowedIPs: peer1Allowed,
}
peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
peer1b64 := base64.StdEncoding.EncodeToString(peer1.PublicKey[:])
key2 := [32]byte{}
key2[0] = 2
peer2Allowed, peer2AllowedString := constructAllowedIPs(t, []string{"10.0.0.3/32", "10.0.0.4/32"})
peer2 := wgtypes.Peer{
Endpoint: &net.UDPAddr{
IP: net.ParseIP("::2"),
Port: 2,
},
PublicKey: key2,
AllowedIPs: peer2Allowed,
}
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
p := &WGSD{
Next: test.ErrorHandler(),
client: &mockClient{
@ -74,6 +98,7 @@ func TestWGSD(t *testing.T) {
},
Extra: []dns.RR{
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())),
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
},
},
{
@ -85,6 +110,7 @@ func TestWGSD(t *testing.T) {
},
Extra: []dns.RR{
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
},
},
{
@ -103,6 +129,22 @@ func TestWGSD(t *testing.T) {
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
Qtype: dns.TypeTXT,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32),
Qtype: dns.TypeTXT,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
},
},
{
Qname: "nxdomain.example.com.",
Qtype: dns.TypeAAAA,