package peer

import (
	"bytes"
	"net/netip"
	"time"

	"vppn/m"
)

// stateFunc is one state of the remote peer's state machine. It handles a
// single message and returns the state that should handle the next one.
type stateFunc func(msg any) stateFunc

// sentProbe records when and to which address a client-side probe was sent.
type sentProbe struct {
	SentAt time.Time
	Addr   netip.AddrPort
}

// remoteFSM drives the connection state machine for a single remote peer.
// All state is read and written by the goroutine executing Run.
//
// Fields:
//   - pingTimer: source of pingTimerMsg; created stopped, reset on entering
//     the server/client states and stopped when disconnected.
//   - lastSeen: start of the current timeout window; refreshed when the peer
//     is heard from (SYN on the server, matching ACK on the client).
//   - traceID: trace ID of the client's outstanding INIT or SYN.
//   - probes: client probes awaiting a reply, keyed by trace ID.
//   - sharedKey: server's copy of the data key from the most recent SYN.
//   - buf: scratch buffer for marshalling outgoing packets.
type remoteFSM struct {
	*Remote

	pingTimer *time.Ticker
	lastSeen  time.Time
	traceID   uint64
	probes    map[uint64]sentProbe
	sharedKey [32]byte
	buf       []byte
}

// newRemoteFSM returns a state machine for r. The ping ticker is created and
// immediately stopped; states start it with Reset when they need it.
func newRemoteFSM(r *Remote) *remoteFSM {
	fsm := &remoteFSM{
		Remote:    r,
		pingTimer: time.NewTicker(timeoutInterval),
		probes:    map[uint64]sentProbe{},
		buf:       make([]byte, bufferSize),
	}
	fsm.pingTimer.Stop()
	return fsm
}

// Run forwards ping ticks into r.messages and feeds every message received on
// r.messages to the current state, starting from the disconnected state. It
// returns only when r.messages is closed.
//
// NOTE(review): the tick-forwarding goroutine never exits; if r.messages were
// ever closed, its next send would panic. Confirm Run is meant to live for the
// lifetime of the process.
func (r *remoteFSM) Run() {
	go func() {
		for range r.pingTimer.C {
			r.messages <- pingTimerMsg{}
		}
	}()

	state := r.enterDisconnected()
	for msg := range r.messages {
		state = state(msg)
	}
}

// ----------------------------------------------------------------------------

// enterDisconnected publishes an empty configuration and stops the ping
// ticker. Nothing in the disconnected state uses ping ticks, so leaving the
// ticker running would only produce "Unexpected ping" log noise.
func (r *remoteFSM) enterDisconnected() stateFunc {
	r.updateConf(remoteConfig{})
	r.pingTimer.Stop()
	return r.stateDisconnected
}

// stateDisconnected waits for a peer update. Control packets and ping ticks
// are logged and otherwise ignored.
func (r *remoteFSM) stateDisconnected(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetInit]:
		r.logf("Unexpected INIT")
	case controlMsg[packetSyn]:
		r.logf("Unexpected SYN")
	case controlMsg[packetAck]:
		r.logf("Unexpected ACK")
	case controlMsg[packetProbe]:
		r.logf("Unexpected probe")
	case controlMsg[packetLocalDiscovery]:
		// Ignore
	case pingTimerMsg:
		r.logf("Unexpected ping")
	default:
		r.logf("Ignoring message: %#v", iMsg)
	}
	return r.stateDisconnected
}

// ----------------------------------------------------------------------------

// enterPeerUpdating installs a fresh configuration for peer and chooses a
// role. A nil peer returns to the disconnected state.
//
// Role selection:
//   - remote has a valid public IP: act as server only when LocalAddrValid is
//     set AND our peer IP is the lower one; otherwise act as client.
//   - remote has no public IP: act as server when LocalAddrValid is set OR our
//     peer IP is the lower one; otherwise act as client.
func (r *remoteFSM) enterPeerUpdating(peer *m.Peer) stateFunc {
	if peer == nil {
		return r.enterDisconnected()
	}

	conf := remoteConfig{
		Peer:          peer,
		ControlCipher: newControlCipher(r.PrivKey, peer.PubKey),
	}
	r.updateConf(conf)

	if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		if r.LocalAddrValid && r.LocalPeerIP < peer.PeerIP {
			return r.enterServer()
		}
		return r.enterClientInit()
	}

	if r.LocalAddrValid || r.LocalPeerIP < peer.PeerIP {
		return r.enterServer()
	}

	return r.enterClientInit()
}

