wip: trying to get relaying to work.

This commit is contained in:
jdl 2024-12-20 17:11:20 +01:00
parent 1be5c79186
commit 5b34b3311b
3 changed files with 142 additions and 27 deletions

View File

@ -108,11 +108,11 @@ func main(netName, listenIP string, port uint16) {
peers := remotePeers{} peers := remotePeers{}
for i := range peers { for i := range peers {
peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter) peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter, &peers)
} }
go newHubPoller(netName, conf, peers).Run() go newHubPoller(netName, conf, peers).Run()
go readFromConn(conn, peers) go readFromConn(conf.PeerIP, conn, peers)
readFromIFace(iface, peers) readFromIFace(iface, peers)
} }
@ -131,7 +131,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
func readFromConn(conn *net.UDPConn, peers remotePeers) { func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) {
defer panicHandler() defer panicHandler()
@ -157,7 +157,12 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) {
} }
h.Parse(data) h.Parse(data)
if h.DestIP == localIP {
peers[h.SourceIP].HandlePacket(remoteAddr, h, data) peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
} else {
peers[h.DestIP].ForwardPacket(data)
}
} }
} }

View File

@ -1,6 +1,8 @@
package node package node
import ( import (
"log"
"math/rand"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
@ -47,12 +49,15 @@ func (rp *peerSuper) Run() {
func (rp *peerSuper) stateInit() stateFunc { func (rp *peerSuper) stateInit() stateFunc {
//rp.logf("STATE: Init") //rp.logf("STATE: Init")
x := peerData{} x := peerData{}
rp.shared.Store(&x) rp.shared.Store(&x)
rp.peerData.relay = false
rp.peerData.controlCipher = nil rp.peerData.controlCipher = nil
rp.peerData.dataCipher = nil rp.peerData.dataCipher = nil
rp.peerData.remoteAddr = zeroAddrPort rp.peerData.remoteAddr = zeroAddrPort
rp.peerData.relayIP = 0
if rp.peer == nil { if rp.peer == nil {
return rp.stateDisconnected return rp.stateDisconnected
@ -62,6 +67,8 @@ func (rp *peerSuper) stateInit() stateFunc {
addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
if rp.remotePublic { if rp.remotePublic {
rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
} else {
rp.peerData.relay = false
} }
rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
@ -89,8 +96,7 @@ func (rp *peerSuper) stateSelectRole() stateFunc {
rp.logf("STATE: SelectRole") rp.logf("STATE: SelectRole")
if !rp.localPublic && !rp.remotePublic { if !rp.localPublic && !rp.remotePublic {
// TODO! return rp.stateSelectMediator
return rp.stateDisconnected
} }
if !rp.localPublic { if !rp.localPublic {
@ -99,12 +105,55 @@ func (rp *peerSuper) stateSelectRole() stateFunc {
return rp.stateClient return rp.stateClient
} }
if rp.localIP < rp.peer.PeerIP { if rp.localIP < rp.remoteIP {
return rp.stateClient return rp.stateClient
} }
return rp.stateServer return rp.stateServer
} }
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateSelectMediator() stateFunc {
rp.logf("STATE: SelectMediator")
for {
log.Printf("Selecting mediator...")
if ip := rp.selectMediator(); ip != 0 {
rp.logf("Got mediator: %d", ip)
rp.peerData.relayIP = ip
if rp.localIP < rp.remoteIP {
return rp.stateClient
}
return rp.stateServer
}
select {
case <-time.After(pingInterval):
continue
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
func (rp *peerSuper) selectMediator() byte {
possible := make([]byte, 0, 8)
for _, peer := range rp.peers {
if peer.canRelay() {
rp.logf("relay: %v", peer.shared.Load())
possible = append(possible, peer.remoteIP)
}
}
if len(possible) == 0 {
return 0
}
return possible[rand.Intn(len(possible))]
}
// ----------------------------------------------------------------------------
// The remote is a server. // The remote is a server.
func (rp *peerSuper) stateServer() stateFunc { func (rp *peerSuper) stateServer() stateFunc {
rp.logf("STATE: Server") rp.logf("STATE: Server")
@ -113,9 +162,11 @@ func (rp *peerSuper) stateServer() stateFunc {
var ( var (
pingTimer = time.NewTimer(pingInterval) pingTimer = time.NewTimer(pingInterval)
timeoutTimer = time.NewTimer(timeoutInterval)
ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
) )
defer pingTimer.Stop() defer pingTimer.Stop()
defer timeoutTimer.Stop()
ping.SentAt = time.Now().UnixMilli() ping.SentAt = time.Now().UnixMilli()
rp.sendControlPacket(ping) rp.sendControlPacket(ping)
@ -127,8 +178,18 @@ func (rp *peerSuper) stateServer() stateFunc {
rp.sendControlPacket(ping) rp.sendControlPacket(ping)
pingTimer.Reset(pingInterval) pingTimer.Reset(pingInterval)
case <-rp.controlPackets: case cPkt := <-rp.controlPackets:
// Ignore if _, ok := cPkt.Payload.(pongPacket); ok {
timeoutTimer.Reset(timeoutInterval)
}
case <-timeoutTimer.C:
if rp.peerData.relayIP != 0 {
rp.logf("Timeout (server, relay)")
return rp.stateSelectMediator
} else {
rp.logf("Timeout (server)")
}
case rp.peer = <-rp.peerUpdates: case rp.peer = <-rp.peerUpdates:
return rp.stateInit return rp.stateInit
@ -143,8 +204,12 @@ func (rp *peerSuper) stateClient() stateFunc {
rp.logf("STATE: Client") rp.logf("STATE: Client")
rp.updateShared() rp.updateShared()
// TODO: Could use timeout to set dataCipher to nil. var (
var currentKey = [32]byte{} currentKey = [32]byte{}
timeoutTimer = time.NewTimer(timeoutInterval)
)
defer timeoutTimer.Stop()
for { for {
select { select {
@ -163,12 +228,22 @@ func (rp *peerSuper) stateClient() stateFunc {
if ping.SharedKey != currentKey { if ping.SharedKey != currentKey {
rp.logf("Connected with new shared key") rp.logf("Connected with new shared key")
currentKey = ping.SharedKey currentKey = ping.SharedKey
rp.peerData.up = true
rp.peerData.dataCipher = newDataCipherFromKey(currentKey) rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
rp.updateShared() rp.updateShared()
} }
timeoutTimer.Reset(timeoutInterval)
rp.sendControlPacket(newPongPacket(ping.SentAt)) rp.sendControlPacket(newPongPacket(ping.SentAt))
case <-timeoutTimer.C:
if rp.peerData.relayIP != 0 {
rp.logf("Timeout (server, relay)")
return rp.stateSelectMediator
} else {
rp.logf("Timeout (server)")
}
case rp.peer = <-rp.peerUpdates: case rp.peer = <-rp.peerUpdates:
return rp.stateInit return rp.stateInit
} }
@ -193,5 +268,10 @@ func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte })
DestIP: rp.remoteIP, DestIP: rp.remoteIP,
} }
buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
if rp.peerData.relayIP == 0 {
rp.conn.WriteTo(buf, rp.peerData.remoteAddr) rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
return
}
rp.peers[rp.peerData.relayIP].RelayControlData(buf)
} }

View File

@ -12,6 +12,8 @@ import (
type remotePeers [256]*remotePeer type remotePeers [256]*remotePeer
type peerData struct { type peerData struct {
up bool
relay bool
controlCipher *controlCipher controlCipher *controlCipher
dataCipher *dataCipher dataCipher *dataCipher
remoteAddr netip.AddrPort remoteAddr netip.AddrPort
@ -28,6 +30,7 @@ type remotePeer struct {
conn *connWriter conn *connWriter
// Shared state. // Shared state.
peers *remotePeers
shared *atomic.Pointer[peerData] shared *atomic.Pointer[peerData]
// Only used in HandlePeerUpdate. // Only used in HandlePeerUpdate.
@ -48,7 +51,7 @@ type remotePeer struct {
controlPackets chan controlPacket controlPackets chan controlPacket
} }
func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter) *remotePeer { func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter, peers *remotePeers) *remotePeer {
rp := &remotePeer{ rp := &remotePeer{
localIP: conf.PeerIP, localIP: conf.PeerIP,
remoteIP: remoteIP, remoteIP: remoteIP,
@ -56,6 +59,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
localPublic: addrIsValid(conf.PublicIP), localPublic: addrIsValid(conf.PublicIP),
iface: iface, iface: iface,
conn: conn, conn: conn,
peers: peers,
shared: &atomic.Pointer[peerData]{}, shared: &atomic.Pointer[peerData]{},
dupCheck: newDupCheck(0), dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize), decryptBuf: make([]byte, bufferSize),
@ -97,10 +101,6 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
case dataStreamID: case dataStreamID:
rp.handleDataPacket(data) rp.handleDataPacket(data)
case forwardStreamID:
fallthrough
// TODO
//rp.handleForwardPacket(h, data)
default: default:
rp.logf("Unknown stream ID: %d", h.StreamID) rp.logf("Unknown stream ID: %d", h.StreamID)
} }
@ -115,6 +115,11 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
return return
} }
if h.DestIP != rp.localIP {
rp.logf("Incorrect destination IP on control packet.")
return
}
out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf)
if !ok { if !ok {
rp.logf("Failed to decrypt control packet.") rp.logf("Failed to decrypt control packet.")
@ -131,13 +136,6 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
return return
} }
if h.DestIP != rp.localIP {
// TODO: Forward control packet.
// TODO: Probably this should be dropped.
// Control packets should be forwarded as data for efficiency.
return
}
pkt := controlPacket{ pkt := controlPacket{
SrcIP: h.SourceIP, SrcIP: h.SourceIP,
RemoteAddr: addr, RemoteAddr: addr,
@ -167,6 +165,8 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
} }
} }
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleDataPacket(data []byte) { func (rp *remotePeer) handleDataPacket(data []byte) {
shared := rp.shared.Load() shared := rp.shared.Load()
if shared.dataCipher == nil { if shared.dataCipher == nil {
@ -189,6 +189,29 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
// //
// This function is called by a single thread. // This function is called by a single thread.
func (rp *remotePeer) SendData(data []byte) { func (rp *remotePeer) SendData(data []byte) {
rp.sendData(dataStreamID, data)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) RelayControlData(data []byte) {
rp.sendData(forwardStreamID, data)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) ForwardPacket(data []byte) {
shared := rp.shared.Load()
if shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (forward).")
return
}
rp.conn.WriteTo(data, shared.remoteAddr)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) sendData(streamID byte, data []byte) {
shared := rp.shared.Load() shared := rp.shared.Load()
if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send).") rp.logf("Not connected (send).")
@ -196,7 +219,7 @@ func (rp *remotePeer) SendData(data []byte) {
} }
h := header{ h := header{
StreamID: dataStreamID, StreamID: streamID,
Counter: atomic.AddUint64(&rp.counter, 1), Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP, SourceIP: rp.localIP,
DestIP: rp.remoteIP, DestIP: rp.remoteIP,
@ -205,3 +228,10 @@ func (rp *remotePeer) SendData(data []byte) {
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
rp.conn.WriteTo(enc, shared.remoteAddr) rp.conn.WriteTo(enc, shared.remoteAddr)
} }
// ----------------------------------------------------------------------------
func (rp *remotePeer) canRelay() bool {
shared := rp.shared.Load()
return shared.relay && shared.up
}