mirror of
https://github.com/jwhited/wgsd.git
synced 2025-11-26 21:35:34 +08:00
support multiple zone:device mappings
This commit is contained in:
committed by
Jordan Whited
parent
a700f38f3e
commit
e068f9d9d2
67
wgsd.go
67
wgsd.go
@@ -26,17 +26,21 @@ const (
|
||||
// semantics. WGSD implements the plugin.Handler interface.
|
||||
type WGSD struct {
|
||||
Next plugin.Handler
|
||||
Zones
|
||||
client wgctrlClient // the client for retrieving WireGuard peer information
|
||||
}
|
||||
|
||||
// the client for retrieving WireGuard peer information
|
||||
client wgctrlClient
|
||||
// the DNS zone we are serving records for
|
||||
zone string
|
||||
// the WireGuard device name, e.g. wg0
|
||||
device string
|
||||
// overrides the self endpoint value
|
||||
selfEndpoint *net.UDPAddr
|
||||
// self allowed IPs
|
||||
selfAllowedIPs []net.IPNet
|
||||
type Zones struct {
|
||||
Z map[string]*Zone // a mapping from zone name to zone data
|
||||
Names []string // all keys from the map z as a string slice
|
||||
}
|
||||
|
||||
type Zone struct {
|
||||
name string // the name of the zone we are authoritative for
|
||||
device string // the WireGuard device name, e.g. wg0
|
||||
serveSelf bool // flag to enable serving data about self
|
||||
selfEndpoint *net.UDPAddr // overrides the self endpoint value
|
||||
selfAllowedIPs []net.IPNet // self allowed IPs
|
||||
}
|
||||
|
||||
type wgctrlClient interface {
|
||||
@@ -150,50 +154,61 @@ func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) {
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
func (p *WGSD) getSelfPeer(device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
||||
func getSelfPeer(zone *Zone, device *wgtypes.Device, state request.Request) (wgtypes.Peer, error) {
|
||||
self := wgtypes.Peer{
|
||||
PublicKey: device.PublicKey,
|
||||
}
|
||||
if p.selfEndpoint != nil {
|
||||
self.Endpoint = p.selfEndpoint
|
||||
if zone.selfEndpoint != nil {
|
||||
self.Endpoint = zone.selfEndpoint
|
||||
} else {
|
||||
self.Endpoint = &net.UDPAddr{
|
||||
IP: net.ParseIP(state.LocalIP()),
|
||||
Port: device.ListenPort,
|
||||
}
|
||||
}
|
||||
self.AllowedIPs = p.selfAllowedIPs
|
||||
self.AllowedIPs = zone.selfAllowedIPs
|
||||
return self, nil
|
||||
}
|
||||
|
||||
func (p *WGSD) getPeers(state request.Request) ([]wgtypes.Peer, error) {
|
||||
func getPeers(client wgctrlClient, zone *Zone, state request.Request) (
|
||||
[]wgtypes.Peer, error) {
|
||||
peers := make([]wgtypes.Peer, 0)
|
||||
device, err := p.client.Device(p.device)
|
||||
device, err := client.Device(zone.device)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers = append(peers, device.Peers...)
|
||||
self, err := p.getSelfPeer(device, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if zone.serveSelf {
|
||||
self, err := getSelfPeer(zone, device, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers = append(peers, self)
|
||||
}
|
||||
return append(peers, self), nil
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
r *dns.Msg) (int, error) {
|
||||
// request.Request is a convenience struct we wrap around the msg and
|
||||
// ResponseWriter.
|
||||
state := request.Request{W: w, Req: r, Zone: p.zone}
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
||||
// Check if the request is for the zone we are serving. If it doesn't match
|
||||
// we pass the request on to the next plugin.
|
||||
if plugin.Zones([]string{p.zone}).Matches(state.Name()) == "" {
|
||||
// Check if the request is for a zone we are serving. If it doesn't match we
|
||||
// pass the request on to the next plugin.
|
||||
zoneName := plugin.Zones(p.Names).Matches(state.Name())
|
||||
if zoneName == "" {
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
state.Zone = zoneName
|
||||
|
||||
zone, ok := p.Z[zoneName]
|
||||
if !ok {
|
||||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
|
||||
// strip zone from name
|
||||
name := strings.TrimSuffix(state.Name(), p.zone)
|
||||
name := strings.TrimSuffix(state.Name(), zoneName)
|
||||
queryType := state.QType()
|
||||
|
||||
logger.Debugf("received query for: %s type: %s", name,
|
||||
@@ -204,7 +219,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
return nxDomain(state)
|
||||
}
|
||||
|
||||
peers, err := p.getPeers(state)
|
||||
peers, err := getPeers(p.client, zone, state)
|
||||
if err != nil {
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user