mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 22:09:34 +08:00
327 lines
7.9 KiB
Go
327 lines
7.9 KiB
Go
package wgsd
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base32"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
|
|
"github.com/coredns/coredns/plugin"
|
|
clog "github.com/coredns/coredns/plugin/pkg/log"
|
|
"github.com/coredns/coredns/request"
|
|
"github.com/miekg/dns"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
// coredns plugin-specific logger
|
|
var logger = clog.NewWithPlugin(pluginName)
|
|
|
|
// pluginName is the name of this plugin.
const pluginName = "wgsd"
|
|
|
|
// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
|
|
// semantics. WGSD implements the plugin.Handler interface.
|
|
type WGSD struct {
|
|
Next plugin.Handler
|
|
Zones
|
|
client wgctrlClient // the client for retrieving WireGuard peer information
|
|
}
|
|
|
|
type Zones struct {
|
|
Z map[string]*Zone // a mapping from zone name to zone data
|
|
Names []string // all keys from the map z as a string slice
|
|
}
|
|
|
|
// Zone holds the settings for a single zone.
type Zone struct {
	// name is the name of the zone we are authoritative for.
	name string

	// device is the WireGuard device name, e.g. wg0.
	device string

	// serveSelf enables serving data about the device itself in addition
	// to its peers.
	serveSelf bool

	// selfEndpoint, when non-nil, overrides the self endpoint value.
	selfEndpoint *net.UDPAddr

	// selfAllowedIPs are the allowed IPs reported for self.
	selfAllowedIPs []net.IPNet
}
|
|
|
|
type wgctrlClient interface {
|
|
Device(string) (*wgtypes.Device, error)
|
|
}
|
|
|
|
const (
	// keyLen is the number of characters in a base32-encoded WireGuard
	// public key.
	keyLen = 56

	// spPrefix is the service and protocol label pair, with trailing dot.
	spPrefix = "_wireguard._udp."

	// spSubPrefix is spPrefix with a leading dot.
	spSubPrefix = "." + spPrefix

	// serviceInstanceLen is the length of an encoded key followed by
	// spSubPrefix.
	serviceInstanceLen = keyLen + len(spSubPrefix)
)
|
|
|
|
// handlerFn answers a single query given the request state and the peers to
// consider. Its results have the same (rcode, error) shape as ServeDNS.
type handlerFn func(state request.Request, peers []wgtypes.Peer) (int, error)
|
|
|
|
func getHandlerFn(queryType uint16, name string) handlerFn {
|
|
switch {
|
|
case name == spPrefix && queryType == dns.TypePTR:
|
|
return handlePTR
|
|
case len(name) == serviceInstanceLen && queryType == dns.TypeSRV:
|
|
return handleSRV
|
|
case len(name) == len(spSubPrefix)+keyLen && (queryType == dns.TypeA ||
|
|
queryType == dns.TypeAAAA || queryType == dns.TypeTXT):
|
|
return handleHostOrTXT
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func handlePTR(state request.Request, peers []wgtypes.Peer) (int, error) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(state.Req)
|
|
m.Authoritative = true
|
|
for _, peer := range peers {
|
|
if peer.Endpoint == nil {
|
|
continue
|
|
}
|
|
m.Answer = append(m.Answer, &dns.PTR{
|
|
Hdr: dns.RR_Header{
|
|
Name: state.Name(),
|
|
Rrtype: dns.TypePTR,
|
|
Class: dns.ClassINET,
|
|
Ttl: 0,
|
|
},
|
|
Ptr: fmt.Sprintf("%s.%s%s",
|
|
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
|
spPrefix, state.Zone),
|
|
})
|
|
}
|
|
state.W.WriteMsg(m) // nolint: errcheck
|
|
return dns.RcodeSuccess, nil
|
|
}
|
|
|
|
func handleSRV(state request.Request, peers []wgtypes.Peer) (int, error) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(state.Req)
|
|
m.Authoritative = true
|
|
pubKey := state.Name()[:keyLen]
|
|
for _, peer := range peers {
|
|
if strings.EqualFold(
|
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
|
endpoint := peer.Endpoint
|
|
if endpoint == nil {
|
|
return nxDomain(state)
|
|
}
|
|
hostRR := getHostRR(state.Name(), endpoint)
|
|
if hostRR == nil {
|
|
return nxDomain(state)
|
|
}
|
|
txtRR := getTXTRR(state.Name(), peer)
|
|
m.Extra = append(m.Extra, hostRR, txtRR)
|
|
m.Answer = append(m.Answer, &dns.SRV{
|
|
Hdr: dns.RR_Header{
|
|
Name: state.Name(),
|
|
Rrtype: dns.TypeSRV,
|
|
Class: dns.ClassINET,
|
|
Ttl: 0,
|
|
},
|
|
Priority: 0,
|
|
Weight: 0,
|
|
Port: uint16(endpoint.Port),
|
|
Target: state.Name(),
|
|
})
|
|
state.W.WriteMsg(m) // nolint: errcheck
|
|
return dns.RcodeSuccess, nil
|
|
}
|
|
}
|
|
return nxDomain(state)
|
|
}
|
|
|
|
func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(state.Req)
|
|
m.Authoritative = true
|
|
pubKey := state.Name()[:keyLen]
|
|
for _, peer := range peers {
|
|
if strings.EqualFold(
|
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
|
endpoint := peer.Endpoint
|
|
if endpoint == nil {
|
|
return nxDomain(state)
|
|
}
|
|
if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA {
|
|
hostRR := getHostRR(state.Name(), endpoint)
|
|
if hostRR == nil {
|
|
return nxDomain(state)
|
|
}
|
|
m.Answer = append(m.Answer, hostRR)
|
|
} else {
|
|
txtRR := getTXTRR(state.Name(), peer)
|
|
m.Answer = append(m.Answer, txtRR)
|
|
}
|
|
state.W.WriteMsg(m) // nolint: errcheck
|
|
return dns.RcodeSuccess, nil
|
|
}
|
|
}
|
|
return nxDomain(state)
|
|
}
|
|
|
|
func getSelfPeer(zone *Zone, device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
|
self := wgtypes.Peer{
|
|
PublicKey: device.PublicKey,
|
|
}
|
|
if zone.selfEndpoint != nil {
|
|
self.Endpoint = zone.selfEndpoint
|
|
} else {
|
|
self.Endpoint = &net.UDPAddr{
|
|
IP: net.ParseIP(state.LocalIP()),
|
|
Port: device.ListenPort,
|
|
}
|
|
}
|
|
self.AllowedIPs = zone.selfAllowedIPs
|
|
return self, nil
|
|
}
|
|
|
|
func getPeers(client wgctrlClient, zone *Zone, state request.Request) (
|
|
[]wgtypes.Peer, error) {
|
|
peers := make([]wgtypes.Peer, 0)
|
|
device, err := client.Device(zone.device)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
peers = append(peers, device.Peers...)
|
|
if zone.serveSelf {
|
|
self, err := getSelfPeer(zone, device, state)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
peers = append(peers, self)
|
|
}
|
|
return peers, nil
|
|
}
|
|
|
|
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|
r *dns.Msg) (int, error) {
|
|
// request.Request is a convenience struct we wrap around the msg and
|
|
// ResponseWriter.
|
|
state := request.Request{W: w, Req: r}
|
|
|
|
// Check if the request is for a zone we are serving. If it doesn't match we
|
|
// pass the request on to the next plugin.
|
|
zoneName := plugin.Zones(p.Names).Matches(state.Name())
|
|
if zoneName == "" {
|
|
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
|
}
|
|
state.Zone = zoneName
|
|
|
|
zone, ok := p.Z[zoneName]
|
|
if !ok {
|
|
return dns.RcodeServerFailure, nil
|
|
}
|
|
|
|
// strip zone from name
|
|
name := strings.TrimSuffix(state.Name(), zoneName)
|
|
queryType := state.QType()
|
|
|
|
logger.Debugf("received query for: %s type: %s", name,
|
|
dns.TypeToString[queryType])
|
|
|
|
handler := getHandlerFn(queryType, name)
|
|
if handler == nil {
|
|
return nxDomain(state)
|
|
}
|
|
|
|
peers, err := getPeers(p.client, zone, state)
|
|
if err != nil {
|
|
return dns.RcodeServerFailure, err
|
|
}
|
|
|
|
return handler(state, peers)
|
|
}
|
|
|
|
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
|
switch {
|
|
case endpoint.IP.To4() != nil:
|
|
return &dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 0,
|
|
},
|
|
A: endpoint.IP,
|
|
}
|
|
case endpoint.IP.To16() != nil:
|
|
return &dns.AAAA{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 0,
|
|
},
|
|
AAAA: endpoint.IP,
|
|
}
|
|
default:
|
|
// TODO: this shouldn't happen
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// txtVersion is the value of the first key/value pair in the TXT RR. It
// serves to aid clients with maintaining backwards compatibility.
//
// https://tools.ietf.org/html/rfc6763#section-6.7
const txtVersion = 1
|
|
|
|
func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
|
|
var allowedIPs string
|
|
for i, prefix := range peer.AllowedIPs {
|
|
if i != 0 {
|
|
allowedIPs += ","
|
|
}
|
|
allowedIPs += prefix.String()
|
|
}
|
|
return &dns.TXT{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeTXT,
|
|
Class: dns.ClassINET,
|
|
Ttl: 0,
|
|
},
|
|
Txt: []string{
|
|
fmt.Sprintf("txtvers=%d", txtVersion),
|
|
fmt.Sprintf("pub=%s",
|
|
base64.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
|
fmt.Sprintf("allowed=%s", allowedIPs),
|
|
},
|
|
}
|
|
}
|
|
|
|
func nxDomain(state request.Request) (int, error) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(state.Req)
|
|
m.Authoritative = true
|
|
m.Rcode = dns.RcodeNameError
|
|
m.Ns = []dns.RR{soa(state.Zone)}
|
|
state.W.WriteMsg(m) // nolint: errcheck
|
|
return dns.RcodeSuccess, nil
|
|
}
|
|
|
|
func soa(zone string) dns.RR {
|
|
return &dns.SOA{
|
|
Hdr: dns.RR_Header{
|
|
Name: zone,
|
|
Rrtype: dns.TypeSOA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 60,
|
|
},
|
|
Ns: fmt.Sprintf("ns1.%s", zone),
|
|
Mbox: fmt.Sprintf("postmaster.%s", zone),
|
|
Serial: 1,
|
|
Refresh: 86400,
|
|
Retry: 7200,
|
|
Expire: 3600000,
|
|
Minttl: 60,
|
|
}
|
|
}
|
|
|
|
func (p *WGSD) Name() string {
|
|
return pluginName
|
|
}
|