also respond with records for the host itself

This commit is contained in:
Tobias Krischer 2020-11-01 16:01:39 +01:00
parent 3f4967eb68
commit 0edb1df552
No known key found for this signature in database
GPG Key ID: E9D8FDB171E04060
2 changed files with 70 additions and 27 deletions

View File

@ -2,6 +2,7 @@ package wgsd
import (
"fmt"
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
@ -29,6 +30,14 @@ func setup(c *caddy.Controller) error {
}
device := c.Val()
// parse optional local ip
var wgIP net.IP
if c.NextArg() {
wgIP = net.ParseIP(c.Val())
} else {
wgIP = getOutboundIP()
}
// return an error if there are more tokens on this line
if c.NextArg() {
return plugin.Error("wgsd", c.ArgErr())
@ -48,8 +57,20 @@ func setup(c *caddy.Controller) error {
client: client,
zone: zone,
device: device,
wgIP: wgIP,
}
})
return nil
}
// Get preferred outbound ip of this machine
func getOutboundIP() net.IP {
conn, err := net.Dial("udp", "1.1.1.1:80")
if err != nil {
return nil
}
defer conn.Close()
return conn.LocalAddr().(*net.UDPAddr).IP
}

76
wgsd.go
View File

@ -28,12 +28,19 @@ type WGSD struct {
zone string
// the Wireguard device name, e.g. wg0
device string
// the IP the local wireguard is running on
wgIP net.IP
}
type wgctrlClient interface {
Device(string) (*wgtypes.Device, error)
}
type host struct {
PublicKey wgtypes.Key
Endpoint *net.UDPAddr
}
const (
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
spPrefix = "_wireguard._udp."
@ -63,19 +70,26 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
if err != nil {
return dns.RcodeServerFailure, err
}
if len(device.Peers) == 0 {
return nxDomain(p.zone, w, r)
}
// setup our reply message
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
allPeers := []host{
{PublicKey: device.PublicKey, Endpoint: &net.UDPAddr{Port: device.ListenPort, IP: p.wgIP}},
}
for _, p := range device.Peers {
allPeers = append(allPeers, host{
PublicKey: p.PublicKey,
Endpoint: p.Endpoint,
})
}
switch {
// TODO: handle SOA
case name == spPrefix && qtype == dns.TypePTR:
for _, peer := range device.Peers {
for _, peer := range allPeers {
if peer.Endpoint == nil {
continue
}
@ -93,39 +107,21 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
case len(name) != serviceInstanceLen && qtype == dns.TypeSRV:
return p.sendSRV(m, w, r, allPeers[0])
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
pubKey := name[:keyLen]
for _, peer := range device.Peers {
for _, peer := range allPeers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: fmt.Sprintf("%s.%s",
strings.ToLower(pubKey), p.zone),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
return p.sendSRV(m, w, r, peer)
}
}
return nxDomain(p.zone, w, r)
case len(name) == keyLen+1 && (qtype == dns.TypeA ||
qtype == dns.TypeAAAA):
pubKey := name[:keyLen]
for _, peer := range device.Peers {
for _, peer := range allPeers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
@ -176,6 +172,32 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
}
}
func (p *WGSD) sendSRV(m *dns.Msg, w dns.ResponseWriter, r *dns.Msg, peer host) (int, error) {
state := request.Request{W: w, Req: r}
pubKey := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: fmt.Sprintf("%s.%s",
strings.ToLower(pubKey), p.zone),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetReply(r)