mirror of
https://github.com/jwhited/wgsd.git
synced 2025-01-31 12:19:32 +08:00
standardize handler funcs
This commit is contained in:
parent
a928f85a58
commit
7d03ee7041
141
wgsd.go
141
wgsd.go
@ -42,42 +42,28 @@ const (
|
|||||||
serviceInstanceLen = keyLen + len(spSubPrefix)
|
serviceInstanceLen = keyLen + len(spSubPrefix)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
type handlerFn func(ctx context.Context, state request.Request, peers []wgtypes.Peer) (int, error)
|
||||||
r *dns.Msg) (int, error) {
|
|
||||||
// request.Request is a convenience struct we wrap around the msg and
|
|
||||||
// ResponseWriter.
|
|
||||||
state := request.Request{W: w, Req: r}
|
|
||||||
|
|
||||||
// Check if the request is for the zone we are serving. If it doesn't match
|
|
||||||
// we pass the request on to the next plugin.
|
|
||||||
if plugin.Zones([]string{p.zone}).Matches(state.Name()) == "" {
|
|
||||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// strip zone from name
|
|
||||||
name := strings.TrimSuffix(state.Name(), p.zone)
|
|
||||||
qType := state.QType()
|
|
||||||
|
|
||||||
logger.Debugf("received query for: %s type: %s", name,
|
|
||||||
dns.TypeToString[qType])
|
|
||||||
|
|
||||||
device, err := p.client.Device(p.device)
|
|
||||||
if err != nil {
|
|
||||||
return dns.RcodeServerFailure, err
|
|
||||||
}
|
|
||||||
if len(device.Peers) == 0 {
|
|
||||||
return nxDomain(p.zone, w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup our reply message
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Authoritative = true
|
|
||||||
|
|
||||||
|
func getHandlerFn(queryType uint16, name string) handlerFn {
|
||||||
switch {
|
switch {
|
||||||
// TODO: handle SOA
|
case name == spPrefix && queryType == dns.TypePTR:
|
||||||
case name == spPrefix && qType == dns.TypePTR:
|
return handlePTR
|
||||||
for _, peer := range device.Peers {
|
case len(name) == serviceInstanceLen && queryType == dns.TypeSRV:
|
||||||
|
return handleSRV
|
||||||
|
case len(name) == len(spSubPrefix)+keyLen && (queryType == dns.TypeA ||
|
||||||
|
queryType == dns.TypeAAAA || queryType == dns.TypeTXT):
|
||||||
|
return handleHostOrTXT
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlePTR(ctx context.Context, state request.Request,
|
||||||
|
peers []wgtypes.Peer) (int, error) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(state.Req)
|
||||||
|
m.Authoritative = true
|
||||||
|
for _, peer := range peers {
|
||||||
if peer.Endpoint == nil {
|
if peer.Endpoint == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -90,20 +76,26 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
},
|
},
|
||||||
Ptr: fmt.Sprintf("%s.%s%s",
|
Ptr: fmt.Sprintf("%s.%s%s",
|
||||||
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
||||||
spPrefix, p.zone),
|
spPrefix, state.Zone),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
state.W.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
case len(name) == serviceInstanceLen && qType == dns.TypeSRV:
|
}
|
||||||
pubKey := name[:keyLen]
|
|
||||||
for _, peer := range device.Peers {
|
func handleSRV(ctx context.Context, state request.Request,
|
||||||
|
peers []wgtypes.Peer) (int, error) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(state.Req)
|
||||||
|
m.Authoritative = true
|
||||||
|
pubKey := state.Name()[:keyLen]
|
||||||
|
for _, peer := range peers {
|
||||||
if strings.EqualFold(
|
if strings.EqualFold(
|
||||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||||
endpoint := peer.Endpoint
|
endpoint := peer.Endpoint
|
||||||
hostRR := getHostRR(state.Name(), endpoint)
|
hostRR := getHostRR(state.Name(), endpoint)
|
||||||
if hostRR == nil {
|
if hostRR == nil {
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(state)
|
||||||
}
|
}
|
||||||
txtRR := getTXTRR(state.Name(), peer)
|
txtRR := getTXTRR(state.Name(), peer)
|
||||||
m.Extra = append(m.Extra, hostRR, txtRR)
|
m.Extra = append(m.Extra, hostRR, txtRR)
|
||||||
@ -119,36 +111,73 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
Port: uint16(endpoint.Port),
|
Port: uint16(endpoint.Port),
|
||||||
Target: state.Name(),
|
Target: state.Name(),
|
||||||
})
|
})
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
state.W.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(state)
|
||||||
case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA ||
|
}
|
||||||
qType == dns.TypeAAAA || qType == dns.TypeTXT):
|
|
||||||
pubKey := name[:keyLen]
|
func handleHostOrTXT(ctx context.Context, state request.Request,
|
||||||
for _, peer := range device.Peers {
|
peers []wgtypes.Peer) (int, error) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(state.Req)
|
||||||
|
m.Authoritative = true
|
||||||
|
pubKey := state.Name()[:keyLen]
|
||||||
|
for _, peer := range peers {
|
||||||
if strings.EqualFold(
|
if strings.EqualFold(
|
||||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||||
endpoint := peer.Endpoint
|
endpoint := peer.Endpoint
|
||||||
if qType == dns.TypeA || qType == dns.TypeAAAA {
|
if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA {
|
||||||
hostRR := getHostRR(state.Name(), endpoint)
|
hostRR := getHostRR(state.Name(), endpoint)
|
||||||
if hostRR == nil {
|
if hostRR == nil {
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(state)
|
||||||
}
|
}
|
||||||
m.Answer = append(m.Answer, hostRR)
|
m.Answer = append(m.Answer, hostRR)
|
||||||
} else {
|
} else {
|
||||||
txtRR := getTXTRR(state.Name(), peer)
|
txtRR := getTXTRR(state.Name(), peer)
|
||||||
m.Answer = append(m.Answer, txtRR)
|
m.Answer = append(m.Answer, txtRR)
|
||||||
}
|
}
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
state.W.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(state)
|
||||||
default:
|
|
||||||
return nxDomain(p.zone, w, r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||||
|
r *dns.Msg) (int, error) {
|
||||||
|
// request.Request is a convenience struct we wrap around the msg and
|
||||||
|
// ResponseWriter.
|
||||||
|
state := request.Request{W: w, Req: r, Zone: p.zone}
|
||||||
|
|
||||||
|
// Check if the request is for the zone we are serving. If it doesn't match
|
||||||
|
// we pass the request on to the next plugin.
|
||||||
|
if plugin.Zones([]string{p.zone}).Matches(state.Name()) == "" {
|
||||||
|
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// strip zone from name
|
||||||
|
name := strings.TrimSuffix(state.Name(), p.zone)
|
||||||
|
queryType := state.QType()
|
||||||
|
|
||||||
|
logger.Debugf("received query for: %s type: %s", name,
|
||||||
|
dns.TypeToString[queryType])
|
||||||
|
|
||||||
|
handler := getHandlerFn(queryType, name)
|
||||||
|
if handler == nil {
|
||||||
|
return nxDomain(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
device, err := p.client.Device(p.device)
|
||||||
|
if err != nil {
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
|
if len(device.Peers) == 0 {
|
||||||
|
return nxDomain(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(ctx, state, device.Peers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
|
||||||
@ -211,13 +240,13 @@ func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func nxDomain(state request.Request) (int, error) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(state.Req)
|
||||||
m.Authoritative = true
|
m.Authoritative = true
|
||||||
m.Rcode = dns.RcodeNameError
|
m.Rcode = dns.RcodeNameError
|
||||||
m.Ns = []dns.RR{soa(zone)}
|
m.Ns = []dns.RR{soa(state.Zone)}
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
state.W.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user