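// File: jwhited-wgsd/wgsd.go (210 lines, 5.0 KiB, Go).
// Captured from a web source view (Raw | Normal View | History); the package
// clause below was last changed 2020-05-09 16:47:41 -07:00.
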
package wgsd
import (
	"context"
	"encoding/base32"
	"fmt"
	"net"
	"strings"

	"github.com/coredns/coredns/plugin"
	clog "github.com/coredns/coredns/plugin/pkg/log"
	"github.com/coredns/coredns/request"
	"github.com/miekg/dns"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
2020-05-13 16:15:48 -07:00
// coredns plugin-specific logger
var logger = clog.NewWithPlugin("wgsd")
2020-05-09 16:47:41 -07:00
// WGSD is a CoreDNS plugin that provides Wireguard peer information via DNS-SD
// semantics. WGSD implements the plugin.Handler interface.
type WGSD struct {
Next plugin.Handler
2020-05-12 15:39:48 -07:00
// the client for retrieving Wireguard peer information
client wgctrlClient
// the DNS zone we are serving records for
zone string
// the Wireguard device name, e.g. wg0
device string
}
// wgctrlClient abstracts the WireGuard control API that WGSD depends on.
// NOTE(review): presumably satisfied by *wgctrl.Client in production and by
// fakes in tests — confirm against the plugin setup code.
type wgctrlClient interface {
	// Device returns the current state of the named WireGuard device
	// (e.g. "wg0"), including its peers.
	Device(name string) (*wgtypes.Device, error)
}
const (
	// keyLen is the number of characters in a base32-encoded (RFC 4648
	// standard alphabet, padded) WireGuard public key: 32 key bytes encode to
	// 52 characters plus "====" padding.
	keyLen = 56

	// spPrefix is the DNS-SD service label pair WGSD serves under the zone.
	spPrefix = "_wireguard._udp."

	// serviceInstanceLen is the length of "<key>._wireguard._udp." once the
	// zone has been stripped from a query name.
	serviceInstanceLen = keyLen + len(".") + len(spPrefix)
)
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
r *dns.Msg) (int, error) {
// request.Request is a convenience struct we wrap around the msg and
// ResponseWriter.
state := request.Request{W: w, Req: r}
// Check if the request is for the zone we are serving. If it doesn't match
// we pass the request on to the next plugin.
if plugin.Zones([]string{p.zone}).Matches(state.Name()) == "" {
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
}
// strip zone from name
name := strings.TrimSuffix(state.Name(), p.zone)
qtype := state.QType()
2020-05-13 16:15:48 -07:00
logger.Debugf("received query for: %s type: %s", name,
dns.TypeToString[qtype])
2020-05-12 17:13:40 -07:00
device, err := p.client.Device(p.device)
if err != nil {
2020-05-13 16:15:48 -07:00
return dns.RcodeServerFailure, err
2020-05-12 17:13:40 -07:00
}
if len(device.Peers) == 0 {
2020-05-12 17:35:05 -07:00
return nxDomain(p.zone, w, r)
2020-05-12 17:13:40 -07:00
}
2020-05-12 15:39:48 -07:00
// setup our reply message
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
switch {
2020-05-12 17:37:20 -07:00
// TODO: handle SOA
2020-05-12 15:39:48 -07:00
case name == spPrefix && qtype == dns.TypePTR:
for _, peer := range device.Peers {
if peer.Endpoint == nil {
continue
}
2020-05-12 15:39:48 -07:00
m.Answer = append(m.Answer, &dns.PTR{
Hdr: dns.RR_Header{
2020-05-12 17:13:40 -07:00
Name: state.Name(),
2020-05-12 15:39:48 -07:00
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
},
2020-05-13 20:59:42 -07:00
Ptr: fmt.Sprintf("%s.%s%s",
2020-05-27 16:29:09 -07:00
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
2020-05-13 20:59:42 -07:00
spPrefix, p.zone),
2020-05-12 15:39:48 -07:00
})
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
2020-05-13 16:15:48 -07:00
pubKey := name[:keyLen]
2020-05-12 17:13:40 -07:00
for _, peer := range device.Peers {
2020-05-13 16:15:48 -07:00
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
2020-05-12 17:13:40 -07:00
endpoint := peer.Endpoint
2020-05-12 17:35:05 -07:00
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
2020-05-12 17:13:40 -07:00
}
2020-05-12 17:35:05 -07:00
m.Extra = append(m.Extra, hostRR)
2020-05-12 17:13:40 -07:00
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
2020-05-13 20:59:42 -07:00
Target: fmt.Sprintf("%s.%s",
2020-05-27 16:29:09 -07:00
strings.ToLower(pubKey), p.zone),
2020-05-12 17:13:40 -07:00
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
2020-05-12 17:35:05 -07:00
return nxDomain(p.zone, w, r)
2020-05-13 16:15:48 -07:00
case len(name) == keyLen+1 && (qtype == dns.TypeA ||
2020-05-12 15:39:48 -07:00
qtype == dns.TypeAAAA):
2020-05-13 16:15:48 -07:00
pubKey := name[:keyLen]
2020-05-12 17:35:05 -07:00
for _, peer := range device.Peers {
2020-05-13 16:15:48 -07:00
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
2020-05-12 17:35:05 -07:00
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
2020-05-13 16:15:48 -07:00
m.Answer = append(m.Answer, hostRR)
2020-05-12 17:35:05 -07:00
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(p.zone, w, r)
2020-05-12 15:39:48 -07:00
default:
2020-05-12 17:35:05 -07:00
return nxDomain(p.zone, w, r)
2020-05-12 15:39:48 -07:00
}
2020-05-12 17:35:05 -07:00
}
2020-05-12 15:39:48 -07:00
2020-05-12 17:35:05 -07:00
func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
2020-05-15 13:01:58 -07:00
if endpoint == nil || endpoint.IP == nil {
2020-05-12 17:35:05 -07:00
return nil
}
2020-05-27 16:29:09 -07:00
name := fmt.Sprintf("%s.%s", strings.ToLower(pubKey), zone)
2020-05-12 17:35:05 -07:00
switch {
case endpoint.IP.To4() != nil:
return &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: endpoint.IP,
}
case endpoint.IP.To16() != nil:
return &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: endpoint.IP,
}
default:
// TODO: this shouldn't happen
return nil
}
2020-05-12 15:39:48 -07:00
}
2020-05-12 17:44:52 -07:00
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
2020-05-12 15:39:48 -07:00
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Rcode = dns.RcodeNameError
2020-05-12 17:44:52 -07:00
m.Ns = []dns.RR{soa(zone)}
2020-05-12 15:39:48 -07:00
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
2020-05-09 16:47:41 -07:00
}
2020-05-12 17:44:52 -07:00
func soa(zone string) dns.RR {
2020-05-12 15:39:48 -07:00
return &dns.SOA{
Hdr: dns.RR_Header{
2020-05-12 17:44:52 -07:00
Name: zone,
2020-05-12 15:39:48 -07:00
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 60,
},
2020-05-12 17:44:52 -07:00
Ns: fmt.Sprintf("ns1.%s", zone),
Mbox: fmt.Sprintf("postmaster.%s", zone),
2020-05-12 15:39:48 -07:00
Serial: 1,
Refresh: 86400,
Retry: 7200,
Expire: 3600000,
Minttl: 60,
}
2020-05-09 16:47:41 -07:00
}
2020-05-12 15:39:48 -07:00
func (p *WGSD) Name() string {
2020-05-09 16:47:41 -07:00
return "wgsd"
}