encoding: allow to use sha1, hexa and supports truncation

This commit is contained in:
Julien Balestra 2020-12-18 18:05:24 +01:00
parent ce787925be
commit fb126c1ded
7 changed files with 207 additions and 75 deletions

View File

@ -29,12 +29,25 @@ A basic client is available under [cmd/wgsd-client](cmd/wgsd-client).
## Configuration Syntax
```
wgsd ZONE DEVICE
wgsd ZONE DEVICE ENCODING
```
## Querying
Following RFC6763 this plugin provides a listing of peers via PTR records at the namespace `_wireguard._udp.<zone>`. The target for the PTR records is `<base32PubKey>._wireguard._udp.<zone>` which corresponds to SRV records. SRV targets are of the format `<base32PubKey>.<zone>`. When querying the SRV record for a peer, the target A/AAAA records will be included in the "additional" section of the response. Public keys are represented in Base32 rather than Base64 to allow for their use in node names where they are treated as case-insensitive by the DNS.
Following RFC6763 this plugin provides a listing of peers via PTR records at the namespace `_wireguard._udp.<zone>`. The target for the PTR records is `<$encodingPubKey>._wireguard._udp.<zone>` which corresponds to SRV records. SRV targets are of the format `<base32PubKey>.<zone>`. When querying the SRV record for a peer, the target A/AAAA records will be included in the "additional" section of the response. Public keys are represented in the chosen encoding rather than original Base64 human representation to allow for their use in node names where they are treated as case-insensitive by the DNS.
The supported encoding settings are:
* base32 `b32`
* hexadecimal `hex`
* sha1 `sha1`
Truncation, for example route53 restricts its record maximum size, to keep consistency with these setup limitation, wgsd supports truncation. This truncation could be compared to a short git sha1.
The supported encoding truncation are:
* hexadecimal - e.g.: `hex:7`
* sha1 - e.g: `sha1:9`
Note that the sha1 and truncation allows to obfuscate peers public keys.
## Example
@ -42,7 +55,7 @@ This configuration:
```
$ cat Corefile
.:5353 {
wgsd example.com. wg0
wgsd example.com. wg0 b32
}
```

77
encoder.go Normal file
View File

@ -0,0 +1,77 @@
package wgsd
import (
"crypto/sha1"
"encoding/base32"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
)
type encoder interface {
EncodeToString([]byte) string
}
func getEncoder(e string) (encoder, error) {
parts := strings.Split(e, ":")
if len(parts) > 2 {
return nil, errors.New("failed to parse encoder")
}
name := parts[0]
if len(parts) == 1 {
return buildEncoder(name, 0)
}
trunc, err := strconv.Atoi(parts[1])
if err != nil {
return nil, err
}
if trunc < 0 {
return nil, errors.New("truncation value is < 0")
}
return buildEncoder(name, trunc)
}
func buildEncoder(name string, trunc int) (encoder, error) {
switch name {
case "b32":
if trunc != 0 {
return nil, fmt.Errorf("%s doesn't support truncation", name)
}
return base32.StdEncoding, nil
case "sha1":
return &shaOne{trunc: trunc}, nil
case "hex":
return &hexa{trunc: trunc}, nil
default:
return nil, errors.New("invalid encoder")
}
}
type shaOne struct {
trunc int
}
func (e *shaOne) EncodeToString(b []byte) string {
h := sha1.New()
_, _ = h.Write(b)
sum := h.Sum(nil)
r := hex.EncodeToString(sum)
if e.trunc == 0 || len(r) < e.trunc {
return r
}
return r[:e.trunc]
}
type hexa struct {
trunc int
}
func (e *hexa) EncodeToString(b []byte) string {
r := hex.EncodeToString(b)
if e.trunc == 0 || len(r) < e.trunc {
return r
}
return r[:e.trunc]
}

1
go.mod
View File

@ -6,5 +6,6 @@ require (
github.com/coredns/caddy v1.1.0
github.com/coredns/coredns v1.8.0
github.com/miekg/dns v1.1.34
github.com/stretchr/testify v1.4.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200511024508-91d9787b944f
)