// ----------------------------------------------------------------------------

// enterServer marks the configuration as server-side, starts the ping ticker,
// opens a new timeout window and forgets any previously received data key, so
// the next SYN installs a data cipher.
func (r *remoteFSM) enterServer() stateFunc {
	conf := r.conf()
	conf.Server = true
	r.updateConf(conf)

	r.logf("==> Server")

	r.pingTimer.Reset(pingInterval)
	r.lastSeen = time.Now()

	clear(r.sharedKey[:])
	return r.stateServer
}

// stateServer answers INIT, SYN and probe packets from the client and watches
// for client timeouts on each ping tick.
func (r *remoteFSM) stateServer(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetInit]:
		r.stateServer_onInit(msg)
	case controlMsg[packetSyn]:
		r.stateServer_onSyn(msg)
	case controlMsg[packetAck]:
		r.logf("Unexpected ACK")
	case controlMsg[packetProbe]:
		r.stateServer_onProbe(msg)
	case controlMsg[packetLocalDiscovery]:
		// Ignore
	case pingTimerMsg:
		r.stateServer_onPingTimer()
	default:
		r.logf("Unexpected message: %#v", iMsg)
	}
	return r.stateServer
}

// stateServer_onInit marks the connection down and records the client's
// requested mode and source address. It then echoes an INIT carrying the same
// trace ID and our version.
func (r *remoteFSM) stateServer_onInit(msg controlMsg[packetInit]) {
	conf := r.conf()
	conf.Up = false
	conf.Direct = msg.Packet.Direct
	conf.DirectAddr = msg.SrcAddr
	r.updateConf(conf)

	init := packetInit{
		TraceID: msg.Packet.TraceID,
		Direct:  conf.Direct,
		Version: version,
	}

	r.sendControl(conf, init.Marshal(r.buf))
}

// stateServer_onSyn marks the connection up and installs a new data cipher
// if the key has changed. It then replies with an ACK that echoes the SYN's
// trace ID. For relayed (non-direct) connections, it also probes each address
// the client advertised, stopping at the first invalid entry.
func (r *remoteFSM) stateServer_onSyn(msg controlMsg[packetSyn]) {
	r.lastSeen = time.Now()
	p := msg.Packet

	// Before we can respond to this packet, we need to make sure the
	// route is set up properly.
	conf := r.conf()

	logSyn := !conf.Up || conf.Direct != p.Direct

	conf.Up = true
	conf.Direct = p.Direct
	conf.DirectAddr = msg.SrcAddr

	// Update data cipher if the key has changed.
	if !bytes.Equal(r.sharedKey[:], p.SharedKey[:]) {
		conf.DataCipher = newDataCipherFromKey(p.SharedKey)
		copy(r.sharedKey[:], p.SharedKey[:])
	}

	r.updateConf(conf)

	if logSyn {
		r.logf("Got SYN.")
	}

	r.sendControl(conf, packetAck{
		TraceID:       p.TraceID,
		ToAddr:        conf.DirectAddr,
		PossibleAddrs: r.PubAddrs.Get(),
	}.Marshal(r.buf))

	if p.Direct {
		return
	}

	// Send probes if not a direct connection.
	for _, addr := range msg.Packet.PossibleAddrs {
		if !addr.IsValid() {
			break
		}
		r.logf("Probing %v...", addr)
		r.sendControlToAddr(packetProbe{TraceID: r.NewTraceID()}.Marshal(r.buf), addr)
	}
}

// stateServer_onProbe echoes a probe back to its source address with the
// same trace ID. Probes without a valid source address are dropped.
func (r *remoteFSM) stateServer_onProbe(msg controlMsg[packetProbe]) {
	if !msg.SrcAddr.IsValid() {
		return
	}

	data := packetProbe{TraceID: msg.Packet.TraceID}.Marshal(r.buf)
	r.sendControlToAddr(data, msg.SrcAddr)
}

// stateServer_onPingTimer marks the connection down once no SYN has been seen
// for longer than timeoutInterval.
func (r *remoteFSM) stateServer_onPingTimer() {
	conf := r.conf()
	if time.Since(r.lastSeen) > timeoutInterval && conf.Up {
		conf.Up = false
		r.updateConf(conf)
		r.logf("Timeout.")
	}
}

