standardize handler funcs

This commit is contained in:
Jordan Whited 2021-01-01 15:41:11 -08:00 committed by Jordan Whited
parent a928f85a58
commit 7d03ee7041

205
wgsd.go
View File

@ -42,11 +42,114 @@ const (
serviceInstanceLen = keyLen + len(spSubPrefix) serviceInstanceLen = keyLen + len(spSubPrefix)
) )
type handlerFn func(ctx context.Context, state request.Request, peers []wgtypes.Peer) (int, error)
func getHandlerFn(queryType uint16, name string) handlerFn {
switch {
case name == spPrefix && queryType == dns.TypePTR:
return handlePTR
case len(name) == serviceInstanceLen && queryType == dns.TypeSRV:
return handleSRV
case len(name) == len(spSubPrefix)+keyLen && (queryType == dns.TypeA ||
queryType == dns.TypeAAAA || queryType == dns.TypeTXT):
return handleHostOrTXT
default:
return nil
}
}
func handlePTR(ctx context.Context, state request.Request,
peers []wgtypes.Peer) (int, error) {
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
for _, peer := range peers {
if peer.Endpoint == nil {
continue
}
m.Answer = append(m.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
},
Ptr: fmt.Sprintf("%s.%s%s",
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
spPrefix, state.Zone),
})
}
state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
func handleSRV(ctx context.Context, state request.Request,
peers []wgtypes.Peer) (int, error) {
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
pubKey := state.Name()[:keyLen]
for _, peer := range peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(state)
}
txtRR := getTXTRR(state.Name(), peer)
m.Extra = append(m.Extra, hostRR, txtRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: state.Name(),
})
state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(state)
}
func handleHostOrTXT(ctx context.Context, state request.Request,
peers []wgtypes.Peer) (int, error) {
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
pubKey := state.Name()[:keyLen]
for _, peer := range peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(state)
}
m.Answer = append(m.Answer, hostRR)
} else {
txtRR := getTXTRR(state.Name(), peer)
m.Answer = append(m.Answer, txtRR)
}
state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(state)
}
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
r *dns.Msg) (int, error) { r *dns.Msg) (int, error) {
// request.Request is a convenience struct we wrap around the msg and // request.Request is a convenience struct we wrap around the msg and
// ResponseWriter. // ResponseWriter.
state := request.Request{W: w, Req: r} state := request.Request{W: w, Req: r, Zone: p.zone}
// Check if the request is for the zone we are serving. If it doesn't match // Check if the request is for the zone we are serving. If it doesn't match
// we pass the request on to the next plugin. // we pass the request on to the next plugin.
@ -56,99 +159,25 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
// strip zone from name // strip zone from name
name := strings.TrimSuffix(state.Name(), p.zone) name := strings.TrimSuffix(state.Name(), p.zone)
qType := state.QType() queryType := state.QType()
logger.Debugf("received query for: %s type: %s", name, logger.Debugf("received query for: %s type: %s", name,
dns.TypeToString[qType]) dns.TypeToString[queryType])
handler := getHandlerFn(queryType, name)
if handler == nil {
return nxDomain(state)
}
device, err := p.client.Device(p.device) device, err := p.client.Device(p.device)
if err != nil { if err != nil {
return dns.RcodeServerFailure, err return dns.RcodeServerFailure, err
} }
if len(device.Peers) == 0 { if len(device.Peers) == 0 {
return nxDomain(p.zone, w, r) return nxDomain(state)
} }
// setup our reply message return handler(ctx, state, device.Peers)
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
switch {
// TODO: handle SOA
case name == spPrefix && qType == dns.TypePTR:
for _, peer := range device.Peers {
if peer.Endpoint == nil {
continue
}
m.Answer = append(m.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
},
Ptr: fmt.Sprintf("%s.%s%s",
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
spPrefix, p.zone),
})
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
case len(name) == serviceInstanceLen && qType == dns.TypeSRV:
pubKey := name[:keyLen]
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
txtRR := getTXTRR(state.Name(), peer)
m.Extra = append(m.Extra, hostRR, txtRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: state.Name(),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(p.zone, w, r)
case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA ||
qType == dns.TypeAAAA || qType == dns.TypeTXT):
pubKey := name[:keyLen]
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
if qType == dns.TypeA || qType == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Answer = append(m.Answer, hostRR)
} else {
txtRR := getTXTRR(state.Name(), peer)
m.Answer = append(m.Answer, txtRR)
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxDomain(p.zone, w, r)
default:
return nxDomain(p.zone, w, r)
}
} }
func getHostRR(name string, endpoint *net.UDPAddr) dns.RR { func getHostRR(name string, endpoint *net.UDPAddr) dns.RR {
@ -211,13 +240,13 @@ func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT {
} }
} }
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) { func nxDomain(state request.Request) (int, error) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(state.Req)
m.Authoritative = true m.Authoritative = true
m.Rcode = dns.RcodeNameError m.Rcode = dns.RcodeNameError
m.Ns = []dns.RR{soa(zone)} m.Ns = []dns.RR{soa(state.Zone)}
w.WriteMsg(m) // nolint: errcheck state.W.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil return dns.RcodeSuccess, nil
} }