mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-18 13:59:34 +08:00
support multiple zone:device mappings
This commit is contained in:
parent
a700f38f3e
commit
e068f9d9d2
86
setup.go
86
setup.go
@ -16,64 +16,77 @@ func init() {
|
||||
plugin.Register(pluginName, setup)
|
||||
}
|
||||
|
||||
const (
|
||||
optionSelfAllowedIPs = "self-allowed-ips"
|
||||
optionSelfEndpoint = "self-endpoint"
|
||||
)
|
||||
func parse(c *caddy.Controller) (Zones, error) {
|
||||
z := make(map[string]*Zone)
|
||||
names := []string{}
|
||||
|
||||
func parse(c *caddy.Controller) (*WGSD, error) {
|
||||
p := &WGSD{}
|
||||
for c.Next() {
|
||||
// wgsd zone device
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 2 {
|
||||
return nil, fmt.Errorf("expected 2 args, got %d", len(args))
|
||||
return Zones{}, fmt.Errorf("expected 2 args, got %d", len(args))
|
||||
}
|
||||
p.zone = dns.Fqdn(args[0])
|
||||
p.device = args[1]
|
||||
zone := &Zone{
|
||||
name: dns.Fqdn(args[0]),
|
||||
device: args[1],
|
||||
}
|
||||
names = append(names, zone.name)
|
||||
_, ok := z[zone.name]
|
||||
if ok {
|
||||
return Zones{}, fmt.Errorf("duplicate zone name %s",
|
||||
zone.name)
|
||||
}
|
||||
z[zone.name] = zone
|
||||
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case optionSelfAllowedIPs:
|
||||
p.selfAllowedIPs = make([]net.IPNet, 0)
|
||||
for _, aip := range c.RemainingArgs() {
|
||||
_, prefix, err := net.ParseCIDR(aip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid self-allowed-ips: %s err: %v", c.Val(), err)
|
||||
}
|
||||
p.selfAllowedIPs = append(p.selfAllowedIPs, *prefix)
|
||||
}
|
||||
case optionSelfEndpoint:
|
||||
endpoint := c.RemainingArgs()
|
||||
if len(endpoint) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 arg, got %d", len(endpoint))
|
||||
}
|
||||
host, portS, err := net.SplitHostPort(endpoint[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid self-endpoint, err: %v", err)
|
||||
case "self":
|
||||
// self [endpoint] [allowed-ips ... ]
|
||||
zone.serveSelf = true
|
||||
args = c.RemainingArgs()
|
||||
if len(args) < 1 {
|
||||
break
|
||||
}
|
||||
|
||||
// assume first arg is endpoint
|
||||
host, portS, err := net.SplitHostPort(args[0])
|
||||
if err == nil {
|
||||
port, err := strconv.Atoi(portS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting self-endpoint port: %v", err)
|
||||
return Zones{}, fmt.Errorf("error converting self endpoint port: %v", err)
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid self-endpoint, invalid IP address: %s", host)
|
||||
return Zones{}, fmt.Errorf("invalid self endpoint IP address: %s", host)
|
||||
}
|
||||
p.selfEndpoint = &net.UDPAddr{
|
||||
zone.selfEndpoint = &net.UDPAddr{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
}
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
zone.selfAllowedIPs = make([]net.IPNet, 0)
|
||||
}
|
||||
for _, allowedIPString := range args {
|
||||
_, prefix, err := net.ParseCIDR(allowedIPString)
|
||||
if err != nil {
|
||||
return Zones{}, fmt.Errorf("invalid self allowed-ip '%s' err: %v", allowedIPString, err)
|
||||
}
|
||||
zone.selfAllowedIPs = append(zone.selfAllowedIPs, *prefix)
|
||||
}
|
||||
default:
|
||||
return nil, c.ArgErr()
|
||||
return Zones{}, c.ArgErr()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
return Zones{Z: z, Names: names}, nil
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
wgsd, err := parse(c)
|
||||
zones, err := parse(c)
|
||||
if err != nil {
|
||||
return plugin.Error(pluginName, err)
|
||||
}
|
||||
@ -84,13 +97,14 @@ func setup(c *caddy.Controller) error {
|
||||
err))
|
||||
}
|
||||
c.OnFinalShutdown(client.Close)
|
||||
wgsd.client = client
|
||||
|
||||
// Add the Plugin to CoreDNS, so Servers can use it in their plugin chain.
|
||||
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
wgsd.Next = next
|
||||
return wgsd
|
||||
return &WGSD{
|
||||
Next: next,
|
||||
Zones: zones,
|
||||
client: client,
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
144
setup_test.go
144
setup_test.go
@ -9,105 +9,159 @@ import (
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
_, prefix1, _ := net.ParseCIDR("1.1.1.1/32")
|
||||
_, prefix2, _ := net.ParseCIDR("2.2.2.2/32")
|
||||
_, prefix3, _ := net.ParseCIDR("3.3.3.3/32")
|
||||
_, prefix4, _ := net.ParseCIDR("4.4.4.4/32")
|
||||
endpoint1 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expectSelfAllowedIPs []string
|
||||
expectSelfEndpoint *net.UDPAddr
|
||||
shouldErr bool
|
||||
expectedZones Zones
|
||||
}{
|
||||
{
|
||||
"valid input",
|
||||
"wgsd example.com. wg0",
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
Zones{
|
||||
Z: map[string]*Zone{
|
||||
"example.com.": {
|
||||
name: "example.com.",
|
||||
device: "wg0",
|
||||
},
|
||||
},
|
||||
Names: []string{"example.com."},
|
||||
},
|
||||
},
|
||||
{
|
||||
"missing token",
|
||||
"wgsd example.com.",
|
||||
true,
|
||||
nil,
|
||||
nil,
|
||||
Zones{},
|
||||
},
|
||||
{
|
||||
"too many tokens",
|
||||
"wgsd example.com. wg0 extra",
|
||||
true,
|
||||
nil,
|
||||
nil,
|
||||
Zones{},
|
||||
},
|
||||
{
|
||||
"valid self-allowed-ips",
|
||||
"valid self allowed-ips",
|
||||
`wgsd example.com. wg0 {
|
||||
self-allowed-ips 10.0.0.1/32 10.0.0.2/32
|
||||
self 1.1.1.1/32 2.2.2.2/32
|
||||
}`,
|
||||
false,
|
||||
[]string{"10.0.0.1/32", "10.0.0.2/32"},
|
||||
nil,
|
||||
Zones{
|
||||
Z: map[string]*Zone{
|
||||
"example.com.": {
|
||||
name: "example.com.",
|
||||
device: "wg0",
|
||||
serveSelf: true,
|
||||
selfAllowedIPs: []net.IPNet{*prefix1, *prefix2},
|
||||
},
|
||||
},
|
||||
Names: []string{"example.com."},
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid self-allowed-ips",
|
||||
`wgsd example.com. wg0 {
|
||||
self-allowed-ips 10.0.01/32 10.0.0.2/32
|
||||
self 1.1.11/32 2.2.2.2/32
|
||||
}`,
|
||||
true,
|
||||
nil,
|
||||
nil,
|
||||
Zones{},
|
||||
},
|
||||
{
|
||||
"valid self-endpoint",
|
||||
`wgsd example.com. wg0 {
|
||||
self-endpoint 127.0.0.1:51820
|
||||
self 127.0.0.1:51820
|
||||
}`,
|
||||
false,
|
||||
nil,
|
||||
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820},
|
||||
Zones{
|
||||
Z: map[string]*Zone{
|
||||
"example.com.": {
|
||||
name: "example.com.",
|
||||
device: "wg0",
|
||||
serveSelf: true,
|
||||
selfEndpoint: endpoint1,
|
||||
},
|
||||
},
|
||||
Names: []string{"example.com."},
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid self-endpoint",
|
||||
`wgsd example.com. wg0 {
|
||||
self-endpoint hostname:51820
|
||||
self hostname:51820
|
||||
}`,
|
||||
true,
|
||||
nil,
|
||||
nil,
|
||||
Zones{},
|
||||
},
|
||||
{
|
||||
"multiple blocks",
|
||||
`wgsd example.com. wg0 {
|
||||
self 127.0.0.1:51820 1.1.1.1/32 2.2.2.2/32
|
||||
}
|
||||
wgsd example2.com. wg1 {
|
||||
self 127.0.0.1:51820 3.3.3.3/32 4.4.4.4/32
|
||||
}`,
|
||||
false,
|
||||
Zones{
|
||||
Z: map[string]*Zone{
|
||||
"example.com.": {
|
||||
name: "example.com.",
|
||||
device: "wg0",
|
||||
serveSelf: true,
|
||||
selfEndpoint: endpoint1,
|
||||
selfAllowedIPs: []net.IPNet{*prefix1, *prefix2},
|
||||
},
|
||||
"example2.com.": {
|
||||
name: "example2.com.",
|
||||
device: "wg1",
|
||||
serveSelf: true,
|
||||
selfEndpoint: endpoint1,
|
||||
selfAllowedIPs: []net.IPNet{*prefix3, *prefix4},
|
||||
},
|
||||
},
|
||||
Names: []string{"example.com.", "example2.com."},
|
||||
},
|
||||
},
|
||||
{
|
||||
"all options",
|
||||
`wgsd example.com. wg0 {
|
||||
self-allowed-ips 10.0.0.1/32 10.0.0.2/32
|
||||
self-endpoint 127.0.0.1:51820
|
||||
self 127.0.0.1:51820 1.1.1.1/32 2.2.2.2/32
|
||||
}`,
|
||||
false,
|
||||
[]string{"10.0.0.1/32", "10.0.0.2/32"},
|
||||
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820},
|
||||
Zones{
|
||||
Z: map[string]*Zone{
|
||||
"example.com.": {
|
||||
name: "example.com.",
|
||||
device: "wg0",
|
||||
serveSelf: true,
|
||||
selfEndpoint: endpoint1,
|
||||
selfAllowedIPs: []net.IPNet{*prefix1, *prefix2},
|
||||
},
|
||||
},
|
||||
Names: []string{"example.com."},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := caddy.NewTestController("dns", tc.input)
|
||||
wgsd, err := parse(c)
|
||||
if (err != nil) != tc.expectErr {
|
||||
t.Fatalf("expectErr: %v, got err=%v", tc.expectErr, err)
|
||||
zones, err := parse(c)
|
||||
|
||||
if err == nil && tc.shouldErr {
|
||||
t.Fatal("expected errors, but got no error")
|
||||
} else if err != nil && !tc.shouldErr {
|
||||
t.Fatalf("expected no errors, but got '%v'", err)
|
||||
} else {
|
||||
if !reflect.DeepEqual(tc.expectedZones, zones) {
|
||||
t.Fatalf("expected %v, got %v", tc.expectedZones, zones)
|
||||
}
|
||||
if tc.expectErr {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(wgsd.selfEndpoint, tc.expectSelfEndpoint) {
|
||||
t.Errorf("expected self-endpoint %s but found: %s", tc.expectSelfEndpoint, wgsd.selfEndpoint)
|
||||
}
|
||||
var expectSelfAllowedIPs []net.IPNet
|
||||
if tc.expectSelfAllowedIPs != nil {
|
||||
expectSelfAllowedIPs = make([]net.IPNet, 0)
|
||||
for _, s := range tc.expectSelfAllowedIPs {
|
||||
_, p, _ := net.ParseCIDR(s)
|
||||
expectSelfAllowedIPs = append(expectSelfAllowedIPs, *p)
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(wgsd.selfAllowedIPs, expectSelfAllowedIPs) {
|
||||
t.Errorf("expected self-allowed-ips %s but found: %s", expectSelfAllowedIPs, wgsd.selfAllowedIPs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
63
wgsd.go
63
wgsd.go
@ -26,17 +26,21 @@ const (
|
||||
// semantics. WGSD implements the plugin.Handler interface.
|
||||
type WGSD struct {
|
||||
Next plugin.Handler
|
||||
Zones
|
||||
client wgctrlClient // the client for retrieving WireGuard peer information
|
||||
}
|
||||
|
||||
// the client for retrieving WireGuard peer information
|
||||
client wgctrlClient
|
||||
// the DNS zone we are serving records for
|
||||
zone string
|
||||
// the WireGuard device name, e.g. wg0
|
||||
device string
|
||||
// overrides the self endpoint value
|
||||
selfEndpoint *net.UDPAddr
|
||||
// self allowed IPs
|
||||
selfAllowedIPs []net.IPNet
|
||||
type Zones struct {
|
||||
Z map[string]*Zone // a mapping from zone name to zone data
|
||||
Names []string // all keys from the map z as a string slice
|
||||
}
|
||||
|
||||
type Zone struct {
|
||||
name string // the name of the zone we are authoritative for
|
||||
device string // the WireGuard device name, e.g. wg0
|
||||
serveSelf bool // flag to enable serving data about self
|
||||
selfEndpoint *net.UDPAddr // overrides the self endpoint value
|
||||
selfAllowedIPs []net.IPNet // self allowed IPs
|
||||
}
|
||||
|
||||
type wgctrlClient interface {
|
||||
@ -150,50 +154,61 @@ func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) {
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
||||
func getSelfPeer(zone *Zone, device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
||||
self := wgtypes.Peer{
|
||||
PublicKey: device.PublicKey,
|
||||
}
|
||||
if p.selfEndpoint != nil {
|
||||
self.Endpoint = p.selfEndpoint
|
||||
if zone.selfEndpoint != nil {
|
||||
self.Endpoint = zone.selfEndpoint
|
||||
} else {
|
||||
self.Endpoint = &net.UDPAddr{
|
||||
IP: net.ParseIP(state.LocalIP()),
|
||||
Port: device.ListenPort,
|
||||
}
|
||||
}
|
||||
self.AllowedIPs = p.selfAllowedIPs
|
||||
self.AllowedIPs = zone.selfAllowedIPs
|
||||
return self, nil
|
||||
}
|
||||
|
||||
func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) {
|
||||
func getPeers(client wgctrlClient, zone *Zone, state request.Request) (
|
||||
[]wgtypes.Peer, error) {
|
||||
peers := make([]wgtypes.Peer, 0)
|
||||
device, err := p.client.Device(p.device)
|
||||
device, err := client.Device(zone.device)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers = append(peers, device.Peers...)
|
||||
self, err := p.getSelfPeer(device, state)
|
||||
if zone.serveSelf {
|
||||
self, err := getSelfPeer(zone, device, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(peers, self), nil
|
||||
peers = append(peers, self)
|
||||
}
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
r *dns.Msg) (int, error) {
|
||||
// request.Request is a convenience struct we wrap around the msg and
|
||||
// ResponseWriter.
|
||||
state := request.Request{W: w, Req: r, Zone: p.zone}
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
||||
// Check if the request is for the zone we are serving. If it doesn't match
|
||||
// we pass the request on to the next plugin.
|
||||
if plugin.Zones([]string{p.zone}).Matches(state.Name()) == "" {
|
||||
// Check if the request is for a zone we are serving. If it doesn't match we
|
||||
// pass the request on to the next plugin.
|
||||
zoneName := plugin.Zones(p.Names).Matches(state.Name())
|
||||
if zoneName == "" {
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
state.Zone = zoneName
|
||||
|
||||
zone, ok := p.Z[zoneName]
|
||||
if !ok {
|
||||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
|
||||
// strip zone from name
|
||||
name := strings.TrimSuffix(state.Name(), p.zone)
|
||||
name := strings.TrimSuffix(state.Name(), zoneName)
|
||||
queryType := state.QType()
|
||||
|
||||
logger.Debugf("received query for: %s type: %s", name,
|
||||
@ -204,7 +219,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
peers, err := p.getPeers(state)
|
||||
peers, err := getPeers(p.client, zone, state)
|
||||
if err != nil {
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
|
22
wgsd_test.go
22
wgsd_test.go
@ -16,11 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
device *wgtypes.Device
|
||||
devices map[string]*wgtypes.Device
|
||||
}
|
||||
|
||||
func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
||||
return m.device, nil
|
||||
return m.devices[d], nil
|
||||
}
|
||||
|
||||
func constructAllowedIPs(t *testing.T, prefixes []string) ([]net.IPNet, string) {
|
||||
@ -74,17 +74,27 @@ func TestWGSD(t *testing.T) {
|
||||
peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
|
||||
p := &WGSD{
|
||||
Next: test.ErrorHandler(),
|
||||
Zones: Zones{
|
||||
Names: []string{"example.com."},
|
||||
Z: map[string]*Zone{
|
||||
"example.com.": {
|
||||
name: "example.com.",
|
||||
device: "wg0",
|
||||
serveSelf: true,
|
||||
selfAllowedIPs: selfAllowed,
|
||||
},
|
||||
},
|
||||
},
|
||||
client: &mockClient{
|
||||
device: &wgtypes.Device{
|
||||
devices: map[string]*wgtypes.Device{
|
||||
"wg0": {
|
||||
Name: "wg0",
|
||||
PublicKey: selfKey,
|
||||
ListenPort: 51820,
|
||||
Peers: []wgtypes.Peer{peer1, peer2},
|
||||
},
|
||||
},
|
||||
zone: "example.com.",
|
||||
device: "wg0",
|
||||
selfAllowedIPs: selfAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []test.Case{
|
||||
|
Loading…
x
Reference in New Issue
Block a user