vppn/node/peer-supervisor.go

330 lines
7.2 KiB
Go

package node
import (
"fmt"
"log"
"math/rand"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
const (
// dialTimeout and connectTimeout are not referenced by the supervisor code
// in this part of the file.
// NOTE(review): presumably used elsewhere in the package - confirm.
dialTimeout = 8 * time.Second
connectTimeout = 6 * time.Second
// pingInterval is how often a client re-sends its syn packet, and how long
// clientSelectRelay waits before retrying when no relay is available.
pingInterval = 6 * time.Second
// timeoutInterval is how long a client waits for a syn-ack with a matching
// trace ID before it restarts from clientInit.
timeoutInterval = 20 * time.Second
)
// ----------------------------------------------------------------------------
// peerSupervisor is the state machine for a single remote peer. It reacts to
// peer updates and control packets, edits its private staged route, and makes
// changes visible to other goroutines by storing copies through published.
type peerSupervisor struct {
// The purpose of this state machine is to manage this published data.
// published is this peer's slot in routingTable (see newPeerSupervisor).
published *atomic.Pointer[peerRoute]
staged peerRoute // Local copy of shared data. See publish().
// Immutable data.
remoteIP byte // Remote VPN IP.
// Mutable peer data.
// peer is the most recent value received on peerUpdates; nil means there is
// currently no peer.
peer *m.Peer
// remotePub records that the peer supplied a usable public IP; it is set in
// _peerUpdate.
remotePub bool
// Incoming events.
// peerUpdates delivers replacement peer definitions; every state that waits
// on it restarts the machine through peerUpdate.
peerUpdates chan *m.Peer
// controlPackets delivers control packets received from this peer.
controlPackets chan controlPacket
// Buffers for sending control packets.
// Both are allocated once with bufferSize bytes and handed to
// _sendControlPacket on every send.
buf1 []byte
buf2 []byte
}
// newPeerSupervisor creates the supervisor for the peer at index i. The index
// selects the routing-table slot and the two event channels the supervisor is
// bound to, and (truncated to a byte) becomes the peer's VPN IP.
func newPeerSupervisor(i int) *peerSupervisor {
	sup := new(peerSupervisor)

	// Shared, per-index resources.
	sup.published = routingTable[i]
	sup.remoteIP = byte(i)
	sup.peerUpdates = peerUpdates[i]
	sup.controlPackets = controlPackets[i]

	// Private scratch buffers used when sending control packets.
	sup.buf1 = make([]byte, bufferSize)
	sup.buf2 = make([]byte, bufferSize)

	return sup
}
// stateFunc is one state of the supervisor's state machine: calling it runs
// that state and yields the state to run next.
type stateFunc func() stateFunc
// Run drives the state machine forever, starting in the noPeer state. Each
// state function is invoked and its return value becomes the next state. Run
// never returns.
func (s *peerSupervisor) Run() {
	for next := stateFunc(s.noPeer); ; next = next() {
		// All of the work happens in the post statement above.
	}
}
// ----------------------------------------------------------------------------
// sendControlPacket sends pkt to the peer along the currently staged route,
// using the supervisor's two scratch buffers.
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	route := s.staged
	_sendControlPacket(pkt, route, s.buf1, s.buf2)
}
// sendControlPacketTo sends pkt directly to addr instead of along the staged
// route. It works on a copy of the staged route with the relay cleared and the
// remote address replaced, so the staged route itself is not modified.
//
// An invalid addr is logged and the packet is dropped.
func (s *peerSupervisor) sendControlPacketTo(
	pkt interface{ Marshal([]byte) []byte },
	addr netip.AddrPort,
) {
	if !addr.IsValid() {
		s.logf("ERROR: Attempted to send packet to invalid address: %v", addr)
		return
	}

	// Copy, then override: RelayIP = 0 forces a direct (non-relayed) send.
	route := s.staged
	route.RelayIP = 0
	route.RemoteAddr = addr

	_sendControlPacket(pkt, route, s.buf1, s.buf2)
}
// ----------------------------------------------------------------------------
// getLocalAddr returns the address to advertise to the peer as our own.
//
// Order of preference:
//  1. localAddr, when this node is public (localPub).
//  2. The LocalAddr published for the relay's route, when the staged route is
//     relayed and that address is valid.
//  3. The staged route's own LocalAddr.
func (s *peerSupervisor) getLocalAddr() netip.AddrPort {
	switch {
	case localPub:
		return localAddr
	case s.staged.RelayIP == 0:
		return s.staged.LocalAddr
	}

	relayRoute := routingTable[s.staged.RelayIP].Load()
	if viaRelay := relayRoute.LocalAddr; viaRelay.IsValid() {
		return viaRelay
	}
	return s.staged.LocalAddr
}
// ----------------------------------------------------------------------------
// logf writes a log line prefixed with the peer's VPN IP as a zero-padded,
// three-digit number in square brackets. msg and args follow log.Printf
// conventions.
func (s *peerSupervisor) logf(msg string, args ...any) {
	prefix := fmt.Sprintf("[%03d] ", s.remoteIP)
	log.Printf(prefix+msg, args...)
}
// ----------------------------------------------------------------------------
// publish makes the staged route visible to readers of the published pointer.
// A fresh copy is stored each time, so later edits to s.staged do not affect
// a snapshot that has already been published.
func (s *peerSupervisor) publish() {
	snapshot := new(peerRoute)
	*snapshot = s.staged
	s.published.Store(snapshot)
}
// ----------------------------------------------------------------------------
// noPeer is the idle state: it blocks until a peer update arrives and then
// hands that update to peerUpdate to choose the next state.
func (s *peerSupervisor) noPeer() stateFunc {
	update := <-s.peerUpdates
	return s.peerUpdate(update)
}
// ----------------------------------------------------------------------------
// peerUpdate wraps a received peer in a state function. The update itself is
// applied by _peerUpdate when the returned state is run.
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
	apply := func() stateFunc {
		return s._peerUpdate(peer)
	}
	return apply
}
// _peerUpdate installs a new peer definition (or removes the peer when peer is
// nil), rebuilds the staged route from scratch, and picks the next state.
//
// The staged route is published when the function returns, whichever branch
// is taken.
//
// Role selection:
//   - If exactly one side is public, the public side is the server and the
//     other side is the client.
//   - If both or neither are public, the node with the lower VPN IP is the
//     server.
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRoute{}
	// Reset per-peer state derived from the previous peer. Without this, a
	// peer that used to have a public IP would still be treated as public
	// after an update that removes it.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.IP = s.remoteIP
	s.staged.ControlCipher = newControlCipher(privateKey, peer.PubKey)
	s.staged.DataCipher = newDataCipher()

	// A parseable public IP means the peer is directly reachable.
	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.Relay = peer.Relay
		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
	}

	if s.remotePub == localPub {
		// Symmetric reachability: break the tie on VPN IP.
		if localIP < s.remoteIP {
			return s.server
		}
		return s.clientInit
	}

	if s.remotePub {
		return s.clientInit
	}
	return s.server
}
// ----------------------------------------------------------------------------
// server is the passive side of the handshake. It waits for syn packets from
// the client, brings the route up from their contents, and answers every syn
// with a syn-ack. It also echoes probe packets back to their source.
//
// The state is only left when a peer update arrives.
func (s *peerSupervisor) server() stateFunc {
	s.logf("STATE: server")

	// lastSyn is the syn the current route was built from. A syn carrying a
	// different trace ID replaces it and rebuilds the route.
	var lastSyn synPacket

	for {
		select {
		case update := <-s.peerUpdates:
			return s.peerUpdate(update)

		case incoming := <-s.controlPackets:
			switch payload := incoming.Payload.(type) {
			case probePacket:
				s.logf("SERVER got probe: %v", payload)
				s.logf("SERVER sending probe: %v", incoming.SrcAddr)
				echo := probePacket{TraceID: payload.TraceID}
				s.sendControlPacketTo(echo, incoming.SrcAddr)

			case synPacket:
				// The route must be in place before the reply is sent.
				if payload.TraceID != lastSyn.TraceID {
					lastSyn = payload
					s.staged.Up = true
					s.staged.RemoteAddr = incoming.SrcAddr
					s.staged.DataCipher = newDataCipherFromKey(lastSyn.SharedKey)
					s.staged.RelayIP = lastSyn.RelayIP
					// Must follow the RelayIP assignment: getLocalAddr reads it.
					s.staged.LocalAddr = s.getLocalAddr()
					s.publish()
				}

				// Every syn gets a reply, including repeats.
				reply := synAckPacket{
					TraceID:  lastSyn.TraceID,
					FromAddr: s.staged.LocalAddr,
					ToAddr:   incoming.SrcAddr,
				}
				s.sendControlPacket(reply)

				// Only a relayed route with a known client address is probed.
				if s.staged.RelayIP == 0 || !lastSyn.FromAddr.IsValid() {
					continue
				}
				probe := probePacket{TraceID: newTraceID()}
				s.logf("SERVER sending probe %v: %v", probe, lastSyn.FromAddr)
				s.sendControlPacketTo(probe, lastSyn.FromAddr)
			}
		}
	}
}
// ----------------------------------------------------------------------------
// clientInit is the entry point for the client role. A peer with a public
// address is contacted directly; otherwise a relay has to be chosen first.
func (s *peerSupervisor) clientInit() stateFunc {
	s.logf("STATE: client-init")

	if s.remotePub {
		return s.client
	}
	return s.clientSelectRelay
}
// ----------------------------------------------------------------------------
// clientSelectRelay waits until a relay is available, stages and publishes a
// route through it, and then moves to the client state. The first attempt is
// immediate; later attempts are spaced pingInterval apart. A peer update
// aborts the wait.
func (s *peerSupervisor) clientSelectRelay() stateFunc {
	s.logf("STATE: client-select-relay")

	// A zero-duration timer makes the first selection attempt fire at once.
	retry := time.NewTimer(0)
	defer retry.Stop()

	for {
		select {
		case update := <-s.peerUpdates:
			return s.peerUpdate(update)
		case <-retry.C:
		}

		chosen := s.selectRelay()
		if chosen == nil {
			s.logf("No relay available.")
			retry.Reset(pingInterval)
			continue
		}

		s.logf("Got relay: %d", chosen.IP)
		s.staged.RelayIP = chosen.IP
		s.staged.LocalAddr = chosen.LocalAddr
		s.publish()
		return s.client
	}
}
// selectRelay picks, uniformly at random, one published route that is both up
// and flagged as a relay. It returns nil when no such route exists.
func (s *peerSupervisor) selectRelay() *peerRoute {
	candidates := make([]*peerRoute, 0, 8)

	for idx := range routingTable {
		if entry := routingTable[idx].Load(); entry.Up && entry.Relay {
			candidates = append(candidates, entry)
		}
	}

	if n := len(candidates); n > 0 {
		return candidates[rand.Intn(n)]
	}
	return nil
}
// ----------------------------------------------------------------------------
// client is the active side of the handshake. It sends a syn immediately and
// then again every pingInterval. The first syn-ack with a matching trace ID
// brings the route up, and each matching syn-ack pushes the timeout back.
// While the route is relayed and the server's address is known from a
// syn-ack, a probe is sent directly to that address on every ping.
//
// The state ends on a peer update, or by returning to clientInit when no
// matching syn-ack has arrived for timeoutInterval.
//
// NOTE(review): the timeout path leaves s.staged.Up as it was, so a route that
// came up stays published as up while the handshake restarts - confirm that
// this is intended.
func (s *peerSupervisor) client() stateFunc {
	s.logf("STATE: client")

	// The same syn (and trace ID) is re-sent for the lifetime of this state.
	hello := synPacket{
		TraceID:   newTraceID(),
		SharedKey: s.staged.DataCipher.Key(),
		RelayIP:   s.staged.RelayIP,
		FromAddr:  s.getLocalAddr(),
	}
	// lastAck is the latest accepted syn-ack; zero until one arrives.
	var lastAck synAckPacket
	directProbe := probePacket{TraceID: newTraceID()}

	deadline := time.NewTimer(timeoutInterval)
	resend := time.NewTimer(pingInterval)
	defer deadline.Stop()
	defer resend.Stop()

	s.sendControlPacket(hello)

	for {
		select {
		case update := <-s.peerUpdates:
			return s.peerUpdate(update)

		case <-deadline.C:
			return s.clientInit

		case <-resend.C:
			s.sendControlPacket(hello)
			resend.Reset(pingInterval)

			if s.staged.RelayIP != 0 && lastAck.FromAddr.IsValid() {
				s.logf("CLIENT sending probe %v: %v", directProbe, lastAck.FromAddr)
				s.sendControlPacketTo(directProbe, lastAck.FromAddr)
			}

		case incoming := <-s.controlPackets:
			switch payload := incoming.Payload.(type) {
			case probePacket:
				s.logf("CLIENT got probe: %v", payload)

			case synAckPacket:
				if payload.TraceID == hello.TraceID {
					lastAck = payload
					deadline.Reset(timeoutInterval)

					// Only the first accepted syn-ack brings the route up.
					if !s.staged.Up {
						s.staged.Up = true
						s.staged.LocalAddr = payload.ToAddr
						s.publish()
					}
				} else {
					// A reply to some other syn; ignore it.
					s.logf("Bad traceID?")
				}
			}
		}
	}
}