From fb126c1ded1a707f3901b55150ec825a14af12be Mon Sep 17 00:00:00 2001 From: Julien Balestra Date: Fri, 18 Dec 2020 18:05:24 +0100 Subject: [PATCH] encoding: allow to use sha1, hexa and supports truncation --- README.md | 19 +++++++++-- encoder.go | 77 +++++++++++++++++++++++++++++++++++++++++++ go.mod | 1 + setup.go | 13 ++++++-- setup_test.go | 36 ++++++++++++++++++-- wgsd.go | 91 +++++++++++++++++++++++++-------------------------- wgsd_test.go | 45 +++++++++++++------------ 7 files changed, 207 insertions(+), 75 deletions(-) create mode 100644 encoder.go diff --git a/README.md b/README.md index e7c70cd..7506940 100644 --- a/README.md +++ b/README.md @@ -29,12 +29,25 @@ A basic client is available under [cmd/wgsd-client](cmd/wgsd-client). ## Configuration Syntax ``` -wgsd ZONE DEVICE +wgsd ZONE DEVICE ENCODING ``` ## Querying -Following RFC6763 this plugin provides a listing of peers via PTR records at the namespace `_wireguard._udp.`. The target for the PTR records is `._wireguard._udp.` which corresponds to SRV records. SRV targets are of the format `.`. When querying the SRV record for a peer, the target A/AAAA records will be included in the "additional" section of the response. Public keys are represented in Base32 rather than Base64 to allow for their use in node names where they are treated as case-insensitive by the DNS. +Following RFC6763 this plugin provides a listing of peers via PTR records at the namespace `_wireguard._udp.`. The target for the PTR records is `<$encodingPubKey>._wireguard._udp.` which corresponds to SRV records. SRV targets are of the format `.`. When querying the SRV record for a peer, the target A/AAAA records will be included in the "additional" section of the response. Public keys are represented in the chosen encoding rather than original Base64 human representation to allow for their use in node names where they are treated as case-insensitive by the DNS. + +The supported encoding settings are: +* base32 `b32` +* hexadecimal `hex` +* sha1 `sha1` + +Truncation, for example route53 restricts its record maximum size, to keep consistency with these setup limitation, wgsd supports truncation. This truncation could be compared to a short git sha1. + +The supported encoding truncation are: +* hexadecimal - e.g.: `hex:7` +* sha1 - e.g: `sha1:9` + +Note that the sha1 and truncation allows to obfuscate peers public keys. ## Example @@ -42,7 +55,7 @@ This configuration: ``` $ cat Corefile .:5353 { - wgsd example.com. wg0 + wgsd example.com. wg0 b32 } ``` diff --git a/encoder.go b/encoder.go new file mode 100644 index 0000000..f3dc724 --- /dev/null +++ b/encoder.go @@ -0,0 +1,77 @@ +package wgsd + +import ( + "crypto/sha1" + "encoding/base32" + "encoding/hex" + "errors" + "fmt" + "strconv" + "strings" +) + +type encoder interface { + EncodeToString([]byte) string +} + +func getEncoder(e string) (encoder, error) { + parts := strings.Split(e, ":") + if len(parts) > 2 { + return nil, errors.New("failed to parse encoder") + } + name := parts[0] + if len(parts) == 1 { + return buildEncoder(name, 0) + } + trunc, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, err + } + if trunc < 0 { + return nil, errors.New("truncation value is < 0") + } + return buildEncoder(name, trunc) +} + +func buildEncoder(name string, trunc int) (encoder, error) { + switch name { + case "b32": + if trunc != 0 { + return nil, fmt.Errorf("%s doesn't support truncation", name) + } + return base32.StdEncoding, nil + case "sha1": + return &shaOne{trunc: trunc}, nil + case "hex": + return &hexa{trunc: trunc}, nil + default: + return nil, errors.New("invalid encoder") + } +} + +type shaOne struct { + trunc int +} + +func (e *shaOne) EncodeToString(b []byte) string { + h := sha1.New() + _, _ = h.Write(b) + sum := h.Sum(nil) + r := hex.EncodeToString(sum) + if e.trunc == 0 || len(r) < e.trunc { + return r + } + return r[:e.trunc] +} + +type hexa struct { + trunc int +} + +func (e *hexa) EncodeToString(b []byte) string { + r := hex.EncodeToString(b) + if e.trunc == 0 || len(r) < e.trunc { + return r + } + return r[:e.trunc] +} diff --git a/go.mod b/go.mod index 925d308..fcdf685 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require ( github.com/coredns/caddy v1.1.0 github.com/coredns/coredns v1.8.0 github.com/miekg/dns v1.1.34 + github.com/stretchr/testify v1.4.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200511024508-91d9787b944f ) diff --git a/setup.go b/setup.go index ee057ae..a5f3af0 100644 --- a/setup.go +++ b/setup.go @@ -19,16 +19,24 @@ func setup(c *caddy.Controller) error { // return an error if there is no zone specified if !c.NextArg() { - return plugin.Error("wgsd", c.ArgErr()) + return plugin.Error("wgsd", fmt.Errorf("missing zone: %v", c.ArgErr())) } zone := dns.Fqdn(c.Val()) // return an error if there is no device name specified if !c.NextArg() { - return plugin.Error("wgsd", c.ArgErr()) + return plugin.Error("wgsd", fmt.Errorf("missing wireguard device name: %v", c.ArgErr())) } device := c.Val() + if !c.NextArg() { + return plugin.Error("wgsd", fmt.Errorf("missing wireguard public key encoding: %v", c.ArgErr())) + } + enc, err := getEncoder(c.Val()) + if err != nil { + return plugin.Error("wgsd", fmt.Errorf("unsupported wireguard public key encoding: %v", err)) + } + // return an error if there are more tokens on this line if c.NextArg() { return plugin.Error("wgsd", c.ArgErr()) @@ -48,6 +56,7 @@ func setup(c *caddy.Controller) error { client: client, zone: zone, device: device, + enc: enc, } }) diff --git a/setup_test.go b/setup_test.go index 39e7410..be17f66 100644 --- a/setup_test.go +++ b/setup_test.go @@ -13,10 +13,40 @@ func TestSetup(t *testing.T) { expectErr bool }{ { - "valid input", - "wgsd example.com. wg0", + "valid input b32", + "wgsd example.com. wg0 b32", false, }, + { + "valid input sha1", + "wgsd example.com. wg0 sha1", + false, + }, + { + "valid input hex", + "wgsd example.com. wg0 hex", + false, + }, + { + "valid input sha1 truncate", + "wgsd example.com. wg0 sha1:7", + false, + }, + { + "valid input hex truncate", + "wgsd example.com. wg0 hex:7", + false, + }, + { + "valid input hex truncate", + "wgsd example.com. wg0 hex:-1", + true, + }, + { + "invalid input b32 truncate", + "wgsd example.com. wg0 b32:7", + true, + }, { "missing token", "wgsd example.com.", @@ -24,7 +54,7 @@ func TestSetup(t *testing.T) { }, { "too many tokens", - "wgsd example.com. wg0 extra", + "wgsd example.com. wg0 b32 extra", true, }, } diff --git a/wgsd.go b/wgsd.go index a63529e..b5f2fde 100644 --- a/wgsd.go +++ b/wgsd.go @@ -2,7 +2,6 @@ package wgsd import ( "context" - "encoding/base32" "fmt" "net" "strings" @@ -28,6 +27,9 @@ type WGSD struct { zone string // the Wireguard device name, e.g. wg0 device string + + // the encoder used to encode wireguard peer public keys + enc encoder } type wgctrlClient interface { @@ -35,9 +37,7 @@ type wgctrlClient interface { } const ( - keyLen = 56 // the number of characters in a base32-encoded Wireguard public key - spPrefix = "_wireguard._udp." - serviceInstanceLen = keyLen + len(".") + len(spPrefix) + spPrefix = "_wireguard._udp." ) func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, @@ -68,7 +68,7 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, } // setup our reply message - m := new(dns.Msg) + m := &dns.Msg{} m.SetReply(r) m.Authoritative = true @@ -87,56 +87,55 @@ func (p *WGSD) ServeDNS(ctx context.Context, w dns.ResponseWriter, Ttl: 0, }, Ptr: fmt.Sprintf("%s.%s%s", - strings.ToLower(base32.StdEncoding.EncodeToString(peer.PublicKey[:])), + strings.ToLower(p.enc.EncodeToString(peer.PublicKey[:])), spPrefix, p.zone), }) } - w.WriteMsg(m) // nolint: errcheck + _ = w.WriteMsg(m) return dns.RcodeSuccess, nil - case len(name) == serviceInstanceLen && qtype == dns.TypeSRV: - pubKey := name[:keyLen] + case qtype == dns.TypeSRV: + pubKey := strings.TrimSuffix(name, "."+spPrefix) for _, peer := range device.Peers { - if strings.EqualFold( - base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { - endpoint := peer.Endpoint - hostRR := getHostRR(pubKey, p.zone, endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) - } - m.Extra = append(m.Extra, hostRR) - m.Answer = append(m.Answer, &dns.SRV{ - Hdr: dns.RR_Header{ - Name: state.Name(), - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 0, - }, - Priority: 0, - Weight: 0, - Port: uint16(endpoint.Port), - Target: fmt.Sprintf("%s.%s", - strings.ToLower(pubKey), p.zone), - }) - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil + if !strings.EqualFold(p.enc.EncodeToString(peer.PublicKey[:]), pubKey) { + continue } + endpoint := peer.Endpoint + hostRR := getHostRR(pubKey, p.zone, endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) + } + m.Extra = append(m.Extra, hostRR) + m.Answer = append(m.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: state.Name(), + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 0, + }, + Priority: 0, + Weight: 0, + Port: uint16(endpoint.Port), + Target: fmt.Sprintf("%s.%s", + strings.ToLower(pubKey), p.zone), + }) + _ = w.WriteMsg(m) + return dns.RcodeSuccess, nil } return nxDomain(p.zone, w, r) - case len(name) == keyLen+1 && (qtype == dns.TypeA || - qtype == dns.TypeAAAA): - pubKey := name[:keyLen] + case qtype == dns.TypeA || qtype == dns.TypeAAAA: + pubKey := strings.TrimSuffix(name, ".") for _, peer := range device.Peers { - if strings.EqualFold( - base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) { - endpoint := peer.Endpoint - hostRR := getHostRR(pubKey, p.zone, endpoint) - if hostRR == nil { - return nxDomain(p.zone, w, r) - } - m.Answer = append(m.Answer, hostRR) - w.WriteMsg(m) // nolint: errcheck - return dns.RcodeSuccess, nil + if !strings.EqualFold(p.enc.EncodeToString(peer.PublicKey[:]), pubKey) { + continue } + endpoint := peer.Endpoint + hostRR := getHostRR(pubKey, p.zone, endpoint) + if hostRR == nil { + return nxDomain(p.zone, w, r) + } + m.Answer = append(m.Answer, hostRR) + _ = w.WriteMsg(m) + return dns.RcodeSuccess, nil } return nxDomain(p.zone, w, r) default: @@ -182,7 +181,7 @@ func nxDomain(zone string, w dns.ResponseWriter, r *dns.Msg) (int, error) { m.Authoritative = true m.Rcode = dns.RcodeNameError m.Ns = []dns.RR{soa(zone)} - w.WriteMsg(m) // nolint: errcheck + _ = w.WriteMsg(m) return dns.RcodeSuccess, nil } diff --git a/wgsd_test.go b/wgsd_test.go index 9ef5e56..d029bfd 100644 --- a/wgsd_test.go +++ b/wgsd_test.go @@ -11,6 +11,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -26,8 +27,8 @@ func (m *mockClient) Device(d string) (*wgtypes.Device, error) { } func TestWGSD(t *testing.T) { - key1 := [32]byte{} - key1[0] = 1 + key1, err := wgtypes.ParseKey("JeZlz14G8tg1Bqh6apteFCwVhNhpexJ19FDPfuxQtUY=") + require.NoError(t, err) peer1 := wgtypes.Peer{ Endpoint: &net.UDPAddr{ IP: net.ParseIP("1.1.1.1"), @@ -36,8 +37,9 @@ func TestWGSD(t *testing.T) { PublicKey: key1, } peer1b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer1.PublicKey[:])) - key2 := [32]byte{} - key2[0] = 2 + + key2, err := wgtypes.ParseKey("xScVkH3fUGUv4RrJFfmcqm8rs3SEHr41km6+yffAHw4=") + require.NoError(t, err) peer2 := wgtypes.Peer{ Endpoint: &net.UDPAddr{ IP: net.ParseIP("::2"), @@ -46,6 +48,9 @@ func TestWGSD(t *testing.T) { PublicKey: key2, } peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:])) + + enc, err := getEncoder("b32") + require.NoError(t, err) p := &WGSD{ Next: test.ErrorHandler(), client: &mockClient{ @@ -53,6 +58,7 @@ func TestWGSD(t *testing.T) { }, zone: "example.com.", device: "wg0", + enc: enc, } testCases := []test.Case{ @@ -122,25 +128,22 @@ func TestWGSD(t *testing.T) { m := tc.Msg() rec := dnstest.NewRecorder(&test.ResponseWriter{}) ctx := context.TODO() + _, err := p.ServeDNS(ctx, rec, m) - if err != nil { - t.Errorf("Expected no error, got %v", err) - return - } + require.NoError(t, err) + resp := rec.Msg - if err := test.Header(tc, resp); err != nil { - t.Error(err) - return - } - if err := test.Section(tc, test.Answer, resp.Answer); err != nil { - t.Error(err) - } - if err := test.Section(tc, test.Ns, resp.Ns); err != nil { - t.Error(err) - } - if err := test.Section(tc, test.Extra, resp.Extra); err != nil { - t.Error(err) - } + err = test.Header(tc, resp) + require.NoError(t, err) + + err = test.Section(tc, test.Answer, resp.Answer) + require.NoError(t, err) + + err = test.Section(tc, test.Ns, resp.Ns) + require.NoError(t, err) + + err = test.Section(tc, test.Extra, resp.Extra) + require.NoError(t, err) }) } }