diff --git a/setup.go b/setup.go index ee057ae..0d670a0 100644 --- a/setup.go +++ b/setup.go @@ -2,6 +2,7 @@ package wgsd import ( "fmt" + "net" "github.com/coredns/caddy" "github.com/coredns/coredns/core/dnsserver" @@ -29,6 +30,14 @@ func setup(c *caddy.Controller) error { } device := c.Val() + // parse optional local ip + var wgIP net.IP + if c.NextArg() { + wgIP = net.ParseIP(c.Val()) + } else { + wgIP = getOutboundIP() + } + // return an error if there are more tokens on this line if c.NextArg() { return plugin.Error("wgsd", c.ArgErr()) @@ -48,8 +57,20 @@ func setup(c *caddy.Controller) error { client: client, zone: zone, device: device, + wgIP: wgIP, } }) return nil } + +// Get preferred outbound ip of this machine +func getOutboundIP() net.IP { + conn, err := net.Dial("udp", "1.1.1.1:80") + if err != nil { + return nil + } + defer conn.Close() + + return conn.LocalAddr().(*net.UDPAddr).IP +} diff --git a/wgsd.go b/wgsd.go index a63529e..67e8136 100644 --- a/wgsd.go +++ b/wgsd.go @@ -28,12 +28,19 @@ type WGSD struct { zone string // the Wireguard device name, e.g. wg0 device string + // the IP the local wireguard is running on + wgIP net.IP } type wgctrlClient interface { Device(string) (*wgtypes.Device, error) } +type host struct { + PublicKey wgtypes.Key + Endpoint *net.UDPAddr +} + const ( keyLen = 56 // the number of characters in a base32-encoded Wireguard public key spPrefix = "_wireguard._udp." @@ -63,19 +70,26 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, if err != nil { return dns.RcodeServerFailure, err } - if len(device.Peers) == 0 { - return nxDomain(p.zone, w, r) - } // setup our reply message m := new(dns.Msg) m.SetReply(r) m.Authoritative = true + allPeers := []host{ + {PublicKey: device.PublicKey, Endpoint: &net.UDPAddr{Port: device.ListenPort, IP: p.wgIP}}, + } + for _, p := range device.Peers { + allPeers = append(allPeers, host{ + PublicKey: p.PublicKey, + Endpoint: p.Endpoint, + }) + } + switch { // TODO: handle SOA case name == spPrefix && qtype == dns.TypePTR: - for _, peer := range device.Peers { + for _, peer := range allPeers { if peer.Endpoint == nil { continue } @@ -93,39 +107,21 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, } w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil + case len(name) != serviceInstanceLen && qtype == dns.TypeSRV: + return p.sendSRV(m, w, r, allPeers[0]) case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: pubKey := name[:keyLen] - for _, peer := range device.Peers { + for _, peer := range allPeers { if strings.EqualFold( base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { - endpoint := peer.Endpoint - hostRR := getHostRR(pubKey, p.zone, endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) - } - m.Extra = append(m.Extra, hostRR) - m.Answer = append(m.Answer, &dns.SRV{ - Hdr: dns.RR_Header{ - Name: state.Name(), - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 0, - }, - Priority: 0, - Weight: 0, - Port: uint16(endpoint.Port), - Target: fmt.Sprintf("%s.%s", - strings.ToLower(pubKey), p.zone), - }) - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil + return p.sendSRV(m, w, r, peer) } } return nxDomain(p.zone, w, r) case len(name) == keyLen+1 && (qtype == dns.TypeA || qtype == dns.TypeAAAA): pubKey := name[:keyLen] - for _, peer := range device.Peers { + for _, peer := range allPeers { if strings.EqualFold( base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { endpoint := peer.Endpoint @@ -176,6 +172,32 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR { } } +func (p *WGSD) sendSRV(m *dns.Msg, w dns.ResponseWriter, r *dns.Msg, peer host) (int, error) { + state := request.Request{W: w, Req: r} + pubKey := base32.StdEncoding.EncodeToString(peer.PublicKey[:]) + endpoint := peer.Endpoint + hostRR := getHostRR(pubKey, p.zone, endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) + } + m.Extra = append(m.Extra, hostRR) + m.Answer = append(m.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 0, + }, + Priority: 0, + Weight: 0, + Port: uint16(endpoint.Port), + Target: fmt.Sprintf("%s.%s", + strings.ToLower(pubKey), p.zone), + }) + w.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil +} + func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) { m := new(dns.Msg) m.SetReply(r)