vppn/node/peersupervisor.go
2024-12-18 14:40:25 +01:00

328 lines
6.5 KiB
Go

package node
import (
"fmt"
"log"
"net/netip"
"time"
"vppn/m"
)
// Timing parameters for the peer supervisor state machine.
const (
	// connectTimeout is 6s.
	// NOTE(review): not referenced by the state functions in this file — confirm it is still needed.
	connectTimeout time.Duration = 6 * time.Second

	// pingInterval is the delay sendPing re-arms pingTimer with, i.e. the
	// spacing between consecutive pings.
	pingInterval time.Duration = 6 * time.Second

	// timeoutInterval is how long stateConnected waits for a pong before
	// declaring a timeout and re-initializing.
	timeoutInterval time.Duration = 20 * time.Second
)
// routingPacketWrapper is a routingPacket together with the network address it
// was received from. Values are queued to a peerSupervisor via HandlePacket;
// the state functions use Addr to learn or follow the peer's address.
type routingPacketWrapper struct {
routingPacket
Addr netip.AddrPort // Source.
}
// peerSupervisor tracks one remote peer, identified by remoteIP. Its state
// machine runs in a goroutine started by newPeerSupervisor and receives input
// through HandlePeerUpdate and HandlePacket. It publishes the peer's
// reachability to the routing table.
type peerSupervisor struct {
// Constants:
// localIP is this node's own IP byte (conf.PeerIP).
localIP byte
// localPublic is true when conf.PublicIP parses as an IP address.
localPublic bool
// remoteIP is the supervised peer's IP byte. It keys the routing-table entry
// and prefixes log lines.
remoteIP byte
// privKey is this node's encryption private key. It is combined with the
// peer's EncPubKey to compute sharedKey.
privKey []byte
// Shared data:
// w sends routing packets (pings/pongs) to the peer.
w *connWriter
// table receives this peer's entry (set, or cleared to nil).
table *routingTable
// packets is the inbound routing-packet queue (capacity 256). HandlePacket
// drops packets when it is full.
packets chan routingPacketWrapper
// peerUpdates carries new peer configurations (capacity 1). A nil value
// leads the state machine to stateDisconnected.
peerUpdates chan *m.Peer
// Peer-related items.
version int64 // Only accessed in HandlePeerUpdate.
// peer is the current peer configuration, or nil when there is none.
peer *m.Peer
// remoteAddrPort is the peer's direct address. It is nil when the peer has
// no public IP and none has been learned from an incoming ping.
remoteAddrPort *netip.AddrPort
// mediated is true while in the mediated state (neither side public).
mediated bool
// sharedKey is computed from the peer's public key and privKey.
sharedKey []byte
// Used by our state functions.
// pingTimer schedules the next ping. timeoutTimer fires when a connected peer
// has not sent a pong within timeoutInterval.
pingTimer *time.Timer
timeoutTimer *time.Timer
// buf is scratch space for marshaling outgoing packets.
buf []byte
}
// ----------------------------------------------------------------------------
// newPeerSupervisor builds a supervisor for the peer at remoteIP from the local
// node configuration conf. It starts the supervisor's state machine in a new
// goroutine before returning.
//
// NOTE(review): mainLoop never returns normally and there is no stop
// mechanism, so the goroutine lives as long as the process (or until a panic).
func newPeerSupervisor(
	conf m.PeerConfig,
	remoteIP byte,
	w *connWriter,
	table *routingTable,
) *peerSupervisor {
	// Any slice netip can interpret as an address (4 or 16 bytes) means a
	// public IP is configured.
	_, hasPublicIP := netip.AddrFromSlice(conf.PublicIP)

	sup := &peerSupervisor{
		localIP:     conf.PeerIP,
		localPublic: hasPublicIP,
		remoteIP:    remoteIP,
		privKey:     conf.EncPrivKey,

		w:           w,
		table:       table,
		packets:     make(chan routingPacketWrapper, 256),
		peerUpdates: make(chan *m.Peer, 1),

		pingTimer:    time.NewTimer(pingInterval),
		timeoutTimer: time.NewTimer(timeoutInterval),
		buf:          make([]byte, bufferSize),
	}

	go sup.mainLoop()
	return sup
}
// logf logs a Printf-style message. The message is prefixed with the remote
// peer's IP byte as a zero-padded three-digit number, e.g. "[007] ".
func (s *peerSupervisor) logf(msg string, args ...any) {
	prefix := fmt.Sprintf("[%03d] ", s.remoteIP)
	log.Printf(prefix+msg, args...)
}
// ----------------------------------------------------------------------------
// mainLoop drives the state machine. Starting from stateInit, it repeatedly
// runs the current state and replaces it with the state that state returns.
// The loop never exits normally, so the deferred panicHandler only runs when
// a state panics.
func (s *peerSupervisor) mainLoop() {
	defer panicHandler()

	next := stateFunc(s.stateInit)
	for {
		next = next()
	}
}
// ----------------------------------------------------------------------------
// HandlePeerUpdate forwards a new peer configuration to the state machine.
//
// A non-nil p whose Version equals the last forwarded version is ignored. A
// nil p is always forwarded and resets the remembered version to 0; it sends
// the supervisor to stateDisconnected. The send blocks until the single-slot
// peerUpdates channel has room.
//
// NOTE(review): s.version is not synchronized, so this presumably must be
// called from one goroutine only. Also, while s.version is 0, a peer with
// Version == 0 is ignored — confirm versions never start at 0.
func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) {
	if p == nil {
		s.version = 0
		s.peerUpdates <- nil
		return
	}

	if p.Version == s.version {
		return // Already forwarded this version.
	}
	s.version = p.Version
	s.peerUpdates <- p
}
// HandlePacket enqueues a routing packet for the state machine without
// blocking. If the queue is full, the packet is silently dropped.
func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) {
	select {
	case s.packets <- w:
		// Queued for the current state function.
	default:
		// Queue full: drop the packet instead of blocking the caller.
	}
}
// ----------------------------------------------------------------------------
// stateFunc is one state of a peerSupervisor's state machine. It runs until a
// transition is needed and returns the next state, which mainLoop then runs.
type stateFunc func() stateFunc
// stateInit (re)initializes per-peer data from s.peer.
//
// With no peer configured, it goes to stateDisconnected. Otherwise it does
// three things:
//   - derives the peer's direct address from PublicIP/Port (nil when PublicIP
//     does not parse as an address);
//   - computes the shared key;
//   - hands off to stateSelectRole.
//
// Like every other state, it returns the next state rather than running it
// inline, so all transitions go through mainLoop.
func (s *peerSupervisor) stateInit() stateFunc {
	if s.peer == nil {
		return s.stateDisconnected
	}

	s.remoteAddrPort = nil
	if addr, ok := netip.AddrFromSlice(s.peer.PublicIP); ok {
		addrPort := netip.AddrPortFrom(addr, s.peer.Port)
		s.remoteAddrPort = &addrPort
	}

	s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey)
	return s.stateSelectRole
}
// ----------------------------------------------------------------------------
// stateDisconnected is used when no peer configuration is present. It removes
// the peer's routing-table entry, discards incoming packets, and waits for the
// next peer update before re-initializing.
func (s *peerSupervisor) stateDisconnected() stateFunc {
	s.clearRoutingTable()

	for {
		select {
		case p := <-s.peerUpdates:
			s.peer = p
			return s.stateInit
		case <-s.packets:
			// No peer: the packet is dropped.
		}
	}
}
// ----------------------------------------------------------------------------
// stateSelectRole decides how to reach the peer:
//   - The remote has a direct address: dial it. The exception is when we are
//     public too and have the lower IP; then we wait for the remote to dial us.
//   - Only we are public: wait for the remote to dial us.
//   - Neither side is public: go through the mediator.
//
// The peer's routing-table entry is then published as down. This happens only
// after s.mediated reflects the chosen role, so the entry never carries a
// Mediated flag left over from the previous configuration.
func (s *peerSupervisor) stateSelectRole() stateFunc {
	s.logf("STATE: SelectRole")

	var next stateFunc
	switch {
	case s.remoteAddrPort != nil:
		s.mediated = false
		// If both remote and local are public, one side acts as client, and
		// one side as server.
		if s.localPublic && s.localIP < s.peer.PeerIP {
			next = s.stateAccept
		} else {
			next = s.stateDial
		}
	case s.localPublic:
		// We're public, remote is not => can only wait for connection.
		s.mediated = false
		next = s.stateAccept
	default:
		// Both non-public => need to use mediator.
		s.mediated = true
		next = s.stateMediated
	}

	s.updateRoutingTable(false)
	return next
}
// ----------------------------------------------------------------------------
// stateAccept waits for the remote to ping us. The first ping's source address
// becomes the peer's address. The entry is then marked up, a pong is returned,
// and the supervisor moves to stateConnected. A peer update restarts from
// stateInit.
func (s *peerSupervisor) stateAccept() stateFunc {
	s.logf("STATE: Accept")

	for {
		select {
		case p := <-s.peerUpdates:
			s.peer = p
			return s.stateInit
		case pkt := <-s.packets:
			if pkt.Type != packetTypePing {
				// Still waiting for ping...
				continue
			}
			s.remoteAddrPort = &pkt.Addr
			s.updateRoutingTable(true)
			s.sendPong(pkt.TraceID)
			return s.stateConnected
		}
	}
}
// ----------------------------------------------------------------------------
// stateDial actively contacts the peer. It pings immediately and then every
// time pingTimer fires, until a pong arrives. On a pong the entry is marked up
// and the supervisor moves to stateConnected. A peer update restarts from
// stateInit.
func (s *peerSupervisor) stateDial() stateFunc {
	s.logf("STATE: Dial")
	s.updateRoutingTable(false)
	s.sendPing()

	for {
		select {
		case p := <-s.peerUpdates:
			s.peer = p
			return s.stateInit
		case <-s.pingTimer.C:
			// No pong yet: try again.
			s.sendPing()
		case pkt := <-s.packets:
			if pkt.Type == packetTypePong {
				s.updateRoutingTable(true)
				return s.stateConnected
			}
			// Anything other than a pong is ignored while dialing.
		}
	}
}
// ----------------------------------------------------------------------------
// stateConnected runs while the direct link to the peer is up. It:
//   - pings the peer whenever pingTimer fires;
//   - answers the peer's pings;
//   - when we are public, follows the source address of the peer's pings.
//
// It returns to stateInit when no pong arrives within timeoutInterval, or
// when the peer configuration changes.
func (s *peerSupervisor) stateConnected() stateFunc {
	s.logf("STATE: Connected")

	// resetTimeout re-arms timeoutTimer from now. The timer has been running
	// since the supervisor was created and may have fired while we were in
	// another state. With pre-Go-1.23 timer semantics, Reset alone would leave
	// that stale tick in the channel and trigger an immediate, spurious
	// timeout, so stop the timer and drain it first.
	resetTimeout := func() {
		if !s.timeoutTimer.Stop() {
			select {
			case <-s.timeoutTimer.C:
			default:
			}
		}
		s.timeoutTimer.Reset(timeoutInterval)
	}
	resetTimeout()

	for {
		select {
		case <-s.pingTimer.C:
			s.sendPing()
		case <-s.timeoutTimer.C:
			s.logf("Timeout")
			return s.stateInit
		case pkt := <-s.packets:
			switch pkt.Type {
			case packetTypePing:
				s.sendPong(pkt.TraceID)
				// Server should always follow remote port.
				if s.localPublic && (s.remoteAddrPort == nil || pkt.Addr != *s.remoteAddrPort) {
					s.remoteAddrPort = &pkt.Addr
					s.updateRoutingTable(true)
				}
			case packetTypePong:
				resetTimeout()
			default:
				// Drop packet.
			}
		case s.peer = <-s.peerUpdates:
			s.logf("New peer: %v", s.peer)
			return s.stateInit
		}
	}
}
// ----------------------------------------------------------------------------
// stateMediated is used when neither side is public. It marks the peer as up
// and mediated in the routing table, drops all incoming routing packets, and
// waits for a peer update.
func (s *peerSupervisor) stateMediated() stateFunc {
	s.logf("STATE: Mediated")
	s.mediated = true
	s.updateRoutingTable(true)

	for {
		select {
		case p := <-s.peerUpdates:
			s.peer = p
			s.logf("New peer: %v", s.peer)
			return s.stateInit
		case <-s.packets:
			// Routing packets are dropped while mediated.
		}
	}
}
// ----------------------------------------------------------------------------
// clearRoutingTable removes the peer's entry by storing nil under remoteIP.
func (s *peerSupervisor) clearRoutingTable() {
	s.table.Set(s.remoteIP, nil)
}

// updateRoutingTable publishes a fresh snapshot of the peer under remoteIP:
// its address, shared key, mediator settings, and whether it is up. Requires
// s.peer to be non-nil.
func (s *peerSupervisor) updateRoutingTable(up bool) {
	entry := &peer{
		IP:        s.remoteIP,
		Addr:      s.remoteAddrPort,
		SharedKey: s.sharedKey,
		Mediator:  s.peer.Mediator,
		Mediated:  s.mediated,
		Up:        up,
	}
	s.table.Set(s.remoteIP, entry)
}
// ----------------------------------------------------------------------------
// sendPing sends a ping with a fresh trace ID to the peer. It then re-arms
// pingTimer so the next ping is due pingInterval from now, and returns the
// trace ID it used.
func (s *peerSupervisor) sendPing() uint64 {
	traceID := newTraceID()
	pkt := newRoutingPacket(packetTypePing, traceID)
	s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))

	// pingTimer keeps running in states that never read it (Accept, Mediated,
	// Disconnected). With pre-Go-1.23 timer semantics, a tick it produced
	// there would survive Reset and cause an immediate duplicate ping, so
	// stop and drain it first. This is a no-op when the tick was just consumed.
	if !s.pingTimer.Stop() {
		select {
		case <-s.pingTimer.C:
		default:
		}
	}
	s.pingTimer.Reset(pingInterval)
	return traceID
}
// sendPong answers a ping by sending a pong to the peer that echoes the ping's
// trace ID.
func (s *peerSupervisor) sendPong(traceID uint64) {
	pong := newRoutingPacket(packetTypePong, traceID)
	payload := pong.Marshal(s.buf)
	s.w.WriteTo(s.peer.PeerIP, streamRouting, payload)
}