package node import ( "fmt" "log" "net/netip" "time" "vppn/m" ) const ( connectTimeout = 6 * time.Second pingInterval = 6 * time.Second timeoutInterval = 20 * time.Second ) type routingPacketWrapper struct { routingPacket Addr netip.AddrPort // Source. } type peerSupervisor struct { // Constants: localIP byte localPublic bool remoteIP byte privKey []byte // Shared data: w *connWriter table *routingTable packets chan routingPacketWrapper peerUpdates chan *m.Peer // Peer-related items. version int64 // Ony accessed in HandlePeerUpdate. peer *m.Peer remoteAddrPort *netip.AddrPort mediated bool sharedKey []byte // Used by our state functions. pingTimer *time.Timer timeoutTimer *time.Timer buf []byte } // ---------------------------------------------------------------------------- func newPeerSupervisor( conf m.PeerConfig, remoteIP byte, w *connWriter, table *routingTable, ) *peerSupervisor { s := &peerSupervisor{ localIP: conf.PeerIP, remoteIP: remoteIP, privKey: conf.EncPrivKey, w: w, table: table, packets: make(chan routingPacketWrapper, 256), peerUpdates: make(chan *m.Peer, 1), pingTimer: time.NewTimer(pingInterval), timeoutTimer: time.NewTimer(timeoutInterval), buf: make([]byte, bufferSize), } _, s.localPublic = netip.AddrFromSlice(conf.PublicIP) go s.mainLoop() return s } func (s *peerSupervisor) logf(msg string, args ...any) { msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg log.Printf(msg, args...) } // ---------------------------------------------------------------------------- func (s *peerSupervisor) mainLoop() { defer panicHandler() state := s.stateInit for { state = state() } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) { if p != nil { if p.Version == s.version { return } s.version = p.Version } else { s.version = 0 } s.peerUpdates <- p } func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) { select { case s.packets <- w: default: // Drop } } // ---------------------------------------------------------------------------- type stateFunc func() stateFunc func (s *peerSupervisor) stateInit() stateFunc { if s.peer == nil { return s.stateDisconnected } addr, ok := netip.AddrFromSlice(s.peer.PublicIP) if ok { addrPort := netip.AddrPortFrom(addr, s.peer.Port) s.remoteAddrPort = &addrPort } else { s.remoteAddrPort = nil } s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey) return s.stateSelectRole() } // ---------------------------------------------------------------------------- func (s *peerSupervisor) stateDisconnected() stateFunc { s.clearRoutingTable() for { select { case <-s.packets: // Drop case s.peer = <-s.peerUpdates: return s.stateInit } } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) stateSelectRole() stateFunc { s.logf("STATE: SelectRole") s.updateRoutingTable(false) if s.remoteAddrPort != nil { s.mediated = false // If both remote and local are public, one side acts as client, and one // side as server. if s.localPublic && s.localIP < s.peer.PeerIP { return s.stateAccept } return s.stateDial } // We're public, remote is not => can only wait for connection if s.localPublic { s.mediated = false return s.stateAccept } // Both non-public => need to use mediator. return s.stateMediated } // ---------------------------------------------------------------------------- func (s *peerSupervisor) stateAccept() stateFunc { s.logf("STATE: Accept") for { select { case pkt := <-s.packets: switch pkt.Type { case packetTypePing: s.remoteAddrPort = &pkt.Addr s.updateRoutingTable(true) s.sendPong(pkt.TraceID) return s.stateConnected default: // Still waiting for ping... } case s.peer = <-s.peerUpdates: return s.stateInit } } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) stateDial() stateFunc { s.logf("STATE: Dial") s.updateRoutingTable(false) s.sendPing() for { select { case pkt := <-s.packets: switch pkt.Type { case packetTypePong: s.updateRoutingTable(true) return s.stateConnected default: // Ignore } case <-s.pingTimer.C: s.sendPing() case s.peer = <-s.peerUpdates: return s.stateInit } } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) stateConnected() stateFunc { s.logf("STATE: Connected") s.timeoutTimer.Reset(timeoutInterval) for { select { case <-s.pingTimer.C: s.sendPing() case <-s.timeoutTimer.C: s.logf("Timeout") return s.stateInit case pkt := <-s.packets: switch pkt.Type { case packetTypePing: s.sendPong(pkt.TraceID) // Server should always follow remote port. if s.localPublic { if pkt.Addr != *s.remoteAddrPort { s.remoteAddrPort = &pkt.Addr s.updateRoutingTable(true) } } case packetTypePong: s.timeoutTimer.Reset(timeoutInterval) default: // Drop packet. } case s.peer = <-s.peerUpdates: s.logf("New peer: %v", s.peer) return s.stateInit } } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) stateMediated() stateFunc { s.logf("STATE: Mediated") s.mediated = true s.updateRoutingTable(true) for { select { case <-s.packets: // Drop. case s.peer = <-s.peerUpdates: s.logf("New peer: %v", s.peer) return s.stateInit } } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) clearRoutingTable() { s.table.Set(s.remoteIP, nil) } func (s *peerSupervisor) updateRoutingTable(up bool) { s.table.Set(s.remoteIP, &peer{ Up: up, Mediator: s.peer.Mediator, Mediated: s.mediated, IP: s.remoteIP, Addr: s.remoteAddrPort, SharedKey: s.sharedKey, }) } // ---------------------------------------------------------------------------- func (s *peerSupervisor) sendPing() uint64 { traceID := newTraceID() pkt := newRoutingPacket(packetTypePing, traceID) s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) s.pingTimer.Reset(pingInterval) return traceID } func (s *peerSupervisor) sendPong(traceID uint64) { pkt := newRoutingPacket(packetTypePong, traceID) s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) }