mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 22:09:34 +08:00
serve allowed ips and public key via TXT RR
This commit is contained in:
parent
016a366d0f
commit
a928f85a58
79
wgsd.go
79
wgsd.go
@ -3,6 +3,7 @@ package wgsd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@ -17,16 +18,16 @@ import (
|
|||||||
// coredns plugin-specific logger
|
// coredns plugin-specific logger
|
||||||
var logger = clog.NewWithPlugin("wgsd")
|
var logger = clog.NewWithPlugin("wgsd")
|
||||||
|
|
||||||
// WGSD is a CoreDNS plugin that provides Wireguard peer information via DNS-SD
|
// WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD
|
||||||
// semantics. WGSD implements the plugin.Handler interface.
|
// semantics. WGSD implements the plugin.Handler interface.
|
||||||
type WGSD struct {
|
type WGSD struct {
|
||||||
Next plugin.Handler
|
Next plugin.Handler
|
||||||
|
|
||||||
// the client for retrieving Wireguard peer information
|
// the client for retrieving WireGuard peer information
|
||||||
client wgctrlClient
|
client wgctrlClient
|
||||||
// the DNS zone we are serving records for
|
// the DNS zone we are serving records for
|
||||||
zone string
|
zone string
|
||||||
// the Wireguard device name, e.g. wg0
|
// the WireGuard device name, e.g. wg0
|
||||||
device string
|
device string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,7 +36,7 @@ type wgctrlClient interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
|
keyLen = 56 // the number of characters in a base32-encoded WireGuard public key
|
||||||
spPrefix = "_wireguard._udp."
|
spPrefix = "_wireguard._udp."
|
||||||
spSubPrefix = "." + spPrefix
|
spSubPrefix = "." + spPrefix
|
||||||
serviceInstanceLen = keyLen + len(spSubPrefix)
|
serviceInstanceLen = keyLen + len(spSubPrefix)
|
||||||
@ -55,10 +56,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
|
|
||||||
// strip zone from name
|
// strip zone from name
|
||||||
name := strings.TrimSuffix(state.Name(), p.zone)
|
name := strings.TrimSuffix(state.Name(), p.zone)
|
||||||
qtype := state.QType()
|
qType := state.QType()
|
||||||
|
|
||||||
logger.Debugf("received query for: %s type: %s", name,
|
logger.Debugf("received query for: %s type: %s", name,
|
||||||
dns.TypeToString[qtype])
|
dns.TypeToString[qType])
|
||||||
|
|
||||||
device, err := p.client.Device(p.device)
|
device, err := p.client.Device(p.device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -75,7 +76,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
// TODO: handle SOA
|
// TODO: handle SOA
|
||||||
case name == spPrefix && qtype == dns.TypePTR:
|
case name == spPrefix && qType == dns.TypePTR:
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range device.Peers {
|
||||||
if peer.Endpoint == nil {
|
if peer.Endpoint == nil {
|
||||||
continue
|
continue
|
||||||
@ -94,17 +95,18 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
}
|
}
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
w.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
|
case len(name) == serviceInstanceLen && qType == dns.TypeSRV:
|
||||||
pubKey := name[:keyLen]
|
pubKey := name[:keyLen]
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range device.Peers {
|
||||||
if strings.EqualFold(
|
if strings.EqualFold(
|
||||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||||
endpoint := peer.Endpoint
|
endpoint := peer.Endpoint
|
||||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
hostRR := getHostRR(state.Name(), endpoint)
|
||||||
if hostRR == nil {
|
if hostRR == nil {
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(p.zone, w, r)
|
||||||
}
|
}
|
||||||
m.Extra = append(m.Extra, hostRR)
|
txtRR := getTXTRR(state.Name(), peer)
|
||||||
|
m.Extra = append(m.Extra, hostRR, txtRR)
|
||||||
m.Answer = append(m.Answer, &dns.SRV{
|
m.Answer = append(m.Answer, &dns.SRV{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: state.Name(),
|
Name: state.Name(),
|
||||||
@ -115,25 +117,30 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
Priority: 0,
|
Priority: 0,
|
||||||
Weight: 0,
|
Weight: 0,
|
||||||
Port: uint16(endpoint.Port),
|
Port: uint16(endpoint.Port),
|
||||||
Target: strings.ToLower(pubKey) + spSubPrefix + p.zone,
|
Target: state.Name(),
|
||||||
})
|
})
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
w.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(p.zone, w, r)
|
||||||
case len(name) == len(spSubPrefix)+keyLen && (qtype == dns.TypeA ||
|
case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA ||
|
||||||
qtype == dns.TypeAAAA):
|
qType == dns.TypeAAAA || qType == dns.TypeTXT):
|
||||||
pubKey := name[:keyLen]
|
pubKey := name[:keyLen]
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range device.Peers {
|
||||||
if strings.EqualFold(
|
if strings.EqualFold(
|
||||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||||
endpoint := peer.Endpoint
|
endpoint := peer.Endpoint
|
||||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
if qType == dns.TypeA || qType == dns.TypeAAAA {
|
||||||
if hostRR == nil {
|
hostRR := getHostRR(state.Name(), endpoint)
|
||||||
return nxDomain(p.zone, w, r)
|
if hostRR == nil {
|
||||||
|
return nxDomain(p.zone, w, r)
|
||||||
|
}
|
||||||
|
m.Answer = append(m.Answer, hostRR)
|
||||||
|
} else {
|
||||||
|
txtRR := getTXTRR(state.Name(), peer)
|
||||||
|
m.Answer = append(m.Answer, txtRR)
|
||||||
}
|
}
|
||||||
m.Answer = append(m.Answer, hostRR)
|
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
w.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
}
|
}
|
||||||
@ -144,11 +151,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
|
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
||||||
if endpoint == nil || endpoint.IP == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
name := strings.ToLower(pubKey) + spSubPrefix + zone
|
|
||||||
switch {
|
switch {
|
||||||
case endpoint.IP.To4() != nil:
|
case endpoint.IP.To4() != nil:
|
||||||
return &dns.A{
|
return &dns.A{
|
||||||
@ -176,6 +179,38 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// txtVersion is the first key/value pair in the TXT RR. Its serves to aid
|
||||||
|
// clients with maintaining backwards compatibility.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc6763#section-6.7
|
||||||
|
txtVersion = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
|
||||||
|
var allowedIPs string
|
||||||
|
for i, prefix := range peer.AllowedIPs {
|
||||||
|
if i != 0 {
|
||||||
|
allowedIPs += ","
|
||||||
|
}
|
||||||
|
allowedIPs += prefix.String()
|
||||||
|
}
|
||||||
|
return &dns.TXT{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: name,
|
||||||
|
Rrtype: dns.TypeTXT,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 0,
|
||||||
|
},
|
||||||
|
Txt: []string{
|
||||||
|
fmt.Sprintf("txtvers=%d", txtVersion),
|
||||||
|
fmt.Sprintf("pub=%s",
|
||||||
|
base64.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
||||||
|
fmt.Sprintf("allowed=%s", allowedIPs),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
|
46
wgsd_test.go
46
wgsd_test.go
@ -3,6 +3,7 @@ package wgsd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@ -25,27 +26,50 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
|
||||||
|
var allowed []net.IPNet
|
||||||
|
var allowedString string
|
||||||
|
for i, s := range prefixes {
|
||||||
|
_, prefix, err := net.ParseCIDR(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error parsing cidr: %v", err)
|
||||||
|
}
|
||||||
|
allowed = append(allowed, *prefix)
|
||||||
|
if i != 0 {
|
||||||
|
allowedString += ","
|
||||||
|
}
|
||||||
|
allowedString += prefix.String()
|
||||||
|
}
|
||||||
|
return allowed, allowedString
|
||||||
|
}
|
||||||
|
|
||||||
func TestWGSD(t *testing.T) {
|
func TestWGSD(t *testing.T) {
|
||||||
key1 := [32]byte{}
|
key1 := [32]byte{}
|
||||||
key1[0] = 1
|
key1[0] = 1
|
||||||
|
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
|
||||||
peer1 := wgtypes.Peer{
|
peer1 := wgtypes.Peer{
|
||||||
Endpoint: &net.UDPAddr{
|
Endpoint: &net.UDPAddr{
|
||||||
IP: net.ParseIP("1.1.1.1"),
|
IP: net.ParseIP("1.1.1.1"),
|
||||||
Port: 1,
|
Port: 1,
|
||||||
},
|
},
|
||||||
PublicKey: key1,
|
PublicKey: key1,
|
||||||
|
AllowedIPs: peer1Allowed,
|
||||||
}
|
}
|
||||||
peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
|
peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
|
||||||
|
peer1b64 := base64.StdEncoding.EncodeToString(peer1.PublicKey[:])
|
||||||
key2 := [32]byte{}
|
key2 := [32]byte{}
|
||||||
key2[0] = 2
|
key2[0] = 2
|
||||||
|
peer2Allowed, peer2AllowedString := constructAllowedIPs(t, []string{"10.0.0.3/32", "10.0.0.4/32"})
|
||||||
peer2 := wgtypes.Peer{
|
peer2 := wgtypes.Peer{
|
||||||
Endpoint: &net.UDPAddr{
|
Endpoint: &net.UDPAddr{
|
||||||
IP: net.ParseIP("::2"),
|
IP: net.ParseIP("::2"),
|
||||||
Port: 2,
|
Port: 2,
|
||||||
},
|
},
|
||||||
PublicKey: key2,
|
PublicKey: key2,
|
||||||
|
AllowedIPs: peer2Allowed,
|
||||||
}
|
}
|
||||||
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
|
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
|
||||||
|
peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
|
||||||
p := &WGSD{
|
p := &WGSD{
|
||||||
Next: test.ErrorHandler(),
|
Next: test.ErrorHandler(),
|
||||||
client: &mockClient{
|
client: &mockClient{
|
||||||
@ -74,6 +98,7 @@ func TestWGSD(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Extra: []dns.RR{
|
Extra: []dns.RR{
|
||||||
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())),
|
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())),
|
||||||
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -85,6 +110,7 @@ func TestWGSD(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Extra: []dns.RR{
|
Extra: []dns.RR{
|
||||||
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
||||||
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -103,6 +129,22 @@ func TestWGSD(t *testing.T) {
|
|||||||
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
||||||
|
Qtype: dns.TypeTXT,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
Answer: []dns.RR{
|
||||||
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32),
|
||||||
|
Qtype: dns.TypeTXT,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
Answer: []dns.RR{
|
||||||
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Qname: "nxdomain.example.com.",
|
Qname: "nxdomain.example.com.",
|
||||||
Qtype: dns.TypeAAAA,
|
Qtype: dns.TypeAAAA,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user