package peer

import (
	"net/netip"
	"testing"
	"time"

	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"vppn/m"
)

// mustKey generates a fresh WireGuard private key and returns the matching
// public key, failing the test immediately if key generation errors.
func mustKey(t *testing.T) wgtypes.Key {
	t.Helper()
	k, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}
	return k.PublicKey()
}

// TestOnAddPeer drives App.onAddPeer through a table of scenarios and checks
// both the App's peer maps and the calls recorded by the fake WG device.
func TestOnAddPeer(t *testing.T) {
	ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
	ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
	peerVPNIP := netip.MustParseAddr("10.0.0.2")

	testCases := []struct {
		name string
		// setup is optional; calls it makes on the device are discarded
		// before the call under test.
		setup func(a *App, key wgtypes.Key)
		// peer builds the m.Peer handed to onAddPeer.
		peer  func(key wgtypes.Key) m.Peer
		check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
	}{
		{
			name: "non-public peer registered in WG via AddPeer",
			peer: func(k wgtypes.Key) m.Peer {
				return m.Peer{WGPubKey: k, PeerIP: 2}
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				p := a.peersByKey[key]
				if p == nil {
					t.Fatal("not in peersByKey")
				}
				if a.peersByIP[peerVPNIP] == nil {
					t.Fatal("not in peersByIP")
				}
				if p.State != StateRelayed {
					t.Fatalf("state = %v, want StateRelayed", p.State)
				}
				dev.AssertAddPeer(t, 0, key)
			},
		},
		{
			name: "public peer with endpoint registered via AddDirect",
			peer: func(k wgtypes.Key) m.Peer {
				return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				p := a.peersByKey[key]
				if p == nil {
					t.Fatal("not in peersByKey")
				}
				dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP)
			},
		},
		{
			name: "re-add removes old WG entry before adding new one",
			setup: func(a *App, key wgtypes.Key) {
				a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
			},
			peer: func(k wgtypes.Key) m.Peer {
				return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()}
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				if len(dev.Calls) != 2 {
					t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls)
				}
				dev.AssertRemovePeer(t, 0, key)
				dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP)
				if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 {
					t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP))
				}
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			a, dev, _ := newTestApp(t, "10.0.0.1", false)
			key := mustKey(t)
			if tc.setup != nil {
				tc.setup(a, key)
				// Only the calls made by the onAddPeer under test are checked.
				dev.Calls = nil
			}
			a.onAddPeer(tc.peer(key))
			tc.check(t, a, dev, key)
		})
	}
}

// TestOnRemovePeer drives App.onRemovePeer through a table of scenarios.
// Each setup returns the key that is then removed; that same key is handed to
// check so RemovePeer assertions compare against the key that was actually
// requested rather than against whatever the device happened to record.
func TestOnRemovePeer(t *testing.T) {
	ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
	ep2 := netip.MustParseAddrPort("5.6.7.8:51820")

	testCases := []struct {
		name  string
		setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove
		// check receives the key that was passed to onRemovePeer.
		check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
	}{
		{
			name: "unknown key is a no-op",
			setup: func(t *testing.T, a *App) wgtypes.Key {
				return mustKey(t)
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				dev.AssertNoCalls(t)
				if len(a.peersByKey) != 0 {
					t.Errorf("peersByKey should be empty")
				}
			},
		},
		{
			name: "StateRelayed peer removed from maps with RemovePeer",
			setup: func(t *testing.T, a *App) wgtypes.Key {
				key := mustKey(t)
				a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2})
				return key
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				if len(dev.Calls) != 1 {
					t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
				}
				dev.AssertRemovePeer(t, 0, key)
				if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
					t.Errorf("maps should be empty after remove")
				}
			},
		},
		{
			name: "StateDirect peer removed from maps with RemovePeer",
			setup: func(t *testing.T, a *App) wgtypes.Key {
				key := mustKey(t)
				a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
				return key
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				if len(dev.Calls) != 1 {
					t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
				}
				dev.AssertRemovePeer(t, 0, key)
				if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
					t.Errorf("maps should be empty after remove")
				}
			},
		},
		{
			name: "removing active relay with no backup clears relay field",
			setup: func(t *testing.T, a *App) wgtypes.Key {
				relay := addRelayPeer(t, a, "10.0.0.10", ep1)
				a.relay = relay
				return relay.PubKey()
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				if len(dev.Calls) != 1 {
					t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
				}
				dev.AssertRemovePeer(t, 0, key)
				if a.relay != nil {
					t.Errorf("relay should be nil after removing only relay")
				}
			},
		},
		{
			name: "removing active relay elects backup via SetRelay",
			setup: func(t *testing.T, a *App) wgtypes.Key {
				relay1 := addRelayPeer(t, a, "10.0.0.10", ep1)
				addRelayPeer(t, a, "10.0.0.11", ep2)
				a.relay = relay1
				return relay1.PubKey()
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
				if len(dev.Calls) != 2 {
					t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls)
				}
				dev.AssertRemovePeer(t, 0, key)
				// The PubKey argument is read back from the recorded call, so
				// the ep2 endpoint is what pins the elected peer to the backup.
				dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
				if a.relay == nil || a.relay.EndpointV4 != ep2 {
					t.Errorf("relay should be set to backup after failover")
				}
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			a, dev, _ := newTestApp(t, "10.0.0.1", false)
			key := tc.setup(t, a)
			// Discard device calls made during setup.
			dev.Calls = nil
			a.onRemovePeer(key)
			tc.check(t, a, dev, key)
		})
	}
}

