package node

import (
	"fmt"
	"log"
	"math/rand"
	"net/netip"
	"sync/atomic"
	"time"

	"vppn/m"
)

const (
	dialTimeout     = 8 * time.Second
	connectTimeout  = 6 * time.Second
	pingInterval    = 6 * time.Second
	timeoutInterval = 20 * time.Second
)

// ----------------------------------------------------------------------------

// peerSupervisor runs the connection state machine for a single remote peer.
// Each state is a method returning the next state (see stateFunc). The
// machine's observable output is the peerRoute it publishes into its
// routingTable slot.
type peerSupervisor struct {
	// The purpose of this state machine is to manage this published data.
	published *atomic.Pointer[peerRoute]
	staged    peerRoute // Local copy of shared data. See publish().

	// Immutable data.
	remoteIP byte // Remote VPN IP.

	// Mutable peer data.
	peer      *m.Peer
	remotePub bool // True when the current peer config has a valid PublicIP.

	// Incoming events.
	peerUpdates    chan *m.Peer
	controlPackets chan controlPacket

	// Buffers for sending control packets.
	buf1 []byte
	buf2 []byte
}

// newPeerSupervisor creates the supervisor for remote VPN IP i. It is wired
// to the i-th entries of the package-level routingTable, peerUpdates and
// controlPackets. NOTE(review): i is truncated to a byte for remoteIP, so it
// is assumed to be < 256.
func newPeerSupervisor(i int) *peerSupervisor {
	return &peerSupervisor{
		published:      routingTable[i],
		remoteIP:       byte(i),
		peerUpdates:    peerUpdates[i],
		controlPackets: controlPackets[i],
		buf1:           make([]byte, bufferSize),
		buf2:           make([]byte, bufferSize),
	}
}

// stateFunc is one state of the machine. Running it returns the next state.
type stateFunc func() stateFunc

// Run drives the state machine forever, starting in noPeer. It never returns.
func (s *peerSupervisor) Run() {
	state := s.noPeer
	for {
		state = state()
	}
}

// ----------------------------------------------------------------------------

// sendControlPacket hands pkt to _sendControlPacket together with a copy of
// the currently staged route and the supervisor's scratch buffers.
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
}

// ----------------------------------------------------------------------------

// logf logs msg prefixed with the zero-padded remote VPN IP.
func (s *peerSupervisor) logf(msg string, args ...any) {
	log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
}

// ----------------------------------------------------------------------------

// publish stores a fresh copy of the staged route in the shared table.
// Because a new value is stored each time, readers that already loaded a
// pointer are unaffected by later changes to s.staged.
func (s *peerSupervisor) publish() {
	data := s.staged
	s.published.Store(&data)
}

// ----------------------------------------------------------------------------

// noPeer blocks until a peer update arrives, then processes it.
func (s *peerSupervisor) noPeer() stateFunc {
	return s.peerUpdate(<-s.peerUpdates)
}

// ----------------------------------------------------------------------------

// peerUpdate returns a state that applies the given peer config.
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
	return func() stateFunc { return s._peerUpdate(peer) }
}

// _peerUpdate resets all route state for the new peer config and chooses
// whether this node acts as client or server for the peer. A nil peer
// returns to noPeer. The resulting route is always published on return.
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRoute{}
	// Reset explicitly: a previous config may have had a public IP that the
	// new one lacks. Without this, the stale value would drive the role choice.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.IP = s.remoteIP
	s.staged.ControlCipher = newControlCipher(privateKey, peer.EncPubKey)
	s.staged.DataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.Relay = peer.Mediator
		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
	}

	// When both sides have the same publicness (localPub is presumably this
	// node's equivalent of remotePub), the lower VPN IP acts as server.
	// Otherwise the side facing a public peer dials it.
	if s.remotePub == localPub {
		if localIP < s.remoteIP {
			return s.serverAccept
		}
		return s.clientInit
	}

	if s.remotePub {
		return s.clientInit
	}
	return s.serverAccept
}

// ----------------------------------------------------------------------------

// serverAccept withdraws the route and waits for the client handshake.
// Each syn installs the client's data key and address and is answered with
// a synAck. An ack whose TraceID matches the latest syn completes the
// handshake.
func (s *peerSupervisor) serverAccept() stateFunc {
	s.logf("STATE: server-accept")
	s.staged.Up = false
	s.staged.DataCipher = nil
	s.staged.RemoteAddr = zeroAddrPort
	s.staged.RelayIP = 0
	s.publish()

	var (
		syn synPacket
		// haveSyn guards against an ack arriving before any syn. Otherwise it
		// would match the zero-valued syn and connect with no DataCipher.
		haveSyn bool
	)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case synPacket:
				syn = p
				haveSyn = true
				s.staged.RemoteAddr = pkt.RemoteAddr
				s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
				s.staged.RelayIP = syn.RelayIP
				s.publish()
				s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})

			case ackPacket:
				if !haveSyn || p.TraceID != syn.TraceID {
					continue
				}
				return s.serverConnected(syn.TraceID)
			}
		}
	}
}

