serve allowed ips and public key via TXT RR

This commit is contained in:
Jordan Whited
2020-12-31 14:14:27 -08:00
committed by Jordan Whited
parent 016a366d0f
commit a928f85a58
2 changed files with 101 additions and 24 deletions

79
wgsd.go
View File

@@ -3,6 +3,7 @@ package wgsd
import (
"context"
"encoding/base32"
"encoding/base64"
"fmt"
"net"
"strings"
@@ -17,16 +18,16 @@ import (
// coredns plugin-specific logger
var logger = clog.NewWithPlugin("wgsd")
// WGSD is a CoreDNS plugin that provides Wireguard peer information via DNS-SD
// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
// semantics. WGSD implements the plugin.Handler interface.
type WGSD struct {
Next plugin.Handler
// the client for retrieving Wireguard peer information
// the client for retrieving WireGuard peer information
client wgctrlClient
// the DNS zone we are serving records for
zone string
// the Wireguard device name, e.g. wg0
// the WireGuard device name, e.g. wg0
device string
}
@@ -35,7 +36,7 @@ type wgctrlClient interface {
}
const (
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
keyLen = 56 // the number of characters in a base32-encoded WireGuard public key
spPrefix = "_wireguard._udp."
spSubPrefix = "." + spPrefix
serviceInstanceLen = keyLen + len(spSubPrefix)
@@ -55,10 +56,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
// strip zone from name
name := strings.TrimSuffix(state.Name(), p.zone)
qtype := state.QType()
qType := state.QType()
logger.Debugf("received query for: %s type: %s", name,
dns.TypeToString[qtype])
dns.TypeToString[qType])
device, err := p.client.Device(p.device)
if err != nil {
@@ -75,7 +76,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
switch {
// TODO: handle SOA
case name == spPrefix && qtype == dns.TypePTR:
case name == spPrefix && qType == dns.TypePTR:
for _, peer := range device.Peers {
if peer.Endpoint == nil {
continue
@@ -94,17 +95,18 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
case len(name) == serviceInstanceLen && qType == dns.TypeSRV:
pubKey := name[:keyLen]
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
txtRR := getTXTRR(state.Name(), peer)
m.Extra = append(m.Extra, hostRR, txtRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
@@ -115,25 +117,30 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: strings.ToLower(pubKey) + spSubPrefix + p.zone,
Target: state.Name(),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(p.zone, w, r)
case len(name) == len(spSubPrefix)+keyLen && (qtype == dns.TypeA ||
qtype == dns.TypeAAAA):
case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA ||
qType == dns.TypeAAAA || qType == dns.TypeTXT):
pubKey := name[:keyLen]
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
if qType == dns.TypeA || qType == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Answer = append(m.Answer, hostRR)
} else {
txtRR := getTXTRR(state.Name(), peer)
m.Answer = append(m.Answer, txtRR)
}
m.Answer = append(m.Answer, hostRR)
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
@@ -144,11 +151,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
}
}
func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
if endpoint == nil || endpoint.IP == nil {
return nil
}
name := strings.ToLower(pubKey) + spSubPrefix + zone
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
switch {
case endpoint.IP.To4() != nil:
return &dns.A{
@@ -176,6 +179,38 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
}
}
const (
// txtVersion is the first key/value pair in the TXT RR. Its serves to aid
// clients with maintaining backwards compatibility.
//
// https://tools.ietf.org/html/rfc6763#section-6.7
txtVersion = 1
)
func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
var allowedIPs string
for i, prefix := range peer.AllowedIPs {
if i != 0 {
allowedIPs += ","
}
allowedIPs += prefix.String()
}
return &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 0,
},
Txt: []string{
fmt.Sprintf("txtvers=%d", txtVersion),
fmt.Sprintf("pub=%s",
base64.StdEncoding.EncodeToString(peer.PublicKey[:])),
fmt.Sprintf("allowed=%s", allowedIPs),
},
}
}
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetReply(r)