vppn/peer/remotefsm.go
2025-09-01 18:15:42 +02:00

452 lines
9.7 KiB
Go

package peer
import (
"bytes"
"net/netip"
"time"
"vppn/m"
)
// stateFunc is one state of the remote FSM: it consumes a single message and
// returns the state that should handle the next message (see Run).
type stateFunc func(msg any) stateFunc
// sentProbe records a probe sent by the client while the connection is not
// direct: when it was sent and the address it was sent to. Entries live in
// remoteFSM.probes, keyed by the probe's trace ID.
type sentProbe struct {
	SentAt time.Time
	Addr   netip.AddrPort
}
// remoteFSM drives the connection state machine for a single remote peer.
// It embeds *Remote and processes messages from r.messages in Run.
type remoteFSM struct {
	*Remote
	// pingTimer's ticks are turned into pingTimerMsg values by Run. It is
	// created stopped (newRemoteFSM) and started with Reset when entering the
	// server, client-init and client states.
	pingTimer *time.Ticker
	// lastSeen is the start of the current timeout window: refreshed when a
	// SYN (server) or matching ACK (client) arrives, and when certain states
	// are entered. Compared against timeoutInterval on each ping tick.
	lastSeen time.Time
	// traceID is the trace ID of the client's outstanding INIT or SYN;
	// replies carrying a different ID are ignored.
	traceID uint64
	// probes holds the client's recently sent probes, keyed by trace ID.
	probes map[uint64]sentProbe
	// sharedKey is the server's copy of the data key from the last SYN, used
	// to detect when the data cipher has to be rebuilt.
	sharedKey [32]byte
	// buf is scratch space that outgoing control packets are marshalled into.
	buf []byte
}
// newRemoteFSM builds the state machine for r with an empty probe table and a
// bufferSize marshalling buffer. The ping ticker is created and immediately
// stopped so that each state can start it later with Reset.
func newRemoteFSM(r *Remote) *remoteFSM {
	ticker := time.NewTicker(timeoutInterval)
	ticker.Stop()

	return &remoteFSM{
		Remote:    r,
		pingTimer: ticker,
		probes:    make(map[uint64]sentProbe),
		buf:       make([]byte, bufferSize),
	}
}
// Run is the FSM's event loop. It starts a helper goroutine that forwards
// every pingTimer tick into r.messages as a pingTimerMsg, so that all state
// transitions happen on this goroutine. It then enters the Disconnected state
// and feeds each message to the current state until r.messages is closed.
//
// NOTE(review): the forwarding goroutine never exits. If r.messages can be
// closed while the ticker is running, its next send would panic; confirm the
// channel's lifetime with the owner of Remote.
func (r *remoteFSM) Run() {
	go func() {
		for {
			<-r.pingTimer.C
			r.messages <- pingTimerMsg{}
		}
	}()

	current := r.enterDisconnected()
	for {
		msg, open := <-r.messages
		if !open {
			return
		}
		current = current(msg)
	}
}
// ----------------------------------------------------------------------------
// enterDisconnected moves to the Disconnected state. It stops the ping timer,
// because nothing needs periodic work until a peer update arrives, and
// clears the remote's configuration.
//
// Stopping the ticker is a no-op on the initial entry from Run, since
// newRemoteFSM creates it stopped. On a transition from Server/Client, a tick
// already in flight may still be delivered once.
func (r *remoteFSM) enterDisconnected() stateFunc {
	// Without this, a remote that drops back to Disconnected (nil peer in a
	// peer update) keeps ticking, and stateDisconnected logs
	// "Unexpected ping" on every interval.
	r.pingTimer.Stop()
	r.updateConf(remoteConfig{})
	return r.stateDisconnected
}
// stateDisconnected handles messages while no peer is configured. Only a peer
// update moves the FSM out of this state. Control packets and ping ticks are
// logged as unexpected, local-discovery packets are dropped silently, and
// anything else is logged with its Go representation.
func (r *remoteFSM) stateDisconnected(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetLocalDiscovery]:
		// Ignored: there is no peer to probe yet.
	case pingTimerMsg:
		r.logf("Unexpected ping")
	case controlMsg[packetProbe]:
		r.logf("Unexpected probe")
	case controlMsg[packetInit]:
		r.logf("Unexpected INIT")
	case controlMsg[packetSyn]:
		r.logf("Unexpected SYN")
	case controlMsg[packetAck]:
		r.logf("Unexpected ACK")
	default:
		r.logf("Ignoring message: %#v", iMsg)
	}
	return r.stateDisconnected
}
// ----------------------------------------------------------------------------
// enterPeerUpdating applies a new peer definition. A nil peer disconnects.
// Otherwise the configuration is reset to just the peer and a control cipher
// derived from our private key and the peer's public key. This side then
// picks its role:
//
//   - peer has a valid public IP: act as server only if LocalAddrValid is set
//     AND our peer IP is the lower one; otherwise connect as client.
//   - peer has no public IP: act as server if LocalAddrValid is set OR our
//     peer IP is the lower one; otherwise connect as client.
func (r *remoteFSM) enterPeerUpdating(peer *m.Peer) stateFunc {
	if peer == nil {
		return r.enterDisconnected()
	}

	r.updateConf(remoteConfig{
		Peer:          peer,
		ControlCipher: newControlCipher(r.PrivKey, peer.PubKey),
	})

	_, peerIsPublic := netip.AddrFromSlice(peer.PublicIP)
	weAreLower := r.LocalPeerIP < peer.PeerIP

	actAsServer := r.LocalAddrValid || weAreLower
	if peerIsPublic {
		actAsServer = r.LocalAddrValid && weAreLower
	}

	if actAsServer {
		return r.enterServer()
	}
	return r.enterClientInit()
}
// ----------------------------------------------------------------------------
// enterServer switches this side into the server role. It marks the config
// as Server, starts the ping timer and restarts the timeout window. It also
// forgets any previously received data key, so the next SYN always rebuilds
// the data cipher.
func (r *remoteFSM) enterServer() stateFunc {
	next := r.conf()
	next.Server = true
	r.updateConf(next)

	r.logf("==> Server")

	r.pingTimer.Reset(pingInterval)
	r.lastSeen = time.Now()
	r.sharedKey = [32]byte{}

	return r.stateServer
}
// stateServer handles messages while acting as server. It answers the
// client's INIT, SYN and probe packets, checks for timeouts on each ping
// tick, and treats an ACK as unexpected. Only a peer update leaves this state.
func (r *remoteFSM) stateServer(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case pingTimerMsg:
		r.stateServer_onPingTimer()
	case controlMsg[packetInit]:
		r.stateServer_onInit(msg)
	case controlMsg[packetSyn]:
		r.stateServer_onSyn(msg)
	case controlMsg[packetProbe]:
		r.stateServer_onProbe(msg)
	case controlMsg[packetAck]:
		r.logf("Unexpected ACK")
	case controlMsg[packetLocalDiscovery]:
		// Local discovery is only acted on by the client side.
	default:
		r.logf("Unexpected message: %#v", iMsg)
	}
	return r.stateServer
}
// stateServer_onInit answers a client's INIT. It marks the connection down
// until the following SYN and records the client's Direct flag and source
// address. It then replies with an INIT that echoes the client's trace ID and
// carries our version.
func (r *remoteFSM) stateServer_onInit(msg controlMsg[packetInit]) {
	conf := r.conf()
	conf.Up, conf.Direct, conf.DirectAddr = false, msg.Packet.Direct, msg.SrcAddr
	r.updateConf(conf)

	reply := packetInit{
		TraceID: msg.Packet.TraceID,
		Direct:  conf.Direct,
		Version: version,
	}
	r.sendControl(conf, reply.Marshal(r.buf))
}
// stateServer_onSyn handles a client's SYN. It refreshes lastSeen and marks
// the peer up. It records the SYN's Direct flag and source address, and
// rebuilds the data cipher when the client's key differs from the one held
// in sharedKey. The reply is an ACK carrying the client's trace ID, the
// address the SYN came from, and our candidate public addresses.
//
// On a non-direct connection, it also probes every address the client
// advertised.
func (r *remoteFSM) stateServer_onSyn(msg controlMsg[packetSyn]) {
	r.lastSeen = time.Now()
	syn := msg.Packet

	// The route must reflect this SYN before we answer it.
	conf := r.conf()
	stateChanged := !conf.Up || conf.Direct != syn.Direct
	conf.Up = true
	conf.Direct = syn.Direct
	conf.DirectAddr = msg.SrcAddr

	if !bytes.Equal(syn.SharedKey[:], r.sharedKey[:]) {
		conf.DataCipher = newDataCipherFromKey(syn.SharedKey)
		copy(r.sharedKey[:], syn.SharedKey[:])
	}
	r.updateConf(conf)

	if stateChanged {
		r.logf("Got SYN.")
	}

	ack := packetAck{
		TraceID:       syn.TraceID,
		ToAddr:        conf.DirectAddr,
		PossibleAddrs: r.PubAddrs.Get(),
	}
	r.sendControl(conf, ack.Marshal(r.buf))

	if syn.Direct {
		return
	}

	// Not direct: probe the client's advertised addresses. The list ends at
	// the first invalid entry.
	for i := 0; i < len(syn.PossibleAddrs) && syn.PossibleAddrs[i].IsValid(); i++ {
		target := syn.PossibleAddrs[i]
		r.logf("Probing %v...", target)
		r.sendControlToAddr(packetProbe{TraceID: r.NewTraceID()}.Marshal(r.buf), target)
	}
}
// stateServer_onProbe echoes a probe back to the address it came from. The
// echo keeps the sender's trace ID, so the client can match it against its
// probe table. Probes without a valid source address are dropped.
func (r *remoteFSM) stateServer_onProbe(msg controlMsg[packetProbe]) {
	if src := msg.SrcAddr; src.IsValid() {
		echo := packetProbe{TraceID: msg.Packet.TraceID}
		r.sendControlToAddr(echo.Marshal(r.buf), src)
	}
}
// stateServer_onPingTimer marks the peer down once more than timeoutInterval
// has passed since the last SYN (or since entering the server role). The
// server keeps its role and waits for the client to handshake again.
func (r *remoteFSM) stateServer_onPingTimer() {
	conf := r.conf()
	if !conf.Up || time.Since(r.lastSeen) <= timeoutInterval {
		return
	}
	conf.Up = false
	r.updateConf(conf)
	r.logf("Timeout.")
}
// ----------------------------------------------------------------------------
// enterClientInit (re)starts the client handshake. The config becomes a
// client that is down, with a brand-new data cipher. A direct connection to
// the peer's public IP and port is attempted first whenever that IP is valid.
// The timeout window and ping timer are then restarted and the first INIT is
// sent.
func (r *remoteFSM) enterClientInit() stateFunc {
	conf := r.conf()
	publicIP, hasPublicIP := netip.AddrFromSlice(conf.Peer.PublicIP)

	conf.Up, conf.Server = false, false
	conf.Direct = hasPublicIP
	conf.DirectAddr = netip.AddrPortFrom(publicIP, conf.Peer.Port)
	conf.DataCipher = newDataCipher()
	r.updateConf(conf)

	r.logf("==> ClientInit")

	r.lastSeen = time.Now()
	r.pingTimer.Reset(pingInterval)
	r.stateClientInit_sendInit()

	return r.stateClientInit
}
// stateClientInit waits for the server's INIT reply. Peer updates, INIT
// replies and ping ticks may change state. Probes and local-discovery packets
// are ignored, and SYN/ACK are logged as unexpected.
func (r *remoteFSM) stateClientInit(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetInit]:
		return r.stateClientInit_onInit(msg)
	case pingTimerMsg:
		return r.stateClientInit_onPing()
	case controlMsg[packetProbe], controlMsg[packetLocalDiscovery]:
		// Not relevant until the INIT exchange completes.
	case controlMsg[packetSyn]:
		r.logf("Unexpected SYN")
	case controlMsg[packetAck]:
		r.logf("Unexpected ACK")
	default:
		r.logf("Unexpected message: %#v", iMsg)
	}
	return r.stateClientInit
}
// stateClientInit_sendInit sends an INIT under a freshly generated trace ID.
// That ID is stored in r.traceID, and the server's INIT reply must echo it.
func (r *remoteFSM) stateClientInit_sendInit() {
	conf := r.conf()
	r.traceID = r.NewTraceID()
	pkt := packetInit{TraceID: r.traceID, Direct: conf.Direct, Version: version}
	r.sendControl(conf, pkt.Marshal(r.buf))
}
// stateClientInit_onInit accepts the server's INIT reply when it echoes our
// current trace ID, and moves on to the Client state. Replies with any other
// trace ID are logged and ignored.
func (r *remoteFSM) stateClientInit_onInit(msg controlMsg[packetInit]) stateFunc {
	if reply := msg.Packet; reply.TraceID == r.traceID {
		r.logf("Got INIT version %d.", reply.Version)
		return r.enterClient()
	}
	r.logf("Invalid trace ID on INIT.")
	return r.stateClientInit
}
// stateClientInit_onPing resends INIT on every tick until timeoutInterval has
// passed without a reply. When a direct attempt times out, it falls back to
// an indirect attempt with a fresh timeout window. When an indirect attempt
// times out, the whole ClientInit sequence restarts.
func (r *remoteFSM) stateClientInit_onPing() stateFunc {
	if time.Since(r.lastSeen) < timeoutInterval {
		r.stateClientInit_sendInit()
		return r.stateClientInit
	}

	conf := r.conf()
	if !conf.Direct {
		// The indirect attempt timed out as well.
		r.logf("Timeout.")
		return r.enterClientInit()
	}

	// The direct attempt timed out; retry indirectly.
	conf.Direct = false
	r.updateConf(conf)
	r.lastSeen = time.Now()
	r.stateClientInit_sendInit()
	r.logf("Direct connection failed. Attempting indirect connection.")
	return r.stateClientInit
}
// ----------------------------------------------------------------------------
// enterClient moves to the Client state once the server has answered our
// INIT. It does the following:
//   - replaces the probe table;
//   - sends the first SYN under a new trace ID;
//   - restarts the ping timer;
//   - restarts the timeout window. The INIT reply proves the peer is
//     reachable, so the Client state gets a full timeoutInterval to collect
//     an ACK.
func (r *remoteFSM) enterClient() stateFunc {
	conf := r.conf()
	r.probes = make(map[uint64]sentProbe, 8)
	r.traceID = r.NewTraceID()
	// Without this, lastSeen still holds the time ClientInit was entered,
	// which can already be nearly timeoutInterval ago. A single lost SYN/ACK
	// would then bounce the FSM straight back to ClientInit.
	r.lastSeen = time.Now()
	r.stateClient_sendSyn(conf)
	r.pingTimer.Reset(pingInterval)
	r.logf("==> Client")
	return r.stateClient
}
// stateClient handles messages after the INIT exchange. It processes ACKs to
// our SYNs, echoes of our probes, local-discovery packets and ping ticks.
// Ping ticks and peer updates may change state; anything else is logged and
// ignored.
func (r *remoteFSM) stateClient(iMsg any) stateFunc {
	switch msg := iMsg.(type) {
	case pingTimerMsg:
		return r.stateClient_onPingTimer()
	case peerUpdateMsg:
		return r.enterPeerUpdating(msg.Peer)
	case controlMsg[packetAck]:
		r.stateClient_onAck(msg)
	case controlMsg[packetProbe]:
		r.stateClient_onProbe(msg)
	case controlMsg[packetLocalDiscovery]:
		r.stateClient_onLocalDiscovery(msg)
	default:
		r.logf("Ignoring message: %v", iMsg)
	}
	return r.stateClient
}
// stateClient_onAck processes an ACK for our current SYN; ACKs with any other
// trace ID are dropped. It refreshes lastSeen and marks the peer up.
//   - Direct connection: the address the server saw our SYN come from
//     (ToAddr) is stored in PubAddrs.
//   - Otherwise: stale probe records are discarded and every address the
//     server advertised is probed.
func (r *remoteFSM) stateClient_onAck(msg controlMsg[packetAck]) {
	ack := msg.Packet
	if ack.TraceID != r.traceID {
		return
	}
	r.lastSeen = time.Now()

	conf := r.conf()
	if !conf.Up {
		conf.Up = true
		r.updateConf(conf)
		r.logf("Got ACK.")
	}

	if conf.Direct {
		r.PubAddrs.Store(ack.ToAddr)
		return
	}

	// Not direct: try to open a direct path. The list ends at the first
	// invalid entry.
	r.stateClient_cleanProbes()
	for i := 0; i < len(ack.PossibleAddrs) && ack.PossibleAddrs[i].IsValid(); i++ {
		r.stateClient_sendProbeTo(ack.PossibleAddrs[i])
	}
}
// stateClient_cleanProbes forgets probes sent more than pingInterval ago, so
// late echoes of them are no longer accepted by stateClient_onProbe.
func (r *remoteFSM) stateClient_cleanProbes() {
	for id := range r.probes {
		if time.Since(r.probes[id].SentAt) > pingInterval {
			delete(r.probes, id)
		}
	}
}
// stateClient_sendProbeTo sends a probe with a fresh trace ID to addr. It
// records the probe in r.probes so the server's echo can be matched later.
func (r *remoteFSM) stateClient_sendProbeTo(addr netip.AddrPort) {
	id := r.NewTraceID()
	r.probes[id] = sentProbe{Addr: addr, SentAt: time.Now()}
	r.logf("Probing %v...", addr)
	r.sendControlToAddr(packetProbe{TraceID: id}.Marshal(r.buf), addr)
}
// stateClient_onProbe handles an echoed probe while the connection is not
// direct. If the trace ID matches a recent probe, the connection switches to
// direct on that probe's address, and a SYN with a new trace ID is sent using
// the updated config. Unknown or stale probes are ignored, as is every probe
// once the connection is already direct.
func (r *remoteFSM) stateClient_onProbe(msg controlMsg[packetProbe]) {
	conf := r.conf()
	if conf.Direct {
		return
	}

	r.stateClient_cleanProbes()
	sent, found := r.probes[msg.Packet.TraceID]
	if !found {
		return
	}

	conf.Direct, conf.DirectAddr = true, sent.Addr
	r.updateConf(conf)

	r.traceID = r.NewTraceID()
	r.stateClient_sendSyn(conf)
	r.logf("Successful probe to %v.", sent.Addr)
}
// stateClient_onLocalDiscovery reacts to a local-discovery packet while the
// connection is not direct. It probes the sender's IP on the peer's
// configured port, and does nothing once the connection is direct.
func (r *remoteFSM) stateClient_onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
	conf := r.conf()
	if !conf.Direct {
		// Discovery packets arrive from the multicast port, so the peer's
		// listening port must come from its configuration instead.
		r.stateClient_sendProbeTo(netip.AddrPortFrom(msg.SrcAddr.Addr(), conf.Peer.Port))
	}
}
// stateClient_onPingTimer re-sends the SYN while we are still within
// timeoutInterval of lastSeen. Otherwise it logs a timeout (only if the peer
// was up) and restarts the handshake from ClientInit.
func (r *remoteFSM) stateClient_onPingTimer() stateFunc {
	conf := r.conf()
	if time.Since(r.lastSeen) <= timeoutInterval {
		r.stateClient_sendSyn(conf)
		return r.stateClient
	}
	if conf.Up {
		r.logf("Timeout.")
	}
	return r.enterClientInit()
}
// stateClient_sendSyn sends a SYN for the current trace ID. It carries the
// key of conf's data cipher, conf's Direct flag and our candidate public
// addresses.
func (r *remoteFSM) stateClient_sendSyn(conf remoteConfig) {
	pkt := packetSyn{
		TraceID:       r.traceID,
		Direct:        conf.Direct,
		SharedKey:     conf.DataCipher.Key(),
		PossibleAddrs: r.PubAddrs.Get(),
	}
	r.sendControl(conf, pkt.Marshal(r.buf))
}