package node

import (
	"fmt"
	"log"
	"net/netip"
	"sync/atomic"
	"time"

	"vppn/m"
)

const (
	// pingInterval is the intended period between pings.
	// NOTE(review): not referenced in this file; presumably whatever emits
	// pingTimerMsg uses it — confirm.
	pingInterval = 8 * time.Second

	// timeoutInterval is how long a peer may stay silent before the
	// connection is treated as down.
	timeoutInterval = 25 * time.Second
)

// ----------------------------------------------------------------------------

// peerSupervisor runs the connection state machine for a single remote peer
// and publishes the resulting route for other readers.
type peerSupervisor struct {
	// The purpose of this state machine is to manage this published data.
	// It is the shared routingTable slot for this peer (see newPeerSupervisor).
	published *atomic.Pointer[peerRoute]
	staged    peerRoute // Local copy of shared data. See publish().

	// Immutable data.
	remoteIP byte // Remote VPN IP.

	// Mutable peer data.
	peer      *m.Peer
	remotePub bool // Whether the current peer advertises a valid public IP.

	// Incoming events: peer updates, control packets and ping timer ticks.
	messages chan any

	// Buffers for sending control packets.
	buf1 []byte
	buf2 []byte
}

// newPeerSupervisor builds the supervisor for the peer whose VPN IP is i,
// wiring it to the shared routingTable[i] slot and the messages[i] channel.
func newPeerSupervisor(i int) *peerSupervisor {
	return &peerSupervisor{
		published: routingTable[i],
		remoteIP:  byte(i),
		messages:  messages[i],
		buf1:      make([]byte, bufferSize),
		buf2:      make([]byte, bufferSize),
	}
}

// stateFunc is one state of the supervisor: it runs until a transition is
// needed and returns the next state to run.
type stateFunc func() stateFunc

// Run drives the state machine forever, starting in the noPeer state.
// It never returns.
func (s *peerSupervisor) Run() {
	state := s.noPeer
	for {
		state = state()
	}
}

// ----------------------------------------------------------------------------

// sendControlPacket sends pkt along the currently staged route, then sleeps
// 500ms to rate-limit control traffic (this blocks the state machine).
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
	time.Sleep(500 * time.Millisecond) // Rate limit packets.
}

// sendControlPacketTo sends pkt directly to addr, using a copy of the staged
// route with its destination overridden. The staged route itself is not
// modified. Invalid addresses are logged and the packet is dropped.
func (s *peerSupervisor) sendControlPacketTo(
	pkt interface{ Marshal([]byte) []byte },
	addr netip.AddrPort,
) {
	if !addr.IsValid() {
		s.logf("ERROR: Attempted to send packet to invalid address: %v", addr)
		return
	}
	route := s.staged
	route.Direct = true
	route.RemoteAddr = addr
	_sendControlPacket(pkt, route, s.buf1, s.buf2)
	time.Sleep(500 * time.Millisecond) // Rate limit packets.
}

// ----------------------------------------------------------------------------

// logf logs msg prefixed with this peer's zero-padded VPN IP, e.g. "[007] ".
func (s *peerSupervisor) logf(msg string, args ...any) {
	log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
}

// ----------------------------------------------------------------------------

// publish stores a pointer to a fresh copy of the staged route, so readers of
// the published pointer never observe later mutations of s.staged.
func (s *peerSupervisor) publish() {
	data := s.staged
	s.published.Store(&data)
}

// ----------------------------------------------------------------------------

// noPeer is the idle state: it discards every message until a peer update
// arrives.
func (s *peerSupervisor) noPeer() stateFunc {
	for {
		rawMsg := <-s.messages
		if msg, ok := rawMsg.(peerUpdateMsg); ok {
			return s.peerUpdate(msg.Peer)
		}
	}
}

// ----------------------------------------------------------------------------

// peerUpdate returns a state that applies the given peer configuration.
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
	return func() stateFunc { return s._peerUpdate(peer) }
}

// _peerUpdate rebuilds the staged route from scratch for peer, publishes it
// (via defer), and picks the next state: noPeer when peer is nil, otherwise
// the server or client role.
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRoute{}
	// Forget what we knew about the previous peer; otherwise a peer that no
	// longer advertises a public IP would still be treated as public below.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.IP = s.remoteIP
	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
	s.staged.PubSignKey = peer.PubSignKey
	s.staged.DataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.Relay = peer.Relay
		s.staged.Direct = true
		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
	} else if localPub {
		s.staged.Direct = true
	}

	// Role selection: if both sides are public (or both are not), the side
	// with the lower VPN IP serves; otherwise the public side serves.
	if s.remotePub == localPub {
		if localIP < s.remoteIP {
			return s.server
		}
		return s.client
	}

	if s.remotePub {
		return s.client
	}
	return s.server
}

// ----------------------------------------------------------------------------

