diff --git a/wgsd.go b/wgsd.go index b70ce45..6cc2aa0 100644 --- a/wgsd.go +++ b/wgsd.go @@ -42,11 +42,114 @@ const ( serviceInstanceLen = keyLen + len(spSubPrefix) ) +type handlerFn func(ctx context.Context, state request.Request, peers []wgtypes.Peer) (int, error) + +func getHandlerFn(queryType uint16, name string) handlerFn { + switch { + case name == spPrefix && queryType == dns.TypePTR: + return handlePTR + case len(name) == serviceInstanceLen && queryType == dns.TypeSRV: + return handleSRV + case len(name) == len(spSubPrefix)+keyLen && (queryType == dns.TypeA || + queryType == dns.TypeAAAA || queryType == dns.TypeTXT): + return handleHostOrTXT + default: + return nil + } +} + +func handlePTR(ctx context.Context, state request.Request, + peers []wgtypes.Peer) (int, error) { + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + for _, peer := range peers { + if peer.Endpoint == nil { + continue + } + m.Answer = append(m.Answer, &dns.PTR{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 0, + }, + Ptr: fmt.Sprintf("%s.%s%s", + strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])), + spPrefix, state.Zone), + }) + } + state.W.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil +} + +func handleSRV(ctx context.Context, state request.Request, + peers []wgtypes.Peer) (int, error) { + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + pubKey := state.Name()[:keyLen] + for _, peer := range peers { + if strings.EqualFold( + base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { + endpoint := peer.Endpoint + hostRR := getHostRR(state.Name(), endpoint) + if hostRR == nil { + return nxDomain(state) + } + txtRR := getTXTRR(state.Name(), peer) + m.Extra = append(m.Extra, hostRR, txtRR) + m.Answer = append(m.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 0, + }, + Priority: 0, + Weight: 0, + Port: uint16(endpoint.Port), + Target: state.Name(), + }) + state.W.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil + } + } + return nxDomain(state) +} + +func handleHostOrTXT(ctx context.Context, state request.Request, + peers []wgtypes.Peer) (int, error) { + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + pubKey := state.Name()[:keyLen] + for _, peer := range peers { + if strings.EqualFold( + base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { + endpoint := peer.Endpoint + if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA { + hostRR := getHostRR(state.Name(), endpoint) + if hostRR == nil { + return nxDomain(state) + } + m.Answer = append(m.Answer, hostRR) + } else { + txtRR := getTXTRR(state.Name(), peer) + m.Answer = append(m.Answer, txtRR) + } + state.W.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil + } + } + return nxDomain(state) +} + func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { // request.Request is a convenience struct we wrap around the msg and // ResponseWriter. - state := request.Request{W: w, Req: r} + state := request.Request{W: w, Req: r, Zone: p.zone} // Check if the request is for the zone we are serving. If it doesn't match // we pass the request on to the next plugin. @@ -56,99 +159,25 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, // strip zone from name name := strings.TrimSuffix(state.Name(), p.zone) - qType := state.QType() + queryType := state.QType() logger.Debugf("received query for: %s type: %s", name, - dns.TypeToString[qType]) + dns.TypeToString[queryType]) + + handler := getHandlerFn(queryType, name) + if handler == nil { + return nxDomain(state) + } device, err := p.client.Device(p.device) if err != nil { return dns.RcodeServerFailure, err } if len(device.Peers) == 0 { - return nxDomain(p.zone, w, r) + return nxDomain(state) } - // setup our reply message - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - - switch { - // TODO: handle SOA - case name == spPrefix && qType == dns.TypePTR: - for _, peer := range device.Peers { - if peer.Endpoint == nil { - continue - } - m.Answer = append(m.Answer, &dns.PTR{ - Hdr: dns.RR_Header{ - Name: state.Name(), - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - Ttl: 0, - }, - Ptr: fmt.Sprintf("%s.%s%s", - strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])), - spPrefix, p.zone), - }) - } - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil - case len(name) == serviceInstanceLen && qType == dns.TypeSRV: - pubKey := name[:keyLen] - for _, peer := range device.Peers { - if strings.EqualFold( - base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { - endpoint := peer.Endpoint - hostRR := getHostRR(state.Name(), endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) - } - txtRR := getTXTRR(state.Name(), peer) - m.Extra = append(m.Extra, hostRR, txtRR) - m.Answer = append(m.Answer, &dns.SRV{ - Hdr: dns.RR_Header{ - Name: state.Name(), - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 0, - }, - Priority: 0, - Weight: 0, - Port: uint16(endpoint.Port), - Target: state.Name(), - }) - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil - } - } - return nxDomain(p.zone, w, r) - case len(name) == len(spSubPrefix)+keyLen && (qType == dns.TypeA || - qType == dns.TypeAAAA || qType == dns.TypeTXT): - pubKey := name[:keyLen] - for _, peer := range device.Peers { - if strings.EqualFold( - base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { - endpoint := peer.Endpoint - if qType == dns.TypeA || qType == dns.TypeAAAA { - hostRR := getHostRR(state.Name(), endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) - } - m.Answer = append(m.Answer, hostRR) - } else { - txtRR := getTXTRR(state.Name(), peer) - m.Answer = append(m.Answer, txtRR) - } - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil - } - } - return nxDomain(p.zone, w, r) - default: - return nxDomain(p.zone, w, r) - } + return handler(ctx, state, device.Peers) } func getHostRR(name string, endpoint *net.UDPAddr) dns.RR { @@ -211,13 +240,13 @@ func getTXTRR(name string, peer wgtypes.Peer) *dns.TXT { } } -func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) { +func nxDomain(state request.Request) (int, error) { m := new(dns.Msg) - m.SetReply(r) + m.SetReply(state.Req) m.Authoritative = true m.Rcode = dns.RcodeNameError - m.Ns = []dns.RR{soa(zone)} - w.WriteMsg(m) // nolint: errcheck + m.Ns = []dns.RR{soa(state.Zone)} + state.W.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil }