vppn/node/peer-supervisor.go
2025-01-02 07:42:00 +01:00

353 lines
7.4 KiB
Go

package node
import (
"fmt"
"log"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
// Control-plane timing.
const (
	// pingInterval is presumably the period at which pingTimerMsg events are
	// delivered (the timer is not set up in this file — TODO confirm).
	pingInterval time.Duration = 8 * time.Second

	// timeoutInterval is how long a supervisor tolerates silence from the
	// peer before treating the connection as timed out.
	timeoutInterval time.Duration = 25 * time.Second
)
// ----------------------------------------------------------------------------
// peerSupervisor runs the connection state machine for one remote peer
// (see Run). The route it publishes is the state machine's output.
type peerSupervisor struct {
	// published is the shared slot this state machine exists to maintain;
	// staged is the local working copy that publish() copies into it.
	published *atomic.Pointer[peerRoute]
	staged    peerRoute

	// remoteIP is the peer's VPN IP. It is fixed at construction.
	remoteIP byte

	// peer is the most recent peer record received; remotePub is set by
	// _peerUpdate when that record advertises a valid public IP.
	peer      *m.Peer
	remotePub bool

	// messages delivers incoming events to the current state.
	messages chan any

	// buf1 and buf2 are scratch buffers handed to _sendControlPacket.
	buf1, buf2 []byte
}
// newPeerSupervisor creates the supervisor for VPN IP i. It wires the
// supervisor to the routing-table slot and message channel at index i and
// allocates two control-packet buffers of bufferSize bytes each.
func newPeerSupervisor(i int) *peerSupervisor {
	sup := &peerSupervisor{remoteIP: byte(i)}
	sup.published = routingTable[i]
	sup.messages = messages[i]
	sup.buf1 = make([]byte, bufferSize)
	sup.buf2 = make([]byte, bufferSize)
	return sup
}
// stateFunc is one state of the supervisor's state machine. It runs until a
// transition is due and returns the state to run next.
type stateFunc func() stateFunc

// Run drives the state machine forever, starting in the noPeer state. It
// never returns.
func (s *peerSupervisor) Run() {
	for next := stateFunc(s.noPeer); ; next = next() {
		// Each state blocks inside its own call until it hands off to its
		// successor; there is nothing to do between states.
	}
}
// ----------------------------------------------------------------------------
// sendControlPacket hands pkt to _sendControlPacket together with the
// currently staged route and the supervisor's scratch buffers. It then
// sleeps 500ms to rate-limit control traffic.
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	const rateLimit = 500 * time.Millisecond

	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
	time.Sleep(rateLimit)
}
// sendControlPacketTo is like sendControlPacket, except that the packet is
// sent directly to addr instead of along the staged route.
//
// It sends using a copy of the staged route with Direct forced to true and
// RemoteAddr set to addr; s.staged itself is not modified. If addr is
// invalid, the call logs an error and returns immediately, without sending
// and without the rate-limit sleep. Otherwise it sleeps 500ms after sending.
func (s *peerSupervisor) sendControlPacketTo(
	pkt interface{ Marshal([]byte) []byte },
	addr netip.AddrPort,
) {
	if !addr.IsValid() {
		s.logf("ERROR: Attempted to send packet to invalid address: %v", addr)
		return
	}

	// s.staged is a peerRoute struct value, so this assignment copies it and
	// the overrides below do not leak into the staged route.
	route := s.staged
	route.Direct = true
	route.RemoteAddr = addr

	_sendControlPacket(pkt, route, s.buf1, s.buf2)
	time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
// ----------------------------------------------------------------------------
// logf formats msg with args (printf-style) and logs the result prefixed
// with the peer's VPN IP as a zero-padded three-digit number, e.g. "[007] ".
func (s *peerSupervisor) logf(msg string, args ...any) {
	body := fmt.Sprintf(msg, args...)
	log.Print(fmt.Sprintf("[%03d] ", s.remoteIP) + body)
}
// ----------------------------------------------------------------------------
// publish atomically stores a private copy of the staged route into the
// shared published slot. Later field assignments on s.staged therefore do
// not affect what readers of the slot see until the next publish.
func (s *peerSupervisor) publish() {
	snapshot := new(peerRoute)
	*snapshot = s.staged
	s.published.Store(snapshot)
}
// ----------------------------------------------------------------------------
// noPeer is the idle state. It discards every event until a peerUpdateMsg
// arrives and then transitions to the state returned by peerUpdate.
func (s *peerSupervisor) noPeer() stateFunc {
	for {
		update, isUpdate := (<-s.messages).(peerUpdateMsg)
		if !isUpdate {
			continue
		}
		return s.peerUpdate(update.Peer)
	}
}
// ----------------------------------------------------------------------------
// peerUpdate returns a state that applies peer via _peerUpdate. The update
// itself runs only when Run invokes the returned state, i.e. after the
// current state has returned.
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
	apply := func() stateFunc {
		return s._peerUpdate(peer)
	}
	return apply
}
// _peerUpdate resets all per-peer state from peer and picks the next state.
// The deferred publish means the new staged route is published on every
// return path.
//
// A nil peer clears the route and returns to noPeer. Otherwise it builds
// fresh control and data ciphers and, when the peer advertises a valid public
// IP, a direct route to that IP and port. It then selects a role:
//   - if exactly one side has a public IP, that side is the server;
//   - if both or neither do, the side with the lower VPN IP is the server.
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRoute{}
	// remotePub must be recomputed for every update. Otherwise a peer that
	// once advertised a public IP would still be treated as public after
	// dropping it, and the role selection below would go wrong.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.IP = s.remoteIP
	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
	s.staged.PubSignKey = peer.PubSignKey
	s.staged.DataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.Relay = peer.Relay
		s.staged.Direct = true
		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
	} else if localPub {
		// Only this node is public: the peer can reach us directly.
		s.staged.Direct = true
	}

	switch {
	case s.remotePub == localPub:
		// Symmetric reachability: break the tie on VPN IP.
		if localIP < s.remoteIP {
			return s.server
		}
		return s.client
	case s.remotePub:
		return s.client
	default:
		return s.server
	}
}
// ----------------------------------------------------------------------------
// server runs the responding side of the connection handshake. It stays in
// this state until a peer update arrives.
//
// On each synPacket it refreshes lastSeen. When the syn's TraceID differs
// from the last accepted one, or the route is currently down, it adopts the
// client's Direct flag, shared data key and source address into the staged
// route and publishes it. Every syn is answered with a synAckPacket that
// echoes the accepted TraceID and reports this node's local address. While
// the route is relayed, the server also sends a fresh probePacket to the
// client's advertised FromAddr (when valid).
//
// Incoming probePackets are echoed back to their source address. On a ping
// tick, if no syn has been seen for longer than timeoutInterval, the route
// is marked down and republished.
func (s *peerSupervisor) server() stateFunc {
	logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) }
	logf("DOWN")
	var (
		// syn is the last syn accepted; its TraceID identifies the current setup.
		syn      synPacket
		lastSeen = time.Now()
	)
	for {
		rawMsg := <-s.messages
		switch msg := rawMsg.(type) {
		case peerUpdateMsg:
			return s.peerUpdate(msg.Peer)
		case controlMsg[synPacket]:
			p := msg.Packet
			lastSeen = time.Now()
			// Before we can respond to this packet, we need to make sure the
			// route is set up properly.
			//
			// The client will update the syn's TraceID whenever there's a change.
			// The server will follow the client's request.
			if p.TraceID != syn.TraceID || !s.staged.Up {
				if p.Direct {
					logf("UP - Direct")
				} else {
					logf("UP - Relayed")
				}
				syn = p
				s.staged.Up = true
				s.staged.Direct = syn.Direct
				s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
				s.staged.RemoteAddr = msg.SrcAddr
				s.publish()
			}
			// We should always respond.
			ack := synAckPacket{
				TraceID:  syn.TraceID,
				FromAddr: getLocalAddr(),
			}
			s.sendControlPacket(ack)
			// A direct route needs nothing more. A relayed route also probes
			// the client's advertised address. The client code in this file
			// does not echo these probes; presumably they open a path toward
			// the client for its own probes — TODO confirm.
			if s.staged.Direct {
				continue
			}
			if !syn.FromAddr.IsValid() {
				continue
			}
			probe := probePacket{TraceID: newTraceID()}
			s.sendControlPacketTo(probe, syn.FromAddr)
		case controlMsg[probePacket]:
			// Echo the probe's TraceID back to whoever sent it.
			if !msg.SrcAddr.IsValid() {
				logf("Invalid probe address")
				continue
			}
			s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
		case pingTimerMsg:
			// Mark the route down once the client has been silent too long.
			// A later syn brings it back up via the !s.staged.Up check above.
			if time.Since(lastSeen) > timeoutInterval && s.staged.Up {
				logf("Connection timeout")
				s.staged.Up = false
				s.publish()
			}
		}
	}
}
// ----------------------------------------------------------------------------
// client runs the initiating side of the connection handshake. On entry it
// sends a synPacket carrying a fresh TraceID, the staged data-cipher key, the
// staged Direct flag and this node's local address. After that:
//   - a synAckPacket with the current TraceID marks the route up and records
//     the server's reported address for probing;
//   - each ping tick re-sends the syn, rotating the TraceID only if the local
//     address changed. While relayed, it also probes the address from the
//     last ack;
//   - an echo of one of its own probes upgrades a relayed route to direct,
//     then re-sends the syn with a new TraceID so the server follows;
//   - a local-discovery packet triggers a probe to the discovered host at
//     the peer's listed port;
//   - if no matching ack arrives within timeoutInterval, the state restarts
//     from peerUpdate with the same peer.
func (s *peerSupervisor) client() stateFunc {
	logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) }
	logf("DOWN")

	var (
		syn = synPacket{
			TraceID:   newTraceID(),
			SharedKey: s.staged.DataCipher.Key(),
			Direct:    s.staged.Direct,
			FromAddr:  getLocalAddr(),
		}
		lastSeen = time.Now()

		ack synAckPacket

		probe     probePacket
		probeAddr netip.AddrPort

		localProbe     probePacket
		localProbeAddr netip.AddrPort
	)

	// The initial syn already advertises the current local address. Seeding
	// lastLocalAddr with it stops the first ping tick from rotating the
	// TraceID, and thereby discarding the server's acks, when nothing changed.
	lastLocalAddr := syn.FromAddr

	s.sendControlPacket(syn)

	for {
		rawMsg := <-s.messages
		switch msg := rawMsg.(type) {
		case peerUpdateMsg:
			return s.peerUpdate(msg.Peer)

		case controlMsg[synAckPacket]:
			p := msg.Packet
			if p.TraceID != syn.TraceID {
				continue // Ack for a superseded syn; ignore it.
			}
			lastSeen = time.Now()
			ack = msg.Packet
			if !s.staged.Up {
				if s.staged.Direct {
					logf("UP - Direct")
				} else {
					logf("UP - Relayed")
				}
				s.staged.Up = true
				s.publish()
			}

		case controlMsg[probePacket]:
			if s.staged.Direct {
				continue
			}
			p := msg.Packet
			if p.TraceID != localProbe.TraceID && p.TraceID != probe.TraceID {
				continue
			}
			// Upgrade connection.
			s.staged.Direct = true
			if p.TraceID == localProbe.TraceID {
				logf("UP - Local")
				s.staged.RemoteAddr = localProbeAddr
			} else {
				logf("UP - Direct")
				s.staged.RemoteAddr = probeAddr
			}
			s.publish()
			// Tell the server about the change with a new TraceID.
			syn.TraceID = newTraceID()
			syn.Direct = true
			syn.FromAddr = getLocalAddr()
			lastLocalAddr = syn.FromAddr
			s.sendControlPacket(syn)

		case controlMsg[localDiscoveryPacket]:
			if s.staged.Direct {
				continue
			}
			// Send probe.
			//
			// The source port will be the multicast port, so we'll have to
			// construct the correct address using the peer's listed port.
			localProbe = probePacket{TraceID: newTraceID()}
			localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
			s.sendControlPacketTo(localProbe, localProbeAddr)

		case pingTimerMsg:
			if time.Since(lastSeen) > timeoutInterval {
				logf("Connection timeout")
				return s.peerUpdate(s.peer)
			}
			// A changed local address is a change the server must follow,
			// so it gets a new TraceID.
			syn.FromAddr = getLocalAddr()
			if syn.FromAddr != lastLocalAddr {
				syn.TraceID = newTraceID()
				lastLocalAddr = syn.FromAddr
			}
			s.sendControlPacket(syn)
			if s.staged.Direct {
				continue
			}
			if !ack.FromAddr.IsValid() {
				continue
			}
			probe = probePacket{TraceID: newTraceID()}
			probeAddr = ack.FromAddr
			s.sendControlPacketTo(probe, ack.FromAddr)
		}
	}
}