package node import ( "fmt" "log" "net/netip" "strings" "sync/atomic" "time" "vppn/m" "git.crumpington.com/lib/go/ratelimiter" ) const ( pingInterval = 8 * time.Second timeoutInterval = 30 * time.Second ) // ---------------------------------------------------------------------------- func startPeerSuper() { peers := [256]peerState{} for i := range peers { data := &peerStateData{ published: routingTable[i], remoteIP: byte(i), buf1: make([]byte, bufferSize), buf2: make([]byte, bufferSize), limiter: ratelimiter.New(ratelimiter.Config{ FillPeriod: 20 * time.Millisecond, MaxWaitCount: 1, }), } peers[i] = data.OnPeerUpdate(nil) } go runPeerSuper(peers) } func runPeerSuper(peers [256]peerState) { for raw := range messages { switch msg := raw.(type) { case peerUpdateMsg: peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer) case controlMsg[synPacket]: peers[msg.SrcIP].OnSyn(msg) case controlMsg[ackPacket]: peers[msg.SrcIP].OnAck(msg) case controlMsg[probePacket]: peers[msg.SrcIP].OnProbe(msg) case controlMsg[localDiscoveryPacket]: peers[msg.SrcIP].OnLocalDiscovery(msg) case pingTimerMsg: publicAddrs.Clean() for i := range peers { if newState := peers[i].OnPingTimer(); newState != nil { peers[i] = newState } } default: log.Printf("WARNING: unknown message type: %+v", msg) } } } // ---------------------------------------------------------------------------- type peerState interface { OnPeerUpdate(*m.Peer) peerState OnSyn(controlMsg[synPacket]) OnAck(controlMsg[ackPacket]) OnProbe(controlMsg[probePacket]) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) OnPingTimer() peerState } // ---------------------------------------------------------------------------- type peerStateData struct { // The purpose of this state machine is to manage this published data. published *atomic.Pointer[peerRoute] staged peerRoute // Local copy of shared data. See publish(). // Immutable data. remoteIP byte // Remote VPN IP. // Mutable peer data. peer *m.Peer remotePub bool // Buffers for sending control packets. buf1 []byte buf2 []byte // For logging. Set per-state. client bool // We rate limit per remote endpoint because if we don't we tend to lose // packets. limiter *ratelimiter.Limiter } // ---------------------------------------------------------------------------- func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { s._sendControlPacket(pkt, s.staged) } func (s *peerStateData) sendControlPacketTo(pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort) { if !addr.IsValid() { s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) return } route := s.staged route.Direct = true route.RemoteAddr = addr s._sendControlPacket(pkt, route) } func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute) { if err := s.limiter.Limit(); err != nil { s.logf("Not sending control packet: rate limited.") // Shouldn't happen. return } _sendControlPacket(pkt, route, s.buf1, s.buf2) } // ---------------------------------------------------------------------------- func (s *peerStateData) publish() { data := s.staged s.published.Store(&data) } func (s *peerStateData) logf(format string, args ...any) { b := strings.Builder{} b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name)) if s.client { b.WriteString("CLIENT | ") } else { b.WriteString("SERVER | ") } if s.staged.Direct { b.WriteString("DIRECT | ") } else { b.WriteString("RELAYED | ") } if s.staged.Up { b.WriteString("UP | ") } else { b.WriteString("DOWN | ") } log.Printf(b.String()+format, args...) } // ---------------------------------------------------------------------------- func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { defer s.publish() if peer == nil { return enterStateDisconnected(s) } s.peer = peer s.staged = peerRoute{ IP: s.remoteIP, PubSignKey: peer.PubSignKey, ControlCipher: newControlCipher(privKey, peer.PubKey), DataCipher: newDataCipher(), } s.remotePub = false if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { s.remotePub = true s.staged.Relay = peer.Relay s.staged.Direct = true s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) } else if localPub { s.staged.Direct = true } if s.remotePub == localPub { if localIP < s.remoteIP { return enterStateServer(s) } return enterStateClient(s) } if s.remotePub { return enterStateClient(s) } return enterStateServer(s) } // ---------------------------------------------------------------------------- type stateDisconnected struct { *peerStateData } func enterStateDisconnected(s *peerStateData) peerState { s.peer = nil s.staged = peerRoute{} s.publish() return &stateDisconnected{s} } func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {} func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {} func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {} func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {} func (s *stateDisconnected) OnPingTimer() peerState { return nil } // ---------------------------------------------------------------------------- type stateServer struct { *stateDisconnected lastSeen time.Time synTraceID uint64 } func enterStateServer(s *peerStateData) peerState { s.client = false return &stateServer{stateDisconnected: &stateDisconnected{s}} } func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { s.lastSeen = time.Now() p := msg.Packet // Before we can respond to this packet, we need to make sure the // route is setup properly. // // The client will update the syn's TraceID whenever there's a change. // The server will follow the client's request. if p.TraceID != s.synTraceID || !s.staged.Up { s.synTraceID = p.TraceID s.staged.Up = true s.staged.Direct = p.Direct s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) s.staged.RemoteAddr = msg.SrcAddr s.publish() s.logf("Got syn.") } // Always respond. ack := ackPacket{ TraceID: p.TraceID, ToAddr: s.staged.RemoteAddr, PossibleAddrs: publicAddrs.Get(), } s.sendControlPacket(ack) if s.staged.Direct { return } // Not direct => send probes. for _, addr := range p.PossibleAddrs { if !addr.IsValid() { break } s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr) } } func (s *stateServer) OnProbe(msg controlMsg[probePacket]) { if !msg.SrcAddr.IsValid() { s.logf("Invalid probe address.") return } s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) } func (s *stateServer) OnPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { s.staged.Up = false s.publish() s.logf("Connection timeout.") } return nil } // ---------------------------------------------------------------------------- type stateClient struct { *stateDisconnected lastSeen time.Time syn synPacket ack ackPacket probes map[uint64]netip.AddrPort localDiscoveryAddr netip.AddrPort } func enterStateClient(s *peerStateData) peerState { s.client = true ss := &stateClient{ stateDisconnected: &stateDisconnected{s}, probes: map[uint64]netip.AddrPort{}, } ss.syn = synPacket{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), Direct: s.staged.Direct, PossibleAddrs: publicAddrs.Get(), } ss.sendControlPacket(ss.syn) return ss } func (s *stateClient) sendProbeTo(addr netip.AddrPort) { probe := probePacket{TraceID: newTraceID()} s.probes[probe.TraceID] = addr s.sendControlPacketTo(probe, addr) } func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { if msg.Packet.TraceID != s.syn.TraceID { s.logf("Ack has incorrect trace ID") return } s.ack = msg.Packet s.lastSeen = time.Now() if !s.staged.Up { s.staged.Up = true s.logf("Got ack.") s.publish() } // Store possible public address if we're not a public node. if !localPub && s.remotePub { publicAddrs.Store(msg.Packet.ToAddr) } } func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { if s.staged.Direct { return } addr, ok := s.probes[msg.Packet.TraceID] if !ok { return } s.staged.RemoteAddr = addr s.staged.Direct = true s.publish() s.syn.TraceID = newTraceID() s.syn.Direct = true s.syn.PossibleAddrs = [8]netip.AddrPort{} s.sendControlPacket(s.syn) s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) } func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { if s.staged.Direct { return } // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) } func (s *stateClient) OnPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval { if s.staged.Up { s.logf("Connection timeout.") } return s.OnPeerUpdate(s.peer) } s.sendControlPacket(s.syn) if s.staged.Direct { return nil } clear(s.probes) for _, addr := range s.ack.PossibleAddrs { if !addr.IsValid() { break } s.sendProbeTo(addr) } if s.localDiscoveryAddr.IsValid() { s.sendProbeTo(s.localDiscoveryAddr) s.localDiscoveryAddr = netip.AddrPort{} } return nil }