package peer

import (
	"fmt"
	"log"
	"net/netip"
	"strings"
	"time"
	"vppn/m"

	"git.crumpington.com/lib/go/ratelimiter"
)

// peerState is one state of the per-peer connection state machine.
//
// Handlers that return a peerState return the state the machine is in after
// handling the event; that is the receiver itself when nothing changes.
// The concrete states embed one another to inherit default handlers, so any
// inherited handler that would return the *embedded* value (rather than the
// outer state) must be overridden by the embedding state.
type peerState interface {
	OnPeerUpdate(*m.Peer) peerState
	OnSyn(controlMsg[packetSyn]) peerState
	OnAck(controlMsg[packetAck])
	OnProbe(controlMsg[packetProbe]) peerState
	OnLocalDiscovery(controlMsg[packetLocalDiscovery])
	OnPingTimer() peerState
}

// Compile-time checks that every concrete state satisfies peerState.
var (
	_ peerState = (*stateDisconnected)(nil)
	_ peerState = (*stateServer)(nil)
	_ peerState = (*stateClientDirect)(nil)
	_ peerState = (*stateClientRelayed)(nil)
)

// ----------------------------------------------------------------------------

// pState holds the data shared by every state of one remote peer's state
// machine. Each concrete state embeds (a pointer to) the same pState.
type pState struct {
	// Output.
	publish           func(remotePeer)
	sendControlPacket func(remotePeer, Marshaller)

	// Immutable data.
	localIP   byte
	remoteIP  byte
	privKey   []byte
	localAddr netip.AddrPort // If valid, then local peer is publicly accessible.

	pubAddrs *pubAddrStore

	// The purpose of this state machine is to manage the RemotePeer object,
	// publishing it as necessary.
	staged remotePeer // Local copy of shared data. See publish().

	// Mutable peer data.
	peer *m.Peer

	// We rate limit per remote endpoint because if we don't we tend to lose
	// packets.
	limiter *ratelimiter.Limiter
}

// OnPeerUpdate resets the staged route from the new peer configuration and
// returns the initial state for it. A nil peer yields stateDisconnected.
//
// Role selection for a non-nil peer:
//   - Remote has a valid public IP: we are the server only if we are also
//     publicly reachable (localAddr valid) and have the lower IP; otherwise
//     we are a direct client targeting the remote's public address.
//   - Remote has no public IP but we are publicly reachable: we are the
//     server and the route is marked direct.
//   - Neither side is public: the lower IP is the (non-direct) server and the
//     higher IP is a relayed client.
//
// The staged route is published on return in every case.
func (s *pState) OnPeerUpdate(peer *m.Peer) peerState {
	defer func() {
		// Don't defer directly otherwise s.staged will be evaluated immediately
		// and won't reflect changes made in the function.
		s.publish(s.staged)
	}()

	s.peer = peer
	s.staged.localIP = s.localIP
	s.staged.Up = false
	s.staged.Relay = false
	s.staged.Direct = false
	s.staged.DirectAddr = netip.AddrPort{}
	s.staged.PubSignKey = nil
	s.staged.ControlCipher = nil
	s.staged.DataCipher = nil

	if peer == nil {
		return enterStateDisconnected(s)
	}

	s.staged.IP = peer.PeerIP
	s.staged.PubSignKey = peer.PubSignKey
	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey)
	s.staged.DataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.staged.Relay = peer.Relay
		s.staged.Direct = true
		s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port)

		if s.localAddr.IsValid() && s.localIP < s.remoteIP {
			return enterStateServer(s)
		}
		return enterStateClientDirect(s)
	}

	if s.localAddr.IsValid() {
		s.staged.Direct = true
		return enterStateServer(s)
	}

	if s.localIP < s.remoteIP {
		return enterStateServer(s)
	}
	return enterStateClientRelayed(s)
}

// logf logs a message prefixed with the remote IP, the peer's name (if any)
// and the staged DIRECT/RELAYED and UP/DOWN flags.
//
// The caller's format is expanded on its own so that text in the prefix (in
// particular the configured peer name) is never interpreted as format verbs.
func (s *pState) logf(format string, args ...any) {
	name := ""
	if s.peer != nil {
		name = s.peer.Name
	}

	var b strings.Builder
	fmt.Fprintf(&b, "%03d%30s: ", s.remoteIP, name)

	if s.staged.Direct {
		b.WriteString("DIRECT | ")
	} else {
		b.WriteString("RELAYED | ")
	}

	if s.staged.Up {
		b.WriteString("UP | ")
	} else {
		b.WriteString("DOWN | ")
	}

	fmt.Fprintf(&b, format, args...)
	log.Print(b.String())
}

// ----------------------------------------------------------------------------

