mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 22:09:34 +08:00
serve self peer info
This commit is contained in:
parent
7d03ee7041
commit
6f78170fbe
41
wgsd.go
41
wgsd.go
@ -29,6 +29,10 @@ type WGSD struct {
|
|||||||
zone string
|
zone string
|
||||||
// the WireGuard device name, e.g. wg0
|
// the WireGuard device name, e.g. wg0
|
||||||
device string
|
device string
|
||||||
|
// overrides the self endpoint value
|
||||||
|
selfEndpoint *net.UDPAddr
|
||||||
|
// self allowed IPs
|
||||||
|
selfAllowedIPs []net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
type wgctrlClient interface {
|
type wgctrlClient interface {
|
||||||
@ -145,6 +149,36 @@ func handleHostOrTXT(ctx context.Context, state request.Request,
|
|||||||
return nxDomain(state)
|
return nxDomain(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
||||||
|
self := wgtypes.Peer{
|
||||||
|
PublicKey: device.PublicKey,
|
||||||
|
}
|
||||||
|
if p.selfEndpoint != nil {
|
||||||
|
self.Endpoint = p.selfEndpoint
|
||||||
|
} else {
|
||||||
|
self.Endpoint = &net.UDPAddr{
|
||||||
|
IP: net.ParseIP(state.LocalIP()),
|
||||||
|
Port: device.ListenPort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.AllowedIPs = p.selfAllowedIPs
|
||||||
|
return self, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) {
|
||||||
|
peers := make([]wgtypes.Peer, 0)
|
||||||
|
device, err := p.client.Device(p.device)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
peers = append(peers, device.Peers...)
|
||||||
|
self, err := p.getSelfPeer(device, state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return append(peers, self), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||||
r *dns.Msg) (int, error) {
|
r *dns.Msg) (int, error) {
|
||||||
// request.Request is a convenience struct we wrap around the msg and
|
// request.Request is a convenience struct we wrap around the msg and
|
||||||
@ -169,15 +203,12 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
return nxDomain(state)
|
return nxDomain(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := p.client.Device(p.device)
|
peers, err := p.getPeers(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return dns.RcodeServerFailure, err
|
return dns.RcodeServerFailure, err
|
||||||
}
|
}
|
||||||
if len(device.Peers) == 0 {
|
|
||||||
return nxDomain(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
return handler(ctx, state, device.Peers)
|
return handler(ctx, state, peers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
||||||
|
53
wgsd_test.go
53
wgsd_test.go
@ -16,14 +16,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type mockClient struct {
|
type mockClient struct {
|
||||||
peers []wgtypes.Peer
|
device *wgtypes.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
||||||
return &wgtypes.Device{
|
return m.device, nil
|
||||||
Name: d,
|
|
||||||
Peers: m.peers,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
|
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
|
||||||
@ -44,6 +41,11 @@ func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWGSD(t *testing.T) {
|
func TestWGSD(t *testing.T) {
|
||||||
|
selfKey := [32]byte{}
|
||||||
|
selfKey[0] = 99
|
||||||
|
selfb32 := strings.ToLower(base32.StdEncoding.EncodeToString(selfKey[:]))
|
||||||
|
selfb64 := base64.StdEncoding.EncodeToString(selfKey[:])
|
||||||
|
selfAllowed, selfAllowedString := constructAllowedIPs(t, []string{"10.0.0.99/32", "10.0.0.100/32"})
|
||||||
key1 := [32]byte{}
|
key1 := [32]byte{}
|
||||||
key1[0] = 1
|
key1[0] = 1
|
||||||
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
|
peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
|
||||||
@ -73,10 +75,16 @@ func TestWGSD(t *testing.T) {
|
|||||||
p := &WGSD{
|
p := &WGSD{
|
||||||
Next: test.ErrorHandler(),
|
Next: test.ErrorHandler(),
|
||||||
client: &mockClient{
|
client: &mockClient{
|
||||||
peers: []wgtypes.Peer{peer1, peer2},
|
device: &wgtypes.Device{
|
||||||
|
Name: "wg0",
|
||||||
|
PublicKey: selfKey,
|
||||||
|
ListenPort: 51820,
|
||||||
|
Peers: []wgtypes.Peer{peer1, peer2},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
zone: "example.com.",
|
zone: "example.com.",
|
||||||
device: "wg0",
|
device: "wg0",
|
||||||
|
selfAllowedIPs: selfAllowed,
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []test.Case{
|
testCases := []test.Case{
|
||||||
@ -87,6 +95,19 @@ func TestWGSD(t *testing.T) {
|
|||||||
Answer: []dns.RR{
|
Answer: []dns.RR{
|
||||||
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)),
|
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)),
|
||||||
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)),
|
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)),
|
||||||
|
test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", selfb32)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
|
||||||
|
Qtype: dns.TypeSRV,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
Answer: []dns.RR{
|
||||||
|
test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 51820 %s._wireguard._udp.example.com.", selfb32, selfb32)),
|
||||||
|
},
|
||||||
|
Extra: []dns.RR{
|
||||||
|
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
|
||||||
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -113,6 +134,14 @@ func TestWGSD(t *testing.T) {
|
|||||||
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
Answer: []dns.RR{
|
||||||
|
test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
||||||
Qtype: dns.TypeA,
|
Qtype: dns.TypeA,
|
||||||
@ -129,6 +158,14 @@ func TestWGSD(t *testing.T) {
|
|||||||
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
|
||||||
|
Qtype: dns.TypeTXT,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
Answer: []dns.RR{
|
||||||
|
test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
|
||||||
Qtype: dns.TypeTXT,
|
Qtype: dns.TypeTXT,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user