diff --git a/wgsd.go b/wgsd.go index b8c199a..b70ce45 100644 --- a/wgsd.go +++ b/wgsd.go @@ -3,6 +3,7 @@ package wgsd import ( "context" "encoding/base32" + "encoding/base64" "fmt" "net" "strings" @@ -17,16 +18,16 @@ import ( // coredns plugin-specific logger var logger = clog.NewWithPlugin("wgsd") -// WGSD is a CoreDNS plugin that provides Wireguard peer information via DNS-SD +// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD // semantics. WGSD implements the plugin.Handler interface. type WGSD struct { Next plugin.Handler - // the client for retrieving Wireguard peer information + // the client for retrieving WireGuard peer information client wgctrlClient // the DNS zone we are serving records for zone string - // the Wireguard device name, e.g. wg0 + // the WireGuard device name, e.g. wg0 device string } @@ -35,7 +36,7 @@ type wgctrlClient interface { } const ( - keyLen = 56 // the number of characters in a base32-encoded Wireguard public key + keyLen = 56 // the number of characters in a base32-encoded WireGuard public key spPrefix = "_wireguard._udp." spSubPrefix = "." + spPrefix serviceInstanceLen = keyLen + len(spSubPrefix) @@ -55,10 +56,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, // strip zone from name name := strings.TrimSuffix(state.Name(), p.zone) - qtype := state.QType() + qType := state.QType() logger.Debugf("received query for: %s type: %s", name, - dns.TypeToString[qtype]) + dns.TypeToString[qType]) device, err := p.client.Device(p.device) if err != nil { @@ -75,7 +76,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, switch { // TODO: handle SOA - case name == spPrefix && qtype == dns.TypePTR: + case name == spPrefix && qType == dns.TypePTR: for _, peer := range device.Peers { if peer.Endpoint == nil { continue @@ -94,17 +95,18 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, } w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil - case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: + case len(name) == serviceInstanceLen && qType == dns.TypeSRV: pubKey := name[:keyLen] for _, peer := range device.Peers { if strings.EqualFold( base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { endpoint := peer.Endpoint - hostRR := getHostRR(pubKey, p.zone, endpoint) + hostRR := getHostRR(state.Name(), endpoint) if hostRR == nil { return nxDomain(p.zone, w, r) } - m.Extra = append(m.Extra, hostRR) + txtRR := getTXTRR(state.Name(), peer) + m.Extra = append(m.Extra, hostRR, txtRR) m.Answer = append(m.Answer, &dns.SRV{ Hdr: dns.RR_Header{ Name: state.Name(), @@ -115,25 +117,30 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, Priority: 0, Weight: 0, Port: uint16(endpoint.Port), - Target: strings.ToLower(pubKey) + spSubPrefix + p.zone, + Target: state.Name(), }) w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil } } return nxDomain(p.zone, w, r) - case len(name) == len(spSubPrefix)+keyLen && (qtype == dns.TypeA || - qtype == dns.TypeAAAA): + case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA || + qType == dns.TypeAAAA || qType == dns.TypeTXT): pubKey := name[:keyLen] for _, peer := range device.Peers { if strings.EqualFold( base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { endpoint := peer.Endpoint - hostRR := getHostRR(pubKey, p.zone, endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) + if qType == dns.TypeA || qType == dns.TypeAAAA { + hostRR := getHostRR(state.Name(), endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) + } + m.Answer = append(m.Answer, hostRR) + } else { + txtRR := getTXTRR(state.Name(), peer) + m.Answer = append(m.Answer, txtRR) } - m.Answer = append(m.Answer, hostRR) w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil } @@ -144,11 +151,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, } } -func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR { - if endpoint == nil || endpoint.IP == nil { - return nil - } - name := strings.ToLower(pubKey) + spSubPrefix + zone +func getHostRR(name string, endpoint *net.UDPAddr) dns.RR { switch { case endpoint.IP.To4() != nil: return &dns.A{ @@ -176,6 +179,38 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR { } } +const ( + // txtVersion is the first key/value pair in the TXT RR. Its serves to aid + // clients with maintaining backwards compatibility. + // + // https://tools.ietf.org/html/rfc6763#section-6.7 + txtVersion = 1 +) + +func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT { + var allowedIPs string + for i, prefix := range peer.AllowedIPs { + if i != 0 { + allowedIPs += "," + } + allowedIPs += prefix.String() + } + return &dns.TXT{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 0, + }, + Txt: []string{ + fmt.Sprintf("txtvers=%d", txtVersion), + fmt.Sprintf("pub=%s", + base64.StdEncoding.EncodeToString(peer.PublicKey[:])), + fmt.Sprintf("allowed=%s", allowedIPs), + }, + } +} + func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) { m := new(dns.Msg) m.SetReply(r) diff --git a/wgsd_test.go b/wgsd_test.go index 35569c5..15741b5 100644 --- a/wgsd_test.go +++ b/wgsd_test.go @@ -3,6 +3,7 @@ package wgsd import ( "context" "encoding/base32" + "encoding/base64" "fmt" "net" "strings" @@ -25,27 +26,50 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) { }, nil } +func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) { + var allowed []net.IPNet + var allowedString string + for i, s := range prefixes { + _, prefix, err := net.ParseCIDR(s) + if err != nil { + t.Fatalf("error parsing cidr: %v", err) + } + allowed = append(allowed, *prefix) + if i != 0 { + allowedString += "," + } + allowedString += prefix.String() + } + return allowed, allowedString +} + func TestWGSD(t *testing.T) { key1 := [32]byte{} key1[0] = 1 + peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"}) peer1 := wgtypes.Peer{ Endpoint: &net.UDPAddr{ IP: net.ParseIP("1.1.1.1"), Port: 1, }, - PublicKey: key1, + PublicKey: key1, + AllowedIPs: peer1Allowed, } peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:])) + peer1b64 := base64.StdEncoding.EncodeToString(peer1.PublicKey[:]) key2 := [32]byte{} key2[0] = 2 + peer2Allowed, peer2AllowedString := constructAllowedIPs(t, []string{"10.0.0.3/32", "10.0.0.4/32"}) peer2 := wgtypes.Peer{ Endpoint: &net.UDPAddr{ IP: net.ParseIP("::2"), Port: 2, }, - PublicKey: key2, + PublicKey: key2, + AllowedIPs: peer2Allowed, } peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:])) + peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:]) p := &WGSD{ Next: test.ErrorHandler(), client: &mockClient{ @@ -74,6 +98,7 @@ func TestWGSD(t *testing.T) { }, Extra: []dns.RR{ test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())), + test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)), }, }, { @@ -85,6 +110,7 @@ func TestWGSD(t *testing.T) { }, Extra: []dns.RR{ test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())), + test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)), }, }, { @@ -103,6 +129,22 @@ func TestWGSD(t *testing.T) { test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())), }, }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32), + Qtype: dns.TypeTXT, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)), + }, + }, + { + Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32), + Qtype: dns.TypeTXT, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)), + }, + }, { Qname: "nxdomain.example.com.", Qtype: dns.TypeAAAA,