// server is the passive role.
//
// It follows the client's syn packets:
//   - it adopts the route and data key they describe;
//   - it acknowledges every syn;
//   - it echoes probes back to their sender.
//
// If no syn arrives for timeoutInterval, it marks the route down.
func (s *peerSupervisor) server() stateFunc {
	logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) }
	logf("DOWN")

	var (
		syn      synPacket
		lastSeen = time.Now()
	)

	for {
		rawMsg := <-s.messages
		switch msg := rawMsg.(type) {

		case peerUpdateMsg:
			return s.peerUpdate(msg.Peer)

		case controlMsg[synPacket]:
			p := msg.Packet
			lastSeen = time.Now()

			// Before we can respond to this packet, we need to make sure the
			// route is setup properly.
			//
			// The client will update the syn's TraceID whenever there's a change.
			// The server will follow the client's request.
			if p.TraceID != syn.TraceID || !s.staged.Up {
				if p.Direct {
					logf("UP - Direct")
				} else {
					logf("UP - Relayed")
				}
				syn = p
				s.staged.Up = true
				s.staged.Direct = syn.Direct
				s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
				s.staged.RemoteAddr = msg.SrcAddr
				s.publish()
			}

			// We should always respond.
			ack := synAckPacket{
				TraceID:  syn.TraceID,
				FromAddr: getLocalAddr(),
			}
			s.sendControlPacket(ack)

			// While not direct, also probe the address the client reported for
			// itself. NOTE(review): presumably this helps open a direct path
			// (e.g. through NAT) — confirm.
			if s.staged.Direct {
				continue
			}
			if !syn.FromAddr.IsValid() {
				continue
			}
			probe := probePacket{TraceID: newTraceID()}
			s.sendControlPacketTo(probe, syn.FromAddr)

		case controlMsg[probePacket]:
			// Echo the probe so its sender can confirm the path works.
			if !msg.SrcAddr.IsValid() {
				logf("Invalid probe address")
				continue
			}
			s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)

		case pingTimerMsg:
			// lastSeen is not refreshed after a timeout, so without the Up
			// check this would log and republish on every tick while down.
			if s.staged.Up && time.Since(lastSeen) > timeoutInterval {
				logf("Connection timeout")
				s.staged.Up = false
				s.publish()
			}
		}
	}
}

// ----------------------------------------------------------------------------

// client is the active role.
//
// It sends syn packets carrying a TraceID and a shared data key, and
// considers the route up once a matching synAck arrives. While relayed, it
// sends probes to try to upgrade to a direct (or local) path.
//
// If no matching synAck arrives for timeoutInterval, it restarts from a
// fresh peer update.
func (s *peerSupervisor) client() stateFunc {
	logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) }
	logf("DOWN")

	var (
		syn = synPacket{
			TraceID:   newTraceID(),
			SharedKey: s.staged.DataCipher.Key(),
			Direct:    s.staged.Direct,
			FromAddr:  getLocalAddr(),
		}
		lastSeen = time.Now()

		ack synAckPacket

		probe     probePacket
		probeAddr netip.AddrPort

		localProbe     probePacket
		localProbeAddr netip.AddrPort

		// Start from the address advertised in the initial syn so the first
		// ping tick does not see a spurious change and rotate the TraceID.
		lastLocalAddr = syn.FromAddr
	)

	s.sendControlPacket(syn)

	for {
		rawMsg := <-s.messages
		switch msg := rawMsg.(type) {

		case peerUpdateMsg:
			return s.peerUpdate(msg.Peer)

		case controlMsg[synAckPacket]:
			p := msg.Packet
			if p.TraceID != syn.TraceID {
				continue // Ack for a syn we have since replaced; ignore it.
			}

			lastSeen = time.Now()
			ack = msg.Packet

			if !s.staged.Up {
				if s.staged.Direct {
					logf("UP - Direct")
				} else {
					logf("UP - Relayed")
				}
				s.staged.Up = true
				s.publish()
			}

		case controlMsg[probePacket]:
			if s.staged.Direct {
				continue
			}

			// Accept only an echo of a probe we actually sent. A probe that was
			// never sent still has an invalid address, and must not trigger an
			// upgrade to an invalid route.
			p := msg.Packet
			var addr netip.AddrPort
			switch {
			case localProbeAddr.IsValid() && p.TraceID == localProbe.TraceID:
				logf("UP - Local")
				addr = localProbeAddr
			case probeAddr.IsValid() && p.TraceID == probe.TraceID:
				logf("UP - Direct")
				addr = probeAddr
			default:
				continue
			}

			// Upgrade connection.
			s.staged.Direct = true
			s.staged.RemoteAddr = addr
			s.publish()

			// A new TraceID tells the server to adopt the new route.
			syn.TraceID = newTraceID()
			syn.Direct = true
			syn.FromAddr = getLocalAddr()
			lastLocalAddr = syn.FromAddr
			s.sendControlPacket(syn)

		case controlMsg[localDiscoveryPacket]:
			if s.staged.Direct {
				continue
			}

			// Send probe.
			//
			// The source port will be the multicast port, so we'll have to
			// construct the correct address using the peer's listed port.
			localProbe = probePacket{TraceID: newTraceID()}
			localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
			s.sendControlPacketTo(localProbe, localProbeAddr)

		case pingTimerMsg:
			if time.Since(lastSeen) > timeoutInterval {
				logf("Connection timeout")
				return s.peerUpdate(s.peer)
			}

			// Rotate the TraceID when our local address changes so the server
			// re-adopts the route.
			syn.FromAddr = getLocalAddr()
			if syn.FromAddr != lastLocalAddr {
				syn.TraceID = newTraceID()
				lastLocalAddr = syn.FromAddr
			}

			s.sendControlPacket(syn)

			// While relayed, probe the address the server reported for itself.
			if s.staged.Direct {
				continue
			}
			if !ack.FromAddr.IsValid() {
				continue
			}

			probe = probePacket{TraceID: newTraceID()}
			probeAddr = ack.FromAddr
			s.sendControlPacketTo(probe, ack.FromAddr)
		}
	}
}