package wgsd

import (
	"context"
	"encoding/base32"
	"encoding/base64"
	"fmt"
	"net"
	"os"
	"strings"
	"testing"

	"github.com/coredns/coredns/plugin/pkg/dnstest"
	"github.com/coredns/coredns/plugin/test"
	"github.com/miekg/dns"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// mockClient is an in-memory stand-in for the WireGuard control client that
// WGSD queries. It serves devices from a fixed map keyed by interface name.
type mockClient struct {
	devices map[string]*wgtypes.Device
}

// Device returns the device registered under name.
//
// Mirroring the real wgctrl client, an unregistered name yields an error that
// satisfies errors.Is(err, os.ErrNotExist) rather than a nil device paired
// with a nil error. This makes a misconfigured test fail with a clear error
// instead of handing the plugin a nil device it may not guard against.
func (m *mockClient) Device(name string) (*wgtypes.Device, error) {
	device, ok := m.devices[name]
	if !ok {
		return nil, fmt.Errorf("wireguard device %q: %w", name, os.ErrNotExist)
	}
	return device, nil
}

// constructAllowedIPs parses each CIDR in prefixes and returns the parsed
// networks together with their canonical string forms joined by commas. That
// joined form is what the tests expect in the "allowed=" TXT attribute.
//
// Host bits are masked off by net.ParseCIDR, so "10.0.0.2/24" becomes
// "10.0.0.0/24". An empty prefixes slice yields a nil slice and "". The test
// is failed immediately if any prefix cannot be parsed.
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
	// Attribute parse failures to the caller's line, not this helper's.
	t.Helper()

	// Left nil (not pre-made) so empty input still returns a nil slice.
	var allowed []net.IPNet
	strs := make([]string, 0, len(prefixes))
	for _, s := range prefixes {
		_, prefix, err := net.ParseCIDR(s)
		if err != nil {
			t.Fatalf("error parsing cidr: %v", err)
		}
		allowed = append(allowed, *prefix)
		strs = append(strs, prefix.String())
	}
	return allowed, strings.Join(strs, ",")
}

