From e068f9d9d2d0311c7a5ed8d90ebaeda6fa3c87fb Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Sat, 2 Jan 2021 15:51:53 -0800 Subject: [PATCH] support multiple zone:device mappings --- setup.go | 96 ++++++++++++++++++-------------- setup_test.go | 148 ++++++++++++++++++++++++++++++++++---------------- wgsd.go | 67 ++++++++++++++--------- wgsd_test.go | 32 +++++++---- 4 files changed, 218 insertions(+), 125 deletions(-) diff --git a/setup.go b/setup.go index 44958dc..40df890 100644 --- a/setup.go +++ b/setup.go @@ -16,64 +16,77 @@ func init() { plugin.Register(pluginName, setup) } -const ( - optionSelfAllowedIPs = "self-allowed-ips" - optionSelfEndpoint = "self-endpoint" -) +func parse(c *caddy.Controller) (Zones, error) { + z := make(map[string]*Zone) + names := []string{} -func parse(c *caddy.Controller) (*WGSD, error) { - p := &WGSD{} for c.Next() { + // wgsd zone device args := c.RemainingArgs() if len(args) != 2 { - return nil, fmt.Errorf("expected 2 args, got %d", len(args)) + return Zones{}, fmt.Errorf("expected 2 args, got %d", len(args)) } - p.zone = dns.Fqdn(args[0]) - p.device = args[1] + zone := &Zone{ + name: dns.Fqdn(args[0]), + device: args[1], + } + names = append(names, zone.name) + _, ok := z[zone.name] + if ok { + return Zones{}, fmt.Errorf("duplicate zone name %s", + zone.name) + } + z[zone.name] = zone for c.NextBlock() { switch c.Val() { - case optionSelfAllowedIPs: - p.selfAllowedIPs = make([]net.IPNet, 0) - for _, aip := range c.RemainingArgs() { - _, prefix, err := net.ParseCIDR(aip) + case "self": + // self [endpoint] [allowed-ips ... ] + zone.serveSelf = true + args = c.RemainingArgs() + if len(args) < 1 { + break + } + + // assume first arg is endpoint + host, portS, err := net.SplitHostPort(args[0]) + if err == nil { + port, err := strconv.Atoi(portS) if err != nil { - return nil, fmt.Errorf("invalid self-allowed-ips: %s err: %v", c.Val(), err) + return Zones{}, fmt.Errorf("error converting self endpoint port: %v", err) } - p.selfAllowedIPs = append(p.selfAllowedIPs, *prefix) + ip := net.ParseIP(host) + if ip == nil { + return Zones{}, fmt.Errorf("invalid self endpoint IP address: %s", host) + } + zone.selfEndpoint = &net.UDPAddr{ + IP: ip, + Port: port, + } + args = args[1:] } - case optionSelfEndpoint: - endpoint := c.RemainingArgs() - if len(endpoint) != 1 { - return nil, fmt.Errorf("expected 1 arg, got %d", len(endpoint)) + + if len(args) > 0 { + zone.selfAllowedIPs = make([]net.IPNet, 0) } - host, portS, err := net.SplitHostPort(endpoint[0]) - if err != nil { - return nil, fmt.Errorf("invalid self-endpoint, err: %v", err) - } - port, err := strconv.Atoi(portS) - if err != nil { - return nil, fmt.Errorf("error converting self-endpoint port: %v", err) - } - ip := net.ParseIP(host) - if ip == nil { - return nil, fmt.Errorf("invalid self-endpoint, invalid IP address: %s", host) - } - p.selfEndpoint = &net.UDPAddr{ - IP: ip, - Port: port, + for _, allowedIPString := range args { + _, prefix, err := net.ParseCIDR(allowedIPString) + if err != nil { + return Zones{}, fmt.Errorf("invalid self allowed-ip '%s' err: %v", allowedIPString, err) + } + zone.selfAllowedIPs = append(zone.selfAllowedIPs, *prefix) } default: - return nil, c.ArgErr() + return Zones{}, c.ArgErr() } } } - return p, nil + return Zones{Z: z, Names: names}, nil } func setup(c *caddy.Controller) error { - wgsd, err := parse(c) + zones, err := parse(c) if err != nil { return plugin.Error(pluginName, err) } @@ -84,13 +97,14 @@ func setup(c *caddy.Controller) error { err)) } c.OnFinalShutdown(client.Close) - wgsd.client = client // Add the Plugin to CoreDNS, so Servers can use it in their plugin chain. dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { - wgsd.Next = next - return wgsd + return &WGSD{ + Next: next, + Zones: zones, + client: client, + } }) - return nil } diff --git a/setup_test.go b/setup_test.go index 03ce778..815b872 100644 --- a/setup_test.go +++ b/setup_test.go @@ -9,106 +9,160 @@ import ( ) func TestSetup(t *testing.T) { + _, prefix1, _ := net.ParseCIDR("1.1.1.1/32") + _, prefix2, _ := net.ParseCIDR("2.2.2.2/32") + _, prefix3, _ := net.ParseCIDR("3.3.3.3/32") + _, prefix4, _ := net.ParseCIDR("4.4.4.4/32") + endpoint1 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820} + testCases := []struct { - name string - input string - expectErr bool - expectSelfAllowedIPs []string - expectSelfEndpoint *net.UDPAddr + name string + input string + shouldErr bool + expectedZones Zones }{ { "valid input", "wgsd example.com. wg0", false, - nil, - nil, + Zones{ + Z: map[string]*Zone{ + "example.com.": { + name: "example.com.", + device: "wg0", + }, + }, + Names: []string{"example.com."}, + }, }, { "missing token", "wgsd example.com.", true, - nil, - nil, + Zones{}, }, { "too many tokens", "wgsd example.com. wg0 extra", true, - nil, - nil, + Zones{}, }, { - "valid self-allowed-ips", + "valid self allowed-ips", `wgsd example.com. wg0 { - self-allowed-ips 10.0.0.1/32 10.0.0.2/32 + self 1.1.1.1/32 2.2.2.2/32 }`, false, - []string{"10.0.0.1/32", "10.0.0.2/32"}, - nil, + Zones{ + Z: map[string]*Zone{ + "example.com.": { + name: "example.com.", + device: "wg0", + serveSelf: true, + selfAllowedIPs: []net.IPNet{*prefix1, *prefix2}, + }, + }, + Names: []string{"example.com."}, + }, }, { "invalid self-allowed-ips", `wgsd example.com. wg0 { - self-allowed-ips 10.0.01/32 10.0.0.2/32 + self 1.1.11/32 2.2.2.2/32 }`, true, - nil, - nil, + Zones{}, }, { "valid self-endpoint", `wgsd example.com. wg0 { - self-endpoint 127.0.0.1:51820 + self 127.0.0.1:51820 }`, false, - nil, - &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820}, + Zones{ + Z: map[string]*Zone{ + "example.com.": { + name: "example.com.", + device: "wg0", + serveSelf: true, + selfEndpoint: endpoint1, + }, + }, + Names: []string{"example.com."}, + }, }, { "invalid self-endpoint", `wgsd example.com. wg0 { - self-endpoint hostname:51820 + self hostname:51820 }`, true, - nil, - nil, + Zones{}, + }, + { + "multiple blocks", + `wgsd example.com. wg0 { + self 127.0.0.1:51820 1.1.1.1/32 2.2.2.2/32 + } + wgsd example2.com. wg1 { + self 127.0.0.1:51820 3.3.3.3/32 4.4.4.4/32 + }`, + false, + Zones{ + Z: map[string]*Zone{ + "example.com.": { + name: "example.com.", + device: "wg0", + serveSelf: true, + selfEndpoint: endpoint1, + selfAllowedIPs: []net.IPNet{*prefix1, *prefix2}, + }, + "example2.com.": { + name: "example2.com.", + device: "wg1", + serveSelf: true, + selfEndpoint: endpoint1, + selfAllowedIPs: []net.IPNet{*prefix3, *prefix4}, + }, + }, + Names: []string{"example.com.", "example2.com."}, + }, }, { "all options", `wgsd example.com. wg0 { - self-allowed-ips 10.0.0.1/32 10.0.0.2/32 - self-endpoint 127.0.0.1:51820 + self 127.0.0.1:51820 1.1.1.1/32 2.2.2.2/32 }`, false, - []string{"10.0.0.1/32", "10.0.0.2/32"}, - &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820}, + Zones{ + Z: map[string]*Zone{ + "example.com.": { + name: "example.com.", + device: "wg0", + serveSelf: true, + selfEndpoint: endpoint1, + selfAllowedIPs: []net.IPNet{*prefix1, *prefix2}, + }, + }, + Names: []string{"example.com."}, + }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := caddy.NewTestController("dns", tc.input) - wgsd, err := parse(c) - if (err != nil) != tc.expectErr { - t.Fatalf("expectErr: %v, got err=%v", tc.expectErr, err) - } - if tc.expectErr { - return - } - if !reflect.DeepEqual(wgsd.selfEndpoint, tc.expectSelfEndpoint) { - t.Errorf("expected self-endpoint %s but found: %s", tc.expectSelfEndpoint, wgsd.selfEndpoint) - } - var expectSelfAllowedIPs []net.IPNet - if tc.expectSelfAllowedIPs != nil { - expectSelfAllowedIPs = make([]net.IPNet, 0) - for _, s := range tc.expectSelfAllowedIPs { - _, p, _ := net.ParseCIDR(s) - expectSelfAllowedIPs = append(expectSelfAllowedIPs, *p) + zones, err := parse(c) + + if err == nil && tc.shouldErr { + t.Fatal("expected errors, but got no error") + } else if err != nil && !tc.shouldErr { + t.Fatalf("expected no errors, but got '%v'", err) + } else { + if !reflect.DeepEqual(tc.expectedZones, zones) { + t.Fatalf("expected %v, got %v", tc.expectedZones, zones) } } - if !reflect.DeepEqual(wgsd.selfAllowedIPs, expectSelfAllowedIPs) { - t.Errorf("expected self-allowed-ips %s but found: %s", expectSelfAllowedIPs, wgsd.selfAllowedIPs) - } }) } } diff --git a/wgsd.go b/wgsd.go index b61eca9..15f7777 100644 --- a/wgsd.go +++ b/wgsd.go @@ -26,17 +26,21 @@ const ( // semantics. WGSD implements the plugin.Handler interface. type WGSD struct { Next plugin.Handler + Zones + client wgctrlClient // the client for retrieving WireGuard peer information +} - // the client for retrieving WireGuard peer information - client wgctrlClient - // the DNS zone we are serving records for - zone string - // the WireGuard device name, e.g. wg0 - device string - // overrides the self endpoint value - selfEndpoint *net.UDPAddr - // self allowed IPs - selfAllowedIPs []net.IPNet +type Zones struct { + Z map[string]*Zone // a mapping from zone name to zone data + Names []string // all keys from the map z as a string slice +} + +type Zone struct { + name string // the name of the zone we are authoritative for + device string // the WireGuard device name, e.g. wg0 + serveSelf bool // flag to enable serving data about self + selfEndpoint *net.UDPAddr // overrides the self endpoint value + selfAllowedIPs []net.IPNet // self allowed IPs } type wgctrlClient interface { @@ -150,50 +154,61 @@ func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) { return nxDomain(state) } -func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) { +func getSelfPeer(zone *Zone, device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) { self := wgtypes.Peer{ PublicKey: device.PublicKey, } - if p.selfEndpoint != nil { - self.Endpoint = p.selfEndpoint + if zone.selfEndpoint != nil { + self.Endpoint = zone.selfEndpoint } else { self.Endpoint = &net.UDPAddr{ IP: net.ParseIP(state.LocalIP()), Port: device.ListenPort, } } - self.AllowedIPs = p.selfAllowedIPs + self.AllowedIPs = zone.selfAllowedIPs return self, nil } -func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) { +func getPeers(client wgctrlClient, zone *Zone, state request.Request) ( + []wgtypes.Peer, error) { peers := make([]wgtypes.Peer, 0) - device, err := p.client.Device(p.device) + device, err := client.Device(zone.device) if err != nil { return nil, err } peers = append(peers, device.Peers...) - self, err := p.getSelfPeer(device, state) - if err != nil { - return nil, err + if zone.serveSelf { + self, err := getSelfPeer(zone, device, state) + if err != nil { + return nil, err + } + peers = append(peers, self) } - return append(peers, self), nil + return peers, nil } func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { // request.Request is a convenience struct we wrap around the msg and // ResponseWriter. - state := request.Request{W: w, Req: r, Zone: p.zone} + state := request.Request{W: w, Req: r} - // Check if the request is for the zone we are serving. If it doesn't match - // we pass the request on to the next plugin. - if plugin.Zones([]string{p.zone}).Matches(state.Name()) == "" { + // Check if the request is for a zone we are serving. If it doesn't match we + // pass the request on to the next plugin. + zoneName := plugin.Zones(p.Names).Matches(state.Name()) + if zoneName == "" { return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) } + state.Zone = zoneName + + zone, ok := p.Z[zoneName] + if !ok { + return dns.RcodeServerFailure, nil + } // strip zone from name - name := strings.TrimSuffix(state.Name(), p.zone) + name := strings.TrimSuffix(state.Name(), zoneName) queryType := state.QType() logger.Debugf("received query for: %s type: %s", name, @@ -204,7 +219,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, return nxDomain(state) } - peers, err := p.getPeers(state) + peers, err := getPeers(p.client, zone, state) if err != nil { return dns.RcodeServerFailure, err } diff --git a/wgsd_test.go b/wgsd_test.go index f90160d..e3e27c8 100644 --- a/wgsd_test.go +++ b/wgsd_test.go @@ -16,11 +16,11 @@ import ( ) type mockClient struct { - device *wgtypes.Device + devices map[string]*wgtypes.Device } func (m *mockClient) Device(d string) (*wgtypes.Device, error) { - return m.device, nil + return m.devices[d], nil } func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) { @@ -74,17 +74,27 @@ func TestWGSD(t *testing.T) { peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:]) p := &WGSD{ Next: test.ErrorHandler(), - client: &mockClient{ - device: &wgtypes.Device{ - Name: "wg0", - PublicKey: selfKey, - ListenPort: 51820, - Peers: []wgtypes.Peer{peer1, peer2}, + Zones: Zones{ + Names: []string{"example.com."}, + Z: map[string]*Zone{ + "example.com.": { + name: "example.com.", + device: "wg0", + serveSelf: true, + selfAllowedIPs: selfAllowed, + }, + }, + }, + client: &mockClient{ + devices: map[string]*wgtypes.Device{ + "wg0": { + Name: "wg0", + PublicKey: selfKey, + ListenPort: 51820, + Peers: []wgtypes.Peer{peer1, peer2}, + }, }, }, - zone: "example.com.", - device: "wg0", - selfAllowedIPs: selfAllowed, } testCases := []test.Case{