// TestSwitchActiveRelay drives App.switchActiveRelay through a table of
// scenarios and checks which relay, if any, is pushed to the fake WG device.
//
// NOTE(review): the AssertSetRelay calls below pass the PubKey taken from the
// recorded call itself, so the expected endpoint is what distinguishes the
// candidates in these cases.
func TestSwitchActiveRelay(t *testing.T) {
	ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
	ep2 := netip.MustParseAddrPort("5.6.7.8:51820")

	testCases := []struct {
		name  string
		setup func(t *testing.T, a *App)
		check func(t *testing.T, a *App, dev *fakeWGDevice)
	}{
		{
			name:  "no candidates leaves relay nil",
			setup: func(t *testing.T, a *App) {},
			check: func(t *testing.T, a *App, dev *fakeWGDevice) {
				dev.AssertNoCalls(t)
				if a.relay != nil {
					t.Error("relay should be nil")
				}
			},
		},
		{
			name: "single candidate elected via SetRelay",
			setup: func(t *testing.T, a *App) {
				addRelayPeer(t, a, "10.0.0.10", ep1)
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice) {
				if len(dev.Calls) != 1 {
					t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
				}
				dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
				if a.relay == nil {
					t.Error("relay should be set")
				}
			},
		},
		{
			name: "measured RTT beats zero RTT",
			setup: func(t *testing.T, a *App) {
				r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
				r1.RTT = 10 * time.Millisecond
				addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeasured)
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice) {
				if len(dev.Calls) != 1 {
					t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
				}
				dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
			},
		},
		{
			name: "lower RTT wins",
			setup: func(t *testing.T, a *App) {
				r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
				r1.RTT = 5 * time.Millisecond
				r2 := addRelayPeer(t, a, "10.0.0.11", ep2)
				r2.RTT = 20 * time.Millisecond
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice) {
				if len(dev.Calls) != 1 {
					t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
				}
				dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
			},
		},
		{
			name: "stale relay demoted to direct before backup elected",
			setup: func(t *testing.T, a *App) {
				old := addRelayPeer(t, a, "10.0.0.10", ep1)
				old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
				a.relay = old
				addRelayPeer(t, a, "10.0.0.11", ep2)
			},
			check: func(t *testing.T, a *App, dev *fakeWGDevice) {
				if len(dev.Calls) != 2 {
					t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls)
				}
				if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 {
					t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0])
				}
				dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
				if a.relay == nil || a.relay.EndpointV4 != ep2 {
					t.Error("relay should be the backup peer")
				}
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			a, dev, _ := newTestApp(t, "10.0.0.1", false)
			tc.setup(t, a)
			// Discard device calls made during setup.
			dev.Calls = nil
			a.switchActiveRelay()
			tc.check(t, a, dev)
		})
	}
}