124 lines
3.6 KiB
Go
124 lines
3.6 KiB
Go
package peer
|
|
|
|
import (
|
|
"net/netip"
|
|
"sync"
|
|
"testing"
|
|
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
// fakeWGDevice records every call made to it. It is safe to read Calls after
|
|
// the event loop has processed the event under test (single-threaded loop
|
|
// means no extra synchronisation needed, but the mutex guards concurrent test
|
|
// helpers if needed).
|
|
type fakeWGDevice struct {
|
|
mu sync.Mutex
|
|
Calls []fakeCall
|
|
peers []wgtypes.Peer
|
|
}
|
|
|
|
type fakeCall struct {
|
|
Method string
|
|
PubKey wgtypes.Key
|
|
Endpoint netip.AddrPort
|
|
VPNiP netip.Addr
|
|
Network netip.Prefix
|
|
}
|
|
|
|
func (f *fakeWGDevice) record(c fakeCall) {
|
|
f.mu.Lock()
|
|
f.Calls = append(f.Calls, c)
|
|
f.mu.Unlock()
|
|
}
|
|
|
|
// Name reports the fixed interface name of the fake device.
func (f *fakeWGDevice) Name() string {
	const ifaceName = "wg-test"
	return ifaceName
}
|
|
|
|
func (f *fakeWGDevice) Peers() ([]wgtypes.Peer, error) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
out := make([]wgtypes.Peer, len(f.peers))
|
|
copy(out, f.peers)
|
|
return out, nil
|
|
}
|
|
|
|
func (f *fakeWGDevice) AddPeer(pubKey wgtypes.Key) error {
|
|
f.record(fakeCall{Method: "AddPeer", PubKey: pubKey})
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeWGDevice) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error {
|
|
f.record(fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeWGDevice) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error {
|
|
f.record(fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeWGDevice) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
|
|
f.record(fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeWGDevice) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error {
|
|
f.record(fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeWGDevice) RemovePeer(pubKey wgtypes.Key) error {
|
|
f.record(fakeCall{Method: "RemovePeer", PubKey: pubKey})
|
|
return nil
|
|
}
|
|
|
|
// AssertNoCalls fails the test if any dev calls were recorded.
|
|
func (f *fakeWGDevice) AssertNoCalls(t *testing.T) {
|
|
t.Helper()
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
if len(f.Calls) != 0 {
|
|
t.Fatalf("unexpected dev calls: %v", f.Calls)
|
|
}
|
|
}
|
|
|
|
func (f *fakeWGDevice) AssertAddPeer(t *testing.T, i int, pubKey wgtypes.Key) {
|
|
t.Helper()
|
|
f.assertCall(t, i, fakeCall{Method: "AddPeer", PubKey: pubKey})
|
|
}
|
|
|
|
func (f *fakeWGDevice) AssertAddDirect(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) {
|
|
t.Helper()
|
|
f.assertCall(t, i, fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
|
|
}
|
|
|
|
func (f *fakeWGDevice) AssertSetRelay(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) {
|
|
t.Helper()
|
|
f.assertCall(t, i, fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
|
|
}
|
|
|
|
func (f *fakeWGDevice) AssertAddProbe(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort) {
|
|
t.Helper()
|
|
f.assertCall(t, i, fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
|
|
}
|
|
|
|
func (f *fakeWGDevice) AssertPromote(t *testing.T, i int, pubKey wgtypes.Key, vpnIP netip.Addr) {
|
|
t.Helper()
|
|
f.assertCall(t, i, fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
|
|
}
|
|
|
|
func (f *fakeWGDevice) AssertRemovePeer(t *testing.T, i int, pubKey wgtypes.Key) {
|
|
t.Helper()
|
|
f.assertCall(t, i, fakeCall{Method: "RemovePeer", PubKey: pubKey})
|
|
}
|
|
|
|
func (f *fakeWGDevice) assertCall(t *testing.T, i int, c fakeCall) {
|
|
t.Helper()
|
|
if len(f.Calls) <= i {
|
|
t.Fatalf("no call at index %d: %v", i, c)
|
|
}
|
|
if c != f.Calls[i] {
|
|
t.Fatalf("call[%d]: got %v, want %v", i, f.Calls[i], c)
|
|
}
|
|
}
|