vppn/node/peer-supervisor.go
2024-12-23 08:08:23 +01:00

353 lines
7.5 KiB
Go

package node
import (
"fmt"
"log"
"math/rand"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
const (
	// dialTimeout bounds how long clientDial waits for a synAckPacket that
	// matches its synPacket before starting over from clientInit.
	dialTimeout = 8 * time.Second
	// connectTimeout is not referenced in this section of the file.
	// NOTE(review): confirm whether it is still used elsewhere.
	connectTimeout = 6 * time.Second
	// pingInterval is the period between keepalive ackPackets sent in the
	// client-connected state, and the retry period of client-select-relay
	// when no relay is available.
	pingInterval = 6 * time.Second
	// timeoutInterval is how long a connected client or server waits for a
	// matching ackPacket before treating the connection as lost.
	timeoutInterval = 20 * time.Second
)
// ----------------------------------------------------------------------------
// peerSupervisor runs the connection state machine (see Run) for a single
// remote peer and publishes the resulting route for that peer.
type peerSupervisor struct {
	// The purpose of this state machine is to manage this published data.
	// It is this peer's slot in routingTable (see newPeerSupervisor).
	published *atomic.Pointer[peerRoute]
	staged    peerRoute // Local copy of shared data. See publish().

	// Immutable data.
	remoteIP byte // Remote VPN IP.

	// Mutable peer data.
	peer      *m.Peer // Last configuration received; nil when there is no peer.
	remotePub bool    // Whether peer has a valid PublicIP; computed in _peerUpdate.

	// Incoming events.
	peerUpdates    chan *m.Peer
	controlPackets chan controlPacket

	// Buffers for sending control packets.
	// Both are passed to _sendControlPacket as scratch space.
	buf1 []byte
	buf2 []byte
}
// newPeerSupervisor builds the supervisor for the peer with VPN IP i. It
// wires the supervisor to slot i of routingTable and to that peer's
// peer-update and control-packet channels, and allocates two send buffers
// of bufferSize bytes.
func newPeerSupervisor(i int) *peerSupervisor {
	s := new(peerSupervisor)
	s.published = routingTable[i]
	s.remoteIP = byte(i)
	s.peerUpdates = peerUpdates[i]
	s.controlPackets = controlPackets[i]
	s.buf1 = make([]byte, bufferSize)
	s.buf2 = make([]byte, bufferSize)
	return s
}
// stateFunc is one state of the peer state machine. Calling it performs the
// state's work and returns the state to run next.
type stateFunc func() stateFunc

// Run drives the state machine forever, starting in the no-peer state.
// It never returns.
func (s *peerSupervisor) Run() {
	next := stateFunc(s.noPeer)
	for {
		next = next()
	}
}
// ----------------------------------------------------------------------------
// sendControlPacket hands pkt to _sendControlPacket together with a copy of
// the current staged route and the supervisor's two scratch buffers.
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	route := s.staged
	_sendControlPacket(pkt, route, s.buf1, s.buf2)
}
// ----------------------------------------------------------------------------
// logf logs a printf-style message (msg formatted with args), prefixed with
// this peer's zero-padded VPN IP, e.g. "[007] STATE: client-init".
func (s *peerSupervisor) logf(msg string, args ...any) {
	prefix := fmt.Sprintf("[%03d] ", s.remoteIP)
	log.Print(prefix + fmt.Sprintf(msg, args...))
}
// ----------------------------------------------------------------------------
// publish atomically stores a fresh shallow copy of s.staged in the shared
// slot. Later edits to staged are not visible to readers until the next
// publish.
func (s *peerSupervisor) publish() {
	snapshot := new(peerRoute)
	*snapshot = s.staged
	s.published.Store(snapshot)
}
// ----------------------------------------------------------------------------
// noPeer blocks until the next peer configuration arrives and transitions
// to the state that applies it.
func (s *peerSupervisor) noPeer() stateFunc {
	update := <-s.peerUpdates
	return s.peerUpdate(update)
}
// ----------------------------------------------------------------------------
// peerUpdate returns a state that applies the given configuration when it
// runs. A nil peer sends the machine back to noPeer.
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
	apply := func() stateFunc {
		return s._peerUpdate(peer)
	}
	return apply
}
// _peerUpdate discards all per-peer state, rebuilds it from peer, and picks
// this side's role for the handshake. The staged route is published on
// return, whatever the outcome.
//
// Role selection:
//   - If both sides are public, or both are private, the side with the lower
//     VPN IP is the server (this side accepts when localIP < remoteIP).
//   - Otherwise the public side is the server: we dial a public peer and
//     accept from a private one.
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRoute{}
	// remotePub must be recomputed on every update. Otherwise a peer that
	// loses its public IP would keep being treated as public, selecting the
	// wrong role and dialing a zero RemoteAddr.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.IP = s.remoteIP
	s.staged.ControlCipher = newControlCipher(privateKey, peer.EncPubKey)
	s.staged.DataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.Relay = peer.Mediator
		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
	}

	if s.remotePub == localPub {
		if localIP < s.remoteIP {
			return s.serverAccept
		}
		return s.clientInit
	}

	if s.remotePub {
		return s.clientInit
	}
	return s.serverAccept
}
// ----------------------------------------------------------------------------
// serverAccept waits for a client handshake. On entry it clears any previous
// connection state (Up, DataCipher, RemoteAddr, RelayIP) and publishes it.
// Then:
//   - synPacket: records the client's address, data key and relay IP,
//     publishes, and replies with a synAckPacket echoing the syn's TraceID;
//   - ackPacket whose TraceID matches the most recent syn: moves to
//     server-connected. Acks for other traces, or acks arriving before any
//     syn, are ignored.
//
// A peer update restarts the state machine with the new configuration.
func (s *peerSupervisor) serverAccept() stateFunc {
	s.logf("STATE: server-accept")
	s.staged.Up = false
	s.staged.DataCipher = nil
	s.staged.RemoteAddr = zeroAddrPort
	s.staged.RelayIP = 0
	s.publish()

	var (
		syn    synPacket
		gotSyn bool // Set once a syn has been accepted in this state.
	)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case synPacket:
				syn, gotSyn = p, true
				s.staged.RemoteAddr = pkt.RemoteAddr
				s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
				s.staged.RelayIP = syn.RelayIP
				s.publish()
				s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})

			case ackPacket:
				// Without a prior syn, syn.TraceID is the zero value. An ack
				// carrying TraceID 0 would otherwise mark the route Up with a nil
				// DataCipher and no RemoteAddr.
				if !gotSyn || p.TraceID != syn.TraceID {
					continue
				}
				return s.serverConnected(syn.TraceID)
			}
		}
	}
}
// ----------------------------------------------------------------------------
// serverConnected marks the route Up, publishes it, and returns the
// steady-state server loop bound to the handshake's traceID.
func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc {
	s.logf("STATE: server-connected")
	s.staged.Up = true
	s.publish()

	loop := func() stateFunc {
		return s._serverConnected(traceID)
	}
	return loop
}
// _serverConnected keeps a server-side connection alive. Each ackPacket
// carrying traceID is echoed back to the sender and re-arms the timeout.
// It drops back to serverAccept on an ack with a different TraceID, or when
// no matching ack arrives within timeoutInterval. A peer update restarts
// the state machine.
func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc {
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer timeoutTimer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.serverAccept
				}
				s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})

				// Stop and drain before Reset. Before Go 1.23, resetting a timer
				// that already fired leaves the stale tick in C, which would
				// trigger a spurious timeout on the next iteration.
				if !timeoutTimer.Stop() {
					select {
					case <-timeoutTimer.C:
					default:
					}
				}
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-timeoutTimer.C:
			s.logf("server timeout")
			return s.serverAccept
		}
	}
}
// ----------------------------------------------------------------------------
// clientInit (re)starts the client side of the handshake. It first marks
// the route down and publishes it, so the route is not advertised as Up
// while the client redials after a timeout or a mismatched ack. This
// mirrors what serverAccept does on the server side. A public peer is
// dialed directly; otherwise a relay is selected first.
func (s *peerSupervisor) clientInit() stateFunc {
	s.logf("STATE: client-init")
	s.staged.Up = false
	s.publish()

	if !s.remotePub {
		// TODO: Check local discovery for IP.
		// TODO: Attempt UDP hole punch.
		// TODO: client-relayed
		return s.clientSelectRelay
	}
	return s.clientDial
}
// ----------------------------------------------------------------------------
// clientSelectRelay tries to pick a relay right away, then retries every
// pingInterval until one is available. When a relay is found, the relay's
// IP and LocalAddr are staged and published, and the client moves on to
// dialing. A peer update restarts the state machine.
func (s *peerSupervisor) clientSelectRelay() stateFunc {
	s.logf("STATE: client-select-relay")

	retry := time.NewTimer(0)
	defer retry.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)
		case <-retry.C:
		}

		relay := s.selectRelay()
		if relay == nil {
			s.logf("No relay available.")
			retry.Reset(pingInterval)
			continue
		}

		s.logf("Got relay: %d", relay.IP)
		s.staged.RelayIP = relay.IP
		s.staged.LocalAddr = relay.LocalAddr
		s.publish()
		return s.clientDial
	}
}
// selectRelay returns a uniformly random published route that is both Up
// and marked as a Relay, or nil if there is none. Slots that have never
// been published (Load returns nil) are skipped instead of dereferenced.
func (s *peerSupervisor) selectRelay() *peerRoute {
	candidates := make([]*peerRoute, 0, 8)
	for i := range routingTable {
		route := routingTable[i].Load()
		if route == nil || !route.Up || !route.Relay {
			continue
		}
		candidates = append(candidates, route)
	}

	if len(candidates) == 0 {
		return nil
	}
	return candidates[rand.Intn(len(candidates))]
}
// ----------------------------------------------------------------------------
// clientDial sends a synPacket carrying a fresh trace ID, this side's data
// key and the staged relay IP. It then waits up to dialTimeout for the
// matching synAckPacket. On a match it replies with an ackPacket and moves
// to client-connected; on timeout it starts over from clientInit. Packets
// that are not the expected synAck are ignored.
func (s *peerSupervisor) clientDial() stateFunc {
	s.logf("STATE: client-dial")

	syn := synPacket{
		TraceID:   newTraceID(),
		SharedKey: s.staged.DataCipher.Key(),
		RelayIP:   s.staged.RelayIP,
	}
	deadline := time.NewTimer(dialTimeout)
	defer deadline.Stop()

	s.sendControlPacket(syn)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case <-deadline.C:
			return s.clientInit

		case pkt := <-s.controlPackets:
			synAck, ok := pkt.Payload.(synAckPacket)
			if !ok || synAck.TraceID != syn.TraceID {
				continue // Hmm...
			}
			s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
			return s.clientConnected(synAck)
		}
	}
}
// ----------------------------------------------------------------------------
// clientConnected marks the route Up and records as LocalAddr the address
// the server reported in its synAck (the RecvAddr it received the syn
// from). It then publishes and returns the steady-state client loop bound
// to the handshake's trace ID.
func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc {
	s.logf("STATE: client-connected")
	s.staged.LocalAddr = p.RecvAddr
	s.staged.Up = true
	s.publish()

	traceID := p.TraceID
	return func() stateFunc { return s._clientConnected(traceID) }
}
// _clientConnected keeps a client-side connection alive. It sends an
// ackPacket carrying traceID every pingInterval, and each matching
// ackPacket from the server re-arms the timeout. It returns to clientInit
// on an ack with a different TraceID, or when no matching ack arrives
// within timeoutInterval. A peer update restarts the state machine.
func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc {
	pingTimer := time.NewTimer(pingInterval)
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer pingTimer.Stop()
	defer timeoutTimer.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.clientInit
				}
				// Stop and drain before Reset. Before Go 1.23, resetting a timer
				// that already fired leaves the stale tick in C, which would
				// trigger a spurious timeout on the next iteration.
				if !timeoutTimer.Stop() {
					select {
					case <-timeoutTimer.C:
					default:
					}
				}
				timeoutTimer.Reset(timeoutInterval)
			}

		case <-pingTimer.C:
			s.sendControlPacket(ackPacket{TraceID: traceID})
			// C was just received from, so Reset is safe without draining.
			pingTimer.Reset(pingInterval)

		case <-timeoutTimer.C:
			s.logf("client timeout")
			return s.clientInit
		}
	}
}