handle SRV queries

This commit is contained in:
Jordan Whited 2020-05-12 17:13:40 -07:00
parent e6531c81ed
commit 8109291569

67
wgsd.go
View File

@ -51,6 +51,14 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
name := strings.TrimSuffix(state.Name(), p.zone) name := strings.TrimSuffix(state.Name(), p.zone)
qtype := state.QType() qtype := state.QType()
device, err := p.client.Device(p.device)
if err != nil {
return dns.RcodeServerFailure, nil
}
if len(device.Peers) == 0 {
return nxdomain(p.zone, w, r)
}
// setup our reply message // setup our reply message
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
@ -58,17 +66,10 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
switch { switch {
case name == spPrefix && qtype == dns.TypePTR: case name == spPrefix && qtype == dns.TypePTR:
device, err := p.client.Device(p.device)
if err != nil {
return dns.RcodeServerFailure, nil
}
if len(device.Peers) == 0 {
return nxdomain(p.zone, w, r)
}
for _, peer := range device.Peers { for _, peer := range device.Peers {
m.Answer = append(m.Answer, &dns.PTR{ m.Answer = append(m.Answer, &dns.PTR{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: fmt.Sprintf("%s%s", spPrefix, p.zone), Name: state.Name(),
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 0, Ttl: 0,
@ -81,7 +82,55 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
w.WriteMsg(m) // nolint: errcheck w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil return dns.RcodeSuccess, nil
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
// TODO: return SRV + A/AAAA of peer pubKey := name[:44]
for _, peer := range device.Peers {
if base64.StdEncoding.EncodeToString(peer.PublicKey[:]) == pubKey {
endpoint := peer.Endpoint
if endpoint.IP == nil {
return nxdomain(p.zone, w, r)
}
srvTarget := fmt.Sprintf("%s.%s", pubKey, p.zone)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: srvTarget,
})
switch {
case endpoint.IP.To4() != nil:
m.Extra = append(m.Extra, &dns.A{
Hdr: dns.RR_Header{
Name: srvTarget,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: endpoint.IP,
})
case endpoint.IP.To16() != nil:
m.Extra = append(m.Extra, &dns.AAAA{
Hdr: dns.RR_Header{
Name: srvTarget,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: endpoint.IP,
})
default:
// TODO: this shouldn't happen
}
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
}
}
return nxdomain(p.zone, w, r)
case len(name) == keyLen+len(".") && (qtype == dns.TypeA || case len(name) == keyLen+len(".") && (qtype == dns.TypeA ||
qtype == dns.TypeAAAA): qtype == dns.TypeAAAA):
// TODO: return A/AAAA for of peer // TODO: return A/AAAA for of peer