package node import ( "log" "math/rand" "net/netip" "sync/atomic" "time" "vppn/m" ) const ( connectTimeout = 6 * time.Second pingInterval = 6 * time.Second timeoutInterval = 20 * time.Second ) type stateFunc func() stateFunc type peerSuper struct { *remotePeer peer *m.Peer remotePublic bool peerData peerData pktBuf []byte encBuf []byte } func newPeerSuper(rp *remotePeer) *peerSuper { return &peerSuper{ remotePeer: rp, peer: nil, pktBuf: make([]byte, bufferSize), encBuf: make([]byte, bufferSize), } } func (rp *peerSuper) Run() { defer panicHandler() state := rp.stateInit for { state = state() } } // ---------------------------------------------------------------------------- func (rp *peerSuper) stateInit() stateFunc { //rp.logf("STATE: Init") x := peerData{} rp.shared.Store(&x) rp.peerData.relay = false rp.peerData.controlCipher = nil rp.peerData.dataCipher = nil rp.peerData.remoteAddr = zeroAddrPort rp.peerData.relayIP = 0 if rp.peer == nil { return rp.stateDisconnected } var addr netip.Addr addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) if rp.remotePublic { rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) } else { rp.peerData.relay = false } rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) return rp.stateSelectRole() } // ---------------------------------------------------------------------------- func (rp *peerSuper) stateDisconnected() stateFunc { //rp.logf("STATE: Disconnected") for { select { case <-rp.controlPackets: // Drop case rp.peer = <-rp.peerUpdates: return rp.stateInit } } } // ---------------------------------------------------------------------------- func (rp *peerSuper) stateSelectRole() stateFunc { rp.logf("STATE: SelectRole") if !rp.localPublic && !rp.remotePublic { return rp.stateSelectMediator } if !rp.localPublic { return rp.stateServer } else if !rp.remotePublic { return rp.stateClient } if rp.localIP < rp.remoteIP { return rp.stateClient } return rp.stateServer } // ---------------------------------------------------------------------------- func (rp *peerSuper) stateSelectMediator() stateFunc { rp.logf("STATE: SelectMediator") for { log.Printf("Selecting mediator...") if ip := rp.selectMediator(); ip != 0 { rp.logf("Got mediator: %d", ip) rp.peerData.relayIP = ip if rp.localIP < rp.remoteIP { return rp.stateClient } return rp.stateServer } select { case <-time.After(pingInterval): continue case rp.peer = <-rp.peerUpdates: return rp.stateInit } } } func (rp *peerSuper) selectMediator() byte { possible := make([]byte, 0, 8) for _, peer := range rp.peers { if peer.canRelay() { rp.logf("relay: %v", peer.shared.Load()) possible = append(possible, peer.remoteIP) } } if len(possible) == 0 { return 0 } return possible[rand.Intn(len(possible))] } // ---------------------------------------------------------------------------- // The remote is a server. func (rp *peerSuper) stateServer() stateFunc { rp.logf("STATE: Server") rp.peerData.dataCipher = newDataCipher() rp.updateShared() var ( pingTimer = time.NewTimer(pingInterval) timeoutTimer = time.NewTimer(timeoutInterval) ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} ) defer pingTimer.Stop() defer timeoutTimer.Stop() ping.SentAt = time.Now().UnixMilli() rp.sendControlPacket(ping) for { select { case <-pingTimer.C: ping.SentAt = time.Now().UnixMilli() rp.sendControlPacket(ping) pingTimer.Reset(pingInterval) case cPkt := <-rp.controlPackets: if _, ok := cPkt.Payload.(pongPacket); ok { timeoutTimer.Reset(timeoutInterval) } case <-timeoutTimer.C: if rp.peerData.relayIP != 0 { rp.logf("Timeout (server, relay)") return rp.stateSelectMediator } else { rp.logf("Timeout (server)") } case rp.peer = <-rp.peerUpdates: return rp.stateInit } } } // ---------------------------------------------------------------------------- // The remote is a client. func (rp *peerSuper) stateClient() stateFunc { rp.logf("STATE: Client") rp.updateShared() var ( currentKey = [32]byte{} timeoutTimer = time.NewTimer(timeoutInterval) ) defer timeoutTimer.Stop() for { select { case cPkt := <-rp.controlPackets: if cPkt.RemoteAddr != rp.peerData.remoteAddr { rp.peerData.remoteAddr = cPkt.RemoteAddr rp.logf("Got new remote address: %v", cPkt.RemoteAddr) rp.updateShared() } ping, ok := cPkt.Payload.(pingPacket) if !ok { continue } if ping.SharedKey != currentKey { rp.logf("Connected with new shared key") currentKey = ping.SharedKey rp.peerData.up = true rp.peerData.dataCipher = newDataCipherFromKey(currentKey) rp.updateShared() } timeoutTimer.Reset(timeoutInterval) rp.sendControlPacket(newPongPacket(ping.SentAt)) case <-timeoutTimer.C: if rp.peerData.relayIP != 0 { rp.logf("Timeout (server, relay)") return rp.stateSelectMediator } else { rp.logf("Timeout (server)") } case rp.peer = <-rp.peerUpdates: return rp.stateInit } } } // ---------------------------------------------------------------------------- func (rp *peerSuper) updateShared() { data := rp.peerData rp.shared.Store(&data) } // ---------------------------------------------------------------------------- func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { buf := pkt.Marshal(rp.pktBuf) h := header{ StreamID: controlStreamID, Counter: atomic.AddUint64(&rp.counter, 1), SourceIP: rp.localIP, DestIP: rp.remoteIP, } buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) if rp.peerData.relayIP == 0 { rp.conn.WriteTo(buf, rp.peerData.remoteAddr) return } rp.peers[rp.peerData.relayIP].RelayControlData(buf) }