// SendTo sends pkt directly to addr, using a copy of the staged route with
// Direct forced on and DirectAddr replaced. Invalid addresses are ignored.
// The staged route itself is not modified.
func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) {
	if !addr.IsValid() {
		return
	}
	route := s.staged
	route.Direct = true
	route.DirectAddr = addr
	s.Send(route, pkt)
}

// Send passes pkt to sendControlPacket for the given route, unless the
// per-peer rate limiter rejects it, in which case the packet is dropped and
// a log line is written.
func (s *pState) Send(peer remotePeer, pkt Marshaller) {
	if err := s.limiter.Limit(); err != nil {
		s.logf("Rate limited.")
		return
	}
	s.sendControlPacket(peer, pkt)
}

// ----------------------------------------------------------------------------

// stateDisconnected is the state used when there is no peer configuration;
// it ignores every event.
//
// Other states embed it to inherit the no-op OnAck and OnLocalDiscovery
// handlers. Its peerState-returning handlers return the *stateDisconnected
// itself, so every embedding state must override OnSyn, OnProbe and
// OnPingTimer, or those events would silently replace the outer state with
// this one.
type stateDisconnected struct{ *pState }

func enterStateDisconnected(s *pState) peerState {
	return &stateDisconnected{pState: s}
}

func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState     { return s }
func (s *stateDisconnected) OnAck(controlMsg[packetAck])               {}
func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s }
func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {
}
func (s *stateDisconnected) OnPingTimer() peerState { return s }

// ----------------------------------------------------------------------------

// stateServer answers SYNs from the remote client, echoes its probes and
// marks the route down when no SYN has been seen for timeoutInterval.
type stateServer struct {
	*stateDisconnected
	lastSeen   time.Time
	synTraceID uint64
}

func enterStateServer(s *pState) peerState {
	s.logf("==> Server")
	return &stateServer{stateDisconnected: &stateDisconnected{pState: s}}
}

// OnSyn records liveness, adopts the client's route parameters whenever the
// client's TraceID changes (or the route is down), and always replies with an
// ACK. For non-direct SYNs it additionally sends probes to each address the
// client advertised.
func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState {
	s.lastSeen = time.Now()
	p := msg.Packet

	// Before we can respond to this packet, we need to make sure the
	// route is setup properly.
	//
	// The client will update the syn's TraceID whenever there's a change.
	// The server will follow the client's request.
	if p.TraceID != s.synTraceID || !s.staged.Up {
		s.synTraceID = p.TraceID
		s.staged.Up = true
		s.staged.Direct = p.Direct
		s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
		s.staged.DirectAddr = msg.SrcAddr
		s.publish(s.staged)
		s.logf("Got SYN.")
	}

	// Always respond.
	ack := packetAck{
		TraceID:       p.TraceID,
		ToAddr:        s.staged.DirectAddr,
		PossibleAddrs: s.pubAddrs.Get(),
	}
	s.Send(s.staged, ack)

	if p.Direct {
		return s
	}

	// Stop at the first invalid entry in the advertised address list.
	for _, addr := range p.PossibleAddrs {
		if !addr.IsValid() {
			break
		}
		s.SendTo(packetProbe{TraceID: newTraceID()}, addr)
	}

	return s
}

// OnProbe echoes a probe back to its source address, keeping its TraceID.
func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState {
	if msg.SrcAddr.IsValid() {
		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
	}
	return s
}

// OnPingTimer marks the route down (once) after timeoutInterval without a SYN.
func (s *stateServer) OnPingTimer() peerState {
	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
		s.staged.Up = false
		s.publish(s.staged)
		s.logf("Timeout.")
	}
	return s
}

// ----------------------------------------------------------------------------

// stateClientDirect sends a SYN on entry and on every ping tick, marks the
// route up when a matching ACK arrives, and re-runs OnPeerUpdate after
// timeoutInterval without a matching ACK.
type stateClientDirect struct {
	*stateDisconnected
	lastSeen time.Time
	syn      packetSyn
}

func enterStateClientDirect(s *pState) peerState {
	s.logf("==> ClientDirect")
	return newStateClientDirect(s)
}

// newStateClientDirect builds the client state with a fresh SYN (new TraceID,
// the staged data key and Direct flag, and our known public addresses) and
// sends that SYN immediately.
func newStateClientDirect(s *pState) *stateClientDirect {
	state := &stateClientDirect{
		stateDisconnected: &stateDisconnected{s},
		lastSeen:          time.Now(), // Avoid immediate timeout.
	}

	state.syn = packetSyn{
		TraceID:       newTraceID(),
		SharedKey:     s.staged.DataCipher.Key(),
		Direct:        s.staged.Direct,
		PossibleAddrs: s.pubAddrs.Get(),
	}
	state.Send(s.staged, state.syn)
	return state
}

