mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 13:59:34 +08:00
serve allowed ips and public key via TXT RR
This commit is contained in:
parent
016a366d0f
commit
a928f85a58
79
wgsd.go
79
wgsd.go
@ -3,6 +3,7 @@ package wgsd
|
||||
import (
|
||||
"context"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@ -17,16 +18,16 @@ import (
|
||||
// coredns plugin-specific logger
|
||||
var logger = clog.NewWithPlugin("wgsd")
|
||||
|
||||
// WGSD is a CoreDNS plugin that provides Wireguard peer information via DNS-SD
|
||||
// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
|
||||
// semantics. WGSD implements the plugin.Handler interface.
|
||||
type WGSD struct {
|
||||
Next plugin.Handler
|
||||
|
||||
// the client for retrieving Wireguard peer information
|
||||
// the client for retrieving WireGuard peer information
|
||||
client wgctrlClient
|
||||
// the DNS zone we are serving records for
|
||||
zone string
|
||||
// the Wireguard device name, e.g. wg0
|
||||
// the WireGuard device name, e.g. wg0
|
||||
device string
|
||||
}
|
||||
|
||||
@ -35,7 +36,7 @@ type wgctrlClient interface {
|
||||
}
|
||||
|
||||
const (
|
||||
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
|
||||
keyLen = 56 // the number of characters in a base32-encoded WireGuard public key
|
||||
spPrefix = "_wireguard._udp."
|
||||
spSubPrefix = "." + spPrefix
|
||||
serviceInstanceLen = keyLen + len(spSubPrefix)
|
||||
@ -55,10 +56,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
|
||||
// strip zone from name
|
||||
name := strings.TrimSuffix(state.Name(), p.zone)
|
||||
qtype := state.QType()
|
||||
qType := state.QType()
|
||||
|
||||
logger.Debugf("received query for: %s type: %s", name,
|
||||
dns.TypeToString[qtype])
|
||||
dns.TypeToString[qType])
|
||||
|
||||
device, err := p.client.Device(p.device)
|
||||
if err != nil {
|
||||
@ -75,7 +76,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
|
||||
switch {
|
||||
// TODO: handle SOA
|
||||
case name == spPrefix && qtype == dns.TypePTR:
|
||||
case name == spPrefix && qType == dns.TypePTR:
|
||||
for _, peer := range device.Peers {
|
||||
if peer.Endpoint == nil {
|
||||
continue
|
||||
@ -94,17 +95,18 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
}
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
return dns.RcodeSuccess, nil
|
||||
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
|
||||
case len(name) == serviceInstanceLen && qType == dns.TypeSRV:
|
||||
pubKey := name[:keyLen]
|
||||
for _, peer := range device.Peers {
|
||||
if strings.EqualFold(
|
||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||
endpoint := peer.Endpoint
|
||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||
hostRR := getHostRR(state.Name(), endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
}
|
||||
m.Extra = append(m.Extra, hostRR)
|
||||
txtRR := getTXTRR(state.Name(), peer)
|
||||
m.Extra = append(m.Extra, hostRR, txtRR)
|
||||
m.Answer = append(m.Answer, &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: state.Name(),
|
||||
@ -115,25 +117,30 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
Priority: 0,
|
||||
Weight: 0,
|
||||
Port: uint16(endpoint.Port),
|
||||
Target: strings.ToLower(pubKey) + spSubPrefix + p.zone,
|
||||
Target: state.Name(),
|
||||
})
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
}
|
||||
return nxDomain(p.zone, w, r)
|
||||
case len(name) == len(spSubPrefix)+keyLen && (qtype == dns.TypeA ||
|
||||
qtype == dns.TypeAAAA):
|
||||
case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA ||
|
||||
qType == dns.TypeAAAA || qType == dns.TypeTXT):
|
||||
pubKey := name[:keyLen]
|
||||
for _, peer := range device.Peers {
|
||||
if strings.EqualFold(
|
||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||
endpoint := peer.Endpoint
|
||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
if qType == dns.TypeA || qType == dns.TypeAAAA {
|
||||
hostRR := getHostRR(state.Name(), endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
}
|
||||
m.Answer = append(m.Answer, hostRR)
|
||||
} else {
|
||||
txtRR := getTXTRR(state.Name(), peer)
|
||||
m.Answer = append(m.Answer, txtRR)
|
||||
}
|
||||
m.Answer = append(m.Answer, hostRR)
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
@ -144,11 +151,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
}
|
||||
}
|
||||
|
||||
func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
|
||||
if endpoint == nil || endpoint.IP == nil {
|
||||
return nil
|
||||
}
|
||||
name := strings.ToLower(pubKey) + spSubPrefix + zone
|
||||
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
||||
switch {
|
||||
case endpoint.IP.To4() != nil:
|
||||
return &dns.A{
|
||||
@ -176,6 +179,38 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// txtVersion is the first key/value pair in the TXT RR. Its serves to aid
|
||||
// clients with maintaining backwards compatibility.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc6763#section-6.7
|
||||
txtVersion = 1
|
||||
)
|
||||
|
||||
func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
|
||||
var allowedIPs string
|
||||
for i, prefix := range peer.AllowedIPs {
|
||||
if i != 0 {
|
||||
allowedIPs += ","
|
||||
}
|
||||
allowedIPs += prefix.String()
|
||||
}
|
||||
return &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Txt: []string{
|
||||
fmt.Sprintf("txtvers=%d", txtVersion),
|
||||
fmt.Sprintf("pub=%s",
|
||||
base64.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
||||
fmt.Sprintf("allowed=%s", allowedIPs),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
|
46
wgsd_test.go
46
wgsd_test.go
@ -3,6 +3,7 @@ package wgsd
|
||||
import (
|
||||
"context"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@ -25,27 +26,50 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
|
||||
var allowed []net.IPNet
|
||||
var allowedString string
|
||||
for i, s := range prefixes {
|
||||
_, prefix, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing cidr: %v", err)
|
||||
}
|
||||
allowed = append(allowed, *prefix)
|
||||
if i != 0 {
|
||||
allowedString += ","
|
||||
}
|
||||
allowedString += prefix.String()
|
||||
}
|
||||
return allowed, allowedString
|
||||
}
|
||||
|
||||
func TestWGSD(t *testing.T) {
|
||||
key1 := [32]byte{}
|
||||
key1[0] = 1
|
||||
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
|
||||
peer1 := wgtypes.Peer{
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.ParseIP("1.1.1.1"),
|
||||
Port: 1,
|
||||
},
|
||||
PublicKey: key1,
|
||||
PublicKey: key1,
|
||||
AllowedIPs: peer1Allowed,
|
||||
}
|
||||
peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
|
||||
peer1b64 := base64.StdEncoding.EncodeToString(peer1.PublicKey[:])
|
||||
key2 := [32]byte{}
|
||||
key2[0] = 2
|
||||
peer2Allowed, peer2AllowedString := constructAllowedIPs(t, []string{"10.0.0.3/32", "10.0.0.4/32"})
|
||||
peer2 := wgtypes.Peer{
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.ParseIP("::2"),
|
||||
Port: 2,
|
||||
},
|
||||
PublicKey: key2,
|
||||
PublicKey: key2,
|
||||
AllowedIPs: peer2Allowed,
|
||||
}
|
||||
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
|
||||
peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
|
||||
p := &WGSD{
|
||||
Next: test.ErrorHandler(),
|
||||
client: &mockClient{
|
||||
@ -74,6 +98,7 @@ func TestWGSD(t *testing.T) {
|
||||
},
|
||||
Extra: []dns.RR{
|
||||
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())),
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -85,6 +110,7 @@ func TestWGSD(t *testing.T) {
|
||||
},
|
||||
Extra: []dns.RR{
|
||||
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -103,6 +129,22 @@ func TestWGSD(t *testing.T) {
|
||||
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
||||
Qtype: dns.TypeTXT,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
Answer: []dns.RR{
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32),
|
||||
Qtype: dns.TypeTXT,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
Answer: []dns.RR{
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: "nxdomain.example.com.",
|
||||
Qtype: dns.TypeAAAA,
|
||||
|
Loading…
x
Reference in New Issue
Block a user