diff --git a/README.md b/README.md index 3aa4d04..b9d291e 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ ## Roadmap +* Rename Mediator -> Relay * Node: use symmetric encryption after handshake * AEAD-AES uses a 12 byte nonce. We need to shrink the header: * Remove Forward and replace it with a HeaderFlags bitfield. diff --git a/node/addrutil.go b/node/addrutil.go new file mode 100644 index 0000000..590c80c --- /dev/null +++ b/node/addrutil.go @@ -0,0 +1,8 @@ +package node + +import "net/netip" + +func addrIsValid(in []byte) bool { + _, ok := netip.AddrFromSlice(in) + return ok +} diff --git a/node/cipher-control.go b/node/cipher-control.go new file mode 100644 index 0000000..e9b56d5 --- /dev/null +++ b/node/cipher-control.go @@ -0,0 +1,26 @@ +package node + +import "golang.org/x/crypto/nacl/box" + +type controlCipher struct { + sharedKey [32]byte +} + +func newControlCipher(privKey, pubKey []byte) *controlCipher { + shared := [32]byte{} + box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) + return &controlCipher{shared} +} + +func (cc *controlCipher) Encrypt(h xHeader, data, out []byte) []byte { + const s = controlHeaderSize + out = out[:s+controlCipherOverhead+len(data)] + h.Marshal(out[:s]) + box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) + return out +} + +func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = controlHeaderSize + return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) +} diff --git a/node/cipher-routing_test.go b/node/cipher-control_test.go similarity index 62% rename from node/cipher-routing_test.go rename to node/cipher-control_test.go index 09824f7..c571aa2 100644 --- a/node/cipher-routing_test.go +++ b/node/cipher-control_test.go @@ -3,12 +3,13 @@ package node import ( "bytes" "crypto/rand" + "reflect" "testing" "golang.org/x/crypto/nacl/box" ) -func newRoutingCipherForTesting() (c1, c2 routingCipher) { +func newControlCipherForTesting() (c1, c2 *controlCipher) { pubKey1, privKey1, err := box.GenerateKey(rand.Reader) if err != nil { panic(err) @@ -19,14 +20,14 @@ func newRoutingCipherForTesting() (c1, c2 routingCipher) { panic(err) } - return newRoutingCipher(privKey1[:], pubKey2[:]), - newRoutingCipher(privKey2[:], pubKey1[:]) + return newControlCipher(privKey1[:], pubKey2[:]), + newControlCipher(privKey2[:], pubKey1[:]) } -func TestRoutingCipher(t *testing.T) { - c1, c2 := newRoutingCipherForTesting() +func TestControlCipher(t *testing.T) { + c1, c2 := newControlCipherForTesting() - maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) + maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) rand.Read(maxSizePlaintext) testCases := [][]byte{ @@ -40,6 +41,7 @@ func TestRoutingCipher(t *testing.T) { for _, plaintext := range testCases { h1 := xHeader{ + StreamID: controlStreamID, Counter: 235153, SourceIP: 4, DestIP: 88, @@ -49,6 +51,12 @@ func TestRoutingCipher(t *testing.T) { encrypted = c1.Encrypt(h1, plaintext, encrypted) + h2 := xHeader{} + h2.Parse(encrypted) + if !reflect.DeepEqual(h1, h2) { + t.Fatal(h1, h2) + } + decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) if !ok { t.Fatal(ok) @@ -60,9 +68,9 @@ func TestRoutingCipher(t *testing.T) { } } -func TestRoutingCipher_ShortCiphertext(t *testing.T) { - c1, _ := newRoutingCipherForTesting() - shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1) +func TestControlCipher_ShortCiphertext(t *testing.T) { + c1, _ := newControlCipherForTesting() + shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) rand.Read(shortText) _, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) if ok { @@ -70,15 +78,15 @@ func TestRoutingCipher_ShortCiphertext(t *testing.T) { } } -func BenchmarkRoutingCipher_Encrypt(b *testing.B) { - c1, _ := newRoutingCipherForTesting() +func BenchmarkControlCipher_Encrypt(b *testing.B) { + c1, _ := newControlCipherForTesting() h1 := xHeader{ Counter: 235153, SourceIP: 4, DestIP: 88, } - plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) rand.Read(plaintext) encrypted := make([]byte, bufferSize) @@ -89,8 +97,8 @@ func BenchmarkRoutingCipher_Encrypt(b *testing.B) { } } -func BenchmarkRoutingCipher_Decrypt(b *testing.B) { - c1, c2 := newRoutingCipherForTesting() +func BenchmarkControlCipher_Decrypt(b *testing.B) { + c1, c2 := newControlCipherForTesting() h1 := xHeader{ Counter: 235153, @@ -98,7 +106,7 @@ func BenchmarkRoutingCipher_Decrypt(b *testing.B) { DestIP: 88, } - plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) rand.Read(plaintext) encrypted := make([]byte, bufferSize) diff --git a/node/cipher-data.go b/node/cipher-data.go index c0fc273..26d3121 100644 --- a/node/cipher-data.go +++ b/node/cipher-data.go @@ -6,22 +6,23 @@ import ( "crypto/rand" ) +// TODO: Use [32]byte for simplicity everywhere. type dataCipher struct { - key []byte + key [32]byte aead cipher.AEAD } func newDataCipher() *dataCipher { - key := make([]byte, 32) - if _, err := rand.Read(key); err != nil { + key := [32]byte{} + if _, err := rand.Read(key[:]); err != nil { panic(err) } return newDataCipherFromKey(key) } // key must be 32 bytes. -func newDataCipherFromKey(key []byte) *dataCipher { - block, err := aes.NewCipher(key) +func newDataCipherFromKey(key [32]byte) *dataCipher { + block, err := aes.NewCipher(key[:]) if err != nil { panic(err) } @@ -34,14 +35,14 @@ func newDataCipherFromKey(key []byte) *dataCipher { return &dataCipher{key: key, aead: aead} } -func (sc *dataCipher) Key() []byte { +func (sc *dataCipher) Key() [32]byte { return sc.key } func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte { const s = dataHeaderSize out = out[:s+dataCipherOverhead+len(data)] - h.Marshal(dataStreamID, out[:s]) + h.Marshal(out[:s]) sc.aead.Seal(out[s:s], out[:s], data, nil) return out } diff --git a/node/cipher-data_test.go b/node/cipher-data_test.go index d1523d8..c3892bb 100644 --- a/node/cipher-data_test.go +++ b/node/cipher-data_test.go @@ -23,6 +23,7 @@ func TestDataCipher(t *testing.T) { for _, plaintext := range testCases { h1 := xHeader{ + StreamID: dataStreamID, Counter: 235153, SourceIP: 4, DestIP: 88, diff --git a/node/cipher-routing.go b/node/cipher-routing.go deleted file mode 100644 index 795ac7a..0000000 --- a/node/cipher-routing.go +++ /dev/null @@ -1,26 +0,0 @@ -package node - -import "golang.org/x/crypto/nacl/box" - -type routingCipher struct { - sharedKey [32]byte -} - -func newRoutingCipher(privKey, pubKey []byte) routingCipher { - shared := [32]byte{} - box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) - return routingCipher{shared} -} - -func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte { - const s = routingHeaderSize - out = out[:s+routingCipherOverhead+len(data)] - h.Marshal(routingStreamID, out[:s]) - box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey) - return out -} - -func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { - const s = routingHeaderSize - return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey) -} diff --git a/node/conn.go b/node/conn.go index 8a57641..7f7e4e3 100644 --- a/node/conn.go +++ b/node/conn.go @@ -1,6 +1,7 @@ package node import ( + "io" "log" "net" "net/netip" @@ -9,6 +10,48 @@ import ( "vppn/fasttime" ) +// ---------------------------------------------------------------------------- + +type connWriter2 struct { + lock sync.Mutex + conn *net.UDPConn +} + +func newConnWriter2(conn *net.UDPConn) *connWriter2 { + return &connWriter2{conn: conn} +} + +func (w *connWriter2) WriteTo(packet []byte, addr netip.AddrPort) { + w.lock.Lock() + if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { + log.Fatalf("Failed to write to UDP port: %v", err) + } + w.lock.Unlock() +} + +// ---------------------------------------------------------------------------- + +type ifWriter struct { + lock sync.Mutex + iface io.ReadWriteCloser +} + +func newIFWriter(iface io.ReadWriteCloser) *ifWriter { + return &ifWriter{iface: iface} +} + +func (w *ifWriter) Write(packet []byte) { + w.lock.Lock() + if _, err := w.iface.Write(packet); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } + w.lock.Unlock() +} + +// ---------------------------------------------------------------------------- + +// TODO: Delete below?? + type connWriter struct { *net.UDPConn lock sync.Mutex diff --git a/node/header.go b/node/header.go index a409576..d2eb142 100644 --- a/node/header.go +++ b/node/header.go @@ -5,30 +5,33 @@ import "unsafe" // ---------------------------------------------------------------------------- const ( - routingStreamID = 2 - routingHeaderSize = 24 - routingCipherOverhead = 16 + controlStreamID = 2 + controlHeaderSize = 24 + controlCipherOverhead = 16 dataStreamID = 1 dataHeaderSize = 12 dataCipherOverhead = 16 + + forwardStreamID = 3 ) -// TODO: Rename type xHeader struct { + StreamID byte Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. SourceIP byte DestIP byte } func (h *xHeader) Parse(b []byte) { + h.StreamID = b[0] h.Counter = *(*uint64)(unsafe.Pointer(&b[1])) h.SourceIP = b[9] h.DestIP = b[10] } -func (h *xHeader) Marshal(streamID byte, buf []byte) { - buf[0] = streamID +func (h *xHeader) Marshal(buf []byte) { + buf[0] = h.StreamID *(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter buf[9] = h.SourceIP buf[10] = h.DestIP @@ -40,7 +43,7 @@ func (h *xHeader) Marshal(streamID byte, buf []byte) { const ( headerSize = 24 streamData = 1 - streamRouting = 2 + streamControl = 2 ) type header struct { diff --git a/node/header_test.go b/node/header_test.go index 7a87354..0205d87 100644 --- a/node/header_test.go +++ b/node/header_test.go @@ -3,18 +3,17 @@ package node import "testing" func TestHeaderMarshalParse(t *testing.T) { - nIn := header{ + nIn := xHeader{ + StreamID: 23, Counter: 3212, SourceIP: 34, DestIP: 200, - Forward: 1, - Stream: 44, } buf := make([]byte, headerSize) nIn.Marshal(buf) - nOut := header{} + nOut := xHeader{} nOut.Parse(buf) if nIn != nOut { t.Fatal(nIn, nOut) diff --git a/node/main.go b/node/main.go index cac6df8..f5c9bc7 100644 --- a/node/main.go +++ b/node/main.go @@ -102,15 +102,19 @@ func main(netName, listenIP string, port uint16) { log.Fatalf("Failed to open UDP port: %v", err) } - routing := newRoutingTable() + connWriter := newConnWriter2(conn) + ifWriter := newIFWriter(iface) - w := newConnWriter(conn, conf.PeerIP, routing) - r := newConnReader(conn, conf.PeerIP, routing) + peers := remotePeers{} - router := newRouter(netName, conf, routing, w) + for i := range peers { + peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter) + } + + go newHubPoller(netName, conf, peers).Run() + go readFromConn(conn, peers) + readFromIFace(iface, peers) - go nodeConnReader(r, w, iface, router) - nodeIFaceReader(w, iface, router) } // ---------------------------------------------------------------------------- @@ -127,43 +131,39 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { // ---------------------------------------------------------------------------- -func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) { +func readFromConn(conn *net.UDPConn, peers remotePeers) { + defer panicHandler() + var ( remoteAddr netip.AddrPort - h header + n int + err error buf = make([]byte, bufferSize) data []byte - err error + h xHeader ) for { - remoteAddr, h, data = r.Read(buf) - - if h.Forward != 0 { - w.Forward(h.DestIP, data) - continue + n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize]) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) } - switch h.Stream { + data = buf[:n] - case streamData: - if _, err = iface.Write(data); err != nil { - log.Printf("Malformed data from peer %d: %v", h.SourceIP, err) - } - - case streamRouting: - router.HandlePacket(h.SourceIP, remoteAddr, data) - - default: - log.Printf("Dropping unknown stream: %d", h.Stream) + if n < headerSize { + continue // Packet it soo short. } + + h.Parse(data) + peers[h.SourceIP].HandlePacket(remoteAddr, h, data) } } // ---------------------------------------------------------------------------- -func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { +func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) { var ( buf = make([]byte, bufferSize) @@ -173,16 +173,11 @@ func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { ) for { - packet, remoteIP, err = readNextPacket(iface, buf) if err != nil { log.Fatalf("Failed to read from interface: %v", err) } - if remoteIP == w.localIP { - continue // Don't write to self. - } - - w.WriteTo(remoteIP, streamData, packet) + peers[remoteIP].SendData(packet) } } diff --git a/node/packets.go b/node/packets.go index 75f4e6e..d197f58 100644 --- a/node/packets.go +++ b/node/packets.go @@ -16,10 +16,10 @@ const ( // ---------------------------------------------------------------------------- -type packetWrapper struct { +type controlPacket struct { SrcIP byte RemoteAddr netip.AddrPort - Packet any + Payload any } // ---------------------------------------------------------------------------- @@ -46,13 +46,13 @@ func (p pingPacket) Marshal(buf []byte) []byte { return buf } -func (p *pingPacket) Parse(buf []byte) error { +func parsePingPacket(buf []byte) (p pingPacket, err error) { if len(buf) != 41 { - return errMalformedPacket + return p, errMalformedPacket } p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) copy(p.SharedKey[:], buf[9:41]) - return nil + return } // ---------------------------------------------------------------------------- @@ -78,12 +78,11 @@ func (p pongPacket) Marshal(buf []byte) []byte { return buf } -func (p *pongPacket) Parse(buf []byte) error { +func parsePongPacket(buf []byte) (p pongPacket, err error) { if len(buf) != 17 { - return errMalformedPacket + return p, errMalformedPacket } p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9])) - - return nil + return } diff --git a/node/packets_test.go b/node/packets_test.go index bd89215..b385c2b 100644 --- a/node/packets_test.go +++ b/node/packets_test.go @@ -15,8 +15,8 @@ func TestPacketPing(t *testing.T) { p := newPingPacket(sharedKey) out := p.Marshal(buf) - p2 := pingPacket{} - if err := p2.Parse(out); err != nil { + p2, err := parsePingPacket(out) + if err != nil { t.Fatal(err) } @@ -31,8 +31,8 @@ func TestPacketPong(t *testing.T) { p := newPongPacket(123566) out := p.Marshal(buf) - p2 := pongPacket{} - if err := p2.Parse(out); err != nil { + p2, err := parsePongPacket(out) + if err != nil { t.Fatal(err) } diff --git a/node/peer-pollhub.go b/node/peer-pollhub.go new file mode 100644 index 0000000..aa1c91b --- /dev/null +++ b/node/peer-pollhub.go @@ -0,0 +1,97 @@ +package node + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/url" + "time" + "vppn/m" +) + +type hubPoller struct { + netName string + localIP byte + client *http.Client + req *http.Request + peers remotePeers +} + +func newHubPoller(netName string, conf m.PeerConfig, peers remotePeers) *hubPoller { + u, err := url.Parse(conf.HubAddress) + if err != nil { + log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err) + } + u.Path = "/peer/fetch-state/" + + client := &http.Client{Timeout: 8 * time.Second} + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + req.SetBasicAuth("", conf.APIKey) + + return &hubPoller{ + netName: netName, + localIP: conf.PeerIP, + client: client, + req: req, + peers: peers, + } +} + +func (hp *hubPoller) Run() { + defer panicHandler() + + state, err := loadNetworkState(hp.netName) + if err != nil { + log.Printf("Failed to load network state: %v", err) + log.Printf("Polling hub...") + hp.pollHub() + } else { + hp.applyNetworkState(state) + } + + for range time.Tick(64 * time.Second) { + hp.pollHub() + } +} + +func (hp *hubPoller) pollHub() { + var state m.NetworkState + + log.Printf("Fetching peer state...") + resp, err := hp.client.Do(hp.req) + if err != nil { + log.Printf("Failed to fetch peer state: %v", err) + return + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + log.Printf("Failed to read body from hub: %v", err) + return + } + + if err := json.Unmarshal(body, &state); err != nil { + log.Printf("Failed to unmarshal response from hub: %v", err) + return + } + + hp.applyNetworkState(state) + + if err := storeNetworkState(hp.netName, state); err != nil { + log.Printf("Failed to store network state: %v", err) + } +} + +func (hp *hubPoller) applyNetworkState(state m.NetworkState) { + for i := range state.Peers { + if i != int(hp.localIP) { + hp.peers[i].HandlePeerUpdate(state.Peers[i]) + } + } +} diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go new file mode 100644 index 0000000..cfcb43b --- /dev/null +++ b/node/peer-supervisor.go @@ -0,0 +1,197 @@ +package node + +import ( + "net/netip" + "sync/atomic" + "time" + "vppn/m" +) + +const ( + connectTimeout = 6 * time.Second + pingInterval = 6 * time.Second + timeoutInterval = 20 * time.Second +) + +type stateFunc func() stateFunc + +type peerSuper struct { + *remotePeer + + peer *m.Peer + remotePublic bool + peerData peerData + + pktBuf []byte + encBuf []byte +} + +func newPeerSuper(rp *remotePeer) *peerSuper { + return &peerSuper{ + remotePeer: rp, + peer: nil, + pktBuf: make([]byte, bufferSize), + encBuf: make([]byte, bufferSize), + } +} + +func (rp *peerSuper) Run() { + defer panicHandler() + state := rp.stateInit + for { + state = state() + } +} + +// ---------------------------------------------------------------------------- + +func (rp *peerSuper) stateInit() stateFunc { + //rp.logf("STATE: Init") + x := peerData{} + rp.shared.Store(&x) + + rp.peerData.controlCipher = nil + rp.peerData.dataCipher = nil + rp.peerData.remoteAddr = zeroAddrPort + + if rp.peer == nil { + return rp.stateDisconnected + } + + var addr netip.Addr + addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) + if rp.remotePublic { + rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) + } + + rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) + + return rp.stateSelectRole() +} + +// ---------------------------------------------------------------------------- + +func (rp *peerSuper) stateDisconnected() stateFunc { + //rp.logf("STATE: Disconnected") + for { + select { + case <-rp.controlPackets: + // Drop + case rp.peer = <-rp.peerUpdates: + return rp.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (rp *peerSuper) stateSelectRole() stateFunc { + rp.logf("STATE: SelectRole") + + if !rp.localPublic && !rp.remotePublic { + // TODO! + return rp.stateDisconnected + } + + if !rp.localPublic { + return rp.stateServer + } else if !rp.remotePublic { + return rp.stateClient + } + + if rp.localIP < rp.peer.PeerIP { + return rp.stateClient + } + return rp.stateServer +} + +// The remote is a server. +func (rp *peerSuper) stateServer() stateFunc { + rp.logf("STATE: Server") + rp.peerData.dataCipher = newDataCipher() + rp.updateShared() + + var ( + pingTimer = time.NewTimer(pingInterval) + ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} + ) + defer pingTimer.Stop() + + ping.SentAt = time.Now().UnixMilli() + rp.sendControlPacket(ping) + + for { + select { + case <-pingTimer.C: + ping.SentAt = time.Now().UnixMilli() + rp.sendControlPacket(ping) + pingTimer.Reset(pingInterval) + + case <-rp.controlPackets: + // Ignore + + case rp.peer = <-rp.peerUpdates: + return rp.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +// The remote is a client. +func (rp *peerSuper) stateClient() stateFunc { + rp.logf("STATE: Client") + rp.updateShared() + + // TODO: Could use timeout to set dataCipher to nil. + var currentKey = [32]byte{} + + for { + select { + case cPkt := <-rp.controlPackets: + if cPkt.RemoteAddr != rp.peerData.remoteAddr { + rp.peerData.remoteAddr = cPkt.RemoteAddr + rp.logf("Got new remote address: %v", cPkt.RemoteAddr) + rp.updateShared() + } + + ping, ok := cPkt.Payload.(pingPacket) + if !ok { + continue + } + + if ping.SharedKey != currentKey { + rp.logf("Connected with new shared key") + currentKey = ping.SharedKey + rp.peerData.dataCipher = newDataCipherFromKey(currentKey) + rp.updateShared() + } + + rp.sendControlPacket(newPongPacket(ping.SentAt)) + + case rp.peer = <-rp.peerUpdates: + return rp.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (rp *peerSuper) updateShared() { + data := rp.peerData + rp.shared.Store(&data) +} + +// ---------------------------------------------------------------------------- + +func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { + buf := pkt.Marshal(rp.pktBuf) + h := xHeader{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(&rp.counter, 1), + SourceIP: rp.localIP, + DestIP: rp.remoteIP, + } + buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) + rp.conn.WriteTo(buf, rp.peerData.remoteAddr) +} diff --git a/node/peer.go b/node/peer.go index 2b4023a..19cddfd 100644 --- a/node/peer.go +++ b/node/peer.go @@ -1 +1,206 @@ package node + +import ( + "fmt" + "log" + "net/netip" + "sync/atomic" + "time" + "vppn/m" +) + +type remotePeers [256]*remotePeer + +type peerData struct { + controlCipher *controlCipher + dataCipher *dataCipher + remoteAddr netip.AddrPort +} + +type remotePeer struct { + // Immutable data. + localIP byte + remoteIP byte + privKey []byte + localPublic bool // True if local node is public. + iface *ifWriter + conn *connWriter2 + + // Shared state. + shared *atomic.Pointer[peerData] + + // Only used in HandlePeerUpdate. + peerVersion int64 + + // Only used in HandlePacket / Not synchronized. + dupCheck *dupCheck + decryptBuf []byte + + // Only used in SendData / Not synchronized. + encryptBuf []byte + + // Used for sending control and data packets. Atomic access only. + counter uint64 + + // For communicating with the supervisor thread. + peerUpdates chan *m.Peer + controlPackets chan controlPacket +} + +func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter2) *remotePeer { + rp := &remotePeer{ + localIP: conf.PeerIP, + remoteIP: remoteIP, + privKey: conf.EncPrivKey, + localPublic: addrIsValid(conf.PublicIP), + iface: iface, + conn: conn, + shared: &atomic.Pointer[peerData]{}, + dupCheck: newDupCheck(0), + decryptBuf: make([]byte, bufferSize), + encryptBuf: make([]byte, bufferSize), + counter: uint64(time.Now().Unix()) << 30, + peerUpdates: make(chan *m.Peer), + controlPackets: make(chan controlPacket, 512), + } + + pd := peerData{} + rp.shared.Store(&pd) + + go newPeerSuper(rp).Run() + + return rp +} + +func (rp *remotePeer) logf(msg string, args ...any) { + log.Printf(fmt.Sprintf("[%03d] ", rp.remoteIP)+msg, args...) +} + +func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) { + if peer != nil && peer.Version != rp.peerVersion { + rp.peerUpdates <- peer + rp.peerVersion = peer.Version + } +} + +// ---------------------------------------------------------------------------- + +// HandlePacket accepts a raw data packet coming in from the network. +// +// This function is called by a single thread. +func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h xHeader, data []byte) { + switch h.StreamID { + case controlStreamID: + rp.handleControlPacket(addr, h, data) + + case dataStreamID: + rp.handleDataPacket(data) + + case forwardStreamID: + fallthrough + // TODO + //rp.handleForwardPacket(h, data) + default: + rp.logf("Unknown stream ID: %d", h.StreamID) + } +} + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h xHeader, data []byte) { + shared := rp.shared.Load() + if shared.controlCipher == nil { + rp.logf("Not connected (control).") + return + } + + out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) + if !ok { + rp.logf("Failed to decrypt control packet.") + return + } + + if len(out) == 0 { + rp.logf("Empty control packet from: %d", h.SourceIP) + return + } + + if rp.dupCheck.IsDup(h.Counter) { + rp.logf("Duplicate control packet: %d", h.Counter) + return + } + + if h.DestIP != rp.localIP { + // TODO: Forward control packet. + // TODO: Probably this should be dropped. + // Control packets should be forwarded as data for efficiency. + return + } + + pkt := controlPacket{ + SrcIP: h.SourceIP, + RemoteAddr: addr, + } + + var err error + + switch out[0] { + case packetTypePing: + pkt.Payload, err = parsePingPacket(out) + case packetTypePong: + pkt.Payload, err = parsePongPacket(out) + default: + rp.logf("Unknown control packet type: %d", out[0]) + return + } + + if err != nil { + rp.logf("Failed to parse control packet: %v", err) + return + } + + select { + case rp.controlPackets <- pkt: + default: + rp.logf("Dropping control packet.") + } +} + +func (rp *remotePeer) handleDataPacket(data []byte) { + shared := rp.shared.Load() + if shared.dataCipher == nil { + rp.logf("Not connected (recv).") + return + } + + dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf) + if !ok { + rp.logf("Failed to decrypt data packet.") + return + } + + rp.iface.Write(dec) +} + +// ---------------------------------------------------------------------------- + +// SendData sends data coming from the interface going to the network. +// +// This function is called by a single thread. +func (rp *remotePeer) SendData(data []byte) { + shared := rp.shared.Load() + if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { + rp.logf("Not connected (send).") + return + } + + h := xHeader{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&rp.counter, 1), + SourceIP: rp.localIP, + DestIP: rp.remoteIP, + } + + enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) + rp.conn.WriteTo(enc, shared.remoteAddr) +} diff --git a/node/peersupervisor.go b/node/peersupervisor.go index bdcf03f..90763b4 100644 --- a/node/peersupervisor.go +++ b/node/peersupervisor.go @@ -8,12 +8,6 @@ import ( "vppn/m" ) -const ( - connectTimeout = 6 * time.Second - pingInterval = 6 * time.Second - timeoutInterval = 20 * time.Second -) - type routingPacketWrapper struct { routingPacket Addr netip.AddrPort // Source. @@ -113,8 +107,6 @@ func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) { // ---------------------------------------------------------------------------- -type stateFunc func() stateFunc - func (s *peerSupervisor) stateInit() stateFunc { if s.peer == nil { return s.stateDisconnected @@ -316,12 +308,12 @@ func (s *peerSupervisor) updateRoutingTable(up bool) { func (s *peerSupervisor) sendPing() uint64 { traceID := newTraceID() pkt := newRoutingPacket(packetTypePing, traceID) - s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) + s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf)) s.pingTimer.Reset(pingInterval) return traceID } func (s *peerSupervisor) sendPong(traceID uint64) { pkt := newRoutingPacket(packetTypePong, traceID) - s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) + s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf)) } diff --git a/node/router.go b/node/router.go index c99f763..0e74d14 100644 --- a/node/router.go +++ b/node/router.go @@ -19,7 +19,7 @@ type peer struct { Up bool // No data will be sent to peers that are down. Addr netip.AddrPort // If we have direct connection, otherwise use mediator. Mediator bool // True if the peer will mediate. - RoutingCipher routingCipher + RoutingCipher controlCipher DataCipher dataCipher // TODO: Deprecated below.