1
0
mirror of https://github.com/jwhited/wgsd.git synced 2025-04-09 05:29:30 +08:00

also respond with records for the host itself

This commit is contained in:
Tobias Krischer 2020-11-01 16:01:39 +01:00
parent 3f4967eb68
commit 0edb1df552
No known key found for this signature in database
GPG Key ID: E9D8FDB171E04060
2 changed files with 70 additions and 27 deletions

@ -2,6 +2,7 @@ package wgsd
import ( import (
"fmt" "fmt"
"net"
"github.com/coredns/caddy" "github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/core/dnsserver"
@ -29,6 +30,14 @@ func setup(c *caddy.Controller) error {
} }
device := c.Val() device := c.Val()
// parse optional local ip
var wgIP net.IP
if c.NextArg() {
wgIP = net.ParseIP(c.Val())
} else {
wgIP = getOutboundIP()
}
// return an error if there are more tokens on this line // return an error if there are more tokens on this line
if c.NextArg() { if c.NextArg() {
return plugin.Error("wgsd", c.ArgErr()) return plugin.Error("wgsd", c.ArgErr())
@ -48,8 +57,20 @@ func setup(c *caddy.Controller) error {
client: client, client: client,
zone: zone, zone: zone,
device: device, device: device,
wgIP: wgIP,
} }
}) })
return nil return nil
} }
// Get preferred outbound ip of this machine
func getOutboundIP() net.IP {
conn, err := net.Dial("udp", "1.1.1.1:80")
if err != nil {
return nil
}
defer conn.Close()
return conn.LocalAddr().(*net.UDPAddr).IP
}

76
wgsd.go

@ -28,12 +28,19 @@ type WGSD struct {
zone string zone string
// the Wireguard device name, e.g. wg0 // the Wireguard device name, e.g. wg0
device string device string
// the IP the local wireguard is running on
wgIP net.IP
} }
type wgctrlClient interface { type wgctrlClient interface {
Device(string) (*wgtypes.Device, error) Device(string) (*wgtypes.Device, error)
} }
type host struct {
PublicKey wgtypes.Key
Endpoint *net.UDPAddr
}
const ( const (
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
spPrefix = "_wireguard._udp." spPrefix = "_wireguard._udp."
@ -63,19 +70,26 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
if err != nil { if err != nil {
return dns.RcodeServerFailure, err return dns.RcodeServerFailure, err
} }
if len(device.Peers) == 0 {
return nxDomain(p.zone, w, r)
}
// setup our reply message // setup our reply message
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Authoritative = true m.Authoritative = true
allPeers := []host{
{PublicKey: device.PublicKey, Endpoint: &net.UDPAddr{Port: device.ListenPort, IP: p.wgIP}},
}
for _, p := range device.Peers {
allPeers = append(allPeers, host{
PublicKey: p.PublicKey,
Endpoint: p.Endpoint,
})
}
switch { switch {
// TODO: handle SOA // TODO: handle SOA
case name == spPrefix && qtype == dns.TypePTR: case name == spPrefix && qtype == dns.TypePTR:
for _, peer := range device.Peers { for _, peer := range allPeers {
if peer.Endpoint == nil { if peer.Endpoint == nil {
continue continue
} }
@ -93,39 +107,21 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
} }
w.WriteMsg(m) // nolint: errcheck w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil return dns.RcodeSuccess, nil
case len(name) != serviceInstanceLen && qtype == dns.TypeSRV:
return p.sendSRV(m, w, r, allPeers[0])
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
pubKey := name[:keyLen] pubKey := name[:keyLen]
for _, peer := range device.Peers { for _, peer := range allPeers {
if strings.EqualFold( if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint return p.sendSRV(m, w, r, peer)
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: fmt.Sprintf("%s.%s",
strings.ToLower(pubKey), p.zone),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
} }
} }
return nxDomain(p.zone, w, r) return nxDomain(p.zone, w, r)
case len(name) == keyLen+1 && (qtype == dns.TypeA || case len(name) == keyLen+1 && (qtype == dns.TypeA ||
qtype == dns.TypeAAAA): qtype == dns.TypeAAAA):
pubKey := name[:keyLen] pubKey := name[:keyLen]
for _, peer := range device.Peers { for _, peer := range allPeers {
if strings.EqualFold( if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint endpoint := peer.Endpoint
@ -176,6 +172,32 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
} }
} }
func (p *WGSD) sendSRV(m *dns.Msg, w dns.ResponseWriter, r *dns.Msg, peer host) (int, error) {
state := request.Request{W: w, Req: r}
pubKey := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: fmt.Sprintf("%s.%s",
strings.ToLower(pubKey), p.zone),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) { func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)