Refactor - now wireguard based. (#7)
This commit is contained in:
225
peer/wginterface/interface.go
Normal file
225
peer/wginterface/interface.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Package wginterface demonstrates creating and destroying a WireGuard network
|
||||
// interface using only raw system calls — no netlink library.
|
||||
//
|
||||
// Creating a typed interface (kind = "wireguard") requires the NETLINK_ROUTE
|
||||
// protocol; there is no ioctl path for it. Everything else — assigning an IP
|
||||
// address and bringing the link up — can be done with the older AF_INET ioctl
|
||||
// interface, exactly as one would for a TUN device.
|
||||
//
|
||||
// The package requires CAP_NET_ADMIN and the wireguard kernel module.
|
||||
package wginterface
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Create creates a WireGuard interface named name, assigns vpnIP/prefixLen to
|
||||
// it, and brings it up.
|
||||
func Create(name string, vpnIP net.IP, prefixLen int) error {
|
||||
_ = Delete(name) // remove any stale interface left by a previous run
|
||||
if err := nlNewLink(name); err != nil {
|
||||
return fmt.Errorf("failed to create wireguard link: %w", err)
|
||||
}
|
||||
if err := ioctlSetAddr(name, vpnIP, prefixLen); err != nil {
|
||||
_ = Delete(name)
|
||||
return fmt.Errorf("assign address: %w", err)
|
||||
}
|
||||
if err := ioctlLinkUp(name); err != nil {
|
||||
_ = Delete(name)
|
||||
return fmt.Errorf("link up: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes the named interface.
|
||||
func Delete(name string) error {
|
||||
return nlDelLink(name)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Netlink link management
|
||||
//
|
||||
// Creating a WireGuard interface requires an RTM_NEWLINK message with a nested
|
||||
// IFLA_LINKINFO attribute whose IFLA_INFO_KIND is "wireguard". The full
|
||||
// message layout is:
|
||||
//
|
||||
// nlmsghdr (16 bytes)
|
||||
// ifinfomsg (16 bytes, all zeros for a new link)
|
||||
// rtattr IFLA_IFNAME → name + \0
|
||||
// rtattr IFLA_LINKINFO
|
||||
// rtattr IFLA_INFO_KIND → "wireguard" + \0
|
||||
//
|
||||
// All multi-byte integers are in native byte order (little-endian on
|
||||
// x86/arm64). Every attribute is padded to a 4-byte boundary; the len field
|
||||
// in the header records the unpadded length but the attribute occupies the
|
||||
// padded size.
|
||||
|
||||
const (
|
||||
nlmsgHdrLen = 16 // sizeof(struct nlmsghdr)
|
||||
sizeofIfInfo = 16 // sizeof(struct ifinfomsg)
|
||||
|
||||
// Attribute types not exposed by the unix package at the level we need.
|
||||
iflaLinkInfo = 18 // IFLA_LINKINFO — container for link-type attributes
|
||||
iflaInfoKind = 1 // IFLA_INFO_KIND — link type string, nested inside IFLA_LINKINFO
|
||||
)
|
||||
|
||||
// nlNewLink creates the wireguard interface using Netlink.
|
||||
func nlNewLink(name string) error {
|
||||
// Build innermost attribute first, then wrap outward.
|
||||
infoKind := nlAttr(iflaInfoKind, cstring("wireguard"))
|
||||
linkInfo := nlAttr(iflaLinkInfo, infoKind)
|
||||
ifName := nlAttr(unix.IFLA_IFNAME, cstring(name))
|
||||
|
||||
// ifinfomsg: all-zero = AF_UNSPEC, no index, no flags (kernel assigns index).
|
||||
ifInfo := make([]byte, sizeofIfInfo)
|
||||
|
||||
payload := slices.Concat(ifInfo, ifName, linkInfo)
|
||||
flags := uint16(unix.NLM_F_REQUEST | unix.NLM_F_ACK | unix.NLM_F_CREATE | unix.NLM_F_EXCL)
|
||||
return nlRoundtrip(unix.RTM_NEWLINK, flags, payload)
|
||||
}
|
||||
|
||||
func nlDelLink(name string) error {
|
||||
iface, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For RTM_DELLINK the kernel identifies the link by ifi_index. ifi_index
|
||||
// sits at byte offset 4 in the ifinfomsg struct.
|
||||
ifInfo := make([]byte, sizeofIfInfo)
|
||||
binary.NativeEndian.PutUint32(ifInfo[4:8], uint32(iface.Index))
|
||||
|
||||
return nlRoundtrip(unix.RTM_DELLINK, uint16(unix.NLM_F_REQUEST|unix.NLM_F_ACK), ifInfo)
|
||||
}
|
||||
|
||||
// nlRoundtrip opens a NETLINK_ROUTE socket, sends one request, reads the
|
||||
// NLMSG_ERROR acknowledgement, and closes the socket.
|
||||
func nlRoundtrip(msgType uint16, flags uint16, payload []byte) error {
|
||||
fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("socket: %w", err)
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
if err := unix.Bind(fd, &unix.SockaddrNetlink{Family: unix.AF_NETLINK}); err != nil {
|
||||
return fmt.Errorf("bind: %w", err)
|
||||
}
|
||||
|
||||
msg := nlMsg(msgType, flags, payload)
|
||||
if err := unix.Sendto(fd, msg, 0, &unix.SockaddrNetlink{Family: unix.AF_NETLINK}); err != nil {
|
||||
return fmt.Errorf("sendto: %w", err)
|
||||
}
|
||||
|
||||
resp := make([]byte, 4096)
|
||||
n, _, err := unix.Recvfrom(fd, resp, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("recvfrom: %w", err)
|
||||
}
|
||||
return nlAckErr(resp[:n])
|
||||
}
|
||||
|
||||
// nlMsg prepends an nlmsghdr to payload.
|
||||
func nlMsg(msgType uint16, flags uint16, payload []byte) []byte {
|
||||
buf := make([]byte, nlmsgHdrLen+len(payload))
|
||||
binary.NativeEndian.PutUint32(buf[0:4], uint32(len(buf))) // nlmsg_len
|
||||
binary.NativeEndian.PutUint16(buf[4:6], msgType) // nlmsg_type
|
||||
binary.NativeEndian.PutUint16(buf[6:8], flags) // nlmsg_flags
|
||||
binary.NativeEndian.PutUint32(buf[8:12], 1) // nlmsg_seq
|
||||
binary.NativeEndian.PutUint32(buf[12:16], 0) // nlmsg_pid (0 = kernel)
|
||||
copy(buf[nlmsgHdrLen:], payload)
|
||||
return buf
|
||||
}
|
||||
|
||||
// nlAckErr parses an NLMSG_ERROR response. The error field is a negated errno
|
||||
// (0 = success, -EEXIST = interface exists, etc.).
|
||||
func nlAckErr(resp []byte) error {
|
||||
if len(resp) < nlmsgHdrLen+4 {
|
||||
return fmt.Errorf("netlink response too short (%d bytes)", len(resp))
|
||||
}
|
||||
if binary.NativeEndian.Uint16(resp[4:6]) != unix.NLMSG_ERROR {
|
||||
return fmt.Errorf("unexpected nlmsg_type %d", binary.NativeEndian.Uint16(resp[4:6]))
|
||||
}
|
||||
// Error code follows the nlmsghdr; it is a signed int32 holding -errno.
|
||||
code := int32(binary.NativeEndian.Uint32(resp[nlmsgHdrLen:]))
|
||||
if code != 0 {
|
||||
return unix.Errno(-code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nlAttr encodes one netlink attribute: [len:u16][type:u16][data][pad to 4
|
||||
// bytes]. The len field counts the header + data (before padding); the
|
||||
// allocation is padded so that the next attribute starts on a 4-byte boundary.
|
||||
func nlAttr(attrType uint16, data []byte) []byte {
|
||||
const hdr = 4
|
||||
attrLen := hdr + len(data)
|
||||
padded := (attrLen + 3) &^ 3
|
||||
buf := make([]byte, padded)
|
||||
binary.NativeEndian.PutUint16(buf[0:2], uint16(attrLen))
|
||||
binary.NativeEndian.PutUint16(buf[2:4], attrType)
|
||||
copy(buf[hdr:], data)
|
||||
return buf
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ioctl-based address assignment and link-up
|
||||
//
|
||||
// These operations could also be done via RTM_NEWADDR / RTM_NEWLINK netlink
|
||||
// messages, but the AF_INET ioctl interface is simpler.
|
||||
|
||||
func ioctlSetAddr(name string, ip net.IP, prefixLen int) error {
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
req, err := unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := req.SetInet4Addr(ip.To4()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := unix.IoctlIfreq(fd, unix.SIOCSIFADDR, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err = unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mask := net.CIDRMask(prefixLen, 32)
|
||||
if err := req.SetInet4Addr([]byte(mask)); err != nil {
|
||||
return err
|
||||
}
|
||||
return unix.IoctlIfreq(fd, unix.SIOCSIFNETMASK, req)
|
||||
}
|
||||
|
||||
func ioctlLinkUp(name string) error {
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
req, err := unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := unix.IoctlIfreq(fd, unix.SIOCGIFFLAGS, req); err != nil {
|
||||
return err
|
||||
}
|
||||
req.SetUint16(req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING)
|
||||
return unix.IoctlIfreq(fd, unix.SIOCSIFFLAGS, req)
|
||||
}
|
||||
|
||||
// cstring returns b as a null-terminated byte slice.
|
||||
func cstring(s string) []byte {
|
||||
return append([]byte(s), 0)
|
||||
}
|
||||
184
peer/wginterface/manage.go
Normal file
184
peer/wginterface/manage.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package wginterface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
// RekeyTimeout is the WireGuard session lifetime before a new handshake
|
||||
// is initiated. Sessions older than this but younger than SessionTimeout
|
||||
// remain valid.
|
||||
RekeyTimeout = 120 * time.Second
|
||||
|
||||
// SessionTimeout is the WireGuard session lifetime after which sessions
|
||||
// are rejected. A peer with LastHandshakeTime older than this is
|
||||
// effectively disconnected.
|
||||
SessionTimeout = 180 * time.Second
|
||||
)
|
||||
|
||||
const ProbeKeepalive = 8 * time.Second
|
||||
|
||||
var zeroKeepalive = time.Duration(0)
|
||||
|
||||
// Device wraps a wgctrl client bound to a named WireGuard interface.
|
||||
type Device struct {
|
||||
client *wgctrl.Client
|
||||
name string
|
||||
}
|
||||
|
||||
// Open attaches to an existing WireGuard interface.
|
||||
func Open(name string) (*Device, error) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wgctrl: %w", err)
|
||||
}
|
||||
return &Device{client: client, name: name}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying wgctrl client.
|
||||
func (d *Device) Close() error {
|
||||
return d.client.Close()
|
||||
}
|
||||
|
||||
// Name returns the interface name.
|
||||
func (d *Device) Name() string {
|
||||
return d.name
|
||||
}
|
||||
|
||||
// Configure sets the device's private key and UDP listen port.
|
||||
func (d *Device) Configure(privKey wgtypes.Key, listenPort int) error {
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
PrivateKey: &privKey,
|
||||
ListenPort: &listenPort,
|
||||
})
|
||||
}
|
||||
|
||||
// Peers returns the current state of all peers on the device.
|
||||
func (d *Device) Peers() ([]wgtypes.Peer, error) {
|
||||
dev, err := d.client.Device(d.name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %q: %w", d.name, err)
|
||||
}
|
||||
return dev.Peers, nil
|
||||
}
|
||||
|
||||
// Peer returns the current state of a single peer by public key.
|
||||
func (d *Device) Peer(pubKey wgtypes.Key) (wgtypes.Peer, error) {
|
||||
peers, err := d.Peers()
|
||||
if err != nil {
|
||||
return wgtypes.Peer{}, err
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.PublicKey == pubKey {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return wgtypes.Peer{}, fmt.Errorf("peer %v not found in %q", pubKey, d.name)
|
||||
}
|
||||
|
||||
// AddPeer registers a peer with no AllowedIPs and no endpoint. WireGuard will
|
||||
// accept handshakes from this peer but route no traffic to it yet.
|
||||
func (d *Device) AddPeer(pubKey wgtypes.Key) error {
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: pubKey,
|
||||
ReplaceAllowedIPs: true,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// SetRelay configures the relay peer with AllowedIPs covering the entire VPN
|
||||
// network prefix. This is the fallback route for all VPN traffic.
|
||||
func (d *Device) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error {
|
||||
masked := network.Masked()
|
||||
a4 := masked.Addr().As4()
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: pubKey,
|
||||
Endpoint: net.UDPAddrFromAddrPort(endpoint),
|
||||
AllowedIPs: []net.IPNet{{
|
||||
IP: net.IP(a4[:]),
|
||||
Mask: net.CIDRMask(masked.Bits(), 32),
|
||||
}},
|
||||
ReplaceAllowedIPs: true,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// AddProbe adds a peer with no AllowedIPs and a 5s keepalive. WireGuard will
|
||||
// attempt handshakes without routing any traffic through this peer yet.
|
||||
func (d *Device) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
|
||||
keepalive := ProbeKeepalive
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: pubKey,
|
||||
Endpoint: net.UDPAddrFromAddrPort(endpoint),
|
||||
AllowedIPs: []net.IPNet{},
|
||||
ReplaceAllowedIPs: true,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// Promote upgrades a probe entry to a /32 AllowedIPs and removes the probe
|
||||
// keepalive, causing WireGuard to prefer this peer's direct path over the
|
||||
// relay's wider route.
|
||||
func (d *Device) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error {
|
||||
a4 := vpnIP.As4()
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: []net.IPNet{{
|
||||
IP: net.IP(a4[:]),
|
||||
Mask: net.CIDRMask(32, 32),
|
||||
}},
|
||||
ReplaceAllowedIPs: true,
|
||||
PersistentKeepaliveInterval: &zeroKeepalive,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// AddDirect adds a peer with a known endpoint and /32 AllowedIPs in one step,
|
||||
// for peers with a stable public endpoint reported by the hub.
|
||||
func (d *Device) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error {
|
||||
a4 := vpnIP.As4()
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: pubKey,
|
||||
Endpoint: net.UDPAddrFromAddrPort(endpoint),
|
||||
AllowedIPs: []net.IPNet{{
|
||||
IP: net.IP(a4[:]),
|
||||
Mask: net.CIDRMask(32, 32),
|
||||
}},
|
||||
ReplaceAllowedIPs: true,
|
||||
PersistentKeepaliveInterval: &zeroKeepalive,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the device.
|
||||
func (d *Device) RemovePeer(pubKey wgtypes.Key) error {
|
||||
return d.client.ConfigureDevice(d.name, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// EnableForwarding enables IPv4 forwarding globally and on the interface,
|
||||
// required for relay peers that forward traffic between VPN peers.
|
||||
func (d *Device) EnableForwarding() error {
|
||||
if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/forwarding", d.name)
|
||||
return os.WriteFile(path, []byte("1\n"), 0644)
|
||||
}
|
||||
303
peer/wginterface/manage_test.go
Normal file
303
peer/wginterface/manage_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
//go:build integration
|
||||
|
||||
package wginterface_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"vppn/peer/wginterface"
|
||||
)
|
||||
|
||||
const (
|
||||
testBasePort = 59100
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if os.Getuid() != 0 {
|
||||
fmt.Fprintln(os.Stderr, "wginterface integration tests require root; skipping")
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
type testPeer struct {
|
||||
Name string
|
||||
VpnIP netip.Addr
|
||||
Port int
|
||||
PrivKey wgtypes.Key
|
||||
PubKey wgtypes.Key
|
||||
Dev *wginterface.Device
|
||||
}
|
||||
|
||||
func (p *testPeer) Endpoint() netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), uint16(p.Port))
|
||||
}
|
||||
|
||||
func newTestPeer(t *testing.T, name string, vpnIP netip.Addr, port int) *testPeer {
|
||||
t.Helper()
|
||||
|
||||
privKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
|
||||
a4 := vpnIP.As4()
|
||||
if err := wginterface.Create(name, net.IP(a4[:]), 24); err != nil {
|
||||
t.Fatalf("create %s: %v", name, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := wginterface.Delete(name); err != nil {
|
||||
log.Printf("Failed to delete interface %s: %v", name, err)
|
||||
}
|
||||
})
|
||||
|
||||
dev, err := wginterface.Open(name)
|
||||
if err != nil {
|
||||
t.Fatalf("open %s: %v", name, err)
|
||||
}
|
||||
t.Cleanup(func() { dev.Close() })
|
||||
|
||||
if err := dev.Configure(privKey, port); err != nil {
|
||||
t.Fatalf("configure %s: %v", name, err)
|
||||
}
|
||||
|
||||
return &testPeer{
|
||||
Name: name,
|
||||
VpnIP: vpnIP,
|
||||
Port: port,
|
||||
PrivKey: privKey,
|
||||
PubKey: privKey.PublicKey(),
|
||||
Dev: dev,
|
||||
}
|
||||
}
|
||||
|
||||
// waitHandshake polls until the named peer has completed a handshake or the timeout elapses.
|
||||
func waitHandshake(t *testing.T, dev *wginterface.Device, pubKey wgtypes.Key, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
p, err := dev.Peer(pubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("peer lookup: %v", err)
|
||||
}
|
||||
if !p.LastHandshakeTime.IsZero() {
|
||||
return
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("no handshake within %v", timeout)
|
||||
}
|
||||
|
||||
func TestDirectHandshake(t *testing.T) {
|
||||
p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1)
|
||||
|
||||
if err := p1.Dev.AddDirect(p2.PubKey, p2.Endpoint(), p2.VpnIP); err != nil {
|
||||
t.Fatalf("p1 AddDirect: %v", err)
|
||||
}
|
||||
if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil {
|
||||
t.Fatalf("p2 AddDirect: %v", err)
|
||||
}
|
||||
|
||||
waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second)
|
||||
waitHandshake(t, p2.Dev, p1.PubKey, 30*time.Second)
|
||||
}
|
||||
|
||||
func TestProbeAndPromote(t *testing.T) {
|
||||
p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1)
|
||||
|
||||
// p2 needs a peer entry for p1 so it can respond to the handshake initiation.
|
||||
if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil {
|
||||
t.Fatalf("p2 AddDirect: %v", err)
|
||||
}
|
||||
|
||||
if err := p1.Dev.AddProbe(p2.PubKey, p2.Endpoint()); err != nil {
|
||||
t.Fatalf("AddProbe: %v", err)
|
||||
}
|
||||
waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second)
|
||||
|
||||
if err := p1.Dev.Promote(p2.PubKey, p2.VpnIP); err != nil {
|
||||
t.Fatalf("Promote: %v", err)
|
||||
}
|
||||
|
||||
peer, err := p1.Dev.Peer(p2.PubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Peer: %v", err)
|
||||
}
|
||||
checkAllowedIP(t, peer, p2.VpnIP, 32)
|
||||
}
|
||||
|
||||
func TestRelayHandshakes(t *testing.T) {
|
||||
vpnNetwork := netip.MustParsePrefix("192.168.99.0/24")
|
||||
|
||||
relay := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
peer1 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1)
|
||||
peer2 := newTestPeer(t, "wgtest2", netip.MustParseAddr("192.168.99.3"), testBasePort+2)
|
||||
|
||||
if err := relay.Dev.AddDirect(peer1.PubKey, peer1.Endpoint(), peer1.VpnIP); err != nil {
|
||||
t.Fatalf("relay AddDirect peer1: %v", err)
|
||||
}
|
||||
if err := relay.Dev.AddDirect(peer2.PubKey, peer2.Endpoint(), peer2.VpnIP); err != nil {
|
||||
t.Fatalf("relay AddDirect peer2: %v", err)
|
||||
}
|
||||
if err := peer1.Dev.SetRelay(relay.PubKey, relay.Endpoint(), vpnNetwork); err != nil {
|
||||
t.Fatalf("peer1 SetRelay: %v", err)
|
||||
}
|
||||
if err := peer2.Dev.SetRelay(relay.PubKey, relay.Endpoint(), vpnNetwork); err != nil {
|
||||
t.Fatalf("peer2 SetRelay: %v", err)
|
||||
}
|
||||
|
||||
waitHandshake(t, relay.Dev, peer1.PubKey, 30*time.Second)
|
||||
waitHandshake(t, relay.Dev, peer2.PubKey, 30*time.Second)
|
||||
waitHandshake(t, peer1.Dev, relay.PubKey, 30*time.Second)
|
||||
waitHandshake(t, peer2.Dev, relay.PubKey, 30*time.Second)
|
||||
|
||||
// relay has /32 entries for each peer
|
||||
p, err := relay.Dev.Peer(peer1.PubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("relay peer1: %v", err)
|
||||
}
|
||||
checkAllowedIP(t, p, peer1.VpnIP, 32)
|
||||
|
||||
p, err = relay.Dev.Peer(peer2.PubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("relay peer2: %v", err)
|
||||
}
|
||||
checkAllowedIP(t, p, peer2.VpnIP, 32)
|
||||
|
||||
// peers have /24 fallback route via relay
|
||||
p, err = peer1.Dev.Peer(relay.PubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("peer1 relay: %v", err)
|
||||
}
|
||||
checkAllowedIP(t, p, vpnNetwork.Masked().Addr(), 24)
|
||||
|
||||
p, err = peer2.Dev.Peer(relay.PubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("peer2 relay: %v", err)
|
||||
}
|
||||
checkAllowedIP(t, p, vpnNetwork.Masked().Addr(), 24)
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1)
|
||||
|
||||
if err := p1.Dev.AddDirect(p2.PubKey, p2.Endpoint(), p2.VpnIP); err != nil {
|
||||
t.Fatalf("AddDirect: %v", err)
|
||||
}
|
||||
if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil {
|
||||
t.Fatalf("AddDirect: %v", err)
|
||||
}
|
||||
waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second)
|
||||
|
||||
if err := p1.Dev.RemovePeer(p2.PubKey); err != nil {
|
||||
t.Fatalf("RemovePeer: %v", err)
|
||||
}
|
||||
if _, err := p1.Dev.Peer(p2.PubKey); err == nil {
|
||||
t.Fatal("expected error after RemovePeer, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableForwarding(t *testing.T) {
|
||||
p := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
|
||||
if err := p.Dev.EnableForwarding(); err != nil {
|
||||
t.Fatalf("EnableForwarding: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/forwarding", p.Name))
|
||||
if err != nil {
|
||||
t.Fatalf("read forwarding: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(string(data)) != "1" {
|
||||
t.Fatalf("expected forwarding=1, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromoteKeepalive(t *testing.T) {
|
||||
p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1)
|
||||
|
||||
if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil {
|
||||
t.Fatalf("p2 AddDirect: %v", err)
|
||||
}
|
||||
if err := p1.Dev.AddProbe(p2.PubKey, p2.Endpoint()); err != nil {
|
||||
t.Fatalf("AddProbe: %v", err)
|
||||
}
|
||||
waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second)
|
||||
|
||||
if err := p1.Dev.Promote(p2.PubKey, p2.VpnIP); err != nil {
|
||||
t.Fatalf("Promote: %v", err)
|
||||
}
|
||||
|
||||
peer, err := p1.Dev.Peer(p2.PubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Peer: %v", err)
|
||||
}
|
||||
if peer.PersistentKeepaliveInterval != 0 {
|
||||
t.Fatalf("expected keepalive disabled after promote, got %v", peer.PersistentKeepaliveInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeersCount(t *testing.T) {
|
||||
relay := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort)
|
||||
peer1 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1)
|
||||
peer2 := newTestPeer(t, "wgtest2", netip.MustParseAddr("192.168.99.3"), testBasePort+2)
|
||||
|
||||
if err := relay.Dev.AddDirect(peer1.PubKey, peer1.Endpoint(), peer1.VpnIP); err != nil {
|
||||
t.Fatalf("AddDirect peer1: %v", err)
|
||||
}
|
||||
if err := relay.Dev.AddDirect(peer2.PubKey, peer2.Endpoint(), peer2.VpnIP); err != nil {
|
||||
t.Fatalf("AddDirect peer2: %v", err)
|
||||
}
|
||||
|
||||
peers, err := relay.Dev.Peers()
|
||||
if err != nil {
|
||||
t.Fatalf("Peers: %v", err)
|
||||
}
|
||||
if len(peers) != 2 {
|
||||
t.Fatalf("expected 2 peers, got %d", len(peers))
|
||||
}
|
||||
|
||||
if err := relay.Dev.RemovePeer(peer1.PubKey); err != nil {
|
||||
t.Fatalf("RemovePeer: %v", err)
|
||||
}
|
||||
|
||||
peers, err = relay.Dev.Peers()
|
||||
if err != nil {
|
||||
t.Fatalf("Peers after remove: %v", err)
|
||||
}
|
||||
if len(peers) != 1 {
|
||||
t.Fatalf("expected 1 peer after remove, got %d", len(peers))
|
||||
}
|
||||
if peers[0].PublicKey != peer2.PubKey {
|
||||
t.Fatal("wrong peer remained after remove")
|
||||
}
|
||||
}
|
||||
|
||||
// checkAllowedIP asserts that a peer has exactly one AllowedIP matching addr/bits.
|
||||
func checkAllowedIP(t *testing.T, p wgtypes.Peer, addr netip.Addr, bits int) {
|
||||
t.Helper()
|
||||
if len(p.AllowedIPs) != 1 {
|
||||
t.Fatalf("expected 1 AllowedIP, got %d", len(p.AllowedIPs))
|
||||
}
|
||||
ones, _ := p.AllowedIPs[0].Mask.Size()
|
||||
if ones != bits {
|
||||
t.Fatalf("expected /%d, got /%d", bits, ones)
|
||||
}
|
||||
got := netip.AddrFrom4([4]byte(p.AllowedIPs[0].IP.To4()))
|
||||
if got != addr {
|
||||
t.Fatalf("expected AllowedIP %v, got %v", addr, got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user