From 8109291569c4e0e54201f70b7a921354f8ad92da Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 12 May 2020 17:13:40 -0700 Subject: [PATCH] handle SRV queries --- wgsd.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/wgsd.go b/wgsd.go index c790137..5acf349 100644 --- a/wgsd.go +++ b/wgsd.go @@ -51,6 +51,14 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, name := strings.TrimSuffix(state.Name(), p.zone) qtype := state.QType() + device, err := p.client.Device(p.device) + if err != nil { + return dns.RcodeServerFailure, nil + } + if len(device.Peers) == 0 { + return nxdomain(p.zone, w, r) + } + // setup our reply message m := new(dns.Msg) m.SetReply(r) @@ -58,17 +66,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, switch { case name == spPrefix && qtype == dns.TypePTR: - device, err := p.client.Device(p.device) - if err != nil { - return dns.RcodeServerFailure, nil - } - if len(device.Peers) == 0 { - return nxdomain(p.zone, w, r) - } for _, peer := range device.Peers { m.Answer = append(m.Answer, &dns.PTR{ Hdr: dns.RR_Header{ - Name: fmt.Sprintf("%s%s", spPrefix, p.zone), + Name: state.Name(), Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0, @@ -81,7 +82,55 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, w.WriteMsg(m) // nolint: errcheck return dns.RcodeSuccess, nil case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: - // TODO: return SRV + A/AAAA of peer + pubKey := name[:44] + for _, peer := range device.Peers { + if base64.StdEncoding.EncodeToString(peer.PublicKey[:]) == pubKey { + endpoint := peer.Endpoint + if endpoint.IP == nil { + return nxdomain(p.zone, w, r) + } + srvTarget := fmt.Sprintf("%s.%s", pubKey, p.zone) + m.Answer = append(m.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 0, + }, + Priority: 0, + Weight: 0, + Port: uint16(endpoint.Port), + Target: srvTarget, + }) + switch { + case endpoint.IP.To4() != nil: + m.Extra = append(m.Extra, &dns.A{ + Hdr: dns.RR_Header{ + Name: srvTarget, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: endpoint.IP, + }) + case endpoint.IP.To16() != nil: + m.Extra = append(m.Extra, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: srvTarget, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 0, + }, + AAAA: endpoint.IP, + }) + default: + // TODO: this shouldn't happen + } + w.WriteMsg(m) // nolint: errcheck + return dns.RcodeSuccess, nil + } + } + return nxdomain(p.zone, w, r) case len(name) == keyLen+len(".") && (qtype == dns.TypeA || qtype == dns.TypeAAAA): // TODO: return A/AAAA for of peer