mirror of
https://github.com/jwhited/wgsd.git
synced 2025-04-03 18:49:32 +08:00
encoding: allow to use sha1, hexa and supports truncation
This commit is contained in:
parent
ce787925be
commit
fb126c1ded
19
README.md
19
README.md
@ -29,12 +29,25 @@ A basic client is available under [cmd/wgsd-client](cmd/wgsd-client).
|
||||
## Configuration Syntax
|
||||
|
||||
```
|
||||
wgsd ZONE DEVICE
|
||||
wgsd ZONE DEVICE ENCODING
|
||||
```
|
||||
|
||||
## Querying
|
||||
|
||||
Following RFC6763 this plugin provides a listing of peers via PTR records at the namespace `_wireguard._udp.<zone>`. The target for the PTR records is `<base32PubKey>._wireguard._udp.<zone>` which corresponds to SRV records. SRV targets are of the format `<base32PubKey>.<zone>`. When querying the SRV record for a peer, the target A/AAAA records will be included in the "additional" section of the response. Public keys are represented in Base32 rather than Base64 to allow for their use in node names where they are treated as case-insensitive by the DNS.
|
||||
Following RFC6763 this plugin provides a listing of peers via PTR records at the namespace `_wireguard._udp.<zone>`. The target for the PTR records is `<$encodingPubKey>._wireguard._udp.<zone>` which corresponds to SRV records. SRV targets are of the format `<base32PubKey>.<zone>`. When querying the SRV record for a peer, the target A/AAAA records will be included in the "additional" section of the response. Public keys are represented in the chosen encoding rather than original Base64 human representation to allow for their use in node names where they are treated as case-insensitive by the DNS.
|
||||
|
||||
The supported encoding settings are:
|
||||
* base32 `b32`
|
||||
* hexadecimal `hex`
|
||||
* sha1 `sha1`
|
||||
|
||||
Truncation, for example route53 restricts its record maximum size, to keep consistency with these setup limitation, wgsd supports truncation. This truncation could be compared to a short git sha1.
|
||||
|
||||
The supported encoding truncation are:
|
||||
* hexadecimal - e.g.: `hex:7`
|
||||
* sha1 - e.g: `sha1:9`
|
||||
|
||||
Note that the sha1 and truncation allows to obfuscate peers public keys.
|
||||
|
||||
## Example
|
||||
|
||||
@ -42,7 +55,7 @@ This configuration:
|
||||
```
|
||||
$ cat Corefile
|
||||
.:5353 {
|
||||
wgsd example.com. wg0
|
||||
wgsd example.com. wg0 b32
|
||||
}
|
||||
```
|
||||
|
||||
|
77
encoder.go
Normal file
77
encoder.go
Normal file
@ -0,0 +1,77 @@
|
||||
package wgsd
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/base32"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type encoder interface {
|
||||
EncodeToString([]byte) string
|
||||
}
|
||||
|
||||
func getEncoder(e string) (encoder, error) {
|
||||
parts := strings.Split(e, ":")
|
||||
if len(parts) > 2 {
|
||||
return nil, errors.New("failed to parse encoder")
|
||||
}
|
||||
name := parts[0]
|
||||
if len(parts) == 1 {
|
||||
return buildEncoder(name, 0)
|
||||
}
|
||||
trunc, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if trunc < 0 {
|
||||
return nil, errors.New("truncation value is < 0")
|
||||
}
|
||||
return buildEncoder(name, trunc)
|
||||
}
|
||||
|
||||
func buildEncoder(name string, trunc int) (encoder, error) {
|
||||
switch name {
|
||||
case "b32":
|
||||
if trunc != 0 {
|
||||
return nil, fmt.Errorf("%s doesn't support truncation", name)
|
||||
}
|
||||
return base32.StdEncoding, nil
|
||||
case "sha1":
|
||||
return &shaOne{trunc: trunc}, nil
|
||||
case "hex":
|
||||
return &hexa{trunc: trunc}, nil
|
||||
default:
|
||||
return nil, errors.New("invalid encoder")
|
||||
}
|
||||
}
|
||||
|
||||
type shaOne struct {
|
||||
trunc int
|
||||
}
|
||||
|
||||
func (e *shaOne) EncodeToString(b []byte) string {
|
||||
h := sha1.New()
|
||||
_, _ = h.Write(b)
|
||||
sum := h.Sum(nil)
|
||||
r := hex.EncodeToString(sum)
|
||||
if e.trunc == 0 || len(r) < e.trunc {
|
||||
return r
|
||||
}
|
||||
return r[:e.trunc]
|
||||
}
|
||||
|
||||
type hexa struct {
|
||||
trunc int
|
||||
}
|
||||
|
||||
func (e *hexa) EncodeToString(b []byte) string {
|
||||
r := hex.EncodeToString(b)
|
||||
if e.trunc == 0 || len(r) < e.trunc {
|
||||
return r
|
||||
}
|
||||
return r[:e.trunc]
|
||||
}
|
1
go.mod
1
go.mod
@ -6,5 +6,6 @@ require (
|
||||
github.com/coredns/caddy v1.1.0
|
||||
github.com/coredns/coredns v1.8.0
|
||||
github.com/miekg/dns v1.1.34
|
||||
github.com/stretchr/testify v1.4.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200511024508-91d9787b944f
|
||||
)
|
||||
|
13
setup.go
13
setup.go
@ -19,16 +19,24 @@ func setup(c *caddy.Controller) error {
|
||||
|
||||
// return an error if there is no zone specified
|
||||
if !c.NextArg() {
|
||||
return plugin.Error("wgsd", c.ArgErr())
|
||||
return plugin.Error("wgsd", fmt.Errorf("missing zone: %v", c.ArgErr()))
|
||||
}
|
||||
zone := dns.Fqdn(c.Val())
|
||||
|
||||
// return an error if there is no device name specified
|
||||
if !c.NextArg() {
|
||||
return plugin.Error("wgsd", c.ArgErr())
|
||||
return plugin.Error("wgsd", fmt.Errorf("missing wireguard device name: %v", c.ArgErr()))
|
||||
}
|
||||
device := c.Val()
|
||||
|
||||
if !c.NextArg() {
|
||||
return plugin.Error("wgsd", fmt.Errorf("missing wireguard public key encoding: %v", c.ArgErr()))
|
||||
}
|
||||
enc, err := getEncoder(c.Val())
|
||||
if err != nil {
|
||||
return plugin.Error("wgsd", fmt.Errorf("unsupported wireguard public key encoding: %v", err))
|
||||
}
|
||||
|
||||
// return an error if there are more tokens on this line
|
||||
if c.NextArg() {
|
||||
return plugin.Error("wgsd", c.ArgErr())
|
||||
@ -48,6 +56,7 @@ func setup(c *caddy.Controller) error {
|
||||
client: client,
|
||||
zone: zone,
|
||||
device: device,
|
||||
enc: enc,
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -13,10 +13,40 @@ func TestSetup(t *testing.T) {
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
"valid input",
|
||||
"wgsd example.com. wg0",
|
||||
"valid input b32",
|
||||
"wgsd example.com. wg0 b32",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid input sha1",
|
||||
"wgsd example.com. wg0 sha1",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid input hex",
|
||||
"wgsd example.com. wg0 hex",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid input sha1 truncate",
|
||||
"wgsd example.com. wg0 sha1:7",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid input hex truncate",
|
||||
"wgsd example.com. wg0 hex:7",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid input hex truncate",
|
||||
"wgsd example.com. wg0 hex:-1",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid input b32 truncate",
|
||||
"wgsd example.com. wg0 b32:7",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"missing token",
|
||||
"wgsd example.com.",
|
||||
@ -24,7 +54,7 @@ func TestSetup(t *testing.T) {
|
||||
},
|
||||
{
|
||||
"too many tokens",
|
||||
"wgsd example.com. wg0 extra",
|
||||
"wgsd example.com. wg0 b32 extra",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
91
wgsd.go
91
wgsd.go
@ -2,7 +2,6 @@ package wgsd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base32"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@ -28,6 +27,9 @@ type WGSD struct {
|
||||
zone string
|
||||
// the Wireguard device name, e.g. wg0
|
||||
device string
|
||||
|
||||
// the encoder used to encode wireguard peer public keys
|
||||
enc encoder
|
||||
}
|
||||
|
||||
type wgctrlClient interface {
|
||||
@ -35,9 +37,7 @@ type wgctrlClient interface {
|
||||
}
|
||||
|
||||
const (
|
||||
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
|
||||
spPrefix = "_wireguard._udp."
|
||||
serviceInstanceLen = keyLen + len(".") + len(spPrefix)
|
||||
spPrefix = "_wireguard._udp."
|
||||
)
|
||||
|
||||
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
@ -68,7 +68,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
}
|
||||
|
||||
// setup our reply message
|
||||
m := new(dns.Msg)
|
||||
m := &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
|
||||
@ -87,56 +87,55 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
|
||||
Ttl: 0,
|
||||
},
|
||||
Ptr: fmt.Sprintf("%s.%s%s",
|
||||
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
|
||||
strings.ToLower(p.enc.EncodeToString(peer.PublicKey[:])),
|
||||
spPrefix, p.zone),
|
||||
})
|
||||
}
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
_ = w.WriteMsg(m)
|
||||
return dns.RcodeSuccess, nil
|
||||
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
|
||||
pubKey := name[:keyLen]
|
||||
case qtype == dns.TypeSRV:
|
||||
pubKey := strings.TrimSuffix(name, "."+spPrefix)
|
||||
for _, peer := range device.Peers {
|
||||
if strings.EqualFold(
|
||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||
endpoint := peer.Endpoint
|
||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
}
|
||||
m.Extra = append(m.Extra, hostRR)
|
||||
m.Answer = append(m.Answer, &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: state.Name(),
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Priority: 0,
|
||||
Weight: 0,
|
||||
Port: uint16(endpoint.Port),
|
||||
Target: fmt.Sprintf("%s.%s",
|
||||
strings.ToLower(pubKey), p.zone),
|
||||
})
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
return dns.RcodeSuccess, nil
|
||||
if !strings.EqualFold(p.enc.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||
continue
|
||||
}
|
||||
endpoint := peer.Endpoint
|
||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
}
|
||||
m.Extra = append(m.Extra, hostRR)
|
||||
m.Answer = append(m.Answer, &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: state.Name(),
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Priority: 0,
|
||||
Weight: 0,
|
||||
Port: uint16(endpoint.Port),
|
||||
Target: fmt.Sprintf("%s.%s",
|
||||
strings.ToLower(pubKey), p.zone),
|
||||
})
|
||||
_ = w.WriteMsg(m)
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
return nxDomain(p.zone, w, r)
|
||||
case len(name) == keyLen+1 && (qtype == dns.TypeA ||
|
||||
qtype == dns.TypeAAAA):
|
||||
pubKey := name[:keyLen]
|
||||
case qtype == dns.TypeA || qtype == dns.TypeAAAA:
|
||||
pubKey := strings.TrimSuffix(name, ".")
|
||||
for _, peer := range device.Peers {
|
||||
if strings.EqualFold(
|
||||
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||
endpoint := peer.Endpoint
|
||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
}
|
||||
m.Answer = append(m.Answer, hostRR)
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
return dns.RcodeSuccess, nil
|
||||
if !strings.EqualFold(p.enc.EncodeToString(peer.PublicKey[:]), pubKey) {
|
||||
continue
|
||||
}
|
||||
endpoint := peer.Endpoint
|
||||
hostRR := getHostRR(pubKey, p.zone, endpoint)
|
||||
if hostRR == nil {
|
||||
return nxDomain(p.zone, w, r)
|
||||
}
|
||||
m.Answer = append(m.Answer, hostRR)
|
||||
_ = w.WriteMsg(m)
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
return nxDomain(p.zone, w, r)
|
||||
default:
|
||||
@ -182,7 +181,7 @@ func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
m.Authoritative = true
|
||||
m.Rcode = dns.RcodeNameError
|
||||
m.Ns = []dns.RR{soa(zone)}
|
||||
w.WriteMsg(m) // nolint: errcheck
|
||||
_ = w.WriteMsg(m)
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
|
45
wgsd_test.go
45
wgsd_test.go
@ -11,6 +11,7 @@ import (
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@ -26,8 +27,8 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
|
||||
}
|
||||
|
||||
func TestWGSD(t *testing.T) {
|
||||
key1 := [32]byte{}
|
||||
key1[0] = 1
|
||||
key1, err := wgtypes.ParseKey("JeZlz14G8tg1Bqh6apteFCwVhNhpexJ19FDPfuxQtUY=")
|
||||
require.NoError(t, err)
|
||||
peer1 := wgtypes.Peer{
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.ParseIP("1.1.1.1"),
|
||||
@ -36,8 +37,9 @@ func TestWGSD(t *testing.T) {
|
||||
PublicKey: key1,
|
||||
}
|
||||
peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
|
||||
key2 := [32]byte{}
|
||||
key2[0] = 2
|
||||
|
||||
key2, err := wgtypes.ParseKey("xScVkH3fUGUv4RrJFfmcqm8rs3SEHr41km6+yffAHw4=")
|
||||
require.NoError(t, err)
|
||||
peer2 := wgtypes.Peer{
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.ParseIP("::2"),
|
||||
@ -46,6 +48,9 @@ func TestWGSD(t *testing.T) {
|
||||
PublicKey: key2,
|
||||
}
|
||||
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
|
||||
|
||||
enc, err := getEncoder("b32")
|
||||
require.NoError(t, err)
|
||||
p := &WGSD{
|
||||
Next: test.ErrorHandler(),
|
||||
client: &mockClient{
|
||||
@ -53,6 +58,7 @@ func TestWGSD(t *testing.T) {
|
||||
},
|
||||
zone: "example.com.",
|
||||
device: "wg0",
|
||||
enc: enc,
|
||||
}
|
||||
|
||||
testCases := []test.Case{
|
||||
@ -122,25 +128,22 @@ func TestWGSD(t *testing.T) {
|
||||
m := tc.Msg()
|
||||
rec := dnstest.NewRecorder(&test.ResponseWriter{})
|
||||
ctx := context.TODO()
|
||||
|
||||
_, err := p.ServeDNS(ctx, rec, m)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rec.Msg
|
||||
if err := test.Header(tc, resp); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if err := test.Section(tc, test.Answer, resp.Answer); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := test.Section(tc, test.Ns, resp.Ns); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := test.Section(tc, test.Extra, resp.Extra); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = test.Header(tc, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = test.Section(tc, test.Answer, resp.Answer)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = test.Section(tc, test.Ns, resp.Ns)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = test.Section(tc, test.Extra, resp.Extra)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user