package node

import (
	"math/rand"
	"net/netip"
	"time"

	"vppn/m"
)

// ----------------------------------------------------------------------------

// noPeer blocks until the next peer update arrives and hands it to
// peerUpdate.
func (s *peerSuper) noPeer() stateFunc {
	return s.peerUpdate(<-s.peerUpdates)
}

// ----------------------------------------------------------------------------

// peerUpdate returns a state that applies the given peer configuration when
// run.
func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc {
	return func() stateFunc { return s._peerUpdate(peer) }
}

// _peerUpdate installs peer as the current peer, resets the staged route
// info, and chooses the next state:
//   - a nil peer moves to noPeer;
//   - if both sides are public (or both private), the side with the lower
//     local IP acts as server and the other as client;
//   - otherwise the side that can reach a public remote acts as client.
//
// The staged route info is published on return.
func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRouteInfo{}
	// remotePub is derived from the peer config, so recompute it on every
	// update; otherwise a peer that loses its public IP would keep a stale
	// "public" flag and the wrong role would be chosen.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
	s.staged.dataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.relay = peer.Mediator
		s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
	}

	if s.remotePub == s.localPub {
		// Symmetric case: break the tie deterministically by IP.
		if s.localIP < s.remoteIP {
			return s.serverAccept
		}
		return s.clientInit
	}

	if s.remotePub {
		return s.clientInit
	}
	return s.serverAccept
}

// ----------------------------------------------------------------------------

// serverAccept marks the route down and waits for a client handshake: a
// synPacket (answered with a synAckPacket) followed by an ackPacket carrying
// the same trace ID. Each syn replaces the staged remote address, data
// cipher and relay IP.
func (s *peerSuper) serverAccept() stateFunc {
	s.logf("STATE: server-accept")
	s.staged.up = false
	s.staged.dataCipher = nil
	s.staged.remoteAddr = zeroAddrPort
	s.staged.relayIP = 0
	s.publish()

	var (
		syn    synPacket
		gotSyn bool // Guards against an ack arriving before any syn.
	)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {

			case synPacket:
				syn = p
				gotSyn = true
				s.staged.remoteAddr = pkt.RemoteAddr
				s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey)
				s.staged.relayIP = syn.RelayIP
				s.publish()
				s.sendControlPacket(synAckPacket{
					TraceID:  syn.TraceID,
					RecvAddr: pkt.RemoteAddr,
				})

			case ackPacket:
				// Without a prior syn the data cipher is still nil, so an ack
				// that happens to match the zero-valued trace ID must not
				// complete the handshake.
				if !gotSyn || p.TraceID != syn.TraceID {
					continue
				}
				return s.serverConnected(syn.TraceID)
			}
		}
	}
}

// ----------------------------------------------------------------------------

// serverConnected marks the route up, publishes it, and returns the state
// that services the established session identified by traceID.
func (s *peerSuper) serverConnected(traceID uint64) stateFunc {
	s.logf("STATE: server-connected")
	s.staged.up = true
	s.publish()
	return func() stateFunc { return s._serverConnected(traceID) }
}

// _serverConnected echoes each ackPacket for traceID back to the client and
// restarts the timeout. An ack with a different trace ID, or no ack within
// timeoutInterval, returns to serverAccept.
func (s *peerSuper) _serverConnected(traceID uint64) stateFunc {
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer timeoutTimer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.serverAccept
				}
				s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-timeoutTimer.C:
			s.logf("server timeout")
			return s.serverAccept
		}
	}
}

// ----------------------------------------------------------------------------

// clientInit is the entry point of the client side. It marks the route down
// (it is re-entered after a lost session) and chooses how to reach the
// remote: directly when the remote is public, otherwise via a relay.
func (s *peerSuper) clientInit() stateFunc {
	s.logf("STATE: client-init")

	// Mirror serverAccept: the route must not stay published as up while
	// the client re-selects a relay and re-dials.
	if s.staged.up {
		s.staged.up = false
		s.publish()
	}

	if !s.remotePub {
		// TODO: Check local discovery for IP.
		// TODO: Attempt UDP hole punch.
		// TODO: client-relayed
		return s.clientSelectRelay
	}

	return s.clientDial
}

// ----------------------------------------------------------------------------

// clientSelectRelay tries to pick a relay immediately, then retries every
// pingInterval until one is available. The chosen relay IP is staged and
// published before dialing.
func (s *peerSuper) clientSelectRelay() stateFunc {
	s.logf("STATE: client-select-relay")

	timer := time.NewTimer(0)
	defer timer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case <-timer.C:
			ip := s.selectRelayIP()
			if ip != 0 {
				s.logf("Got relay: %d", ip)
				s.staged.relayIP = ip
				s.publish()
				return s.clientDial
			}

			s.logf("No relay available.")
			timer.Reset(pingInterval)
		}
	}
}

// selectRelayIP returns the index of a randomly chosen entry in s.peers whose
// CanRelay reports true, or 0 when none does.
// NOTE(review): 0 doubles as "no relay", so this assumes index 0 is never a
// valid relay — confirm.
func (s *peerSuper) selectRelayIP() byte {
	possible := make([]byte, 0, 8)
	for i, peer := range s.peers {
		if peer.CanRelay() {
			possible = append(possible, byte(i))
		}
	}

	if len(possible) == 0 {
		return 0
	}
	return possible[rand.Intn(len(possible))]
}

// ----------------------------------------------------------------------------

// clientDial sends a synPacket with a fresh trace ID, the staged data-cipher
// key and the staged relay IP, then waits up to dialTimeout for the matching
// synAckPacket. On success it acknowledges and moves to clientConnected; on
// timeout it restarts from clientInit.
func (s *peerSuper) clientDial() stateFunc {
	s.logf("STATE: client-dial")

	var (
		syn = synPacket{
			TraceID:   newTraceID(),
			SharedKey: s.staged.dataCipher.Key(),
			RelayIP:   s.staged.relayIP,
		}

		timeout = time.NewTimer(dialTimeout)
	)

	defer timeout.Stop()

	s.sendControlPacket(syn)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case synAckPacket:
				if p.TraceID != syn.TraceID {
					continue // Reply to a different dial attempt; keep waiting.
				}
				s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
				return s.clientConnected(p)
			}

		case <-timeout.C:
			return s.clientInit
		}
	}
}

// ----------------------------------------------------------------------------

// clientConnected marks the route up, records the address the server saw us
// at (p.RecvAddr) as the local address, publishes, and returns the state that
// maintains the session identified by p.TraceID.
func (s *peerSuper) clientConnected(p synAckPacket) stateFunc {
	s.logf("STATE: client-connected")
	s.staged.up = true
	s.staged.localAddr = p.RecvAddr
	s.publish()
	return func() stateFunc { return s._clientConnected(p.TraceID) }
}

// _clientConnected sends an ackPacket every pingInterval and restarts the
// timeout whenever the server echoes an ack for traceID. An ack with a
// different trace ID, or no echo within timeoutInterval, restarts from
// clientInit.
func (s *peerSuper) _clientConnected(traceID uint64) stateFunc {
	pingTimer := time.NewTimer(pingInterval)
	timeoutTimer := time.NewTimer(timeoutInterval)

	defer pingTimer.Stop()
	defer timeoutTimer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.clientInit
				}
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-pingTimer.C:
			s.sendControlPacket(ackPacket{TraceID: traceID})
			pingTimer.Reset(pingInterval)

		case <-timeoutTimer.C:
			s.logf("client timeout")
			return s.clientInit
		}
	}
}