diff --git a/node/conn.go b/node/conn.go index 344d8d5..7671f36 100644 --- a/node/conn.go +++ b/node/conn.go @@ -5,6 +5,7 @@ import ( "log" "net" "net/netip" + "runtime/debug" "sync" ) @@ -22,6 +23,7 @@ func newConnWriter(conn *net.UDPConn) *connWriter { func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) { w.lock.Lock() if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { + debug.PrintStack() log.Fatalf("Failed to write to UDP port: %v", err) } w.lock.Unlock() diff --git a/node/header.go b/node/header.go index 97e5872..1a022a2 100644 --- a/node/header.go +++ b/node/header.go @@ -10,7 +10,7 @@ const ( controlHeaderSize = 24 dataStreamID = 1 dataHeaderSize = 12 - forwardStreamID = 3 + relayStreamID = 3 ) type header struct { diff --git a/node/main.go b/node/main.go index 9273823..35a00e6 100644 --- a/node/main.go +++ b/node/main.go @@ -114,7 +114,6 @@ func main(netName, listenIP string, port uint16) { go newHubPoller(netName, conf, peers).Run() go readFromConn(conf.PeerIP, conn, peers) readFromIFace(iface, peers) - } // ---------------------------------------------------------------------------- @@ -157,12 +156,7 @@ func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) { } h.Parse(data) - - if h.DestIP == localIP { - peers[h.SourceIP].HandlePacket(remoteAddr, h, data) - } else { - peers[h.DestIP].ForwardPacket(data) - } + peers[h.SourceIP].HandlePacket(remoteAddr, h, data) } } @@ -183,6 +177,6 @@ func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) { log.Fatalf("Failed to read from interface: %v", err) } - peers[remoteIP].SendData(packet) + peers[remoteIP].HandleInterfacePacket(packet) } } diff --git a/node/packets.go b/node/packets.go index d197f58..57c7341 100644 --- a/node/packets.go +++ b/node/packets.go @@ -7,7 +7,10 @@ import ( "unsafe" ) -var errMalformedPacket = errors.New("malformed packet") +var ( + errMalformedPacket = errors.New("malformed packet") + errUnknownPacketType = errors.New("unknown packet type") +) const ( packetTypePing = iota + 1 @@ -22,6 +25,18 @@ type controlPacket struct { Payload any } +func (p *controlPacket) ParsePayload(buf []byte) (err error) { + switch buf[0] { + case packetTypePing: + p.Payload, err = parsePingPacket(buf) + case packetTypePong: + p.Payload, err = parsePongPacket(buf) + default: + return errUnknownPacketType + } + return err +} + // ---------------------------------------------------------------------------- // A pingPacket is sent from a node acting as a client, to a node acting @@ -32,9 +47,9 @@ type pingPacket struct { SharedKey [32]byte } -func newPingPacket(sharedKey []byte) (pp pingPacket) { +func newPingPacket(sharedKey [32]byte) (pp pingPacket) { pp.SentAt = time.Now().UnixMilli() - copy(pp.SharedKey[:], sharedKey) + copy(pp.SharedKey[:], sharedKey[:]) return } diff --git a/node/packets_test.go b/node/packets_test.go index b385c2b..da242d4 100644 --- a/node/packets_test.go +++ b/node/packets_test.go @@ -12,7 +12,7 @@ func TestPacketPing(t *testing.T) { buf := make([]byte, bufferSize) - p := newPingPacket(sharedKey) + p := newPingPacket([32]byte(sharedKey)) out := p.Marshal(buf) p2, err := parsePingPacket(out) diff --git a/node/peer-states.go b/node/peer-states.go new file mode 100644 index 0000000..c3c0904 --- /dev/null +++ b/node/peer-states.go @@ -0,0 +1,214 @@ +package node + +import ( + "fmt" + "log" + "net/netip" + "sync/atomic" + "time" + "vppn/m" +) + +type peerState interface { + Name() string + OnPeerUpdate(*m.Peer) peerState + OnPing(netip.AddrPort, pingPacket) peerState + OnPong(netip.AddrPort, pongPacket) peerState + OnPingTimer() peerState + OnTimeoutTimer() peerState +} + +// ---------------------------------------------------------------------------- + +type stateBase struct { + // The purpose of this state machine is to manage this published data. + published *atomic.Pointer[peerData] + + // The other remote peers. + peers *remotePeers + + // Immutable data. + localIP byte + localPub bool + remoteIP byte + privKey []byte + conn *connWriter + + // For sending to peer. + counter *uint64 + + // Mutable peer data. + peer *m.Peer + remotePub bool + data peerData // Local copy of shared data. See publish(). + + // Timers + pingTimer *time.Timer + timeoutTimer *time.Timer + + buf []byte + encBuf []byte +} + +func (sb *stateBase) Name() string { return "idle" } + +func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { + // Both nil: no change. + if peer == nil && s.peer == nil { + return nil + } + + // No change. + if peer != nil && s.peer != nil && s.peer.Version == peer.Version { + return nil + } + + s.peer = peer + + s.data = peerData{} + s.data.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) + + ip, isValid := netip.AddrFromSlice(peer.PublicIP) + if isValid { + s.remotePub = true + s.data.remoteAddr = netip.AddrPortFrom(ip, peer.Port) + s.data.relay = peer.Mediator + + if s.localPub && s.localIP < s.remoteIP { + return newStateServer(s) + } + return newStateClient(s) + } + + if s.localPub { + return newStateServer(s) + } + + // TODO: return newStateMediated(a/b) + + return nil +} + +func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil } +func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil } +func (s *stateBase) OnPingTimer() peerState { return nil } +func (s *stateBase) OnTimeoutTimer() peerState { return nil } + +// Helpers. + +func (s *stateBase) resetPingTimer() { s.pingTimer.Reset(pingInterval) } +func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) } +func (s *stateBase) stopPingTimer() { s.pingTimer.Stop() } +func (s *stateBase) stopTimeoutTimer() { s.timeoutTimer.Stop() } + +func (s *stateBase) logf(msg string, args ...any) { + log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) +} + +func (s *stateBase) publish() { + data := s.data + s.published.Store(&data) +} + +func (s *stateBase) sendPing(sharedKey [32]byte) { + s.sendControlPacket(newPingPacket(sharedKey)) +} + +func (s *stateBase) sendPong(ping pingPacket) { + s.sendControlPacket(newPongPacket(ping.SentAt)) +} + +func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { + buf := pkt.Marshal(s.buf) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(s.counter, 1), + SourceIP: s.localIP, + DestIP: s.remoteIP, + } + + buf = s.data.controlCipher.Encrypt(h, buf, s.encBuf) + if s.data.relayIP == 0 { + s.conn.WriteTo(buf, s.data.remoteAddr) + return + } + + // TODO: Relay! +} + +// ---------------------------------------------------------------------------- + +type stateClient struct { + sharedKey [32]byte + *stateBase +} + +func newStateClient(b *stateBase) peerState { + s := &stateClient{stateBase: b} + s.publish() + + s.data.dataCipher = newDataCipher() + s.sharedKey = s.data.dataCipher.Key() + + s.sendPing(s.sharedKey) + s.resetPingTimer() + s.resetTimeoutTimer() + return s +} + +func (s *stateClient) Name() string { return "client" } + +func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState { + if !s.data.up { + s.data.up = true + s.publish() + } + s.resetTimeoutTimer() + return nil +} + +func (s *stateClient) OnPingTimer() peerState { + s.sendPing(s.sharedKey) + s.resetPingTimer() + return nil +} + +func (s *stateClient) OnTimeoutTimer() peerState { + s.data.up = false + s.publish() + return nil +} + +// ---------------------------------------------------------------------------- + +type stateServer struct { + *stateBase +} + +func newStateServer(b *stateBase) peerState { + s := &stateServer{b} + s.publish() + s.stopPingTimer() + s.stopTimeoutTimer() + return s +} + +func (s *stateServer) Name() string { return "server" } + +func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState { + if addr != s.data.remoteAddr { + s.logf("Got new peer address: %v", addr) + s.data.remoteAddr = addr + s.data.up = true + s.publish() + } + + if s.data.dataCipher == nil || p.SharedKey != s.data.dataCipher.Key() { + s.logf("Got new shared key.") + s.data.dataCipher = newDataCipherFromKey(p.SharedKey) + s.publish() + } + + s.sendPong(p) + return nil +} diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index cc615ab..2c46ad2 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -1,10 +1,6 @@ package node import ( - "log" - "math/rand" - "net/netip" - "sync/atomic" "time" "vppn/m" ) @@ -15,263 +11,64 @@ const ( timeoutInterval = 20 * time.Second ) -type stateFunc func() stateFunc - -type peerSuper struct { - *remotePeer - - peer *m.Peer - remotePublic bool - peerData peerData - - pktBuf []byte - encBuf []byte -} - -func newPeerSuper(rp *remotePeer) *peerSuper { - return &peerSuper{ - remotePeer: rp, - peer: nil, - pktBuf: make([]byte, bufferSize), - encBuf: make([]byte, bufferSize), - } -} - -func (rp *peerSuper) Run() { +func (rp *remotePeer) supervise( + conf m.PeerConfig, + remoteIP byte, + conn *connWriter, + peers *remotePeers, +) { defer panicHandler() - state := rp.stateInit - for { - state = state() - } -} -// ---------------------------------------------------------------------------- - -func (rp *peerSuper) stateInit() stateFunc { - //rp.logf("STATE: Init") - - x := peerData{} - rp.shared.Store(&x) - - rp.peerData.relay = false - rp.peerData.controlCipher = nil - rp.peerData.dataCipher = nil - rp.peerData.remoteAddr = zeroAddrPort - rp.peerData.relayIP = 0 - - if rp.peer == nil { - return rp.stateDisconnected + base := &stateBase{ + published: rp.published, + peers: rp.peers, + localIP: rp.localIP, + remoteIP: rp.remoteIP, + privKey: conf.EncPrivKey, + localPub: addrIsValid(conf.PublicIP), + conn: rp.conn, + counter: &rp.counter, + pingTimer: time.NewTimer(time.Second), + timeoutTimer: time.NewTimer(time.Second), + buf: make([]byte, bufferSize), + encBuf: make([]byte, bufferSize), } - var addr netip.Addr - addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) - if rp.remotePublic { - rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) - } else { - rp.peerData.relay = false - } - - rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) - - return rp.stateSelectRole() -} - -// ---------------------------------------------------------------------------- - -func (rp *peerSuper) stateDisconnected() stateFunc { - //rp.logf("STATE: Disconnected") - for { - select { - case <-rp.controlPackets: - // Drop - case rp.peer = <-rp.peerUpdates: - return rp.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (rp *peerSuper) stateSelectRole() stateFunc { - rp.logf("STATE: SelectRole") - - if !rp.localPublic && !rp.remotePublic { - return rp.stateSelectMediator - } - - if !rp.localPublic { - return rp.stateServer - } else if !rp.remotePublic { - return rp.stateClient - } - - if rp.localIP < rp.remoteIP { - return rp.stateClient - } - return rp.stateServer -} - -// ---------------------------------------------------------------------------- - -func (rp *peerSuper) stateSelectMediator() stateFunc { - rp.logf("STATE: SelectMediator") - - for { - log.Printf("Selecting mediator...") - if ip := rp.selectMediator(); ip != 0 { - rp.logf("Got mediator: %d", ip) - rp.peerData.relayIP = ip - - if rp.localIP < rp.remoteIP { - return rp.stateClient - } - return rp.stateServer - } - - select { - case <-time.After(pingInterval): - continue - case rp.peer = <-rp.peerUpdates: - return rp.stateInit - } - } - -} - -func (rp *peerSuper) selectMediator() byte { - possible := make([]byte, 0, 8) - for _, peer := range rp.peers { - if peer.canRelay() { - rp.logf("relay: %v", peer.shared.Load()) - possible = append(possible, peer.remoteIP) - } - } - if len(possible) == 0 { - return 0 - } - return possible[rand.Intn(len(possible))] -} - -// ---------------------------------------------------------------------------- - -// The remote is a server. -func (rp *peerSuper) stateServer() stateFunc { - rp.logf("STATE: Server") - rp.peerData.dataCipher = newDataCipher() - rp.updateShared() + base.pingTimer.Stop() + base.timeoutTimer.Stop() var ( - pingTimer = time.NewTimer(pingInterval) - timeoutTimer = time.NewTimer(timeoutInterval) - ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} + curState peerState = base + nextState peerState ) - defer pingTimer.Stop() - defer timeoutTimer.Stop() - - ping.SentAt = time.Now().UnixMilli() - rp.sendControlPacket(ping) for { + nextState = nil + select { - case <-pingTimer.C: - ping.SentAt = time.Now().UnixMilli() - rp.sendControlPacket(ping) - pingTimer.Reset(pingInterval) + case peer := <-rp.peerUpdates: + nextState = curState.OnPeerUpdate(peer) - case cPkt := <-rp.controlPackets: - if _, ok := cPkt.Payload.(pongPacket); ok { - timeoutTimer.Reset(timeoutInterval) + case pkt := <-rp.controlPackets: + switch p := pkt.Payload.(type) { + case pingPacket: + nextState = curState.OnPing(pkt.RemoteAddr, p) + case pongPacket: + nextState = curState.OnPong(pkt.RemoteAddr, p) + default: + // Unknown packet type. } - case <-timeoutTimer.C: - if rp.peerData.relayIP != 0 { - rp.logf("Timeout (server, relay)") - return rp.stateSelectMediator - } else { - rp.logf("Timeout (server)") - } + case <-base.pingTimer.C: + nextState = curState.OnPingTimer() - case rp.peer = <-rp.peerUpdates: - return rp.stateInit + case <-base.timeoutTimer.C: + nextState = curState.OnTimeoutTimer() + } + + if nextState != nil { + rp.logf("%s --> %s", curState.Name(), nextState.Name()) + curState = nextState } } } - -// ---------------------------------------------------------------------------- - -// The remote is a client. -func (rp *peerSuper) stateClient() stateFunc { - rp.logf("STATE: Client") - rp.updateShared() - - var ( - currentKey = [32]byte{} - timeoutTimer = time.NewTimer(timeoutInterval) - ) - - defer timeoutTimer.Stop() - - for { - select { - case cPkt := <-rp.controlPackets: - if cPkt.RemoteAddr != rp.peerData.remoteAddr { - rp.peerData.remoteAddr = cPkt.RemoteAddr - rp.logf("Got new remote address: %v", cPkt.RemoteAddr) - rp.updateShared() - } - - ping, ok := cPkt.Payload.(pingPacket) - if !ok { - continue - } - - if ping.SharedKey != currentKey { - rp.logf("Connected with new shared key") - currentKey = ping.SharedKey - rp.peerData.up = true - rp.peerData.dataCipher = newDataCipherFromKey(currentKey) - rp.updateShared() - } - - timeoutTimer.Reset(timeoutInterval) - rp.sendControlPacket(newPongPacket(ping.SentAt)) - - case <-timeoutTimer.C: - if rp.peerData.relayIP != 0 { - rp.logf("Timeout (server, relay)") - return rp.stateSelectMediator - } else { - rp.logf("Timeout (server)") - } - - case rp.peer = <-rp.peerUpdates: - return rp.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (rp *peerSuper) updateShared() { - data := rp.peerData - rp.shared.Store(&data) -} - -// ---------------------------------------------------------------------------- - -func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { - buf := pkt.Marshal(rp.pktBuf) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&rp.counter, 1), - SourceIP: rp.localIP, - DestIP: rp.remoteIP, - } - buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) - if rp.peerData.relayIP == 0 { - rp.conn.WriteTo(buf, rp.peerData.remoteAddr) - return - } - - rp.peers[rp.peerData.relayIP].RelayControlData(buf) -} diff --git a/node/peer.go b/node/peer.go index d999339..3cc308e 100644 --- a/node/peer.go +++ b/node/peer.go @@ -22,19 +22,14 @@ type peerData struct { type remotePeer struct { // Immutable data. - localIP byte - remoteIP byte - privKey []byte - localPublic bool // True if local node is public. - iface *ifWriter - conn *connWriter + localIP byte + remoteIP byte + iface *ifWriter + conn *connWriter // Shared state. - peers *remotePeers - shared *atomic.Pointer[peerData] - - // Only used in HandlePeerUpdate. - peerVersion int64 + peers *remotePeers + published *atomic.Pointer[peerData] // Only used in HandlePacket / Not synchronized. dupCheck *dupCheck @@ -55,12 +50,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn rp := &remotePeer{ localIP: conf.PeerIP, remoteIP: remoteIP, - privKey: conf.EncPrivKey, - localPublic: addrIsValid(conf.PublicIP), iface: iface, conn: conn, peers: peers, - shared: &atomic.Pointer[peerData]{}, + published: &atomic.Pointer[peerData]{}, dupCheck: newDupCheck(0), decryptBuf: make([]byte, bufferSize), encryptBuf: make([]byte, bufferSize), @@ -70,10 +63,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn } pd := peerData{} - rp.shared.Store(&pd) - - go newPeerSuper(rp).Run() + rp.published.Store(&pd) + //go newPeerSuper(rp).Run() + go rp.supervise(conf, remoteIP, conn, peers) return rp } @@ -82,10 +75,7 @@ func (rp *remotePeer) logf(msg string, args ...any) { } func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) { - if peer != nil && peer.Version != rp.peerVersion { - rp.peerUpdates <- peer - rp.peerVersion = peer.Version - } + rp.peerUpdates <- peer } // ---------------------------------------------------------------------------- @@ -101,6 +91,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) { case dataStreamID: rp.handleDataPacket(data) + case relayStreamID: + rp.handleRelayPacket(h, data) + default: rp.logf("Unknown stream ID: %d", h.StreamID) } @@ -109,8 +102,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) { // ---------------------------------------------------------------------------- func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) { - shared := rp.shared.Load() + shared := rp.published.Load() if shared.controlCipher == nil { + log.Printf("Shared: %+v", *shared) rp.logf("Not connected (control).") return } @@ -141,19 +135,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] RemoteAddr: addr, } - var err error - - switch out[0] { - case packetTypePing: - pkt.Payload, err = parsePingPacket(out) - case packetTypePong: - pkt.Payload, err = parsePongPacket(out) - default: - rp.logf("Unknown control packet type: %d", out[0]) - return - } - - if err != nil { + if err := pkt.ParsePayload(out); err != nil { rp.logf("Failed to parse control packet: %v", err) return } @@ -168,7 +150,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] // ---------------------------------------------------------------------------- func (rp *remotePeer) handleDataPacket(data []byte) { - shared := rp.shared.Load() + shared := rp.published.Load() if shared.dataCipher == nil { rp.logf("Not connected (recv).") return @@ -185,34 +167,65 @@ func (rp *remotePeer) handleDataPacket(data []byte) { // ---------------------------------------------------------------------------- +func (rp *remotePeer) handleRelayPacket(h header, data []byte) { + shared := rp.published.Load() + if shared.dataCipher == nil { + rp.logf("Not connected (recv).") + return + } + + dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf) + if !ok { + rp.logf("Failed to decrypt data packet.") + return + } + + rp.peers[h.DestIP].sendDirect(dec) +} + +// ---------------------------------------------------------------------------- + // SendData sends data coming from the interface going to the network. // // This function is called by a single thread. func (rp *remotePeer) SendData(data []byte) { - rp.sendData(dataStreamID, data) + rp.sendData(dataStreamID, rp.remoteIP, data) } -// ---------------------------------------------------------------------------- +func (rp *remotePeer) HandleInterfacePacket(data []byte) { + shared := rp.published.Load() -func (rp *remotePeer) RelayControlData(data []byte) { - rp.sendData(forwardStreamID, data) -} - -// ---------------------------------------------------------------------------- - -func (rp *remotePeer) ForwardPacket(data []byte) { - shared := rp.shared.Load() - if shared.remoteAddr == zeroAddrPort { - rp.logf("Not connected (forward).") + if shared.dataCipher == nil { + rp.logf("Not connected (handle interface).") return } - rp.conn.WriteTo(data, shared.remoteAddr) + + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&rp.counter, 1), + SourceIP: rp.localIP, + DestIP: rp.remoteIP, + } + + enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) + + if shared.relayIP != 0 { + rp.peers[shared.relayIP].RelayData(shared.relayIP, enc) + } else { + rp.SendData(data) + } } // ---------------------------------------------------------------------------- -func (rp *remotePeer) sendData(streamID byte, data []byte) { - shared := rp.shared.Load() +func (rp *remotePeer) RelayData(destIP byte, data []byte) { + rp.sendData(relayStreamID, destIP, data) +} + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) { + shared := rp.published.Load() if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { rp.logf("Not connected (send).") return @@ -222,16 +235,18 @@ func (rp *remotePeer) sendData(streamID byte, data []byte) { StreamID: streamID, Counter: atomic.AddUint64(&rp.counter, 1), SourceIP: rp.localIP, - DestIP: rp.remoteIP, + DestIP: destIP, } enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) rp.conn.WriteTo(enc, shared.remoteAddr) } -// ---------------------------------------------------------------------------- - -func (rp *remotePeer) canRelay() bool { - shared := rp.shared.Load() - return shared.relay && shared.up +func (rp *remotePeer) sendDirect(data []byte) { + shared := rp.published.Load() + if shared.remoteAddr == zeroAddrPort { + rp.logf("Not connected (send).") + return + } + rp.conn.WriteTo(data, shared.remoteAddr) }