// ----------------------------------------------------------------------------

// serverConnected marks the route up and publishes it, then returns the
// connected-server loop for the given handshake trace ID.
func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc {
	s.logf("STATE: server-connected")
	s.staged.Up = true
	s.publish()
	return func() stateFunc {
		return s._serverConnected(traceID)
	}
}

// _serverConnected echoes each ack for traceID back to the client and
// restarts the timeout. Three events send the server back to serverAccept:
// an ack carrying a different trace ID, or no matching ack within
// timeoutInterval. A peer update restarts the machine.
func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc {
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer timeoutTimer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.serverAccept
				}
				s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-timeoutTimer.C:
			s.logf("server timeout")
			return s.serverAccept
		}
	}
}

// ----------------------------------------------------------------------------

// clientInit is the client's (re)entry point. It runs after a peer update
// and again after dial or connection timeouts. It withdraws the route, then
// dials directly when the peer is public, or selects a relay first
// otherwise.
func (s *peerSupervisor) clientInit() stateFunc {
	s.logf("STATE: client-init")

	// After a timeout the previous connection is gone. Stop advertising the
	// route as up (matching serverAccept) until a new handshake completes.
	s.staged.Up = false
	s.publish()

	if !s.remotePub {
		// TODO: Check local discovery for IP.
		// TODO: Attempt UDP hole punch.
		// TODO: client-relayed
		return s.clientSelectRelay
	}

	return s.clientDial
}

// ----------------------------------------------------------------------------

// clientSelectRelay tries to pick a relay immediately, then retries every
// pingInterval until one is available. It then records the relay's IP and
// LocalAddr and moves on to clientDial.
func (s *peerSupervisor) clientSelectRelay() stateFunc {
	s.logf("STATE: client-select-relay")

	timer := time.NewTimer(0)
	defer timer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case <-timer.C:
			relay := s.selectRelay()
			if relay != nil {
				s.logf("Got relay: %d", relay.IP)
				s.staged.RelayIP = relay.IP
				s.staged.LocalAddr = relay.LocalAddr
				s.publish()
				return s.clientDial
			}

			s.logf("No relay available.")
			timer.Reset(pingInterval)
		}
	}
}

// selectRelay returns a uniformly random published route that is both Up
// and marked as a Relay, or nil if there is none. Slots that have never
// been published (nil) are skipped.
func (s *peerSupervisor) selectRelay() *peerRoute {
	possible := make([]*peerRoute, 0, 8)
	for i := range routingTable {
		route := routingTable[i].Load()
		if route == nil || !route.Up || !route.Relay {
			continue
		}
		possible = append(possible, route)
	}

	if len(possible) == 0 {
		return nil
	}
	return possible[rand.Intn(len(possible))]
}

// ----------------------------------------------------------------------------

// clientDial sends a single syn carrying this side's data key and chosen
// relay, then waits up to dialTimeout for the matching synAck. On success it
// acks and connects. On timeout it starts over at clientInit.
func (s *peerSupervisor) clientDial() stateFunc {
	s.logf("STATE: client-dial")

	var (
		syn = synPacket{
			TraceID:   newTraceID(),
			SharedKey: s.staged.DataCipher.Key(),
			RelayIP:   s.staged.RelayIP,
		}

		timeout = time.NewTimer(dialTimeout)
	)

	defer timeout.Stop()

	s.sendControlPacket(syn)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case synAckPacket:
				if p.TraceID != syn.TraceID {
					continue // Not for this dial attempt (presumably stale); ignore.
				}
				s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
				return s.clientConnected(p)
			}

		case <-timeout.C:
			return s.clientInit
		}
	}
}

// ----------------------------------------------------------------------------

// clientConnected marks the route up. It records the address the server
// reports having received from as LocalAddr, publishes, and returns the
// connected-client loop.
func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc {
	s.logf("STATE: client-connected")
	s.staged.Up = true
	s.staged.LocalAddr = p.RecvAddr
	s.publish()

	return func() stateFunc {
		return s._clientConnected(p.TraceID)
	}
}

// _clientConnected pings the server with an ack every pingInterval. Each
// matching ack restarts the timeout. It returns to clientInit on a
// mismatched trace ID or when no ack arrives within timeoutInterval.
func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc {
	pingTimer := time.NewTimer(pingInterval)
	timeoutTimer := time.NewTimer(timeoutInterval)

	defer pingTimer.Stop()
	defer timeoutTimer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.clientInit
				}
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-pingTimer.C:
			s.sendControlPacket(ackPacket{TraceID: traceID})
			pingTimer.Reset(pingInterval)

		case <-timeoutTimer.C:
			s.logf("client timeout")
			return s.clientInit
		}
	}
}