diff --git a/README.md b/README.md index abb55f0..1a5ab6a 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,16 @@ ## Roadmap -* Peer: router: create process for managing the routing table -* Peer: router: track mediators, enable / disable ... -* Hub: track peer last-seen timestamp (?) +* Use default port 456 +* Remove signing key from hub +* Peer: UDP hole-punching * Peer: local peer discovery - part of RoutingProcessor * Peer: update hub w/ latest port on startup ## Learnings * Encryption / decryption is 20x faster than signing/opening. +* Allowing out-of order packets is massively important for throughput with TCP ## Principles @@ -111,11 +112,3 @@ TimeoutStopSec=24 [Install] WantedBy=default.target ``` - ---- - -## Sub-packets - -If we make our MTU large, like 8k, our computations become more efficient. - -We can send packets with header like: diff --git a/peer/conndata.go b/peer/conndata.go index dbbc29c..24d680a 100644 --- a/peer/conndata.go +++ b/peer/conndata.go @@ -14,12 +14,9 @@ const ( type connData struct { // Shared data. - routes [256]*atomic.Pointer[route] + routes [MAX_IP]*atomic.Pointer[route] route *atomic.Pointer[route] - // Local data. - mediatorIP byte - // Peer data. server bool // Never changes. peerIP byte // Never changes. @@ -29,13 +26,14 @@ type connData struct { encSharedKey []byte // From hub + private key. publicAddr netip.AddrPort // From hub. + // Connection establishment and maintenance. pingTimer *time.Timer timeoutTimer *time.Timer // Routing data. - addr netip.AddrPort - viaIP byte - up bool + addr netip.AddrPort + useMediator bool + up bool // For sending. buf []byte @@ -47,10 +45,9 @@ func (d *connData) Route() *route { PeerIP: d.peerIP, Up: d.up, Mediator: d.peer.Mediator, - SignPubKey: d.peer.SignPubKey, EncSharedKey: d.encSharedKey, Addr: d.addr, - ViaIP: d.viaIP, + useMediator: d.useMediator, } } @@ -61,14 +58,14 @@ func (d *connData) HandlePeerUpdate(state connState, update peerUpdate) connStat if d.peer == nil && update.Peer == nil { return state } - return newConnStateFromPeer(update, d) + return newStateFromPeerUpdate(update, d) } func (d *connData) HandleSendPing() { route := d.route.Load() req := Ping{SentAt: time.Now().UnixMilli()} req.Marshal(d.buf[:PING_SIZE]) - d.sender.send(PACKET_TYPE_PING, d.buf[:PING_SIZE], route) + d.sender.send(PACKET_TYPE_PING, d.buf[:PING_SIZE], route, nil) d.pingTimer.Reset(pingInterval) } @@ -79,5 +76,5 @@ func (d *connData) sendPong(w wrapper[Ping]) { RecvdAt: time.Now().UnixMilli(), } pong.Marshal(d.buf[:PONG_SIZE]) - d.sender.send(PACKET_TYPE_PONG, d.buf[:PONG_SIZE], route) + d.sender.send(PACKET_TYPE_PONG, d.buf[:PONG_SIZE], route, nil) } diff --git a/peer/connhandler.go b/peer/connhandler.go index c9ff8df..a5dea4e 100644 --- a/peer/connhandler.go +++ b/peer/connhandler.go @@ -21,7 +21,7 @@ type connHandler struct { func newConnHandler( server bool, peerIP byte, - routes [256]*atomic.Pointer[route], + routes [MAX_IP]*atomic.Pointer[route], encPrivKey []byte, sender *safeConnSender, ) *connHandler { @@ -65,9 +65,6 @@ func (h *connHandler) mainLoop() { for { select { - case ip := <-h.mediatorUpdates: - state = state.HandleMediatorUpdate(ip) - case update := <-h.peerUpdates: state = data.HandlePeerUpdate(state, update) @@ -92,13 +89,6 @@ func (h *connHandler) mainLoop() { } } -func (c *connHandler) UpdateMediator(ip byte) { - select { - case c.mediatorUpdates <- ip: - default: - } -} - func (c *connHandler) HandlePing(w wrapper[Ping]) { select { case c.pings <- w: diff --git a/peer/connsender.go b/peer/connsender.go index cf3f3c5..4bf5916 100644 --- a/peer/connsender.go +++ b/peer/connsender.go @@ -9,51 +9,58 @@ import ( ) type connSender struct { - conn *net.UDPConn - sourceIP byte - streamID byte - encrypted []byte - nonceBuf []byte - counter uint64 - signingKey []byte + conn *net.UDPConn + sourceIP byte + streamID byte + encrypted []byte + nonceBuf []byte + counter uint64 } -func newConnSender(conn *net.UDPConn, srcIP, streamID byte, signingPrivKey []byte) *connSender { +func newConnSender(conn *net.UDPConn, srcIP, streamID byte) *connSender { return &connSender{ - conn: conn, - sourceIP: srcIP, - streamID: streamID, - encrypted: make([]byte, BUFFER_SIZE), - nonceBuf: make([]byte, NONCE_SIZE), - counter: uint64(fasttime.Now()) << 30, // Ensure counter is always increasing. - signingKey: signingPrivKey, + conn: conn, + sourceIP: srcIP, + streamID: streamID, + encrypted: make([]byte, BUFFER_SIZE), + nonceBuf: make([]byte, NONCE_SIZE), + counter: uint64(fasttime.Now()) << 30, // Ensure counter is always increasing. } } -func (cs *connSender) send(packetType byte, packet []byte, route *route) { +func (cs *connSender) send(packetType byte, packet []byte, dstRoute, viaRoute *route) { + if dstRoute.useMediator && viaRoute == nil { + log.Printf("Dropping forwarded packet: no mediator.") + return + } + cs.counter++ + nonce := Nonce{ Timestamp: fasttime.Now(), Counter: cs.counter, SourceIP: cs.sourceIP, - ViaIP: route.ViaIP, - DestIP: route.PeerIP, + DestIP: dstRoute.PeerIP, StreamID: cs.streamID, PacketType: packetType, } - nonce.Marshal(cs.nonceBuf) - - encrypted := encryptPacket(route.EncSharedKey, cs.nonceBuf, packet, cs.encrypted) - - var toSend []byte - if route.ViaIP != 0 { - toSend = signPacket(cs.signingKey, encrypted, packet) - } else { - toSend = encrypted + if dstRoute.useMediator { + nonce.ViaIP = viaRoute.PeerIP } - if _, err := cs.conn.WriteToUDPAddrPort(toSend, route.Addr); err != nil { + nonce.Marshal(cs.nonceBuf) + + addr := dstRoute.Addr + + encrypted := encryptPacket(dstRoute.EncSharedKey, cs.nonceBuf, packet, cs.encrypted) + if viaRoute != nil { + packet, encrypted = encrypted, packet + encrypted = encryptPacket(viaRoute.EncSharedKey, cs.nonceBuf, packet, encrypted) + addr = viaRoute.Addr + } + + if _, err := cs.conn.WriteToUDPAddrPort(encrypted, addr); err != nil { log.Fatalf("Failed to write UDP packet: %v\n%s", err, debug.Stack()) } } @@ -69,8 +76,8 @@ func newSafeConnSender(sender *connSender) *safeConnSender { return &safeConnSender{sender: sender} } -func (s *safeConnSender) send(packetType byte, packet []byte, route *route) { +func (s *safeConnSender) send(packetType byte, packet []byte, route, viaRoute *route) { s.lock.Lock() defer s.lock.Unlock() - s.sender.send(packetType, packet, route) + s.sender.send(packetType, packet, route, viaRoute) } diff --git a/peer/connstate.go b/peer/connstate.go index dae3167..73853a8 100644 --- a/peer/connstate.go +++ b/peer/connstate.go @@ -11,32 +11,31 @@ func logState(s connState, msg string, args ...any) { log.Printf("["+s.Name()+"] "+msg, args...) } -// ---------------------------------------------------------------------------- - // The connection state corresponds to what we're connected TO. type connState interface { Name() string - HandleMediatorUpdate(ip byte) connState + //HandleConnReq(wrapper[ConnReq]) connState HandlePing(wrapper[Ping]) connState HandlePong(wrapper[Pong]) connState HandleTimeout() connState } -// Helper function. +// Helper functions. -func newConnStateFromPeer(update peerUpdate, data *connData) connState { - peer := update.Peer - - if peer == nil { - return newConnNull(data) +func newStateFromPeerUpdate(update peerUpdate, data *connData) connState { + if update.Peer != nil { + return newStateFromPeer(update.Peer, data) } + return newConnNull(data) +} +func newStateFromPeer(peer *m.Peer, data *connData) connState { if _, isPublic := netip.AddrFromSlice(peer.PublicIP); isPublic { return newStateServerDown(data, peer) } else if data.server { return newStateClientDown(data, peer) } else { - return newStateMediatedDown(data, peer) + return newStateMediated(data, peer) } } @@ -56,7 +55,7 @@ func newConnNull(data *connData) connState { c.pingTimer.Stop() c.timeoutTimer.Stop() c.addr = c.publicAddr - c.viaIP = 0 + c.useMediator = false c.up = false c.route.Store(nil) return c @@ -66,8 +65,8 @@ func (c connNull) Name() string { return "NoPeer" } -func (c connNull) HandleMediatorUpdate(ip byte) connState { - c.mediatorIP = ip +func (c connNull) HandleConnReq(w wrapper[ConnReq]) connState { + logState(c, "Ignoring conn request.") return c } @@ -102,10 +101,10 @@ func newStateServerDown(data *connData, peer *m.Peer) connState { c.peer = peer c.encSharedKey = computeSharedKey(peer.EncPubKey, c.encPrivKey) c.publicAddr = pubAddr - c.pingTimer.Reset(time.Millisecond) // Ping right away to bring up. - c.timeoutTimer.Stop() // No timeouts yet. + c.pingTimer.Reset(time.Second) // Ping right away to bring up. + c.timeoutTimer.Stop() // No timeouts yet. c.addr = c.publicAddr - c.viaIP = 0 + c.useMediator = false c.up = false c.route.Store(c.Route()) @@ -116,9 +115,9 @@ func (c stateServerDown) Name() string { return "Server:DOWN" } -func (c stateServerDown) HandleMediatorUpdate(ip byte) connState { - // Server connection doesn't use a mediator. - c.mediatorIP = ip +func (c stateServerDown) HandleConnReq(w wrapper[ConnReq]) connState { + // Send ConnResp. + // TODO return c } @@ -149,7 +148,7 @@ func newStateServerUp(data *connData, w wrapper[Pong]) connState { c.pingTimer.Reset(pingInterval) c.timeoutTimer.Reset(timeoutInterval) c.addr = w.SrcAddr - c.viaIP = 0 + c.useMediator = false c.up = true c.route.Store(c.Route()) return c @@ -159,12 +158,6 @@ func (c stateServerUp) Name() string { return "Server:UP" } -func (c stateServerUp) HandleMediatorUpdate(ip byte) connState { - // Server connection doesn't use a mediator. - c.mediatorIP = ip - return c -} - func (c stateServerUp) HandlePing(w wrapper[Ping]) connState { logState(c, "Ignoring ping.") return c @@ -176,7 +169,7 @@ func (c stateServerUp) HandlePong(w wrapper[Pong]) connState { } func (c stateServerUp) HandleTimeout() connState { - return newStateServerDown(c.connData, c.peer) + return newStateFromPeer(c.peer, c.connData) } //////////////////////// @@ -197,7 +190,7 @@ func newStateClientDown(data *connData, peer *m.Peer) connState { c.encPrivKey = data.encPrivKey c.encSharedKey = computeSharedKey(peer.EncPubKey, c.encPrivKey) c.addr = c.publicAddr - c.viaIP = 0 + c.useMediator = false c.up = false c.route.Store(c.Route()) @@ -211,13 +204,8 @@ func (c stateClientDown) Name() string { return "Client:DOWN" } -func (c stateClientDown) HandleMediatorUpdate(ip byte) connState { - // Client connection doesn't use a mediator. - c.mediatorIP = ip - return c -} - func (c stateClientDown) HandlePing(w wrapper[Ping]) connState { + log.Printf("Got ping...") next := newStateClientUp(c.connData, w) c.sendPong(w) // Have to send after transitionsing so route is ok. return next @@ -244,7 +232,7 @@ type stateClientUp struct { func newStateClientUp(data *connData, w wrapper[Ping]) connState { c := stateClientUp{data} c.addr = w.SrcAddr - c.viaIP = 0 + c.useMediator = false c.up = true c.route.Store(c.Route()) @@ -257,12 +245,6 @@ func (c stateClientUp) Name() string { return "Client:UP" } -func (c stateClientUp) HandleMediatorUpdate(ip byte) connState { - // Client connection doesn't use a mediator. - c.mediatorIP = ip - return c -} - func (c stateClientUp) HandlePing(w wrapper[Ping]) connState { // The connection is from a client. If the client's address changes, we // should follow that change. @@ -281,112 +263,51 @@ func (c stateClientUp) HandlePong(w wrapper[Pong]) connState { } func (c stateClientUp) HandleTimeout() connState { - return newStateClientDown(c.connData, c.peer) + return newStateFromPeer(c.peer, c.connData) } -////////////////////////// -// Unconnected Mediator // -////////////////////////// +////////////// +// Mediated // +////////////// -type stateMediatedDown struct { +type stateMediated struct { *connData } -func newStateMediatedDown(data *connData, peer *m.Peer) connState { +func newStateMediated(data *connData, peer *m.Peer) connState { addr, _ := netip.AddrFromSlice(peer.PublicIP) pubAddr := netip.AddrPortFrom(addr, peer.Port) - c := stateMediatedDown{data} + c := stateMediated{data} c.peer = peer c.publicAddr = pubAddr c.encPrivKey = data.encPrivKey c.encSharedKey = computeSharedKey(peer.EncPubKey, c.encPrivKey) c.addr = c.publicAddr - c.viaIP = 0 - c.up = false + c.useMediator = true + c.up = true c.route.Store(c.Route()) c.pingTimer.Stop() // No pings for mediators. c.timeoutTimer.Stop() // No timeouts yet. - - // If we have a mediator route, we can connect. - if mRoute := c.routes[c.mediatorIP].Load(); mRoute != nil { - return newStateMediatedUp(data, mRoute) - } - return c } -func (c stateMediatedDown) Name() string { - return "Mediated:DOWN" +func (c stateMediated) Name() string { + return "Mediated:UP" } -func (c stateMediatedDown) HandleMediatorUpdate(ip byte) connState { - c.mediatorIP = ip - if mRoute := c.routes[c.mediatorIP].Load(); mRoute != nil { - return newStateMediatedUp(c.connData, mRoute) - } - return c -} - -func (c stateMediatedDown) HandlePing(w wrapper[Ping]) connState { +func (c stateMediated) HandlePing(w wrapper[Ping]) connState { logState(c, "Ignorning ping.") return c } -func (c stateMediatedDown) HandlePong(w wrapper[Pong]) connState { +func (c stateMediated) HandlePong(w wrapper[Pong]) connState { logState(c, "Ignorning pong.") return c } -func (c stateMediatedDown) HandleTimeout() connState { +func (c stateMediated) HandleTimeout() connState { logState(c, "Unexpected timeout.") return c } - -//////////////////////// -// Connected Mediator // -//////////////////////// - -type stateMediatedUp struct { - *connData -} - -func newStateMediatedUp(data *connData, route *route) connState { - c := stateMediatedUp{data} - c.addr = route.Addr - c.viaIP = route.PeerIP - c.up = true - c.route.Store(c.Route()) - - // No pings for mediated routes. - c.pingTimer.Stop() - c.timeoutTimer.Stop() - return c -} - -func (c stateMediatedUp) Name() string { - return "Mediated:UP" -} - -func (c stateMediatedUp) HandleMediatorUpdate(ip byte) connState { - c.mediatorIP = ip - if mRoute := c.routes[c.mediatorIP].Load(); mRoute != nil { - return newStateMediatedUp(c.connData, mRoute) - } - return newStateMediatedDown(c.connData, c.peer) -} - -func (c stateMediatedUp) HandlePing(w wrapper[Ping]) connState { - logState(c, "Ignoring ping.") - return c -} - -func (c stateMediatedUp) HandlePong(w wrapper[Pong]) connState { - logState(c, "Ignoring pong.") - return c -} - -func (c stateMediatedUp) HandleTimeout() connState { - return newStateMediatedDown(c.connData, c.peer) -} diff --git a/peer/crypto.go b/peer/crypto.go index 135df55..3d27b32 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -2,7 +2,6 @@ package peer import ( "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" ) func encryptPacket(sharedKey, nonce, packet, out []byte) []byte { @@ -16,15 +15,6 @@ func decryptPacket(sharedKey, packet, out []byte) (decrypted []byte, ok bool) { return decrypted, ok } -// Signed packet should be encrypted with the encryptPacket function first. -func signPacket(privKey, packet, out []byte) []byte { - return sign.Sign(out[:0], packet, (*[64]byte)(privKey)) -} - -func openPacket(pubKey, packet, out []byte) (encPacket []byte, ok bool) { - return sign.Open(out[:0], packet, (*[32]byte)(pubKey)) -} - func computeSharedKey(peerPubKey, privKey []byte) []byte { shared := [32]byte{} box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey)) diff --git a/peer/crypto_test.go b/peer/crypto_test.go index e2b5f1b..623412e 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -6,7 +6,6 @@ import ( "testing" "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" ) func TestEncryptDecryptPacket(t *testing.T) { @@ -105,60 +104,3 @@ func BenchmarkDecryptPacket(b *testing.B) { decrypted, _ = decryptPacket(sharedDecKey[:], encrypted, decrypted) } } - -func BenchmarkSignPacket(b *testing.B) { - _, privKey1, err := sign.GenerateKey(rand.Reader) - if err != nil { - b.Fatal(err) - } - - original := make([]byte, 8192) - rand.Read(original) - out := make([]byte, 9000) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - signPacket(privKey1[:], original, out) - } -} - -func TestSignOpenPacket(t *testing.T) { - pubKey, privKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - - packet := make([]byte, MTU) - - rand.Read(packet) - - signedPacket := signPacket(privKey[:], packet, make([]byte, BUFFER_SIZE)) - - encPacket, ok := openPacket(pubKey[:], signedPacket, make([]byte, BUFFER_SIZE)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(encPacket, packet) { - t.Fatal("not equal") - } -} - -func BenchmarkOpenPacket(b *testing.B) { - pubKey, privKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - b.Fatal(err) - } - - packet := make([]byte, MTU) - - rand.Read(packet) - - signedPacket := signPacket(privKey[:], packet, make([]byte, 9000)) - out := make([]byte, BUFFER_SIZE) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - out, _ = openPacket(pubKey[:], signedPacket, out) - } -} diff --git a/peer/duplist.go b/peer/duplist.go new file mode 100644 index 0000000..622efdc --- /dev/null +++ b/peer/duplist.go @@ -0,0 +1,17 @@ +package peer + +type dupList struct { + items [64]uint64 + index int +} + +func (l *dupList) isDuplicate(in uint64) bool { + for _, i := range l.items { + if i == in { + return true + } + } + l.items[l.index] = in + l.index = (l.index + 1) % 64 + return false +} diff --git a/peer/globals.go b/peer/globals.go index a4560d7..636f895 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -1,22 +1,14 @@ package peer const ( + MAX_IP = 65 DEFAULT_PORT = 515 NONCE_SIZE = 24 KEY_SIZE = 32 SIG_SIZE = 64 - MTU = 1376 - BUFFER_SIZE = 2048 // Definitely big enough. + MTU = 1436 + BUFFER_SIZE = 1536 // Definitely big enough. STREAM_DATA = 0 STREAM_ROUTING = 1 // Routing queries and responses. - - // Basic packet types - PACKET_TYPE_DATA = 0 - PACKET_TYPE_PING = 1 - PACKET_TYPE_PONG = 2 - - // Packet sizes. - PING_SIZE = 8 - PONG_SIZE = 16 ) diff --git a/peer/peer-ifreader.go b/peer/peer-ifreader.go index aaa2855..63ac0c1 100644 --- a/peer/peer-ifreader.go +++ b/peer/peer-ifreader.go @@ -14,15 +14,16 @@ func (peer *Peer) ifReader() { }() var ( - sender = newConnSender(peer.conn, peer.ip, STREAM_DATA, peer.signPrivKey) - n int - destIP byte - router = peer.router - route *route - iface = peer.iface - err error - packet = make([]byte, BUFFER_SIZE) - version byte + sender = newConnSender(peer.conn, peer.ip, STREAM_DATA) + n int + destIP byte + router = peer.router + viaRoute *route + route *route + iface = peer.iface + err error + packet = make([]byte, BUFFER_SIZE) + version byte ) for { @@ -54,6 +55,16 @@ func (peer *Peer) ifReader() { continue } - sender.send(PACKET_TYPE_DATA, packet, route) + if route.useMediator { + viaRoute = router.GetMediator() + if viaRoute == nil || !viaRoute.Up { + log.Printf("Dropping packet due to no mediator: %d", destIP) + continue + } + } else { + viaRoute = nil + } + + sender.send(PACKET_TYPE_DATA, packet, route, viaRoute) } } diff --git a/peer/peer-netreader.go b/peer/peer-netreader.go index 661e0a6..2137df9 100644 --- a/peer/peer-netreader.go +++ b/peer/peer-netreader.go @@ -16,6 +16,7 @@ func (peer *Peer) netReader() { }() var ( + dupList = &dupList{} n int srcAddr netip.AddrPort nonce Nonce @@ -56,8 +57,9 @@ NEXT_PACKET: goto NEXT_PACKET } - if nonce.Counter <= counters[nonce.StreamID][nonce.SourceIP] { - log.Printf("Dropping packet with bad counter: -%d", counters[nonce.StreamID][nonce.SourceIP]-nonce.Counter) + if dupList.isDuplicate(nonce.Counter) { + //if nonce.Counter+64 <= counters[nonce.StreamID][nonce.SourceIP] { + log.Printf("Dropping packet with bad counter: %d (-%d) - %v", nonce.Counter, counters[nonce.StreamID][nonce.SourceIP]-nonce.Counter, srcAddr) goto NEXT_PACKET } @@ -67,26 +69,28 @@ NEXT_PACKET: goto NEXT_PACKET } - switch ip { - case nonce.DestIP: - goto DECRYPT - case nonce.ViaIP: - goto VALIDATE_SIGNATURE - default: - log.Printf("Bad packet: %+v", nonce) - goto NEXT_PACKET - } - -DECRYPT: - decrypted, ok = decryptPacket(route.EncSharedKey, packet, decrypted) if !ok { log.Printf("Failed to decrypt packet: %v", nonce) goto NEXT_PACKET } - // Only updated after verification. - counters[nonce.StreamID][nonce.SourceIP] = nonce.Counter + // Only updated after we've decrypted. + if nonce.Counter > counters[nonce.StreamID][nonce.SourceIP] { + counters[nonce.StreamID][nonce.SourceIP] = nonce.Counter + } + + switch ip { + case nonce.DestIP: + goto PROCESS_LOCAL + case nonce.ViaIP: + goto FORWARD + default: + log.Printf("Invalid nonce: %+v", nonce) + goto NEXT_PACKET + } + +PROCESS_LOCAL: switch nonce.StreamID { case STREAM_DATA: @@ -112,16 +116,7 @@ WRITE_ROUTING_PACKET: goto NEXT_PACKET -VALIDATE_SIGNATURE: - - decrypted, ok = openPacket(route.SignPubKey, packet, decrypted) - if !ok { - log.Printf("Failed to open signed packet: %v", nonce) - goto NEXT_PACKET - } - - // Only updated after verification. - counters[nonce.StreamID][nonce.SourceIP] = nonce.Counter +FORWARD: route = peer.router.GetRoute(nonce.DestIP) if route == nil || !route.Up { @@ -130,7 +125,7 @@ VALIDATE_SIGNATURE: } // We don't forward twice. - if route.ViaIP != 0 { + if route.useMediator { log.Printf("Dropping double-forward packet: %v", nonce) goto NEXT_PACKET } diff --git a/peer/peer.go b/peer/peer.go index c47fcaa..fa50ac2 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -8,17 +8,15 @@ import ( ) type Peer struct { - ip byte // Last byte of IPv4 address. - hubAddr string - apiKey string - isServer bool - isMediator bool - encPubKey []byte - encPrivKey []byte - signPubKey []byte - signPrivKey []byte - conn *net.UDPConn - iface io.ReadWriteCloser + ip byte // Last byte of IPv4 address. + hubAddr string + apiKey string + isServer bool + isMediator bool + encPubKey []byte + encPrivKey []byte + conn *net.UDPConn + iface io.ReadWriteCloser router *Router } @@ -30,14 +28,12 @@ func NewPeer(netName, listenIP string, port uint16) (*Peer, error) { } peer := &Peer{ - ip: conf.PeerIP, - hubAddr: conf.HubAddress, - isMediator: conf.Mediator, - apiKey: conf.APIKey, - encPubKey: conf.EncPubKey, - encPrivKey: conf.EncPrivKey, - signPubKey: conf.SignPubKey, - signPrivKey: conf.SignPrivKey, + ip: conf.PeerIP, + hubAddr: conf.HubAddress, + isMediator: conf.Mediator, + apiKey: conf.APIKey, + encPubKey: conf.EncPubKey, + encPrivKey: conf.EncPrivKey, } _, peer.isServer = netip.AddrFromSlice(conf.PublicIP) diff --git a/peer/router-managemediator.go b/peer/router-managemediator.go index 9778508..ec2d005 100644 --- a/peer/router-managemediator.go +++ b/peer/router-managemediator.go @@ -1,6 +1,7 @@ package peer import ( + "log" "math/rand" "time" ) @@ -31,15 +32,11 @@ func (r *Router) manageMediator() { } if len(mediators) == 0 { - ip = 0 + r.mediatorIP.Store(nil) } else { ip = mediators[rand.Intn(len(mediators))].PeerIP - } - - for _, conn := range r.conns { - if conn != nil { - conn.UpdateMediator(ip) - } + log.Printf("Got mediator IP: %d", ip) + r.mediatorIP.Store(&ip) } } } diff --git a/peer/router-pollhub.go b/peer/router-pollhub.go index eb082b6..88d7dac 100644 --- a/peer/router-pollhub.go +++ b/peer/router-pollhub.go @@ -55,9 +55,9 @@ func (r *Router) _pollHub(client *http.Client, req *http.Request) { return } - for i, peer := range state.Peers { + for i := range r.conns { if r.conns[i] != nil { - r.conns[i].UpdatePeer(peerUpdate{PeerIP: byte(i), Peer: peer}) + r.conns[i].UpdatePeer(peerUpdate{PeerIP: byte(i), Peer: state.Peers[i]}) } } } diff --git a/peer/router-types.go b/peer/router-types.go index 20bd296..f5242d6 100644 --- a/peer/router-types.go +++ b/peer/router-types.go @@ -2,7 +2,6 @@ package peer import ( "net/netip" - "unsafe" "vppn/m" ) @@ -15,10 +14,9 @@ type route struct { PeerIP byte Up bool Mediator bool - SignPubKey []byte - EncSharedKey []byte // Shared key for encoding / decoding packets. - Addr netip.AddrPort // Address to send to. - ViaIP byte // If != 0, this is a forwarding address. + EncSharedKey []byte // Shared key for encoding / decoding packets. + Addr netip.AddrPort + useMediator bool } type peerUpdate struct { @@ -41,34 +39,3 @@ func newWrapper[T any](srcAddr netip.AddrPort, nonce Nonce) wrapper[T] { Nonce: nonce, } } - -// ---------------------------------------------------------------------------- - -type Ping struct { - SentAt int64 // unix milli -} - -func (p *Ping) Parse(buf []byte) { - p.SentAt = *(*int64)(unsafe.Pointer(&buf[0])) -} - -func (p Ping) Marshal(buf []byte) { - *(*int64)(unsafe.Pointer(&buf[0])) = p.SentAt -} - -// ---------------------------------------------------------------------------- - -type Pong struct { - SentAt int64 // unix mili - RecvdAt int64 // unix mili -} - -func (p *Pong) Parse(buf []byte) { - p.SentAt = *(*int64)(unsafe.Pointer(&buf[0])) - p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[8])) -} - -func (p *Pong) Marshal(buf []byte) { - *(*int64)(unsafe.Pointer(&buf[0])) = p.SentAt - *(*int64)(unsafe.Pointer(&buf[8])) = p.RecvdAt -} diff --git a/peer/router.go b/peer/router.go index 60b67ac..6786ee5 100644 --- a/peer/router.go +++ b/peer/router.go @@ -12,12 +12,17 @@ type Router struct { conf m.PeerConfig // Routes used by the peer. - conns [256]*connHandler - routes [256]*atomic.Pointer[route] + conns [MAX_IP]*connHandler + routes [MAX_IP]*atomic.Pointer[route] + addrs [MAX_IP]*atomic.Pointer[netip.AddrPort] + mediatorIP *atomic.Pointer[byte] } func NewRouter(conf m.PeerConfig, conn *net.UDPConn) *Router { - r := &Router{conf: conf} + r := &Router{ + conf: conf, + mediatorIP: &atomic.Pointer[byte]{}, + } for i := range r.routes { r.routes[i] = &atomic.Pointer[route]{} @@ -25,7 +30,7 @@ func NewRouter(conf m.PeerConfig, conn *net.UDPConn) *Router { _, isServer := netip.AddrFromSlice(conf.PublicIP) - sender := newConnSender(conn, conf.PeerIP, STREAM_ROUTING, conf.SignPrivKey) + sender := newConnSender(conn, conf.PeerIP, STREAM_ROUTING) for i := range r.conns { if byte(i) != conf.PeerIP { @@ -54,6 +59,13 @@ func (rm *Router) GetRoute(ip byte) *route { return rm.routes[ip].Load() } +func (rm *Router) GetMediator() *route { + if ip := rm.mediatorIP.Load(); ip != nil { + return rm.GetRoute(*ip) + } + return nil +} + func (r *Router) HandlePacket(src netip.AddrPort, nonce Nonce, data []byte) { if nonce.SourceIP == r.conf.PeerIP { log.Printf("Packet to self...")