// ----------------------------------------------------------------------------

// enterClientInit resets the configuration to a down, client-side state with
// a freshly generated data cipher. It attempts a direct connection when the
// peer has a valid public IP. It then starts the ping ticker and sends the
// first INIT.
func (r *remoteFSM) enterClientInit() stateFunc {
	conf := r.conf()
	ip, ipValid := netip.AddrFromSlice(conf.Peer.PublicIP)

	conf.Up = false
	conf.Server = false
	conf.Direct = ipValid
	conf.DirectAddr = netip.AddrPortFrom(ip, conf.Peer.Port)
	conf.DataCipher = newDataCipher()

	r.updateConf(conf)

	r.logf("==> ClientInit")

	r.lastSeen = time.Now()
	r.pingTimer.Reset(pingInterval)
	r.stateClientInit_sendInit()
	return r.stateClientInit
}

// stateClientInit waits for the server's INIT reply, resending INIT on each
// ping tick until it arrives or the attempt times out.
func (r *remoteFSM) stateClientInit(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetInit]:
		return r.stateClientInit_onInit(msg)
	case controlMsg[packetSyn]:
		r.logf("Unexpected SYN")
	case controlMsg[packetAck]:
		r.logf("Unexpected ACK")
	case controlMsg[packetProbe]:
		// Ignore
	case controlMsg[packetLocalDiscovery]:
		// Ignore
	case pingTimerMsg:
		return r.stateClientInit_onPing()
	default:
		r.logf("Unexpected message: %#v", iMsg)
	}
	return r.stateClientInit
}

// stateClientInit_sendInit sends an INIT under a new trace ID, which the
// server's reply must echo.
func (r *remoteFSM) stateClientInit_sendInit() {
	conf := r.conf()
	r.traceID = r.NewTraceID()
	init := packetInit{
		TraceID: r.traceID,
		Direct:  conf.Direct,
		Version: version,
	}
	r.sendControl(conf, init.Marshal(r.buf))
}

// stateClientInit_onInit moves to the client state when the INIT reply echoes
// the outstanding trace ID; replies with any other trace ID are ignored.
func (r *remoteFSM) stateClientInit_onInit(msg controlMsg[packetInit]) stateFunc {
	if msg.Packet.TraceID != r.traceID {
		r.logf("Invalid trace ID on INIT.")
		return r.stateClientInit
	}
	r.logf("Got INIT version %d.", msg.Packet.Version)
	return r.enterClient()
}

// stateClientInit_onPing resends INIT while within timeoutInterval. On
// timeout, a direct attempt falls back to an indirect one with a fresh
// timeout window; a timed-out indirect attempt restarts ClientInit.
func (r *remoteFSM) stateClientInit_onPing() stateFunc {
	if time.Since(r.lastSeen) < timeoutInterval {
		r.stateClientInit_sendInit()
		return r.stateClientInit
	}

	// Direct connect failed. Try indirect.
	conf := r.conf()
	if conf.Direct {
		conf.Direct = false
		r.updateConf(conf)
		r.lastSeen = time.Now()
		r.stateClientInit_sendInit()
		r.logf("Direct connection failed. Attempting indirect connection.")
		return r.stateClientInit
	}

	// Indirect failed. Re-enter init state.
	r.logf("Timeout.")
	return r.enterClientInit()
}

// ----------------------------------------------------------------------------

// enterClient discards old probes, sends a SYN under a new trace ID and
// starts the ping ticker.
func (r *remoteFSM) enterClient() stateFunc {
	conf := r.conf()
	r.probes = make(map[uint64]sentProbe, 8)

	// Start a fresh timeout window, as enterServer and enterClientInit do.
	// Without this, time already spent in ClientInit counts against the
	// client state and the first ping tick can time out before an ACK has
	// had a chance to arrive.
	r.lastSeen = time.Now()

	r.traceID = r.NewTraceID()
	r.stateClient_sendSyn(conf)

	r.pingTimer.Reset(pingInterval)

	r.logf("==> Client")
	return r.stateClient
}