// TestWGSD drives WGSD.ServeDNS with a table of queries against a mock
// WireGuard device "wg0" configured with three peers. serveSelf is enabled for
// the example.com. zone, so the local device is served too. For every query it
// checks the response header and the answer, authority and additional
// sections.
func TestWGSD(t *testing.T) {
	// The local device. DNS labels carry a public key as lowercase base32,
	// while TXT "pub=" values carry it as standard base64.
	selfKey := [32]byte{}
	selfKey[0] = 99
	selfb32 := strings.ToLower(base32.StdEncoding.EncodeToString(selfKey[:]))
	selfb64 := base64.StdEncoding.EncodeToString(selfKey[:])
	selfAllowed, selfAllowedString := constructAllowedIPs(t, []string{"10.0.0.99/32", "10.0.0.100/32"})
	// peer1 has an IPv4 endpoint; its SRV additional section carries an A record.
	key1 := [32]byte{}
	key1[0] = 1
	peer1Allowed, peer1AllowedString := constructAllowedIPs(t, []string{"10.0.0.1/32", "10.0.0.2/32"})
	peer1 := wgtypes.Peer{
		Endpoint: &net.UDPAddr{
			IP:   net.ParseIP("1.1.1.1"),
			Port: 1,
		},
		PublicKey:  key1,
		AllowedIPs: peer1Allowed,
	}
	peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
	peer1b64 := base64.StdEncoding.EncodeToString(peer1.PublicKey[:])
	// peer2 has an IPv6 endpoint; its SRV additional section carries an AAAA record.
	key2 := [32]byte{}
	key2[0] = 2
	peer2Allowed, peer2AllowedString := constructAllowedIPs(t, []string{"10.0.0.3/32", "10.0.0.4/32"})
	peer2 := wgtypes.Peer{
		Endpoint: &net.UDPAddr{
			IP:   net.ParseIP("::2"),
			Port: 2,
		},
		PublicKey:  key2,
		AllowedIPs: peer2Allowed,
	}
	peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
	peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
	// peer3 has no endpoint. It is absent from the PTR enumeration, and every
	// per-record query for it is expected to return NXDOMAIN.
	key3 := [32]byte{}
	key3[0] = 3
	peer3Allowed, _ := constructAllowedIPs(t, []string{"10.0.0.5/32", "10.0.0.6/32"})
	peer3 := wgtypes.Peer{
		Endpoint:   nil,
		PublicKey:  key3,
		AllowedIPs: peer3Allowed,
	}
	peer3b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer3.PublicKey[:]))
	// The plugin under test serves one zone, example.com., backed by device wg0.
	// Next is test.ErrorHandler. That presumably produces the SERVFAIL expected
	// for out-of-zone names; confirm against the plugin's ServeDNS.
	p := &WGSD{
		Next: test.ErrorHandler(),
		Zones: Zones{
			Names: []string{"example.com."},
			Z: map[string]*Zone{
				"example.com.": {
					name:           "example.com.",
					device:         "wg0",
					serveSelf:      true,
					selfAllowedIPs: selfAllowed,
				},
			},
		},
		client: &mockClient{
			devices: map[string]*wgtypes.Device{
				"wg0": {
					Name:       "wg0",
					PublicKey:  selfKey,
					ListenPort: 51820,
					Peers:      []wgtypes.Peer{peer1, peer2, peer3},
				},
			},
		},
	}

	testCases := []test.Case{
		// PTR enumeration returns one record per peer that has an endpoint,
		// followed by one for the local device.
		{
			Qname: "_wireguard._udp.example.com.",
			Qtype: dns.TypePTR,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer1b32)),
				test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", peer2b32)),
				test.PTR(fmt.Sprintf("_wireguard._udp.example.com. 0 IN PTR %s._wireguard._udp.example.com.", selfb32)),
			},
		},
		// SRV for the local device uses the device ListenPort (51820). The
		// 127.0.0.1 address is presumably taken from the test ResponseWriter's
		// local address; TODO confirm against the plugin.
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
			Qtype: dns.TypeSRV,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 51820 %s._wireguard._udp.example.com.", selfb32, selfb32)),
			},
			Extra: []dns.RR{
				test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
			},
		},
		// SRV for peers takes its port and address from each peer's endpoint.
		// The additional section holds A or AAAA by IP family, plus TXT.
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
			Qtype: dns.TypeSRV,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 1 %s._wireguard._udp.example.com.", peer1b32, peer1b32)),
			},
			Extra: []dns.RR{
				test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())),
				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
			},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32),
			Qtype: dns.TypeSRV,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.SRV(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN SRV 0 0 2 %s._wireguard._udp.example.com.", peer2b32, peer2b32)),
			},
			Extra: []dns.RR{
				test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
			},
		},
		// Direct A/AAAA lookups return the same addresses as the SRV extras.
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
			Qtype: dns.TypeA,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", selfb32, "127.0.0.1")),
			},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
			Qtype: dns.TypeA,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.A(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN A %s", peer1b32, peer1.Endpoint.IP.String())),
			},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32),
			Qtype: dns.TypeAAAA,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.AAAA(fmt.Sprintf("%s._wireguard._udp.example.com. 0 IN AAAA %s", peer2b32, peer2.Endpoint.IP.String())),
			},
		},
		// TXT lookups return txtvers, the base64 public key and the allowed IPs.
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", selfb32),
			Qtype: dns.TypeTXT,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, selfb32, txtVersion, selfb64, selfAllowedString)),
			},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer1b32),
			Qtype: dns.TypeTXT,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer1b32, txtVersion, peer1b64, peer1AllowedString)),
			},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer2b32),
			Qtype: dns.TypeTXT,
			Rcode: dns.RcodeSuccess,
			Answer: []dns.RR{
				test.TXT(fmt.Sprintf(`%s._wireguard._udp.example.com. 0 IN TXT "txtvers=%d" "pub=%s" "allowed=%s"`, peer2b32, txtVersion, peer2b64, peer2AllowedString)),
			},
		},
		// An unknown name inside the zone yields NXDOMAIN, with the zone SOA in
		// the authority section.
		{
			Qname: "nxdomain.example.com.",
			Qtype: dns.TypeAAAA,
			Rcode: dns.RcodeNameError,
			Ns: []dns.RR{
				test.SOA(soa("example.com.").String()),
			},
		},
		// A name outside every configured zone is expected to yield SERVFAIL.
		{
			Qname: "servfail.notexample.com.",
			Qtype: dns.TypeAAAA,
			Rcode: dns.RcodeServerFailure,
		},
		// peer3 has no endpoint, so no record type resolves for it.
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
			Qtype: dns.TypeSRV,
			Rcode: dns.RcodeNameError,
			Ns: []dns.RR{
				test.SOA(soa("example.com.").String()),
			},
			Answer: []dns.RR{},
			Extra:  []dns.RR{},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
			Qtype: dns.TypeA,
			Rcode: dns.RcodeNameError,
			Ns: []dns.RR{
				test.SOA(soa("example.com.").String()),
			},
			Answer: []dns.RR{},
			Extra:  []dns.RR{},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
			Qtype: dns.TypeAAAA,
			Rcode: dns.RcodeNameError,
			Ns: []dns.RR{
				test.SOA(soa("example.com.").String()),
			},
			Answer: []dns.RR{},
			Extra:  []dns.RR{},
		},
		{
			Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
			Qtype: dns.TypeTXT,
			Rcode: dns.RcodeNameError,
			Ns: []dns.RR{
				test.SOA(soa("example.com.").String()),
			},
			Answer: []dns.RR{},
			Extra:  []dns.RR{},
		},
	}
	// Each case runs as a subtest named "<qname> <qtype>". The response is
	// captured by a dnstest recorder and compared section by section.
	for _, tc := range testCases {
		t.Run(fmt.Sprintf("%s %s", tc.Qname, dns.TypeToString[tc.Qtype]), func(t *testing.T) {
			m := tc.Msg()
			rec := dnstest.NewRecorder(&test.ResponseWriter{})
			ctx := context.TODO()
			_, err := p.ServeDNS(ctx, rec, m)
			if err != nil {
				t.Errorf("Expected no error, got %v", err)
				return
			}
			resp := rec.Msg
			// Section comparisons are skipped when the header (e.g. Rcode) is wrong.
			if err := test.Header(tc, resp); err != nil {
				t.Error(err)
				return
			}
			if err := test.Section(tc, test.Answer, resp.Answer); err != nil {
				t.Error(err)
			}
			if err := test.Section(tc, test.Ns, resp.Ns); err != nil {
				t.Error(err)
			}
			if err := test.Section(tc, test.Extra, resp.Extra); err != nil {
				t.Error(err)
			}
		})
	}
}