diff --git a/setup.go b/setup.go index 8e35436..44958dc 100644 --- a/setup.go +++ b/setup.go @@ -2,6 +2,8 @@ package wgsd import ( "fmt" + "net" + "strconv" "github.com/coredns/caddy" "github.com/coredns/coredns/core/dnsserver" @@ -11,45 +13,83 @@ import ( ) func init() { - plugin.Register("wgsd", setup) + plugin.Register(pluginName, setup) +} + +const ( + optionSelfAllowedIPs = "self-allowed-ips" + optionSelfEndpoint = "self-endpoint" +) + +func parse(c *caddy.Controller) (*WGSD, error) { + p := &WGSD{} + for c.Next() { + args := c.RemainingArgs() + if len(args) != 2 { + return nil, fmt.Errorf("expected 2 args, got %d", len(args)) + } + p.zone = dns.Fqdn(args[0]) + p.device = args[1] + + for c.NextBlock() { + switch c.Val() { + case optionSelfAllowedIPs: + p.selfAllowedIPs = make([]net.IPNet, 0) + for _, aip := range c.RemainingArgs() { + _, prefix, err := net.ParseCIDR(aip) + if err != nil { + return nil, fmt.Errorf("invalid self-allowed-ips: %s err: %v", c.Val(), err) + } + p.selfAllowedIPs = append(p.selfAllowedIPs, *prefix) + } + case optionSelfEndpoint: + endpoint := c.RemainingArgs() + if len(endpoint) != 1 { + return nil, fmt.Errorf("expected 1 arg, got %d", len(endpoint)) + } + host, portS, err := net.SplitHostPort(endpoint[0]) + if err != nil { + return nil, fmt.Errorf("invalid self-endpoint, err: %v", err) + } + port, err := strconv.Atoi(portS) + if err != nil { + return nil, fmt.Errorf("error converting self-endpoint port: %v", err) + } + ip := net.ParseIP(host) + if ip == nil { + return nil, fmt.Errorf("invalid self-endpoint, invalid IP address: %s", host) + } + p.selfEndpoint = &net.UDPAddr{ + IP: ip, + Port: port, + } + default: + return nil, c.ArgErr() + } + } + } + + return p, nil } func setup(c *caddy.Controller) error { - c.Next() // Ignore "wgsd" and give us the next token. - - // return an error if there is no zone specified - if !c.NextArg() { - return plugin.Error("wgsd", c.ArgErr()) + wgsd, err := parse(c) + if err != nil { + return plugin.Error(pluginName, err) } - zone := dns.Fqdn(c.Val()) - - // return an error if there is no device name specified - if !c.NextArg() { - return plugin.Error("wgsd", c.ArgErr()) - } - device := c.Val() - - // return an error if there are more tokens on this line - if c.NextArg() { - return plugin.Error("wgsd", c.ArgErr()) - } - client, err := wgctrl.New() if err != nil { - return plugin.Error("wgsd", + return plugin.Error(pluginName, fmt.Errorf("error constructing wgctrl client: %v", err)) } c.OnFinalShutdown(client.Close) + wgsd.client = client // Add the Plugin to CoreDNS, so Servers can use it in their plugin chain. dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { - return &WGSD{ - Next: next, - client: client, - zone: zone, - device: device, - } + wgsd.Next = next + return wgsd }) return nil diff --git a/setup_test.go b/setup_test.go index 39e7410..6b39a37 100644 --- a/setup_test.go +++ b/setup_test.go @@ -8,24 +8,78 @@ import ( func TestSetup(t *testing.T) { testCases := []struct { - name string - input string - expectErr bool + name string + input string + expectErr bool + expectSelfAllowedIPs []string + expectSelfEndpoint []string }{ { "valid input", "wgsd example.com. wg0", false, + nil, + nil, }, { "missing token", "wgsd example.com.", true, + nil, + nil, }, { "too many tokens", "wgsd example.com. wg0 extra", true, + nil, + nil, + }, + { + "valid self-allowed-ips", + `wgsd example.com. wg0 { + self-allowed-ips 10.0.0.1/32 10.0.0.2/32 + }`, + false, + nil, + nil, + }, + { + "invalid self-allowed-ips", + `wgsd example.com. wg0 { + self-allowed-ips 10.0.01/32 10.0.0.2/32 + }`, + true, + nil, + nil, + }, + { + "valid self-endpoint", + `wgsd example.com. wg0 { + self-endpoint 127.0.0.1:51820 + }`, + false, + nil, + nil, + }, + { + "invalid self-endpoint", + `wgsd example.com. wg0 { + self-endpoint hostname:51820 + }`, + true, + nil, + nil, + }, + { + "all options", + `wgsd example.com. wg0 { + self-allowed-ips 10.0.0.1/32 10.0.0.2/32 + self-endpoint 127.0.0.1:51820 + }`, + false, + nil, + nil, }, } @@ -36,6 +90,9 @@ func TestSetup(t *testing.T) { if (err != nil) != tc.expectErr { t.Fatalf("expectErr: %v, got err=%v", tc.expectErr, err) } + if tc.expectErr { + return + } }) } } diff --git a/wgsd.go b/wgsd.go index 1674544..ab572ce 100644 --- a/wgsd.go +++ b/wgsd.go @@ -16,7 +16,11 @@ import ( ) // coredns plugin-specific logger -var logger = clog.NewWithPlugin("wgsd") +var logger = clog.NewWithPlugin(pluginName) + +const ( + pluginName = "wgsd" +) // WGSD is a CoreDNS plugin that provides WireGuard peer information via DNS-SD // semantics. WGSD implements the plugin.Handler interface. @@ -300,5 +304,5 @@ func soa(zone string) dns.RR { } func (p *WGSD) Name() string { - return "wgsd" + return pluginName }