View File

@ -19,16 +19,24 @@ func setup(c *caddy.Controller) error {
// return an error if there is no zone specified
if !c.NextArg() {
return plugin.Error("wgsd", c.ArgErr())
return plugin.Error("wgsd", fmt.Errorf("missing zone: %v", c.ArgErr()))
}
zone := dns.Fqdn(c.Val())
// return an error if there is no device name specified
if !c.NextArg() {
return plugin.Error("wgsd", c.ArgErr())
return plugin.Error("wgsd", fmt.Errorf("missing wireguard device name: %v", c.ArgErr()))
}
device := c.Val()
if !c.NextArg() {
return plugin.Error("wgsd", fmt.Errorf("missing wireguard public key encoding: %v", c.ArgErr()))
}
enc, err := getEncoder(c.Val())
if err != nil {
return plugin.Error("wgsd", fmt.Errorf("unsupported wireguard public key encoding: %v", err))
}
// return an error if there are more tokens on this line
if c.NextArg() {
return plugin.Error("wgsd", c.ArgErr())
@ -48,6 +56,7 @@ func setup(c *caddy.Controller) error {
client: client,
zone: zone,
device: device,
enc: enc,
}
})

View File

@ -13,10 +13,40 @@ func TestSetup(t *testing.T) {
expectErr bool
}{
{
"valid input",
"wgsd example.com. wg0",
"valid input b32",
"wgsd example.com. wg0 b32",
false,
},
{
"valid input sha1",
"wgsd example.com. wg0 sha1",
false,
},
{
"valid input hex",
"wgsd example.com. wg0 hex",
false,
},
{
"valid input sha1 truncate",
"wgsd example.com. wg0 sha1:7",
false,
},
{
"valid input hex truncate",
"wgsd example.com. wg0 hex:7",
false,
},
{
"valid input hex truncate",
"wgsd example.com. wg0 hex:-1",
true,
},
{
"invalid input b32 truncate",
"wgsd example.com. wg0 b32:7",
true,
},
{
"missing token",
"wgsd example.com.",
@ -24,7 +54,7 @@ func TestSetup(t *testing.T) {
},
{
"too many tokens",
"wgsd example.com. wg0 extra",
"wgsd example.com. wg0 b32 extra",
true,
},
}

91
wgsd.go
View File

@ -2,7 +2,6 @@ package wgsd
import (
"context"
"encoding/base32"
"fmt"
"net"
"strings"
@ -28,6 +27,9 @@ type WGSD struct {
zone string
// the Wireguard device name, e.g. wg0
device string
// the encoder used to encode wireguard peer public keys
enc encoder
}
type wgctrlClient interface {
@ -35,9 +37,7 @@ type wgctrlClient interface {
}
const (
keyLen = 56 // the number of characters in a base32-encoded Wireguard public key
spPrefix = "_wireguard._udp."
serviceInstanceLen = keyLen + len(".") + len(spPrefix)
spPrefix = "_wireguard._udp."
)
func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
@ -68,7 +68,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
}
// setup our reply message
m := new(dns.Msg)
m := &dns.Msg{}
m.SetReply(r)
m.Authoritative = true
@ -87,56 +87,55 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter,
Ttl: 0,
},
Ptr: fmt.Sprintf("%s.%s%s",
strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])),
strings.ToLower(p.enc.EncodeToString(peer.PublicKey[:])),
spPrefix, p.zone),
})
}
w.WriteMsg(m) // nolint: errcheck
_ = w.WriteMsg(m)
return dns.RcodeSuccess, nil
case len(name) == serviceInstanceLen && qtype == dns.TypeSRV:
pubKey := name[:keyLen]
case qtype == dns.TypeSRV:
pubKey := strings.TrimSuffix(name, "."+spPrefix)
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: fmt.Sprintf("%s.%s",
strings.ToLower(pubKey), p.zone),
})
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
if !strings.EqualFold(p.enc.EncodeToString(peer.PublicKey[:]), pubKey) {
continue
}
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Extra = append(m.Extra, hostRR)
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: state.Name(),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Priority: 0,
Weight: 0,
Port: uint16(endpoint.Port),
Target: fmt.Sprintf("%s.%s",
strings.ToLower(pubKey), p.zone),
})
_ = w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
return nxDomain(p.zone, w, r)
case len(name) == keyLen+1 && (qtype == dns.TypeA ||
qtype == dns.TypeAAAA):
pubKey := name[:keyLen]
case qtype == dns.TypeA || qtype == dns.TypeAAAA:
pubKey := strings.TrimSuffix(name, ".")
for _, peer := range device.Peers {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Answer = append(m.Answer, hostRR)
w.WriteMsg(m) // nolint: errcheck
return dns.RcodeSuccess, nil
if !strings.EqualFold(p.enc.EncodeToString(peer.PublicKey[:]), pubKey) {
continue
}
endpoint := peer.Endpoint
hostRR := getHostRR(pubKey, p.zone, endpoint)
if hostRR == nil {
return nxDomain(p.zone, w, r)
}
m.Answer = append(m.Answer, hostRR)
_ = w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
return nxDomain(p.zone, w, r)
default:
@ -182,7 +181,7 @@ func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m.Authoritative = true
m.Rcode = dns.RcodeNameError
m.Ns = []dns.RR{soa(zone)}
w.WriteMsg(m) // nolint: errcheck
_ = w.WriteMsg(m)
return dns.RcodeSuccess, nil
}

