Refactor - now wireguard based. (#7)
This commit is contained in:
299
peer/on_hub_test.go
Normal file
299
peer/on_hub_test.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"vppn/m"
|
||||
)
|
||||
|
||||
func mustKey(t *testing.T) wgtypes.Key {
|
||||
t.Helper()
|
||||
k, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
return k.PublicKey()
|
||||
}
|
||||
|
||||
func TestOnAddPeer(t *testing.T) {
|
||||
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
||||
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
||||
peerVPNIP := netip.MustParseAddr("10.0.0.2")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(a *App, key wgtypes.Key)
|
||||
peer func(key wgtypes.Key) m.Peer
|
||||
check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
|
||||
}{
|
||||
{
|
||||
name: "non-public peer registered in WG via AddPeer",
|
||||
peer: func(k wgtypes.Key) m.Peer {
|
||||
return m.Peer{WGPubKey: k, PeerIP: 2}
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
||||
p := a.peersByKey[key]
|
||||
if p == nil {
|
||||
t.Fatal("not in peersByKey")
|
||||
}
|
||||
if a.peersByIP[peerVPNIP] == nil {
|
||||
t.Fatal("not in peersByIP")
|
||||
}
|
||||
if p.State != StateRelayed {
|
||||
t.Fatalf("state = %v, want StateRelayed", p.State)
|
||||
}
|
||||
dev.AssertAddPeer(t, 0, key)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "public peer with endpoint registered via AddDirect",
|
||||
peer: func(k wgtypes.Key) m.Peer {
|
||||
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
||||
p := a.peersByKey[key]
|
||||
if p == nil {
|
||||
t.Fatal("not in peersByKey")
|
||||
}
|
||||
dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "re-add removes old WG entry before adding new one",
|
||||
setup: func(a *App, key wgtypes.Key) {
|
||||
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
|
||||
},
|
||||
peer: func(k wgtypes.Key) m.Peer {
|
||||
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()}
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
||||
if len(dev.Calls) != 2 {
|
||||
t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls)
|
||||
}
|
||||
dev.AssertRemovePeer(t, 0, key)
|
||||
dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP)
|
||||
if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 {
|
||||
t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
||||
key := mustKey(t)
|
||||
if tc.setup != nil {
|
||||
tc.setup(a, key)
|
||||
dev.Calls = nil
|
||||
}
|
||||
a.onAddPeer(tc.peer(key))
|
||||
tc.check(t, a, dev, key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnRemovePeer(t *testing.T) {
|
||||
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
||||
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove
|
||||
check func(t *testing.T, a *App, dev *fakeWGDevice)
|
||||
}{
|
||||
{
|
||||
name: "unknown key is a no-op",
|
||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
||||
return mustKey(t)
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
dev.AssertNoCalls(t)
|
||||
if len(a.peersByKey) != 0 {
|
||||
t.Errorf("peersByKey should be empty")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "StateRelayed peer removed from maps with RemovePeer",
|
||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
||||
key := mustKey(t)
|
||||
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2})
|
||||
return key
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 1 {
|
||||
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
||||
}
|
||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
||||
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
|
||||
t.Errorf("maps should be empty after remove")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "StateDirect peer removed from maps with RemovePeer",
|
||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
||||
key := mustKey(t)
|
||||
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
|
||||
return key
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 1 {
|
||||
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
||||
}
|
||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
||||
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
|
||||
t.Errorf("maps should be empty after remove")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "removing active relay with no backup clears relay field",
|
||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
||||
relay := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
a.relay = relay
|
||||
return relay.PubKey()
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 1 {
|
||||
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
||||
}
|
||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
||||
if a.relay != nil {
|
||||
t.Errorf("relay should be nil after removing only relay")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "removing active relay elects backup via SetRelay",
|
||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
||||
relay1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
addRelayPeer(t, a, "10.0.0.11", ep2)
|
||||
a.relay = relay1
|
||||
return relay1.PubKey()
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 2 {
|
||||
t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls)
|
||||
}
|
||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
||||
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
|
||||
if a.relay == nil {
|
||||
t.Errorf("relay should be set to backup after failover")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
||||
key := tc.setup(t, a)
|
||||
dev.Calls = nil
|
||||
a.onRemovePeer(key)
|
||||
tc.check(t, a, dev)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchActiveRelay(t *testing.T) {
|
||||
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
||||
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, a *App)
|
||||
check func(t *testing.T, a *App, dev *fakeWGDevice)
|
||||
}{
|
||||
{
|
||||
name: "no candidates leaves relay nil",
|
||||
setup: func(t *testing.T, a *App) {},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
dev.AssertNoCalls(t)
|
||||
if a.relay != nil {
|
||||
t.Error("relay should be nil")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single candidate elected via SetRelay",
|
||||
setup: func(t *testing.T, a *App) {
|
||||
addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 1 {
|
||||
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
||||
}
|
||||
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
||||
if a.relay == nil {
|
||||
t.Error("relay should be set")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "measured RTT beats zero RTT",
|
||||
setup: func(t *testing.T, a *App) {
|
||||
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
r1.RTT = 10 * time.Millisecond
|
||||
addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeaured)
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 1 {
|
||||
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
||||
}
|
||||
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "lower RTT wins",
|
||||
setup: func(t *testing.T, a *App) {
|
||||
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
r1.RTT = 5 * time.Millisecond
|
||||
r2 := addRelayPeer(t, a, "10.0.0.11", ep2)
|
||||
r2.RTT = 20 * time.Millisecond
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 1 {
|
||||
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
||||
}
|
||||
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "stale relay demoted to direct before backup elected",
|
||||
setup: func(t *testing.T, a *App) {
|
||||
old := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
|
||||
a.relay = old
|
||||
addRelayPeer(t, a, "10.0.0.11", ep2)
|
||||
},
|
||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
||||
if len(dev.Calls) != 2 {
|
||||
t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls)
|
||||
}
|
||||
if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 {
|
||||
t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0])
|
||||
}
|
||||
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
|
||||
if a.relay == nil || a.relay.EndpointV4 != ep2 {
|
||||
t.Error("relay should be the backup peer")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
||||
tc.setup(t, a)
|
||||
dev.Calls = nil
|
||||
a.switchActiveRelay()
|
||||
tc.check(t, a, dev)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user