diff --git a/setup_test.go b/setup_test.go index 6b39a37..03ce778 100644 --- a/setup_test.go +++ b/setup_test.go @@ -1,6 +1,8 @@ package wgsd import ( + "net" + "reflect" "testing" "github.com/coredns/caddy" @@ -12,7 +14,7 @@ func TestSetup(t *testing.T) { input string expectErr bool expectSelfAllowedIPs []string - expectSelfEndpoint []string + expectSelfEndpoint *net.UDPAddr }{ { "valid input", @@ -41,7 +43,7 @@ func TestSetup(t *testing.T) { self-allowed-ips 10.0.0.1/32 10.0.0.2/32 }`, false, - nil, + []string{"10.0.0.1/32", "10.0.0.2/32"}, nil, }, { @@ -60,7 +62,7 @@ func TestSetup(t *testing.T) { }`, false, nil, - nil, + &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820}, }, { "invalid self-endpoint", @@ -78,21 +80,35 @@ func TestSetup(t *testing.T) { self-endpoint 127.0.0.1:51820 }`, false, - nil, - nil, + []string{"10.0.0.1/32", "10.0.0.2/32"}, + &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := caddy.NewTestController("dns", tc.input) - err := setup(c) + wgsd, err := parse(c) if (err != nil) != tc.expectErr { t.Fatalf("expectErr: %v, got err=%v", tc.expectErr, err) } if tc.expectErr { return } + if !reflect.DeepEqual(wgsd.selfEndpoint, tc.expectSelfEndpoint) { + t.Errorf("expected self-endpoint %s but found: %s", tc.expectSelfEndpoint, wgsd.selfEndpoint) + } + var expectSelfAllowedIPs []net.IPNet + if tc.expectSelfAllowedIPs != nil { + expectSelfAllowedIPs = make([]net.IPNet, 0) + for _, s := range tc.expectSelfAllowedIPs { + _, p, _ := net.ParseCIDR(s) + expectSelfAllowedIPs = append(expectSelfAllowedIPs, *p) + } + } + if !reflect.DeepEqual(wgsd.selfAllowedIPs, expectSelfAllowedIPs) { + t.Errorf("expected self-allowed-ips %s but found: %s", expectSelfAllowedIPs, wgsd.selfAllowedIPs) + } }) } } diff --git a/wgsd.go b/wgsd.go index ab572ce..b61eca9 100644 --- a/wgsd.go +++ b/wgsd.go @@ -50,7 +50,7 @@ const ( serviceInstanceLen = keyLen + len(spSubPrefix) ) -type handlerFn func(ctx context.Context, state request.Request, peers []wgtypes.Peer) (int, error) +type handlerFn func(state request.Request, peers []wgtypes.Peer) (int, error) func getHandlerFn(queryType uint16, name string) handlerFn { switch { @@ -66,8 +66,7 @@ func getHandlerFn(queryType uint16, name string) handlerFn { } } -func handlePTR(ctx context.Context, state request.Request, - peers []wgtypes.Peer) (int, error) { +func handlePTR(state request.Request, peers []wgtypes.Peer) (int, error) { m := new(dns.Msg) m.SetReply(state.Req) m.Authoritative = true @@ -91,8 +90,7 @@ func handlePTR(ctx context.Context, state request.Request, return dns.RcodeSuccess, nil } -func handleSRV(ctx context.Context, state request.Request, - peers []wgtypes.Peer) (int, error) { +func handleSRV(state request.Request, peers []wgtypes.Peer) (int, error) { m := new(dns.Msg) m.SetReply(state.Req) m.Authoritative = true @@ -126,8 +124,7 @@ func handleSRV(ctx context.Context, state request.Request, return nxDomain(state) } -func handleHostOrTXT(ctx context.Context, state request.Request, - peers []wgtypes.Peer) (int, error) { +func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) { m := new(dns.Msg) m.SetReply(state.Req) m.Authoritative = true @@ -212,7 +209,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, return dns.RcodeServerFailure, err } - return handler(ctx, state, peers) + return handler(state, peers) } func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {