package peer import ( "fmt" "log" "net/netip" "strings" "time" "vppn/m" "git.crumpington.com/lib/go/ratelimiter" ) type PeerState interface { OnPeerUpdate(*m.Peer) PeerState OnSyn(controlMsg[PacketSyn]) PeerState OnAck(controlMsg[PacketAck]) OnProbe(controlMsg[PacketProbe]) PeerState OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) OnPingTimer() PeerState } // ---------------------------------------------------------------------------- type State struct { // Output. publish func(RemotePeer) sendControlPacket func(RemotePeer, Marshaller) // Immutable data. localIP byte remoteIP byte privKey []byte localAddr netip.AddrPort // If valid, then local peer is publicly accessible. pubAddrs *pubAddrStore // The purpose of this state machine is to manage the RemotePeer object, // publishing it as necessary. staged RemotePeer // Local copy of shared data. See publish(). // Mutable peer data. peer *m.Peer // We rate limit per remote endpoint because if we don't we tend to lose // packets. limiter *ratelimiter.Limiter } func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { defer func() { // Don't defer directly otherwise s.staged will be evaluated immediately // and won't reflect changes made in the function. s.publish(s.staged) }() if peer == nil { return EnterStateDisconnected(s) } s.peer = peer s.staged.localIP = s.localIP s.staged.IP = peer.PeerIP s.staged.Up = false s.staged.Relay = false s.staged.Direct = false s.staged.DirectAddr = netip.AddrPort{} s.staged.PubSignKey = peer.PubSignKey s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.DataCipher = newDataCipher() if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { s.staged.Relay = peer.Relay s.staged.Direct = true s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) if s.localAddr.IsValid() && s.localIP < s.remoteIP { return EnterStateServer(s) } return EnterStateClientDirect(s) } if s.localAddr.IsValid() { s.staged.Direct = true return EnterStateServer(s) } if s.localIP < s.remoteIP { return EnterStateServer(s) } return EnterStateClientRelayed(s) } func (s *State) logf(format string, args ...any) { b := strings.Builder{} name := "" if s.peer != nil { name = s.peer.Name } b.WriteString(fmt.Sprintf("%30s: ", name)) if s.staged.Direct { b.WriteString("DIRECT | ") } else { b.WriteString("RELAYED | ") } if s.staged.Up { b.WriteString("UP | ") } else { b.WriteString("DOWN | ") } log.Printf(b.String()+format, args...) } // ---------------------------------------------------------------------------- func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { if !addr.IsValid() { return } route := s.staged route.Direct = true route.DirectAddr = addr s.Send(route, pkt) } func (s *State) Send(peer RemotePeer, pkt Marshaller) { if err := s.limiter.Limit(); err != nil { s.logf("Rate limited.") return } s.sendControlPacket(peer, pkt) } // ---------------------------------------------------------------------------- type StateDisconnected struct{ *State } func EnterStateDisconnected(s *State) PeerState { s.logf("==> Disconnected") s.peer = nil s.staged.Up = false s.staged.Relay = false s.staged.Direct = false s.staged.DirectAddr = netip.AddrPort{} s.staged.PubSignKey = nil s.staged.ControlCipher = nil s.staged.DataCipher = nil s.publish(s.staged) return &StateDisconnected{State: s} } func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil } func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {} func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil } func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} func (s *StateDisconnected) OnPingTimer() PeerState { return nil } // ---------------------------------------------------------------------------- type StateServer struct { *StateDisconnected lastSeen time.Time synTraceID uint64 } func EnterStateServer(s *State) PeerState { s.logf("==> Server") return &StateServer{StateDisconnected: &StateDisconnected{State: s}} } func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { s.lastSeen = time.Now() p := msg.Packet // Before we can respond to this packet, we need to make sure the // route is setup properly. // // The client will update the syn's TraceID whenever there's a change. // The server will follow the client's request. if p.TraceID != s.synTraceID || !s.staged.Up { s.synTraceID = p.TraceID s.staged.Up = true s.staged.Direct = p.Direct s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) s.staged.DirectAddr = msg.SrcAddr s.publish(s.staged) s.logf("Got SYN.") } // Always respond. ack := PacketAck{ TraceID: p.TraceID, ToAddr: s.staged.DirectAddr, PossibleAddrs: s.pubAddrs.Get(), } s.Send(s.staged, ack) if p.Direct { return nil } for _, addr := range msg.Packet.PossibleAddrs { if !addr.IsValid() { break } s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) } return nil } func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { if msg.SrcAddr.IsValid() { s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) } return nil } func (s *StateServer) OnPingTimer() PeerState { if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { s.staged.Up = false s.publish(s.staged) s.logf("Timeout.") } return nil } // ---------------------------------------------------------------------------- type StateClientDirect struct { *StateDisconnected lastSeen time.Time syn PacketSyn } func EnterStateClientDirect(s *State) PeerState { s.logf("==> ClientDirect") return NewStateClientDirect(s) } func NewStateClientDirect(s *State) *StateClientDirect { state := &StateClientDirect{ StateDisconnected: &StateDisconnected{s}, lastSeen: time.Now(), // Avoid immediate timeout. } state.syn = PacketSyn{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), Direct: s.staged.Direct, PossibleAddrs: s.pubAddrs.Get(), } state.Send(s.staged, state.syn) return state } func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { if msg.Packet.TraceID != s.syn.TraceID { return } s.lastSeen = time.Now() if !s.staged.Up { s.staged.Up = true s.publish(s.staged) s.logf("Got ACK.") } s.pubAddrs.Store(msg.Packet.ToAddr) } func (s *StateClientDirect) OnPingTimer() PeerState { if time.Since(s.lastSeen) > timeoutInterval { if s.staged.Up { s.staged.Up = false s.publish(s.staged) s.logf("Timeout.") } return s.OnPeerUpdate(s.peer) } s.Send(s.staged, s.syn) return nil } // ---------------------------------------------------------------------------- type StateClientRelayed struct { *StateClientDirect ack PacketAck probes map[uint64]netip.AddrPort localDiscoveryAddr netip.AddrPort } func EnterStateClientRelayed(s *State) PeerState { s.logf("==> ClientRelayed") return &StateClientRelayed{ StateClientDirect: NewStateClientDirect(s), probes: map[uint64]netip.AddrPort{}, } } func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { s.ack = msg.Packet s.StateClientDirect.OnAck(msg) } func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { addr, ok := s.probes[msg.Packet.TraceID] if !ok { return nil } s.staged.DirectAddr = addr s.staged.Direct = true s.publish(s.staged) return EnterStateClientDirect(s.StateClientDirect.State) } func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) } func (s *StateClientRelayed) OnPingTimer() PeerState { if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { return nextState } clear(s.probes) for _, addr := range s.ack.PossibleAddrs { if !addr.IsValid() { break } s.sendProbeTo(addr) } if s.localDiscoveryAddr.IsValid() { s.sendProbeTo(s.localDiscoveryAddr) s.localDiscoveryAddr = netip.AddrPort{} } return nil } func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { probe := PacketProbe{TraceID: newTraceID()} s.probes[probe.TraceID] = addr s.SendTo(probe, addr) }