diff --git a/wgsd.go b/wgsd.go index 5acf349..d560eb9 100644 --- a/wgsd.go +++ b/wgsd.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "net" "strings" "github.com/coredns/coredns/plugin" @@ -56,7 +57,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, return dns.RcodeServerFailure, nil } if len(device.Peers) == 0 { - return nxdomain(p.zone, w, r) + return nxDomain(p.zone, w, r) } // setup our reply message @@ -86,10 +87,11 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, for _, peer := range device.Peers { if base64.StdEncoding.EncodeToString(peer.PublicKey[:]) == pubKey { endpoint := peer.Endpoint - if endpoint.IP == nil { - return nxdomain(p.zone, w, r) + hostRR := getHostRR(pubKey, p.zone, endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) } - srvTarget := fmt.Sprintf("%s.%s", pubKey, p.zone) + m.Extra = append(m.Extra, hostRR) m.Answer = append(m.Answer, &dns.SRV{ Hdr: dns.RR_Header{ Name: state.Name(), @@ -100,49 +102,66 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, Priority: 0, Weight: 0, Port: uint16(endpoint.Port), - Target: srvTarget, + Target: fmt.Sprintf("%s.%s", pubKey, p.zone), }) - switch { - case endpoint.IP.To4() != nil: - m.Extra = append(m.Extra, &dns.A{ - Hdr: dns.RR_Header{ - Name: srvTarget, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 0, - }, - A: endpoint.IP, - }) - case endpoint.IP.To16() != nil: - m.Extra = append(m.Extra, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: srvTarget, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 0, - }, - AAAA: endpoint.IP, - }) - default: - // TODO: this shouldn't happen + w.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil + } + } + return nxDomain(p.zone, w, r) + case len(name) == keyLen+len(".") && (qtype == dns.TypeA || + qtype == dns.TypeAAAA): + pubKey := name[:44] + for _, peer := range device.Peers { + if base64.StdEncoding.EncodeToString(peer.PublicKey[:]) == pubKey { + endpoint := peer.Endpoint + hostRR := getHostRR(pubKey, p.zone, endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) } w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil } } - return nxdomain(p.zone, w, r) - case len(name) == keyLen+len(".") && (qtype == dns.TypeA || - qtype == dns.TypeAAAA): - // TODO: return A/AAAA for of peer + return nxDomain(p.zone, w, r) default: - return nxdomain(p.zone, w, r) + return nxDomain(p.zone, w, r) } - - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil } -func nxdomain(name string, w dns.ResponseWriter, r *dns.Msg) (int, error) { +func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR { + if endpoint.IP == nil { + return nil + } + name := fmt.Sprintf("%s.%s", pubKey, zone) + switch { + case endpoint.IP.To4() != nil: + return &dns.A{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: endpoint.IP, + } + case endpoint.IP.To16() != nil: + return &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 0, + }, + AAAA: endpoint.IP, + } + default: + // TODO: this shouldn't happen + return nil + } +} + +func nxDomain(name string, w dns.ResponseWriter, r *dns.Msg) (int, error) { m := new(dns.Msg) m.SetReply(r) m.Authoritative = true