jwhited-wgsd/wgsd.go

321 lines
7.7 KiB
Go
Raw Normal View History

2020-05-09 16:47:41 -07:00
package wgsd
import (
"context"
2020-05-13 16:15:48 -07:00
"encoding/base32"
"encoding/base64"
2020-05-12 15:39:48 -07:00
"fmt"
2020-05-12 17:35:05 -07:00
"net"
2020-05-12 15:39:48 -07:00
"strings"
2020-05-09 16:47:41 -07:00
"github.com/coredns/coredns/plugin"
2020-05-13 16:15:48 -07:00
clog "github.com/coredns/coredns/plugin/pkg/log"
2020-05-12 15:39:48 -07:00
"github.com/coredns/coredns/request"
2020-05-09 16:47:41 -07:00
"github.com/miekg/dns"
2020-05-12 15:39:48 -07:00
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
2020-05-09 16:47:41 -07:00
)
2020-05-13 16:15:48 -07:00
// pluginName is the name this plugin reports via Name() and the name its
// logger is scoped to.
const pluginName = "wgsd"

// logger is the CoreDNS logger scoped to this plugin.
var logger = clog.NewWithPlugin(pluginName)
2020-05-13 16:15:48 -07:00
// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
2020-05-09 16:47:41 -07:00
// semantics. WGSD implements the plugin.Handler interface.
type WGSD struct {
Next plugin.Handler
2021-01-02 15:51:53 -08:00
Zones
client wgctrlClient // the client for retrieving WireGuard peer information
}
// Zones holds the configuration of every zone the plugin serves.
type Zones struct {
	// Z maps a zone name to that zone's configuration.
	Z map[string]*Zone
	// Names holds every key of Z as a slice, in the form ServeDNS passes to
	// plugin.Zones for matching.
	Names []string
}
2020-05-12 15:39:48 -07:00
2021-01-02 15:51:53 -08:00
// Zone is the configuration of a single zone served by the plugin.
type Zone struct {
	// name is the name of the zone we are authoritative for.
	name string
	// device is the WireGuard device name, e.g. wg0.
	device string
	// serveSelf enables serving data about this host's own device.
	serveSelf bool
	// selfEndpoint, when non-nil, overrides the endpoint served for self.
	selfEndpoint *net.UDPAddr
	// selfAllowedIPs are the allowed IPs served for self.
	selfAllowedIPs []net.IPNet
}
// wgctrlClient is the subset of a WireGuard control client that the plugin
// needs: looking up a device, including its peers, by name.
type wgctrlClient interface {
	Device(name string) (*wgtypes.Device, error)
}
const (
	// keyLen is the number of characters in a base32-encoded WireGuard
	// public key.
	keyLen = 56

	// spPrefix is the DNS-SD service type, relative to the zone.
	spPrefix = "_wireguard._udp."
	// spSubPrefix is spPrefix preceded by the dot that separates it from
	// the service instance label.
	spSubPrefix = "." + spPrefix
	// serviceInstanceLen is the length of a zone-relative service instance
	// name of the form "<key>._wireguard._udp.".
	serviceInstanceLen = keyLen + len(spSubPrefix)
)
// handlerFn answers a query that getHandlerFn routed to it, using the peers
// available for the query's zone. ServeDNS returns its rcode and error as-is.
type handlerFn func(req request.Request, peers []wgtypes.Peer) (int, error)
2021-01-01 15:41:11 -08:00
// getHandlerFn selects the handler for a query, given its type and its name
// relative to the zone (zone suffix already stripped, trailing dot kept).
//
// Routing:
//   - PTR for "_wireguard._udp." lists the service instances (handlePTR).
//   - SRV for "<key>._wireguard._udp." returns the instance's SRV (handleSRV).
//   - A, AAAA or TXT for "<key>._wireguard._udp." returns the instance's
//     address or TXT data (handleHostOrTXT).
//
// It returns nil for anything else.
func getHandlerFn(queryType uint16, name string) handlerFn {
	// The length check alone would accept any name of the right length
	// (e.g. "<key>.xxxxxxxxx.yyyyyy."). Also require the service type
	// suffix so only real service instance names are answered.
	isInstance := len(name) == serviceInstanceLen &&
		strings.HasSuffix(name, spSubPrefix)
	switch {
	case name == spPrefix && queryType == dns.TypePTR:
		return handlePTR
	case isInstance && queryType == dns.TypeSRV:
		return handleSRV
	case isInstance && (queryType == dns.TypeA ||
		queryType == dns.TypeAAAA || queryType == dns.TypeTXT):
		return handleHostOrTXT
	default:
		return nil
	}
}
func handlePTR(state request.Request, peers []wgtypes.Peer) (int, error) {
2021-01-01 15:41:11 -08:00
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
for _, peer := range peers {
if peer.Endpoint == nil {
continue
}
m.Answer = append(m.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
},
Ptr: fmt.Sprintf("%s.%s%s",
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
spPrefix, state.Zone),
})
}
state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
func handleSRV(state request.Request, peers []wgtypes.Peer) (int, error) {
2021-01-01 15:41:11 -08:00
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
pubKey := state.Name()[:keyLen]
for _, peer := range peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(state)
}
txtRR := getTXTRR(state.Name(), peer)
m.Extra = append(m.Extra, hostRR, txtRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: state.Name(),
})
state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(state)
}
func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) {
2021-01-01 15:41:11 -08:00
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
pubKey := state.Name()[:keyLen]
for _, peer := range peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(state)
}
m.Answer = append(m.Answer, hostRR)
} else {
txtRR := getTXTRR(state.Name(), peer)
m.Answer = append(m.Answer, txtRR)
}
state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(state)
}
2021-01-02 15:51:53 -08:00
func getSelfPeer(zone *Zone, device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
2021-01-01 16:23:12 -08:00
self := wgtypes.Peer{
PublicKey: device.PublicKey,
}
2021-01-02 15:51:53 -08:00
if zone.selfEndpoint != nil {
self.Endpoint = zone.selfEndpoint
2021-01-01 16:23:12 -08:00
} else {
self.Endpoint = &net.UDPAddr{
IP: net.ParseIP(state.LocalIP()),
Port: device.ListenPort,
}
}
2021-01-02 15:51:53 -08:00
self.AllowedIPs = zone.selfAllowedIPs
2021-01-01 16:23:12 -08:00
return self, nil
}
2021-01-02 15:51:53 -08:00
func getPeers(client wgctrlClient, zone *Zone, state request.Request) (
[]wgtypes.Peer, error) {
2021-01-01 16:23:12 -08:00
peers := make([]wgtypes.Peer, 0)
2021-01-02 15:51:53 -08:00
device, err := client.Device(zone.device)
2021-01-01 16:23:12 -08:00
if err != nil {
return nil, err
}
peers = append(peers, device.Peers...)
2021-01-02 15:51:53 -08:00
if zone.serveSelf {
self, err := getSelfPeer(zone, device, state)
if err != nil {
return nil, err
}
peers = append(peers, self)
2021-01-01 16:23:12 -08:00
}
2021-01-02 15:51:53 -08:00
return peers, nil
2021-01-01 16:23:12 -08:00
}
2020-05-12 15:39:48 -07:00
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
r *dns.Msg) (int, error) {
// request.Request is a convenience struct we wrap around the msg and
// ResponseWriter.
2021-01-02 15:51:53 -08:00
state := request.Request{W: w, Req: r}
2020-05-12 15:39:48 -07:00
2021-01-02 15:51:53 -08:00
// Check if the request is for a zone we are serving. If it doesn't match we
// pass the request on to the next plugin.
zoneName := plugin.Zones(p.Names).Matches(state.Name())
if zoneName == "" {
2020-05-12 15:39:48 -07:00
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
}
2021-01-02 15:51:53 -08:00
state.Zone = zoneName
zone, ok := p.Z[zoneName]
if !ok {
return dns.RcodeServerFailure, nil
}
2020-05-12 15:39:48 -07:00
// strip zone from name
2021-01-02 15:51:53 -08:00
name := strings.TrimSuffix(state.Name(), zoneName)
2021-01-01 15:41:11 -08:00
queryType := state.QType()
2020-05-12 15:39:48 -07:00
2020-05-13 16:15:48 -07:00
logger.Debugf("received query for: %s type: %s", name,
2021-01-01 15:41:11 -08:00
dns.TypeToString[queryType])
handler := getHandlerFn(queryType, name)
if handler == nil {
return nxDomain(state)
}
2020-05-13 16:15:48 -07:00
2021-01-02 15:51:53 -08:00
peers, err := getPeers(p.client, zone, state)
2020-05-12 17:13:40 -07:00
if err != nil {
2020-05-13 16:15:48 -07:00
return dns.RcodeServerFailure, err
2020-05-12 17:13:40 -07:00
}
return handler(state, peers)
2020-05-12 17:35:05 -07:00
}
2020-05-12 15:39:48 -07:00
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
2020-05-12 17:35:05 -07:00
switch {
case endpoint.IP.To4() != nil:
return &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: endpoint.IP,
}
case endpoint.IP.To16() != nil:
return &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: endpoint.IP,
}
default:
// TODO: this shouldn't happen
return nil
}
2020-05-12 15:39:48 -07:00
}
// txtVersion is the value of the first key/value pair ("txtvers") in the TXT
// RR. It serves to aid clients with maintaining backwards compatibility.
//
// https://tools.ietf.org/html/rfc6763#section-6.7
const txtVersion = 1
// getTXTRR returns the DNS-SD TXT record for peer, owned by name. It carries
// three key/value strings, in this order:
//
//	txtvers=<txtVersion>
//	pub=<standard base64 encoding of the peer's public key>
//	allowed=<comma-separated CIDR list of the peer's allowed IPs>
//
// "allowed=" is empty when the peer has no allowed IPs.
//
// NOTE(review): a single TXT character-string is limited to 255 octets
// (RFC 1035 §3.3). A peer with many allowed IPs could exceed that in the
// "allowed=" entry — confirm how github.com/miekg/dns handles it when packing.
func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
	// Collect and join once instead of concatenating with += in the loop,
	// which is quadratic in the number of allowed IPs.
	prefixes := make([]string, 0, len(peer.AllowedIPs))
	for _, prefix := range peer.AllowedIPs {
		prefixes = append(prefixes, prefix.String())
	}
	return &dns.TXT{
		Hdr: dns.RR_Header{
			Name:   name,
			Rrtype: dns.TypeTXT,
			Class:  dns.ClassINET,
			Ttl:    0,
		},
		Txt: []string{
			fmt.Sprintf("txtvers=%d", txtVersion),
			fmt.Sprintf("pub=%s",
				base64.StdEncoding.EncodeToString(peer.PublicKey[:])),
			fmt.Sprintf("allowed=%s", strings.Join(prefixes, ",")),
		},
	}
}
2021-01-01 15:41:11 -08:00
func nxDomain(state request.Request) (int, error) {
2020-05-12 15:39:48 -07:00
m := new(dns.Msg)
2021-01-01 15:41:11 -08:00
m.SetReply(state.Req)
2020-05-12 15:39:48 -07:00
m.Authoritative = true
m.Rcode = dns.RcodeNameError
2021-01-01 15:41:11 -08:00
m.Ns = []dns.RR{soa(state.Zone)}
state.W.WriteMsg(m) // nolint: errcheck
2020-05-12 15:39:48 -07:00
return dns.RcodeSuccess, nil
2020-05-09 16:47:41 -07:00
}
2020-05-12 17:44:52 -07:00
func soa(zone string) dns.RR {
2020-05-12 15:39:48 -07:00
return &dns.SOA{
Hdr: dns.RR_Header{
2020-05-12 17:44:52 -07:00
Name: zone,
2020-05-12 15:39:48 -07:00
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 60,
},
2020-05-12 17:44:52 -07:00
Ns: fmt.Sprintf("ns1.%s", zone),
Mbox: fmt.Sprintf("postmaster.%s", zone),
2020-05-12 15:39:48 -07:00
Serial: 1,
Refresh: 86400,
Retry: 7200,
Expire: 3600000,
Minttl: 60,
}
2020-05-09 16:47:41 -07:00
}
2020-05-12 15:39:48 -07:00
func (p *WGSD) Name() string {
return pluginName
2020-05-09 16:47:41 -07:00
}