diff --git a/wgsd.go b/wgsd.go index 7baf360..3cfecf4 100644 --- a/wgsd.go +++ b/wgsd.go @@ -35,12 +35,13 @@ type wgctrlClient interface { } const ( - keyLen = 56 // the number of characters in a base32-encoded Wireguard public key - spPrefix = "_wireguard._udp." - spSubPrefix = "." + spPrefix - serviceInstanceLen = keyLen + len(spSubPrefix) + keyLen = 56 // the number of characters in a base32-encoded Wireguard public key + spPrefix = "_wireguard._udp." + spSubPrefix = "." + spPrefix ) +var emptySubnet = net.IPNet{IP: net.IPv4zero, Mask: net.IPv4Mask(0, 0, 0, 0)} + func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { // request.Request is a convenience struct we wrap around the msg and @@ -94,32 +95,61 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, } w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil - case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: - pubKey := name[:keyLen] + case qtype == dns.TypeSRV && strings.HasSuffix(name, spSubPrefix): + name = name[:len(name)-len(spSubPrefix)] + logger.Debugf("received query for: %s type: %s", name, + dns.TypeToString[qtype]) for _, peer := range device.Peers { - if strings.EqualFold( - base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { - endpoint := peer.Endpoint - hostRR := getHostRR(pubKey, p.zone, endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) - } - m.Extra = append(m.Extra, hostRR) - m.Answer = append(m.Answer, &dns.SRV{ - Hdr: dns.RR_Header{ - Name: state.Name(), - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 0, - }, - Priority: 0, - Weight: 0, - Port: uint16(endpoint.Port), - Target: pubKey + spSubPrefix + p.zone, - }) - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil + pubKey := base32.StdEncoding.EncodeToString(peer.PublicKey[:]) + allowedIPs := &emptySubnet + if len(peer.AllowedIPs) >= 1 { + allowedIPs = &peer.AllowedIPs[0] } + if len(name) == keyLen { + // check by keyname + if !strings.EqualFold(pubKey, name[:keyLen]) { + continue + } + } else { + // check by ip + ip := net.ParseIP(name) + if ip == nil || !allowedIPs.Contains(ip) { + continue + } + } + endpoint := peer.Endpoint + hostRR := getHostRR(pubKey, p.zone, endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) + } + m.Extra = append(m.Extra, hostRR) + pubKey = strings.ToLower(pubKey) + m.Extra = append(m.Extra, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 0, + }, + Txt: []string{ + "allowedip=" + allowedIPs.String(), + "pubkey=" + strings.TrimRight(pubKey, "="), + }, + }) + m.Answer = append(m.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 0, + }, + Priority: 0, + Weight: 0, + Port: uint16(endpoint.Port), + Target: pubKey + spSubPrefix + p.zone, + }) + w.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil } return nxDomain(p.zone, w, r) case len(name) == len(spSubPrefix)+keyLen && (qtype == dns.TypeA || diff --git a/wgsd_test.go b/wgsd_test.go index 35569c5..7b92053 100644 --- a/wgsd_test.go +++ b/wgsd_test.go @@ -28,22 +28,26 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) { func TestWGSD(t *testing.T) { key1 := [32]byte{} key1[0] = 1 + _, allowedip1, _ := net.ParseCIDR("2.2.2.2/32") peer1 := wgtypes.Peer{ Endpoint: &net.UDPAddr{ IP: net.ParseIP("1.1.1.1"), Port: 1, }, - PublicKey: key1, + AllowedIPs: []net.IPNet{*allowedip1}, + PublicKey: key1, } peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:])) key2 := [32]byte{} key2[0] = 2 + _, allowedip2, _ := net.ParseCIDR("::3/128") peer2 := wgtypes.Peer{ Endpoint: &net.UDPAddr{ IP: net.ParseIP("::2"), Port: 2, }, - PublicKey: key2, + AllowedIPs: []net.IPNet{*allowedip2}, + PublicKey: key2, } peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:])) p := &WGSD{ @@ -74,6 +78,7 @@ func TestWGSD(t *testing.T) { }, Extra: []dns.RR{ test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())), + test.TXT(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN TXT allowedip=%s pubkey=%s", peer1b32, peer1.AllowedIPs[0].String(), strings.TrimRight(peer1b32, "="))), }, }, { @@ -85,6 +90,7 @@ func TestWGSD(t *testing.T) { }, Extra: []dns.RR{ test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())), + test.TXT(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN TXT allowedip=%s pubkey=%s", peer2b32, peer2.AllowedIPs[0].String(), strings.TrimRight(peer2b32, "="))), }, }, {