View File

@ -11,6 +11,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -26,8 +27,8 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) {
}
func TestWGSD(t *testing.T) {
key1 := [32]byte{}
key1[0] = 1
key1, err := wgtypes.ParseKey("JeZlz14G8tg1Bqh6apteFCwVhNhpexJ19FDPfuxQtUY=")
require.NoError(t, err)
peer1 := wgtypes.Peer{
Endpoint: &net.UDPAddr{
IP: net.ParseIP("1.1.1.1"),
@ -36,8 +37,9 @@ func TestWGSD(t *testing.T) {
PublicKey: key1,
}
peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:]))
key2 := [32]byte{}
key2[0] = 2
key2, err := wgtypes.ParseKey("xScVkH3fUGUv4RrJFfmcqm8rs3SEHr41km6+yffAHw4=")
require.NoError(t, err)
peer2 := wgtypes.Peer{
Endpoint: &net.UDPAddr{
IP: net.ParseIP("::2"),
@ -46,6 +48,9 @@ func TestWGSD(t *testing.T) {
PublicKey: key2,
}
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
enc, err := getEncoder("b32")
require.NoError(t, err)
p := &WGSD{
Next: test.ErrorHandler(),
client: &mockClient{
@ -53,6 +58,7 @@ func TestWGSD(t *testing.T) {
},
zone: "example.com.",
device: "wg0",
enc: enc,
}
testCases := []test.Case{
@ -122,25 +128,22 @@ func TestWGSD(t *testing.T) {
m := tc.Msg()
rec := dnstest.NewRecorder(&test.ResponseWriter{})
ctx := context.TODO()
_, err := p.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("Expected no error, got %v", err)
return
}
require.NoError(t, err)
resp := rec.Msg
if err := test.Header(tc, resp); err != nil {
t.Error(err)
return
}
if err := test.Section(tc, test.Answer, resp.Answer); err != nil {
t.Error(err)
}
if err := test.Section(tc, test.Ns, resp.Ns); err != nil {
t.Error(err)
}
if err := test.Section(tc, test.Extra, resp.Extra); err != nil {
t.Error(err)
}
err = test.Header(tc, resp)
require.NoError(t, err)
err = test.Section(tc, test.Answer, resp.Answer)
require.NoError(t, err)
err = test.Section(tc, test.Ns, resp.Ns)
require.NoError(t, err)
err = test.Section(tc, test.Extra, resp.Extra)
require.NoError(t, err)
})
}
}