// OnSyn ignores SYNs: as the client we send SYNs and never answer them.
// Overridden so the promoted stateDisconnected.OnSyn, which returns the
// embedded *stateDisconnected, cannot replace this state.
func (s *stateClientDirect) OnSyn(controlMsg[packetSyn]) peerState { return s }

// OnProbe ignores probes, e.g. late echoes of probes sent while this peer was
// a relayed client. Overridden for the same reason as OnSyn.
func (s *stateClientDirect) OnProbe(controlMsg[packetProbe]) peerState { return s }

// OnAck accepts only ACKs answering the current SYN: it refreshes liveness,
// marks the route up (publishing once), and stores the address the server
// saw us at in pubAddrs.
func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) {
	if msg.Packet.TraceID != s.syn.TraceID {
		return
	}

	s.lastSeen = time.Now()

	if !s.staged.Up {
		s.staged.Up = true
		s.publish(s.staged)
		s.logf("Got ACK.")
	}

	s.pubAddrs.Store(msg.Packet.ToAddr)
}

func (s *stateClientDirect) OnPingTimer() peerState {
	if next := s.onPingTimer(); next != nil {
		return next
	}
	return s
}

// onPingTimer returns a replacement state if the connection timed out (after
// marking the route down and re-running OnPeerUpdate); otherwise it resends
// the SYN and returns nil.
func (s *stateClientDirect) onPingTimer() peerState {
	if time.Since(s.lastSeen) > timeoutInterval {
		if s.staged.Up {
			s.staged.Up = false
			s.publish(s.staged)
			s.logf("Timeout.")
		}
		return s.OnPeerUpdate(s.peer)
	}

	s.Send(s.staged, s.syn)
	return nil
}

// ----------------------------------------------------------------------------

// stateClientRelayed behaves like stateClientDirect while additionally
// probing the server's advertised addresses (and any locally discovered
// address) each tick; the first probe echo switches to stateClientDirect
// using the address that probe was sent to.
type stateClientRelayed struct {
	*stateClientDirect
	ack                packetAck
	probes             map[uint64]netip.AddrPort
	localDiscoveryAddr netip.AddrPort
}

func enterStateClientRelayed(s *pState) peerState {
	s.logf("==> ClientRelayed")
	return &stateClientRelayed{
		stateClientDirect: newStateClientDirect(s),
		probes:            map[uint64]netip.AddrPort{},
	}
}

// OnSyn ignores SYNs. Overridden so the promoted stateClientDirect.OnSyn,
// which returns the embedded *stateClientDirect, cannot replace this state.
func (s *stateClientRelayed) OnSyn(controlMsg[packetSyn]) peerState { return s }

// OnAck remembers the ACK (whose PossibleAddrs drive probing) only if it
// answers the current SYN, matching the filtering in stateClientDirect.OnAck,
// then delegates to it.
func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) {
	if msg.Packet.TraceID == s.syn.TraceID {
		s.ack = msg.Packet
	}
	s.stateClientDirect.OnAck(msg)
}

// OnProbe switches to a direct client when a probe we sent (matched by
// TraceID) is echoed back; unknown probes are ignored.
func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState {
	addr, ok := s.probes[msg.Packet.TraceID]
	if !ok {
		return s
	}

	s.staged.DirectAddr = addr
	s.staged.Direct = true
	s.publish(s.staged)
	return enterStateClientDirect(s.stateClientDirect.pState)
}

// OnLocalDiscovery records the discovered peer address to probe on the next
// ping tick.
func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
	// The source port will be the multicast port, so we'll have to
	// construct the correct address using the peer's listed port.
	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
}

// OnPingTimer runs the direct-client timer logic, then replaces all
// outstanding probes with fresh ones: one per address in the last accepted
// ACK (stopping at the first invalid entry) plus the pending local-discovery
// address, which is consumed.
func (s *stateClientRelayed) OnPingTimer() peerState {
	if next := s.stateClientDirect.onPingTimer(); next != nil {
		return next
	}

	// Probes only count until the next tick.
	clear(s.probes)
	for _, addr := range s.ack.PossibleAddrs {
		if !addr.IsValid() {
			break
		}
		s.sendProbeTo(addr)
	}

	if s.localDiscoveryAddr.IsValid() {
		s.sendProbeTo(s.localDiscoveryAddr)
		s.localDiscoveryAddr = netip.AddrPort{}
	}

	return s
}

// sendProbeTo sends a probe with a fresh TraceID to addr and remembers the
// TraceID -> addr mapping so an echo can be matched in OnProbe.
func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) {
	probe := packetProbe{TraceID: newTraceID()}
	s.probes[probe.TraceID] = addr
	s.SendTo(probe, addr)
}