// stateClient keeps the session alive with periodic SYNs. It processes ACKs
// and tries to upgrade a relayed connection to a direct one via probes and
// local discovery.
func (r *remoteFSM) stateClient(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetAck]:
		r.stateClient_onAck(msg)
	case controlMsg[packetProbe]:
		r.stateClient_onProbe(msg)
	case controlMsg[packetLocalDiscovery]:
		r.stateClient_onLocalDiscovery(msg)
	case pingTimerMsg:
		return r.stateClient_onPingTimer()
	default:
		r.logf("Ignoring message: %v", iMsg)
	}
	return r.stateClient
}

// stateClient_onAck handles an ACK that echoes the outstanding SYN trace ID
// and ignores all others. It refreshes lastSeen and marks the connection up.
// For a direct connection it stores the address the server saw us at. For a
// relayed one it probes each address the server advertised, stopping at the
// first invalid entry.
func (r *remoteFSM) stateClient_onAck(msg controlMsg[packetAck]) {
	if msg.Packet.TraceID != r.traceID {
		return
	}

	r.lastSeen = time.Now()

	conf := r.conf()
	if !conf.Up {
		conf.Up = true
		r.updateConf(conf)
		r.logf("Got ACK.")
	}

	if conf.Direct {
		r.PubAddrs.Store(msg.Packet.ToAddr)
		return
	}

	// Relayed.

	r.stateClient_cleanProbes()

	for _, addr := range msg.Packet.PossibleAddrs {
		if !addr.IsValid() {
			break
		}
		r.stateClient_sendProbeTo(addr)
	}
}

// stateClient_cleanProbes forgets probes sent more than pingInterval ago, so
// late replies to them are ignored.
func (r *remoteFSM) stateClient_cleanProbes() {
	for key, sent := range r.probes {
		if time.Since(sent.SentAt) > pingInterval {
			delete(r.probes, key)
		}
	}
}

// stateClient_sendProbeTo sends a probe with a new trace ID to addr and
// records it so a matching reply can be recognized.
func (r *remoteFSM) stateClient_sendProbeTo(addr netip.AddrPort) {
	probe := packetProbe{TraceID: r.NewTraceID()}
	r.probes[probe.TraceID] = sentProbe{
		SentAt: time.Now(),
		Addr:   addr,
	}
	r.logf("Probing %v...", addr)
	r.sendControlToAddr(probe.Marshal(r.buf), addr)
}

// stateClient_onProbe switches a relayed connection to direct when a probe
// reply matches a recently sent probe. It then re-announces the new mode
// with a SYN under a new trace ID.
func (r *remoteFSM) stateClient_onProbe(msg controlMsg[packetProbe]) {
	conf := r.conf()
	if conf.Direct {
		return
	}

	r.stateClient_cleanProbes()

	sent, ok := r.probes[msg.Packet.TraceID]
	if !ok {
		return
	}

	conf.Direct = true
	conf.DirectAddr = sent.Addr
	r.updateConf(conf)

	r.traceID = r.NewTraceID()
	r.stateClient_sendSyn(conf)
	r.logf("Successful probe to %v.", sent.Addr)
}

// stateClient_onLocalDiscovery probes a peer found via local discovery while
// the connection is still relayed.
func (r *remoteFSM) stateClient_onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
	conf := r.conf()
	if conf.Direct {
		return
	}

	// The source port will be the multicast port, so we'll have to
	// construct the correct address using the peer's listed port.
	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), conf.Peer.Port)
	r.stateClient_sendProbeTo(addr)
}

// stateClient_onPingTimer restarts ClientInit when no matching ACK has been
// seen for longer than timeoutInterval; otherwise it resends the SYN.
func (r *remoteFSM) stateClient_onPingTimer() stateFunc {
	conf := r.conf()

	if time.Since(r.lastSeen) > timeoutInterval {
		if conf.Up {
			r.logf("Timeout.")
		}
		return r.enterClientInit()
	}

	r.stateClient_sendSyn(conf)
	return r.stateClient
}

// stateClient_sendSyn sends a SYN under the current trace ID. The SYN carries
// our data key, the connection mode and our known public addresses.
func (r *remoteFSM) stateClient_sendSyn(conf remoteConfig) {
	syn := packetSyn{
		TraceID:       r.traceID,
		SharedKey:     conf.DataCipher.Key(),
		Direct:        conf.Direct,
		PossibleAddrs: r.PubAddrs.Get(),
	}

	r.sendControl(conf, syn.Marshal(r.buf))
}