package node import ( "fmt" "log" "net/netip" "sync/atomic" "time" "vppn/m" ) const ( pingInterval = 8 * time.Second timeoutInterval = 25 * time.Second ) // ---------------------------------------------------------------------------- type peerSupervisor struct { // The purpose of this state machine is to manage this published data. published *atomic.Pointer[peerRoute] staged peerRoute // Local copy of shared data. See publish(). // Immutable data. remoteIP byte // Remote VPN IP. // Mutable peer data. peer *m.Peer remotePub bool // Incoming events. peerUpdates chan *m.Peer controlPackets chan controlPacket // Buffers for sending control packets. buf1 []byte buf2 []byte } func newPeerSupervisor(i int) *peerSupervisor { return &peerSupervisor{ published: routingTable[i], remoteIP: byte(i), peerUpdates: peerUpdates[i], controlPackets: controlPackets[i], buf1: make([]byte, bufferSize), buf2: make([]byte, bufferSize), } } type stateFunc func() stateFunc func (s *peerSupervisor) Run() { state := s.noPeer for { state = state() } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { _sendControlPacket(pkt, s.staged, s.buf1, s.buf2) time.Sleep(500 * time.Millisecond) // Rate limit packets. } func (s *peerSupervisor) sendControlPacketTo( pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort, ) { if !addr.IsValid() { s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) return } route := s.staged route.Direct = true route.RemoteAddr = addr _sendControlPacket(pkt, route, s.buf1, s.buf2) time.Sleep(500 * time.Millisecond) // Rate limit packets. } // ---------------------------------------------------------------------------- func (s *peerSupervisor) logf(msg string, args ...any) { log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) } // ---------------------------------------------------------------------------- func (s *peerSupervisor) publish() { data := s.staged s.published.Store(&data) } // ---------------------------------------------------------------------------- func (s *peerSupervisor) noPeer() stateFunc { return s.peerUpdate(<-s.peerUpdates) } // ---------------------------------------------------------------------------- func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc { return func() stateFunc { return s._peerUpdate(peer) } } func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { defer s.publish() s.peer = peer s.staged = peerRoute{} if s.peer == nil { return s.noPeer } s.staged.IP = s.remoteIP s.staged.ControlCipher = newControlCipher(privateKey, peer.PubKey) s.staged.DataCipher = newDataCipher() if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { s.remotePub = true s.staged.Relay = peer.Relay s.staged.Direct = true s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) } else if localPub { s.staged.Direct = true } if s.remotePub == localPub { if localIP < s.remoteIP { return s.server } return s.client } if s.remotePub { return s.client } return s.server } // ---------------------------------------------------------------------------- func (s *peerSupervisor) server() stateFunc { logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) } logf("DOWN") var ( syn synPacket timeoutTimer = time.NewTimer(timeoutInterval) ) // Timer will be restarted once we have established a connection. timeoutTimer.Stop() for { select { case peer := <-s.peerUpdates: return s.peerUpdate(peer) case pkt := <-s.controlPackets: switch p := pkt.Payload.(type) { case synPacket: timeoutTimer.Reset(timeoutInterval) // Before we can respond to this packet, we need to make sure the // route is setup properly. // // The client will update the syn's TraceID whenever there's a change. // The server will follow the client's request. if p.TraceID != syn.TraceID || !s.staged.Up { if p.Direct { logf("UP - Direct") } else { logf("UP - Relayed") } syn = p s.staged.Up = true s.staged.Direct = syn.Direct s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) s.staged.RemoteAddr = pkt.SrcAddr s.publish() } // We should always respond. ack := synAckPacket{ TraceID: syn.TraceID, FromAddr: getLocalAddr(), } s.sendControlPacket(ack) if s.staged.Direct { continue } if !syn.FromAddr.IsValid() { continue } probe := probePacket{TraceID: newTraceID()} s.sendControlPacketTo(probe, syn.FromAddr) case probePacket: if pkt.SrcAddr.IsValid() { s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr) } else { logf("Invalid probe address") } } case <-timeoutTimer.C: logf("Connection timeout") s.staged.Up = false s.publish() } } } // ---------------------------------------------------------------------------- func (s *peerSupervisor) client() stateFunc { logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) } logf("DOWN") var ( syn = synPacket{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), Direct: s.staged.Direct, FromAddr: getLocalAddr(), } ack synAckPacket probe probePacket probeAddr netip.AddrPort lAddr netip.AddrPort timeoutTimer = time.NewTimer(timeoutInterval) pingTimer = time.NewTimer(pingInterval) ) defer timeoutTimer.Stop() defer pingTimer.Stop() s.sendControlPacket(syn) for { select { case peer := <-s.peerUpdates: return s.peerUpdate(peer) case pkt := <-s.controlPackets: switch p := pkt.Payload.(type) { case synAckPacket: if p.TraceID != syn.TraceID { continue // Hmm... } ack = p timeoutTimer.Reset(timeoutInterval) if !s.staged.Up { if s.staged.Direct { logf("UP - Direct") } else { logf("UP - Relayed") } s.staged.Up = true s.publish() } case probePacket: if s.staged.Direct { continue } if p.TraceID != probe.TraceID { continue } // Upgrade connection. logf("UP - Direct") s.staged.Direct = true s.staged.RemoteAddr = probeAddr s.publish() syn.TraceID = newTraceID() syn.Direct = true syn.FromAddr = getLocalAddr() s.sendControlPacket(syn) } case <-pingTimer.C: // Send syn. syn.FromAddr = getLocalAddr() if syn.FromAddr != lAddr { syn.TraceID = newTraceID() lAddr = syn.FromAddr } s.sendControlPacket(syn) pingTimer.Reset(pingInterval) if s.staged.Direct { continue } if !ack.FromAddr.IsValid() { continue } probe = probePacket{TraceID: newTraceID()} probeAddr = ack.FromAddr s.sendControlPacketTo(probe, ack.FromAddr) case <-timeoutTimer.C: logf("Connection timeout") return s.peerUpdate(s.peer) } } }