vppn/node/peer-supervisor.go
2024-12-20 17:11:20 +01:00

278 lines
5.8 KiB
Go

package node
import (
"log"
"math/rand"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
// Timing parameters for the peer supervisor state machine.
const (
	// pingInterval is how often stateServer re-sends its ping, and how long
	// stateSelectMediator waits between attempts to find a relay.
	pingInterval = time.Second * 6

	// timeoutInterval is how long stateServer/stateClient wait without
	// hearing from the remote before handling a timeout.
	timeoutInterval = time.Second * 20

	// NOTE(review): connectTimeout is not referenced by the code in this
	// file; verify other users before changing or removing it.
	connectTimeout = time.Second * 6
)
type (
	// stateFunc is one state of the peer supervisor. Calling it runs that
	// state and returns the state to run next (see Run).
	stateFunc func() stateFunc

	// peerSuper drives the connection state machine for one remotePeer,
	// which it embeds.
	peerSuper struct {
		*remotePeer

		// peer is the current configuration for the remote. nil means "no
		// configuration": stateInit then moves to stateDisconnected.
		peer *m.Peer

		// remotePublic records whether peer.PublicIP parsed as an address.
		remotePublic bool

		// peerData is the supervisor's private working copy; updateShared
		// publishes a copy of it through shared.
		peerData peerData

		// Scratch buffers used by sendControlPacket: pktBuf receives the
		// marshaled packet, encBuf the encrypted output.
		pktBuf []byte
		encBuf []byte
	}
)
// newPeerSuper returns a supervisor for rp that has no peer configuration
// yet and owns two freshly allocated scratch buffers of bufferSize bytes.
func newPeerSuper(rp *remotePeer) *peerSuper {
	ps := &peerSuper{remotePeer: rp}
	ps.pktBuf = make([]byte, bufferSize)
	ps.encBuf = make([]byte, bufferSize)
	return ps
}
// Run executes the supervisor state machine forever, starting in stateInit.
// Each state returns its successor; panics are handed to panicHandler.
func (rp *peerSuper) Run() {
	defer panicHandler()

	for next := stateFunc(rp.stateInit); ; {
		next = next()
	}
}
// ----------------------------------------------------------------------------
// stateInit (re)initializes the connection from rp.peer. It first publishes
// an empty snapshot and resets the working copy rp.peerData. With no peer
// configuration it moves to stateDisconnected. Otherwise it records whether
// the remote has a public address (and, if so, its address/port), builds the
// control cipher, and selects a role.
func (rp *peerSuper) stateInit() stateFunc {
	//rp.logf("STATE: Init")

	// Replace the published snapshot so it no longer carries the previous
	// connection's state.
	rp.shared.Store(&peerData{})

	// Reset the working copy. up must be cleared as well: otherwise the next
	// updateShared() (e.g. on entry to stateClient) would publish up == true
	// alongside a nil dataCipher left over from this reset.
	rp.peerData.up = false
	rp.peerData.relay = false
	rp.peerData.controlCipher = nil
	rp.peerData.dataCipher = nil
	rp.peerData.remoteAddr = zeroAddrPort
	rp.peerData.relayIP = 0

	if rp.peer == nil {
		return rp.stateDisconnected
	}

	var addr netip.Addr
	addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
	if rp.remotePublic {
		rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
	}
	// NOTE(review): relay is never set to true anywhere in this file; confirm
	// whether a non-public remote was meant to set it here.

	rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
	return rp.stateSelectRole()
}
// ----------------------------------------------------------------------------
// stateDisconnected waits for a peer configuration to arrive. Until then,
// incoming control packets are discarded. A new configuration is stored in
// rp.peer and the machine restarts at stateInit.
func (rp *peerSuper) stateDisconnected() stateFunc {
	//rp.logf("STATE: Disconnected")
	for {
		select {
		case rp.peer = <-rp.peerUpdates:
			return rp.stateInit
		case <-rp.controlPackets:
			// Nobody to talk to: discard.
		}
	}
}
// ----------------------------------------------------------------------------
// stateSelectRole picks the next state from which side is publicly
// reachable:
//   - neither side: a mediator must be found first;
//   - only the remote: the remote is the server (stateServer);
//   - only this node: the remote is the client (stateClient);
//   - both: the remote is the client when localIP < remoteIP, else the server.
func (rp *peerSuper) stateSelectRole() stateFunc {
	rp.logf("STATE: SelectRole")

	switch {
	case !rp.localPublic && !rp.remotePublic:
		return rp.stateSelectMediator
	case !rp.localPublic:
		return rp.stateServer
	case !rp.remotePublic:
		return rp.stateClient
	case rp.localIP < rp.remoteIP:
		return rp.stateClient
	default:
		return rp.stateServer
	}
}
// ----------------------------------------------------------------------------
// stateSelectMediator is used when the peers cannot reach each other
// directly. It is entered from stateSelectRole, or after a relayed
// connection times out.
//
// It tries selectMediator, and retries every pingInterval until a relay is
// found. It then records the relay in peerData.relayIP and picks the role
// with the same IP comparison stateSelectRole uses for two public peers. A
// peer update restarts at stateInit.
//
// Control packets that arrive while waiting are dropped, as in
// stateDisconnected, so the control channel keeps being consumed even if
// no mediator becomes available for a long time.
func (rp *peerSuper) stateSelectMediator() stateFunc {
	rp.logf("STATE: SelectMediator")

	// One reusable timer. It is only Reset after a tick has been received
	// from its channel, so a stale tick can never cut a later wait short.
	retryTimer := time.NewTimer(pingInterval)
	defer retryTimer.Stop()

	for {
		log.Printf("Selecting mediator...")
		if ip := rp.selectMediator(); ip != 0 {
			rp.logf("Got mediator: %d", ip)
			rp.peerData.relayIP = ip
			if rp.localIP < rp.remoteIP {
				return rp.stateClient
			}
			return rp.stateServer
		}

	wait:
		for {
			select {
			case <-retryTimer.C:
				break wait
			case <-rp.controlPackets:
				// Drop, as in stateDisconnected.
			case rp.peer = <-rp.peerUpdates:
				return rp.stateInit
			}
		}
		retryTimer.Reset(pingInterval)
	}
}
// selectMediator returns the remoteIP of a peer in rp.peers whose canRelay()
// reports true, chosen uniformly at random. It returns 0 when no peer
// qualifies.
func (rp *peerSuper) selectMediator() byte {
	var candidates []byte
	for _, p := range rp.peers {
		if !p.canRelay() {
			continue
		}
		rp.logf("relay: %v", p.shared.Load())
		candidates = append(candidates, p.remoteIP)
	}

	n := len(candidates)
	if n == 0 {
		return 0
	}
	return candidates[rand.Intn(n)]
}
// ----------------------------------------------------------------------------
// stateServer runs while the remote is the server, i.e. this side sends the
// pings.
//
// On entry it creates a fresh data cipher, publishes it, and immediately
// sends a ping carrying the cipher's key. The ping is repeated every
// pingInterval, and every pong received re-arms the timeout.
//
// On timeout, a relayed connection goes back to mediator selection; a
// direct one only logs. A peer update restarts at stateInit.
//
// NOTE(review): this state never sets peerData.up, unlike stateClient;
// confirm that is intended.
func (rp *peerSuper) stateServer() stateFunc {
	rp.logf("STATE: Server")

	rp.peerData.dataCipher = newDataCipher()
	rp.updateShared()

	ping := pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
	sendPing := func() {
		ping.SentAt = time.Now().UnixMilli()
		rp.sendControlPacket(ping)
	}

	pingTimer := time.NewTimer(pingInterval)
	defer pingTimer.Stop()
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer timeoutTimer.Stop()

	sendPing()

	for {
		select {
		case <-pingTimer.C:
			sendPing()
			pingTimer.Reset(pingInterval)

		case cPkt := <-rp.controlPackets:
			if _, isPong := cPkt.Payload.(pongPacket); isPong {
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-timeoutTimer.C:
			if rp.peerData.relayIP == 0 {
				rp.logf("Timeout (server)")
				continue
			}
			rp.logf("Timeout (server, relay)")
			return rp.stateSelectMediator

		case rp.peer = <-rp.peerUpdates:
			return rp.stateInit
		}
	}
}
// ----------------------------------------------------------------------------
// stateClient runs while the remote is the client, i.e. the remote sends
// pings and this side answers each one with a pong echoing its SentAt.
//
// Incoming control packets update peerData.remoteAddr whenever they arrive
// from a new address. A ping whose SharedKey differs from the last one seen
// installs a new data cipher from that key and marks the peer up. Every ping
// re-arms the timeout.
//
// On timeout, a relayed connection goes back to mediator selection; a
// direct one only logs. A peer update restarts at stateInit.
func (rp *peerSuper) stateClient() stateFunc {
	rp.logf("STATE: Client")
	rp.updateShared()

	var (
		currentKey   = [32]byte{}
		timeoutTimer = time.NewTimer(timeoutInterval)
	)
	defer timeoutTimer.Stop()

	for {
		select {
		case cPkt := <-rp.controlPackets:
			// Follow the remote to whatever address its packets come from.
			// NOTE(review): assumes packets on controlPackets were already
			// authenticated upstream; confirm, since this redirects replies.
			if cPkt.RemoteAddr != rp.peerData.remoteAddr {
				rp.peerData.remoteAddr = cPkt.RemoteAddr
				rp.logf("Got new remote address: %v", cPkt.RemoteAddr)
				rp.updateShared()
			}

			ping, ok := cPkt.Payload.(pingPacket)
			if !ok {
				continue
			}

			if ping.SharedKey != currentKey {
				rp.logf("Connected with new shared key")
				currentKey = ping.SharedKey
				rp.peerData.up = true
				rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
				rp.updateShared()
			}

			timeoutTimer.Reset(timeoutInterval)
			rp.sendControlPacket(newPongPacket(ping.SentAt))

		case <-timeoutTimer.C:
			if rp.peerData.relayIP == 0 {
				// Direct connection: keep waiting for the remote's pings.
				rp.logf("Timeout (client)")
				continue
			}
			rp.logf("Timeout (client, relay)")
			return rp.stateSelectMediator

		case rp.peer = <-rp.peerUpdates:
			return rp.stateInit
		}
	}
}
// ----------------------------------------------------------------------------
// updateShared publishes a copy of the working rp.peerData through
// rp.shared. Because it is a copy, later edits to rp.peerData do not affect
// the published snapshot until the next call.
func (rp *peerSuper) updateShared() {
	snapshot := new(peerData)
	*snapshot = rp.peerData
	rp.shared.Store(snapshot)
}
// ----------------------------------------------------------------------------
// sendControlPacket marshals pkt into rp.pktBuf and encrypts it with the
// control cipher. The header uses the next value of rp.counter and the
// local/remote IPs. The result goes straight to peerData.remoteAddr, or is
// handed to the relay peer when peerData.relayIP is set.
//
// It reuses the supervisor's scratch buffers, so it is presumably only
// called from the Run goroutine.
// NOTE(review): any result of conn.WriteTo is ignored.
func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	plain := pkt.Marshal(rp.pktBuf)

	hdr := header{
		StreamID: controlStreamID,
		Counter:  atomic.AddUint64(&rp.counter, 1),
		SourceIP: rp.localIP,
		DestIP:   rp.remoteIP,
	}
	sealed := rp.peerData.controlCipher.Encrypt(hdr, plain, rp.encBuf)

	if relayIP := rp.peerData.relayIP; relayIP != 0 {
		rp.peers[relayIP].RelayControlData(sealed)
		return
	}
	rp.conn.WriteTo(sealed, rp.peerData.remoteAddr)
}