vppn/node/supervisor.go
2025-01-12 20:31:36 +01:00

429 lines
9.4 KiB
Go

package node
import (
"fmt"
"log"
"net/netip"
"strings"
"sync/atomic"
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
)
// Timing parameters for the peer state machines.
const (
	// pingInterval is the nominal period of the ping timer.
	// NOTE(review): the timer that emits pingTimerMsg is not in this file;
	// presumably it ticks at this interval — confirm.
	pingInterval = time.Second * 8

	// timeoutInterval is how long a peer may go unheard before the
	// OnPingTimer handlers treat the connection as timed out.
	timeoutInterval = time.Second * 30
)
// ----------------------------------------------------------------------------
// startPeerSuper creates one state machine for every possible remote VPN IP
// (0-255), puts each into its initial state via OnPeerUpdate(nil) (which
// yields the disconnected state), and launches the supervisor goroutine.
// Each machine gets its own send buffers and its own control-packet rate
// limiter.
func startPeerSuper() {
	var peers [256]peerState
	for ip := 0; ip < len(peers); ip++ {
		limiter := ratelimiter.New(ratelimiter.Config{
			FillPeriod:   20 * time.Millisecond,
			MaxWaitCount: 1,
		})
		data := &peerStateData{
			published: routingTable[ip],
			remoteIP:  byte(ip),
			buf1:      make([]byte, bufferSize),
			buf2:      make([]byte, bufferSize),
			limiter:   limiter,
		}
		peers[ip] = data.OnPeerUpdate(nil)
	}
	go runPeerSuper(peers)
}
// runPeerSuper is the supervisor loop. It receives from `messages` until the
// range ends and routes each message to the state machine of the peer it
// concerns. The peers array is passed by value, so this goroutine owns its
// own copy of it.
//
// Peer updates always replace the peer's state. Ping ticks replace it only
// when OnPingTimer returns a non-nil state. Control packets never replace it.
func runPeerSuper(peers [256]peerState) {
	for raw := range messages {
		switch msg := raw.(type) {
		case peerUpdateMsg:
			current := peers[msg.PeerIP]
			peers[msg.PeerIP] = current.OnPeerUpdate(msg.Peer)

		case controlMsg[synPacket]:
			peers[msg.SrcIP].OnSyn(msg)

		case controlMsg[ackPacket]:
			peers[msg.SrcIP].OnAck(msg)

		case controlMsg[probePacket]:
			peers[msg.SrcIP].OnProbe(msg)

		case controlMsg[localDiscoveryPacket]:
			peers[msg.SrcIP].OnLocalDiscovery(msg)

		case pingTimerMsg:
			publicAddrs.Clean()
			for i := range peers {
				next := peers[i].OnPingTimer()
				if next == nil {
					continue // nil means "stay in the current state"
				}
				peers[i] = next
			}

		default:
			log.Printf("WARNING: unknown message type: %+v", msg)
		}
	}
}
// ----------------------------------------------------------------------------
// peerState is one state of a peer's connection state machine.
//
// The supervisor (runPeerSuper) calls these methods from a single goroutine.
// Methods that return a peerState can move the machine to a new state.
type peerState interface {
	// OnPeerUpdate applies a new peer configuration and always returns the
	// next state. A nil peer leads to the disconnected state.
	OnPeerUpdate(*m.Peer) peerState

	// OnPingTimer runs periodic work. A nil result means "keep the current
	// state"; a non-nil result replaces it.
	OnPingTimer() peerState

	// Control-packet handlers. They may update internal data but do not
	// return a new state.
	OnSyn(controlMsg[synPacket])
	OnAck(controlMsg[ackPacket])
	OnProbe(controlMsg[probePacket])
	OnLocalDiscovery(controlMsg[localDiscoveryPacket])
}
// ----------------------------------------------------------------------------
// peerStateData holds everything one peer's state machine keeps across state
// transitions. stateDisconnected embeds a pointer to it directly, and
// stateServer and stateClient embed it through *stateDisconnected.
type peerStateData struct {
	// The purpose of this state machine is to manage this published data.
	// NOTE(review): held behind an atomic pointer, presumably because other
	// goroutines read the route concurrently — confirm against its readers.
	published *atomic.Pointer[peerRoute]
	staged    peerRoute // Local copy of shared data. See publish().

	// Immutable data.
	remoteIP byte // Remote VPN IP.

	// Mutable peer data.
	peer      *m.Peer // Current peer configuration; nil while disconnected.
	remotePub bool    // Set in OnPeerUpdate: whether peer.PublicIP parsed as an IP address.

	// Buffers for sending control packets.
	buf1 []byte
	buf2 []byte

	// For logging. Set per-state: true in the client role, false in the
	// server role.
	client bool

	// Rate limiter consulted before every outgoing control packet.
	limiter *ratelimiter.Limiter
}
// ----------------------------------------------------------------------------
// sendControlPacket sends pkt to the peer along the currently staged route.
func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
	route := s.staged
	s._sendControlPacket(pkt, route)
}
// sendControlPacketTo sends pkt directly to addr, regardless of the staged
// route's destination. The override applies to this one packet only:
// s.staged is not modified. If addr is invalid, an error is logged and
// nothing is sent.
func (s *peerStateData) sendControlPacketTo(pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort) {
	if !addr.IsValid() {
		s.logf("ERROR: Attempted to send packet to invalid address: %v", addr)
		return
	}
	// route is a value copy of s.staged, so these overrides do not leak back.
	route := s.staged
	route.Direct = true
	route.RemoteAddr = addr
	s._sendControlPacket(pkt, route)
}
// _sendControlPacket asks the per-peer rate limiter for permission. If it is
// granted, it passes pkt, route and this peer's two scratch buffers to the
// package-level _sendControlPacket. If the limiter refuses, the packet is
// dropped and the refusal is logged.
func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute) {
	if s.limiter.Limit() == nil {
		_sendControlPacket(pkt, route, s.buf1, s.buf2)
		return
	}
	s.logf("Not sending control packet: rate limited.") // Shouldn't happen.
}
// ----------------------------------------------------------------------------
// publish stores a fresh copy of the staged route in the shared atomic
// pointer. Because the stored value is a copy, later edits to s.staged are
// not visible until the next publish.
func (s *peerStateData) publish() {
	snapshot := new(peerRoute)
	*snapshot = s.staged
	s.published.Store(snapshot)
}
// logf logs a formatted message with a status prefix. The prefix contains:
//   - the peer's name, right-aligned to 30 columns (empty if no peer is set);
//   - the role: CLIENT or SERVER;
//   - the path: DIRECT or RELAYED;
//   - the link state: UP or DOWN.
//
// The prefix is written as data, never as part of the format string. A peer
// name containing '%' therefore cannot be misread as a formatting verb.
func (s *peerStateData) logf(format string, args ...any) {
	var b strings.Builder

	// peer is nil in the disconnected state; avoid a nil dereference.
	if s.peer != nil {
		fmt.Fprintf(&b, "%30s: ", s.peer.Name)
	} else {
		fmt.Fprintf(&b, "%30s: ", "")
	}

	if s.client {
		b.WriteString("CLIENT|")
	} else {
		b.WriteString("SERVER|")
	}

	if s.staged.Direct {
		b.WriteString("DIRECT |")
	} else {
		b.WriteString("RELAYED|")
	}

	if s.staged.Up {
		b.WriteString("UP |")
	} else {
		b.WriteString("DOWN|")
	}

	fmt.Fprintf(&b, format, args...)
	log.Print(b.String())
}
// ----------------------------------------------------------------------------
// OnPeerUpdate (re)initializes the route from the peer's configuration and
// returns the state to enter. A nil peer leads to the disconnected state.
//
// Role selection:
//   - If exactly one side has a public IP, the public side is the server.
//   - If both or neither do, the node with the lower VPN IP is the server.
//
// The staged route is published on return (deferred). This happens after
// the chosen state's enter function has run.
func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState {
	defer s.publish()

	if peer == nil {
		return enterStateDisconnected(s)
	}

	s.peer = peer
	s.staged = peerRoute{
		IP:            s.remoteIP,
		PubSignKey:    peer.PubSignKey,
		ControlCipher: newControlCipher(privKey, peer.PubKey),
		DataCipher:    newDataCipher(),
	}

	ip, hasPublicIP := netip.AddrFromSlice(peer.PublicIP)
	s.remotePub = hasPublicIP
	switch {
	case hasPublicIP:
		// The peer has a configured public endpoint we can send to directly.
		s.staged.Relay = peer.Relay
		s.staged.Direct = true
		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
	case localPub:
		// We are public, but the peer's address is not configured. In the
		// server role, stateServer.OnSyn fills it in from the client's syn.
		s.staged.Direct = true
	}

	var actAsServer bool
	if s.remotePub == localPub {
		actAsServer = localIP < s.remoteIP
	} else {
		// Exactly one side is public; that side is the server.
		actAsServer = localPub
	}

	if actAsServer {
		return enterStateServer(s)
	}
	return enterStateClient(s)
}
// ----------------------------------------------------------------------------
// stateDisconnected is the state of a peer with no configuration. It ignores
// every control packet and never changes state from a ping tick. Only
// OnPeerUpdate, promoted from peerStateData, moves it elsewhere.
//
// It is also embedded by stateServer and stateClient, supplying no-op
// handlers for the messages they do not override.
type stateDisconnected struct {
	*peerStateData
}

// enterStateDisconnected forgets the peer, clears and publishes the route,
// and returns the disconnected state.
func enterStateDisconnected(s *peerStateData) peerState {
	s.peer, s.staged = nil, peerRoute{}
	s.publish()
	return &stateDisconnected{peerStateData: s}
}

func (*stateDisconnected) OnSyn(controlMsg[synPacket]) {}

func (*stateDisconnected) OnAck(controlMsg[ackPacket]) {}

func (*stateDisconnected) OnProbe(controlMsg[probePacket]) {}

func (*stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {}

// OnPingTimer never requests a transition (nil means "no change").
func (*stateDisconnected) OnPingTimer() peerState { return nil }
// ----------------------------------------------------------------------------
// stateServer is the server role. It answers syns from the client and
// follows the route parameters the client requests.
type stateServer struct {
	*stateDisconnected
	lastSeen   time.Time // When the most recent syn was received.
	synTraceID uint64    // Trace ID of the syn the current route was built from.
}

// enterStateServer marks the peer as being in the server role and returns a
// fresh server state. Nothing is sent; the server waits for the client's syn.
func enterStateServer(s *peerStateData) peerState {
	s.client = false
	base := &stateDisconnected{peerStateData: s}
	return &stateServer{stateDisconnected: base}
}
// OnSyn handles a syn from the client.
//
// The client drives the connection: it changes the syn's trace ID whenever
// it wants a change. On a new trace ID, or while the route is down, the
// server rebuilds its staged route from the syn and publishes it. The
// rebuilt route is up, uses the syn's Direct flag and shared key, and points
// at the syn's source address.
//
// An ack is always sent back. If the route is relayed, the server also
// probes every valid candidate address the client advertised.
func (s *stateServer) OnSyn(msg controlMsg[synPacket]) {
	s.lastSeen = time.Now()
	syn := msg.Packet

	// The route must reflect the client's latest request before we reply.
	if !s.staged.Up || syn.TraceID != s.synTraceID {
		s.synTraceID = syn.TraceID
		s.staged.Up = true
		s.staged.Direct = syn.Direct
		s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
		s.staged.RemoteAddr = msg.SrcAddr
		s.publish()
		s.logf("Got syn.")
	}

	s.sendControlPacket(ackPacket{
		TraceID:       syn.TraceID,
		ToAddr:        s.staged.RemoteAddr,
		PossibleAddrs: publicAddrs.Get(),
	})

	if !s.staged.Direct {
		// Relayed: probe each address the client thinks it may be reachable at.
		for _, candidate := range syn.PossibleAddrs {
			if candidate.IsValid() {
				s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, candidate)
			}
		}
	}
}
// OnProbe echoes a probe back to the address it came from. The probe's
// trace ID is kept, so the client can match the reply to its outstanding
// probe.
func (s *stateServer) OnProbe(msg controlMsg[probePacket]) {
	from := msg.SrcAddr
	if from.IsValid() {
		s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, from)
		return
	}
	s.logf("Invalid probe address.")
}
// OnPingTimer marks an up route as down (and publishes it) if no syn has
// arrived within timeoutInterval. The server never changes state on its
// own, so this always returns nil.
func (s *stateServer) OnPingTimer() peerState {
	if !s.staged.Up {
		return nil
	}
	if time.Since(s.lastSeen) <= timeoutInterval {
		return nil
	}
	s.staged.Up = false
	s.publish()
	s.logf("Connection timeout.")
	return nil
}
// ----------------------------------------------------------------------------
// stateClient is the client role. It drives the handshake by sending syns,
// and while relayed it probes candidate addresses looking for a direct path.
type stateClient struct {
	*stateDisconnected
	lastSeen           time.Time                 // When the most recent matching ack arrived.
	syn                synPacket                 // The syn sent on entry and resent every ping tick.
	ack                ackPacket                 // The most recent ack matching syn's trace ID.
	probes             map[uint64]netip.AddrPort // Outstanding probes: trace ID -> probed address.
	localDiscoveryAddr chan netip.AddrPort       // Holds at most one pending local-discovery address.
}

// enterStateClient marks the peer as being in the client role and builds
// the initial syn. The syn gets a new trace ID, the staged data key, the
// staged Direct flag and our known public addresses. The syn is sent
// immediately.
func enterStateClient(s *peerStateData) peerState {
	s.client = true

	syn := synPacket{
		TraceID:       newTraceID(),
		SharedKey:     s.staged.DataCipher.Key(),
		Direct:        s.staged.Direct,
		PossibleAddrs: publicAddrs.Get(),
	}
	client := &stateClient{
		stateDisconnected:  &stateDisconnected{peerStateData: s},
		syn:                syn,
		probes:             make(map[uint64]netip.AddrPort),
		localDiscoveryAddr: make(chan netip.AddrPort, 1),
	}
	client.sendControlPacket(client.syn)
	return client
}
// sendProbeTo sends a probe with a fresh trace ID to addr. It records the
// trace ID and address pair so OnProbe can recognize the echoed reply.
func (s *stateClient) sendProbeTo(addr netip.AddrPort) {
	id := newTraceID()
	s.probes[id] = addr
	s.sendControlPacketTo(probePacket{TraceID: id}, addr)
}
// OnAck handles an ack from the server. Acks whose trace ID differs from
// the current syn's are logged and ignored.
//
// A matching ack:
//   - is stored in s.ack and refreshes lastSeen;
//   - marks the route up (and publishes it) if it was down;
//   - when this node is not public but the peer is, has its ToAddr recorded
//     as one of our possible public addresses.
func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
	ack := msg.Packet
	if ack.TraceID != s.syn.TraceID {
		s.logf("Ack has incorrect trace ID")
		return
	}

	s.ack = ack
	s.lastSeen = time.Now()

	// TODO: decide whether an ack arriving while already up needs handling.
	if !s.staged.Up {
		s.staged.Up = true
		s.logf("Got ack.")
		s.publish()
	}

	if s.remotePub && !localPub {
		publicAddrs.Store(ack.ToAddr)
	}
}
// OnProbe handles the echo of one of our probes while the route is relayed.
//
// If the trace ID matches an outstanding probe, the route becomes direct to
// the probed address and is published. A new syn is then sent so the
// server follows the change: it has a new trace ID, Direct set, and no
// candidate addresses. Echoes are ignored once the route is already direct.
func (s *stateClient) OnProbe(msg controlMsg[probePacket]) {
	if s.staged.Direct {
		return
	}
	addr, known := s.probes[msg.Packet.TraceID]
	if !known {
		return
	}

	s.staged.Direct = true
	s.staged.RemoteAddr = addr
	s.publish()

	s.syn.TraceID, s.syn.Direct = newTraceID(), true
	s.syn.PossibleAddrs = [8]netip.AddrPort{}
	s.sendControlPacket(s.syn)

	s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String())
}
// OnLocalDiscovery queues the sender of a local-discovery packet as an
// address to probe on the next ping tick. It is ignored once the route is
// direct. Only one address is held at a time; if one is already pending,
// the new one is dropped.
func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) {
	if s.staged.Direct {
		return
	}

	// The packet arrives from the multicast port, not the peer's listening
	// port, so pair the source IP with the peer's configured port.
	candidate := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)

	select {
	case s.localDiscoveryAddr <- candidate:
	default:
		log.Printf("Local discovery packet dropped.")
	}
}
// OnPingTimer runs the client's periodic work.
//
// If no matching ack has arrived within timeoutInterval, it returns the
// result of OnPeerUpdate(s.peer). That call rebuilds the route (with a new
// data key) and re-enters a role state, which sends a new syn. lastSeen
// starts at the zero time, so until the first ack arrives every tick
// restarts the handshake this way.
//
// Otherwise the current syn is resent. While the route is relayed, a new
// probe round is started:
//   - each valid candidate address the server advertised in its latest ack
//     is probed;
//   - so is at most one pending local-discovery address.
func (s *stateClient) OnPingTimer() peerState {
	if time.Since(s.lastSeen) > timeoutInterval {
		if s.staged.Up {
			s.logf("Connection timeout.")
		}
		return s.OnPeerUpdate(s.peer)
	}

	s.sendControlPacket(s.syn)
	if s.staged.Direct {
		return nil
	}

	// New probe round: echoes of earlier probes are no longer accepted.
	clear(s.probes)

	// Probe the *server's* candidates, taken from its most recent ack.
	// publicAddrs holds this node's own addresses (learned from acks' ToAddr
	// and advertised to the server in our syn), so probing those would
	// target ourselves. Invalid entries are skipped rather than ending the
	// loop, matching the server's probe loop in stateServer.OnSyn.
	for _, addr := range s.ack.PossibleAddrs {
		if addr.IsValid() {
			s.sendProbeTo(addr)
		}
	}

	select {
	case addr := <-s.localDiscoveryAddr:
		s.sendProbeTo(addr)
	default:
		// No local-discovery candidate pending.
	}
	return nil
}