mirror of
https://github.com/jwhited/wgsd.git
synced 2025-04-09 05:29:30 +08:00
also respond with records for the host itself
This commit is contained in:
parent
3f4967eb68
commit
0edb1df552
21
setup.go
21
setup.go
@ -2,6 +2,7 @@ package wgsd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
"github.com/coredns/caddy"
|
"github.com/coredns/caddy"
|
||||||
"github.com/coredns/coredns/core/dnsserver"
|
"github.com/coredns/coredns/core/dnsserver"
|
||||||
@ -29,6 +30,14 @@ func setup(c *caddy.Controller) error {
|
|||||||
}
|
}
|
||||||
device := c.Val()
|
device := c.Val()
|
||||||
|
|
||||||
|
// parse optional local ip
|
||||||
|
var wgIP net.IP
|
||||||
|
if c.NextArg() {
|
||||||
|
wgIP = net.ParseIP(c.Val())
|
||||||
|
} else {
|
||||||
|
wgIP = getOutboundIP()
|
||||||
|
}
|
||||||
|
|
||||||
// return an error if there are more tokens on this line
|
// return an error if there are more tokens on this line
|
||||||
if c.NextArg() {
|
if c.NextArg() {
|
||||||
return plugin.Error("wgsd", c.ArgErr())
|
return plugin.Error("wgsd", c.ArgErr())
|
||||||
@ -48,8 +57,20 @@ func setup(c *caddy.Controller) error {
|
|||||||
client: client,
|
client: client,
|
||||||
zone: zone,
|
zone: zone,
|
||||||
device: device,
|
device: device,
|
||||||
|
wgIP: wgIP,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get preferred outbound ip of this machine
|
||||||
|
func getOutboundIP() net.IP {
|
||||||
|
conn, err := net.Dial("udp", "1.1.1.1:80")
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return conn.LocalAddr().(*net.UDPAddr).IP
|
||||||
|
}
|
||||||
|
76
wgsd.go
76
wgsd.go
@ -28,12 +28,19 @@ type WGSD struct {
|
|||||||
zone string
|
zone string
|
||||||
// the Wireguard device name, e.g. wg0
|
// the Wireguard device name, e.g. wg0
|
||||||
device string
|
device string
|
||||||
|
// the IP the local wireguard is running on
|
||||||
|
wgIP net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
type wgctrlClient interface {
|
type wgctrlClient interface {
|
||||||
Device(string) (*wgtypes.Device, error)
|
Device(string) (*wgtypes.Device, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type host struct {
|
||||||
|
PublicKey wgtypes.Key
|
||||||
|
Endpoint *net.UDPAddr
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
|
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
|
||||||
spPrefix = "_wireguard._udp."
|
spPrefix = "_wireguard._udp."
|
||||||
@ -63,19 +70,26 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return dns.RcodeServerFailure, err
|
return dns.RcodeServerFailure, err
|
||||||
}
|
}
|
||||||
if len(device.Peers) == 0 {
|
|
||||||
return nxDomain(p.zone, w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup our reply message
|
// setup our reply message
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Authoritative = true
|
m.Authoritative = true
|
||||||
|
|
||||||
|
allPeers := []host{
|
||||||
|
{PublicKey: device.PublicKey, Endpoint: &net.UDPAddr{Port: device.ListenPort, IP: p.wgIP}},
|
||||||
|
}
|
||||||
|
for _, p := range device.Peers {
|
||||||
|
allPeers = append(allPeers, host{
|
||||||
|
PublicKey: p.PublicKey,
|
||||||
|
Endpoint: p.Endpoint,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
// TODO: handle SOA
|
// TODO: handle SOA
|
||||||
case name == spPrefix && qtype == dns.TypePTR:
|
case name == spPrefix && qtype == dns.TypePTR:
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range allPeers {
|
||||||
if peer.Endpoint == nil {
|
if peer.Endpoint == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -93,39 +107,21 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
|||||||
}
|
}
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
w.WriteMsg(m) // nolint: errcheck
|
||||||
return dns.RcodeSuccess, nil
|
return dns.RcodeSuccess, nil
|
||||||
|
case len(name) != serviceInstanceLen && qtype == dns.TypeSRV:
|
||||||
|
return p.sendSRV(m, w, r, allPeers[0])
|
||||||
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
|
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
|
||||||
pubKey := name[:keyLen]
|
pubKey := name[:keyLen]
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range allPeers {
|
||||||
if strings.EqualFold(
|
if strings.EqualFold(
|
||||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||||
endpoint := peer.Endpoint
|
return p.sendSRV(m, w, r, peer)
|
||||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
|
||||||
if hostRR == nil {
|
|
||||||
return nxDomain(p.zone, w, r)
|
|
||||||
}
|
|
||||||
m.Extra = append(m.Extra, hostRR)
|
|
||||||
m.Answer = append(m.Answer, &dns.SRV{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: state.Name(),
|
|
||||||
Rrtype: dns.TypeSRV,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 0,
|
|
||||||
},
|
|
||||||
Priority: 0,
|
|
||||||
Weight: 0,
|
|
||||||
Port: uint16(endpoint.Port),
|
|
||||||
Target: fmt.Sprintf("%s.%s",
|
|
||||||
strings.ToLower(pubKey), p.zone),
|
|
||||||
})
|
|
||||||
w.WriteMsg(m) // nolint: errcheck
|
|
||||||
return dns.RcodeSuccess, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nxDomain(p.zone, w, r)
|
return nxDomain(p.zone, w, r)
|
||||||
case len(name) == keyLen+1 && (qtype == dns.TypeA ||
|
case len(name) == keyLen+1 && (qtype == dns.TypeA ||
|
||||||
qtype == dns.TypeAAAA):
|
qtype == dns.TypeAAAA):
|
||||||
pubKey := name[:keyLen]
|
pubKey := name[:keyLen]
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range allPeers {
|
||||||
if strings.EqualFold(
|
if strings.EqualFold(
|
||||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||||
endpoint := peer.Endpoint
|
endpoint := peer.Endpoint
|
||||||
@ -176,6 +172,32 @@ func getHostRR(pubKey, zone string, endpoint *net.UDPAddr) dns.RR {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGSD) sendSRV(m *dns.Msg, w dns.ResponseWriter, r *dns.Msg, peer host) (int, error) {
|
||||||
|
state := request.Request{W: w, Req: r}
|
||||||
|
pubKey := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
|
||||||
|
endpoint := peer.Endpoint
|
||||||
|
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||||
|
if hostRR == nil {
|
||||||
|
return nxDomain(p.zone, w, r)
|
||||||
|
}
|
||||||
|
m.Extra = append(m.Extra, hostRR)
|
||||||
|
m.Answer = append(m.Answer, &dns.SRV{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: state.Name(),
|
||||||
|
Rrtype: dns.TypeSRV,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 0,
|
||||||
|
},
|
||||||
|
Priority: 0,
|
||||||
|
Weight: 0,
|
||||||
|
Port: uint16(endpoint.Port),
|
||||||
|
Target: fmt.Sprintf("%s.%s",
|
||||||
|
strings.ToLower(pubKey), p.zone),
|
||||||
|
})
|
||||||
|
w.WriteMsg(m) // nolint: errcheck
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user