300 lines
8.5 KiB
Go
300 lines
8.5 KiB
Go
package peer
|
|
|
|
import (
|
|
"net/netip"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"vppn/m"
|
|
)
|
|
|
|
func mustKey(t *testing.T) wgtypes.Key {
|
|
t.Helper()
|
|
k, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("generate key: %v", err)
|
|
}
|
|
return k.PublicKey()
|
|
}
|
|
|
|
func TestOnAddPeer(t *testing.T) {
|
|
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
|
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
|
peerVPNIP := netip.MustParseAddr("10.0.0.2")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
setup func(a *App, key wgtypes.Key)
|
|
peer func(key wgtypes.Key) m.Peer
|
|
check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
|
|
}{
|
|
{
|
|
name: "non-public peer registered in WG via AddPeer",
|
|
peer: func(k wgtypes.Key) m.Peer {
|
|
return m.Peer{WGPubKey: k, PeerIP: 2}
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
|
p := a.peersByKey[key]
|
|
if p == nil {
|
|
t.Fatal("not in peersByKey")
|
|
}
|
|
if a.peersByIP[peerVPNIP] == nil {
|
|
t.Fatal("not in peersByIP")
|
|
}
|
|
if p.State != StateRelayed {
|
|
t.Fatalf("state = %v, want StateRelayed", p.State)
|
|
}
|
|
dev.AssertAddPeer(t, 0, key)
|
|
},
|
|
},
|
|
{
|
|
name: "public peer with endpoint registered via AddDirect",
|
|
peer: func(k wgtypes.Key) m.Peer {
|
|
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
|
p := a.peersByKey[key]
|
|
if p == nil {
|
|
t.Fatal("not in peersByKey")
|
|
}
|
|
dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP)
|
|
},
|
|
},
|
|
{
|
|
name: "re-add removes old WG entry before adding new one",
|
|
setup: func(a *App, key wgtypes.Key) {
|
|
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
|
|
},
|
|
peer: func(k wgtypes.Key) m.Peer {
|
|
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()}
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
|
if len(dev.Calls) != 2 {
|
|
t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls)
|
|
}
|
|
dev.AssertRemovePeer(t, 0, key)
|
|
dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP)
|
|
if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 {
|
|
t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP))
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
|
key := mustKey(t)
|
|
if tc.setup != nil {
|
|
tc.setup(a, key)
|
|
dev.Calls = nil
|
|
}
|
|
a.onAddPeer(tc.peer(key))
|
|
tc.check(t, a, dev, key)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOnRemovePeer(t *testing.T) {
|
|
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
|
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove
|
|
check func(t *testing.T, a *App, dev *fakeWGDevice)
|
|
}{
|
|
{
|
|
name: "unknown key is a no-op",
|
|
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
return mustKey(t)
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
dev.AssertNoCalls(t)
|
|
if len(a.peersByKey) != 0 {
|
|
t.Errorf("peersByKey should be empty")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "StateRelayed peer removed from maps with RemovePeer",
|
|
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
key := mustKey(t)
|
|
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2})
|
|
return key
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 1 {
|
|
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
|
}
|
|
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
|
|
t.Errorf("maps should be empty after remove")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "StateDirect peer removed from maps with RemovePeer",
|
|
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
key := mustKey(t)
|
|
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
|
|
return key
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 1 {
|
|
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
|
}
|
|
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
|
|
t.Errorf("maps should be empty after remove")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "removing active relay with no backup clears relay field",
|
|
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
relay := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
a.relay = relay
|
|
return relay.PubKey()
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 1 {
|
|
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
|
}
|
|
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
if a.relay != nil {
|
|
t.Errorf("relay should be nil after removing only relay")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "removing active relay elects backup via SetRelay",
|
|
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
relay1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
addRelayPeer(t, a, "10.0.0.11", ep2)
|
|
a.relay = relay1
|
|
return relay1.PubKey()
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 2 {
|
|
t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls)
|
|
}
|
|
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
|
|
if a.relay == nil {
|
|
t.Errorf("relay should be set to backup after failover")
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
|
key := tc.setup(t, a)
|
|
dev.Calls = nil
|
|
a.onRemovePeer(key)
|
|
tc.check(t, a, dev)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSwitchActiveRelay(t *testing.T) {
|
|
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
|
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
setup func(t *testing.T, a *App)
|
|
check func(t *testing.T, a *App, dev *fakeWGDevice)
|
|
}{
|
|
{
|
|
name: "no candidates leaves relay nil",
|
|
setup: func(t *testing.T, a *App) {},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
dev.AssertNoCalls(t)
|
|
if a.relay != nil {
|
|
t.Error("relay should be nil")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "single candidate elected via SetRelay",
|
|
setup: func(t *testing.T, a *App) {
|
|
addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 1 {
|
|
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
|
}
|
|
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
|
if a.relay == nil {
|
|
t.Error("relay should be set")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "measured RTT beats zero RTT",
|
|
setup: func(t *testing.T, a *App) {
|
|
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
r1.RTT = 10 * time.Millisecond
|
|
addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeaured)
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 1 {
|
|
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
|
}
|
|
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
|
},
|
|
},
|
|
{
|
|
name: "lower RTT wins",
|
|
setup: func(t *testing.T, a *App) {
|
|
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
r1.RTT = 5 * time.Millisecond
|
|
r2 := addRelayPeer(t, a, "10.0.0.11", ep2)
|
|
r2.RTT = 20 * time.Millisecond
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 1 {
|
|
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
|
}
|
|
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
|
},
|
|
},
|
|
{
|
|
name: "stale relay demoted to direct before backup elected",
|
|
setup: func(t *testing.T, a *App) {
|
|
old := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
old.LastPing = time.Time{} // stale — Up() checks LastPing; triggers switch — triggers switch from onTick
|
|
a.relay = old
|
|
addRelayPeer(t, a, "10.0.0.11", ep2)
|
|
},
|
|
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
if len(dev.Calls) != 2 {
|
|
t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls)
|
|
}
|
|
if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 {
|
|
t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0])
|
|
}
|
|
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
|
|
if a.relay == nil || a.relay.EndpointV4 != ep2 {
|
|
t.Error("relay should be the backup peer")
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
|
tc.setup(t, a)
|
|
dev.Calls = nil
|
|
a.switchActiveRelay()
|
|
tc.check(t, a, dev)
|
|
})
|
|
}
|
|
}
|