mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 13:59:34 +08:00
serve self peer info
This commit is contained in:
parent
7d03ee7041
commit
6f78170fbe
41
wgsd.go
41
wgsd.go
@ -29,6 +29,10 @@ type WGSD struct {
|
||||
zone string
|
||||
// the WireGuard device name, e.g. wg0
|
||||
device string
|
||||
// overrides the self endpoint value
|
||||
selfEndpoint *net.UDPAddr
|
||||
// self allowed IPs
|
||||
selfAllowedIPs []net.IPNet
|
||||
}
|
||||
|
||||
type wgctrlClient interface {
|
||||
@ -145,6 +149,36 @@ func handleHostOrTXT(ctx context.Context, state request.Request,
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
||||
self := wgtypes.Peer{
|
||||
PublicKey: device.PublicKey,
|
||||
}
|
||||
if p.selfEndpoint != nil {
|
||||
self.Endpoint = p.selfEndpoint
|
||||
} else {
|
||||
self.Endpoint = &net.UDPAddr{
|
||||
IP: net.ParseIP(state.LocalIP()),
|
||||
Port: device.ListenPort,
|
||||
}
|
||||
}
|
||||
self.AllowedIPs = p.selfAllowedIPs
|
||||
return self, nil
|
||||
}
|
||||
|
||||
func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) {
|
||||
peers := make([]wgtypes.Peer, 0)
|
||||
device, err := p.client.Device(p.device)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers = append(peers, device.Peers...)
|
||||
self, err := p.getSelfPeer(device, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(peers, self), nil
|
||||
}
|
||||
|
||||
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
r *dns.Msg) (int, error) {
|
||||
// request.Request is a convenience struct we wrap around the msg and
|
||||
@ -169,15 +203,12 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
device, err := p.client.Device(p.device)
|
||||
peers, err := p.getPeers(state)
|
||||
if err != nil {
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
if len(device.Peers) == 0 {
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
return handler(ctx, state, device.Peers)
|
||||
return handler(ctx, state, peers)
|
||||
}
|
||||
|
||||
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
||||
|
53
wgsd_test.go
53
wgsd_test.go
@ -16,14 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
peers []wgtypes.Peer
|
||||
device *wgtypes.Device
|
||||
}
|
||||
|
||||
func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
||||
return &wgtypes.Device{
|
||||
Name: d,
|
||||
Peers: m.peers,
|
||||
}, nil
|
||||
return m.device, nil
|
||||
}
|
||||
|
||||
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
|
||||
@ -44,6 +41,11 @@ func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string)
|
||||
}
|
||||
|
||||
func TestWGSD(t *testing.T) {
|
||||
selfKey := [32]byte{}
|
||||
selfKey[0] = 99
|
||||
selfb32 := strings.ToLower(base32.StdEncoding.EncodeToString(selfKey[:]))
|
||||
selfb64 := base64.StdEncoding.EncodeToString(selfKey[:])
|
||||
selfAllowed, selfAllowedString := constructAllowedIPs(t, []string{"10.0.0.99/32", "10.0.0.100/32"})
|
||||
key1 := [32]byte{}
|
||||
key1[0] = 1
|
||||
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
|
||||
@ -73,10 +75,16 @@ func TestWGSD(t *testing.T) {
|
||||
p := &WGSD{
|
||||
Next: test.ErrorHandler(),
|
||||
client: &mockClient{
|
||||
peers: []wgtypes.Peer{peer1, peer2},
|
||||
device: &wgtypes.Device{
|
||||
Name: "wg0",
|
||||
PublicKey: selfKey,
|
||||
ListenPort: 51820,
|
||||
Peers: []wgtypes.Peer{peer1, peer2},
|
||||
},
|
||||
},
|
||||
zone: "example.com.",
|
||||
device: "wg0",
|
||||
zone: "example.com.",
|
||||
device: "wg0",
|
||||
selfAllowedIPs: selfAllowed,
|
||||
}
|
||||
|
||||
testCases := []test.Case{
|
||||
@ -87,6 +95,19 @@ func TestWGSD(t *testing.T) {
|
||||
Answer: []dns.RR{
|
||||
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)),
|
||||
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)),
|
||||
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", selfb32)),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
|
||||
Qtype: dns.TypeSRV,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
Answer: []dns.RR{
|
||||
test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 51820 %s._wireguard._udp.example.com.", selfb32, selfb32)),
|
||||
},
|
||||
Extra: []dns.RR{
|
||||
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -113,6 +134,14 @@ func TestWGSD(t *testing.T) {
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
|
||||
Qtype: dns.TypeA,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
Answer: []dns.RR{
|
||||
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
||||
Qtype: dns.TypeA,
|
||||
@ -129,6 +158,14 @@ func TestWGSD(t *testing.T) {
|
||||
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
|
||||
Qtype: dns.TypeTXT,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
Answer: []dns.RR{
|
||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
|
||||
},
|
||||
},
|
||||
{
|
||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
||||
Qtype: dns.TypeTXT,
|
||||
|
Loading…
x
Reference in New Issue
Block a user