vppn/peer/peerstates.go
2025-02-23 16:58:41 +01:00

421 lines
9.5 KiB
Go

package peer
import (
"fmt"
"log"
"net/netip"
"strings"
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
)
// peerState is one state of the per-peer connection state machine.
// OnMsg handles a single incoming message and returns the state to use
// for the next message. That may be the receiver itself when no
// transition occurs.
type peerState interface {
OnMsg(raw any) peerState
}
// ----------------------------------------------------------------------------
// pState holds the data shared by every state of one remote peer's state
// machine. The concrete states embed it, directly or through
// stateDisconnected, so every transition carries the same pState forward.
type pState struct {
// Output.
// publish receives a copy of staged whenever the route changes.
publish func(remotePeer)
// sendControlPacket transmits a control packet along the given route.
// Callers in this file go through Send/SendTo, which apply limiter first.
sendControlPacket func(remotePeer, marshaller)
// pingTimer is reset to pingInterval when entering the server/client
// states and stopped when entering stateDisconnected. Presumably it is the
// source of pingTimerMsg events — confirm with the owner of this ticker.
pingTimer *time.Ticker
// Immutable data.
// localIP/remoteIP are compared to break role ties: when both sides could
// act as server, the side with the lower value is the server.
localIP byte
remoteIP byte
// privKey is combined with the peer's PubKey to build the control cipher.
privKey []byte
localAddr netip.AddrPort // If valid, then local peer is publicly accessible.
// pubAddrs holds this node's candidate public addresses. They are
// advertised in SYN/ACK PossibleAddrs and updated from a matching ACK's
// ToAddr.
pubAddrs *pubAddrStore
// The purpose of this state machine is to manage the RemotePeer object,
// publishing it as necessary.
staged remotePeer // Local copy of shared data. See publish().
// Mutable peer data.
// peer is the current configuration for the remote peer; nil after an
// update that removed it.
peer *m.Peer
// We rate limit per remote endpoint because if we don't we tend to lose
// packets.
limiter *ratelimiter.Limiter
}
// OnPeerUpdate rebuilds the staged route from a new peer configuration and
// enters the state appropriate for it. A nil peer clears the route and enters
// stateDisconnected. Otherwise the role is chosen as follows:
//
//   - remote has a public IP: server if we are also public and have the
//     lower IP, otherwise a direct client;
//   - remote not public but we are: server;
//   - neither public: the lower IP is the server, the other a relayed client.
//
// The staged route is published exactly once, after the next state has been
// entered.
func (s *pState) OnPeerUpdate(peer *m.Peer) peerState {
	// Wrap in a closure: `defer s.publish(s.staged)` would copy s.staged now
	// and miss every change made below.
	defer func() { s.publish(s.staged) }()

	s.peer = peer

	// Start from a clean, "down" route.
	st := &s.staged
	st.localIP = s.localIP
	st.Up, st.Relay, st.Direct = false, false, false
	st.DirectAddr = netip.AddrPort{}
	st.PubSignKey = nil
	st.ControlCipher = nil
	st.DataCipher = nil

	if peer == nil {
		return enterStateDisconnected(s)
	}

	st.IP = peer.PeerIP
	st.PubSignKey = peer.PubSignKey
	st.ControlCipher = newControlCipher(s.privKey, peer.PubKey)
	st.DataCipher = newDataCipher()

	localPublic := s.localAddr.IsValid()
	lowerIP := s.localIP < s.remoteIP
	pubIP, remotePublic := netip.AddrFromSlice(peer.PublicIP)

	switch {
	case remotePublic:
		st.Relay = peer.Relay
		st.Direct = true
		st.DirectAddr = netip.AddrPortFrom(pubIP, peer.Port)
		if localPublic && lowerIP {
			return enterStateServer(s)
		}
		return enterStateClientDirect(s)
	case localPublic:
		st.Direct = true
		return enterStateServer(s)
	case lowerIP:
		return enterStateServer(s)
	default:
		return enterStateClientRelayed(s)
	}
}
// logf logs a printf-style message prefixed with the remote peer's IP (three
// digits), its configured name (right-aligned to 30 columns, empty if the
// peer is nil), and the staged route state (DIRECT/RELAYED, UP/DOWN).
//
// The prefix is written as data and only the caller's format is interpreted.
// Previously the prefix was spliced into the format string, so a peer name
// containing '%' was parsed as format verbs and garbled the log line.
func (s *pState) logf(format string, args ...any) {
	name := ""
	if s.peer != nil {
		name = s.peer.Name
	}

	var b strings.Builder
	fmt.Fprintf(&b, "%03d%30s: ", s.remoteIP, name)
	if s.staged.Direct {
		b.WriteString("DIRECT | ")
	} else {
		b.WriteString("RELAYED | ")
	}
	if s.staged.Up {
		b.WriteString("UP | ")
	} else {
		b.WriteString("DOWN | ")
	}
	fmt.Fprintf(&b, format, args...)

	// log.Print appends a newline if missing, exactly as log.Printf did.
	log.Print(b.String())
}
// ----------------------------------------------------------------------------
// SendTo sends pkt to addr using a copy of the staged route with Direct
// forced on and DirectAddr replaced by addr. s.staged itself is unchanged.
// An invalid addr is silently ignored.
func (s *pState) SendTo(pkt marshaller, addr netip.AddrPort) {
	if addr.IsValid() {
		direct := s.staged
		direct.Direct, direct.DirectAddr = true, addr
		s.Send(direct, pkt)
	}
}
// Send transmits pkt along the route described by peer, subject to the
// per-peer rate limiter. When the limiter refuses, the packet is dropped and
// a log line is written instead.
func (s *pState) Send(peer remotePeer, pkt marshaller) {
	if err := s.limiter.Limit(); err == nil {
		s.sendControlPacket(peer, pkt)
		return
	}
	s.logf("Rate limited.")
}
// ----------------------------------------------------------------------------
// stateDisconnected is the idle state, entered when there is no peer
// configuration. Its ping timer is stopped. The active states embed it to
// inherit its no-op handlers.
type stateDisconnected struct{ *pState }

// enterStateDisconnected stops the ping timer and returns the idle state.
func enterStateDisconnected(s *pState) peerState {
	idle := &stateDisconnected{pState: s}
	s.pingTimer.Stop()
	return idle
}
// OnMsg reacts only to peer updates while disconnected. Every other message
// is dropped and the machine stays idle.
func (s *stateDisconnected) OnMsg(raw any) peerState {
	if update, ok := raw.(peerUpdateMsg); ok {
		return s.OnPeerUpdate(update.Peer)
	}
	// TODO: Log.
	return s
}
// The handlers below are no-ops while disconnected. States that embed
// *stateDisconnected get them as promoted methods unless they define their
// own. Note that a promoted call returns the embedded *stateDisconnected,
// not the outer state.

// OnSyn ignores the SYN and stays disconnected.
func (s *stateDisconnected) OnSyn(_ controlMsg[packetSyn]) peerState {
	return s
}

// OnAck ignores the ACK.
func (s *stateDisconnected) OnAck(_ controlMsg[packetAck]) {
}

// OnProbe ignores the probe and stays disconnected.
func (s *stateDisconnected) OnProbe(_ controlMsg[packetProbe]) peerState {
	return s
}

// OnLocalDiscovery ignores the discovery packet.
func (s *stateDisconnected) OnLocalDiscovery(_ controlMsg[packetLocalDiscovery]) {
}

// OnPingTimer does nothing and stays disconnected.
func (s *stateDisconnected) OnPingTimer() peerState {
	return s
}
// ----------------------------------------------------------------------------
// stateServer waits for the remote client to initiate. It answers SYNs with
// ACKs, echoes probes, and marks the route down when SYNs stop arriving.
type stateServer struct {
	*stateDisconnected

	lastSeen   time.Time // When the most recent SYN arrived.
	synTraceID uint64    // TraceID of the SYN the current route was built from.
}

// enterStateServer logs the transition, restarts the ping timer, and returns
// a fresh server state with no SYN seen yet.
func enterStateServer(s *pState) peerState {
	s.logf("==> Server")
	s.pingTimer.Reset(pingInterval)
	base := &stateDisconnected{pState: s}
	return &stateServer{stateDisconnected: base}
}
// OnMsg dispatches a message to the server's handlers. Unrecognized messages
// leave the state unchanged.
func (s *stateServer) OnMsg(rawMsg any) peerState {
	var next peerState = s
	switch msg := rawMsg.(type) {
	case peerUpdateMsg:
		next = s.OnPeerUpdate(msg.Peer)
	case controlMsg[packetSyn]:
		next = s.OnSyn(msg)
	case controlMsg[packetProbe]:
		next = s.OnProbe(msg)
	case pingTimerMsg:
		next = s.OnPingTimer()
	default:
		// TODO: Log
	}
	return next
}
// OnSyn handles a SYN from the client and always answers with an ACK.
//
// A SYN carrying a new TraceID, or any SYN while the route is down, rebuilds
// the route from the SYN: Direct flag, shared data key, and the sender's
// address. The route is then published. For SYNs that are not direct, the
// server also probes every address the client advertised.
func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState {
	s.lastSeen = time.Now()
	syn := msg.Packet

	// The client changes the SYN's TraceID whenever the route should change.
	// The server simply follows the client's request.
	if !s.staged.Up || syn.TraceID != s.synTraceID {
		s.synTraceID = syn.TraceID
		s.staged.Up = true
		s.staged.Direct = syn.Direct
		s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
		s.staged.DirectAddr = msg.SrcAddr
		s.publish(s.staged)
		s.logf("Got SYN.")
	}

	reply := packetAck{
		TraceID:       syn.TraceID,
		ToAddr:        s.staged.DirectAddr,
		PossibleAddrs: s.pubAddrs.Get(),
	}
	s.Send(s.staged, reply)

	if syn.Direct {
		return s
	}

	for _, candidate := range syn.PossibleAddrs {
		// The list ends at the first invalid entry (presumably addresses are
		// packed at the front).
		if !candidate.IsValid() {
			break
		}
		s.SendTo(packetProbe{TraceID: newTraceID()}, candidate)
	}
	return s
}
// OnProbe echoes a probe back to its sender with the same TraceID. The
// client matches that TraceID to confirm the path works.
func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState {
	src := msg.SrcAddr
	if !src.IsValid() {
		return s
	}
	s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, src)
	return s
}
// OnPingTimer marks an up route as down once no SYN has arrived for longer
// than timeoutInterval. The server stays in this state either way.
func (s *stateServer) OnPingTimer() peerState {
	if !s.staged.Up {
		return s
	}
	if time.Since(s.lastSeen) <= timeoutInterval {
		return s
	}
	s.staged.Up = false
	s.publish(s.staged)
	s.logf("Timeout.")
	return s
}
// ----------------------------------------------------------------------------
// stateClientDirect is the client side of the handshake when a direct
// address for the remote peer is known. It resends one SYN on every ping
// tick and marks the route up when a matching ACK arrives.
type stateClientDirect struct {
	*stateDisconnected

	lastSeen time.Time // Last matching ACK, or state entry.
	syn      packetSyn // The SYN sent on entry and on every ping tick.
}

// enterStateClientDirect logs the transition, restarts the ping timer, and
// builds the client state. Building the state sends the first SYN.
func enterStateClientDirect(s *pState) peerState {
	s.logf("==> ClientDirect")
	s.pingTimer.Reset(pingInterval)
	return newStateClientDirect(s)
}
// newStateClientDirect builds a client state around a freshly generated SYN
// and sends that SYN immediately. It neither logs nor touches the ping timer,
// so stateClientRelayed can reuse it as its base.
func newStateClientDirect(s *pState) *stateClientDirect {
	client := &stateClientDirect{
		stateDisconnected: &stateDisconnected{pState: s},
		// Start the timeout clock now so the first tick doesn't time out.
		lastSeen: time.Now(),
	}
	client.syn = packetSyn{
		TraceID:       newTraceID(),
		SharedKey:     s.staged.DataCipher.Key(),
		Direct:        s.staged.Direct,
		PossibleAddrs: s.pubAddrs.Get(),
	}
	client.Send(s.staged, client.syn)
	return client
}
// OnMsg dispatches a message to the direct client's handlers. It stays in
// this state unless a peer update or a ping-timer timeout selects another.
func (s *stateClientDirect) OnMsg(raw any) peerState {
	switch msg := raw.(type) {
	case peerUpdateMsg:
		return s.OnPeerUpdate(msg.Peer)
	case controlMsg[packetAck]:
		s.OnAck(msg)
	case pingTimerMsg:
		if next := s.onPingTimer(); next != nil {
			return next
		}
	default:
		// TODO: Log
	}
	return s
}
// OnAck accepts only an ACK whose TraceID answers the current SYN. A matching
// ACK does three things:
//   - refreshes lastSeen;
//   - marks the route up, publishing on the down-to-up transition;
//   - stores the ACK's ToAddr (the address the server recorded for us) as a
//     candidate public address.
func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) {
	ack := msg.Packet
	if ack.TraceID == s.syn.TraceID {
		s.lastSeen = time.Now()
		if !s.staged.Up {
			s.staged.Up = true
			s.publish(s.staged)
			s.logf("Got ACK.")
		}
		s.pubAddrs.Store(ack.ToAddr)
	}
}
// onPingTimer returns nil to mean "stay in the current state". While ACKs are
// fresh, it resends the SYN and returns nil. After timeoutInterval without a
// matching ACK, it marks the route down (publishing if it was up) and
// restarts the machine from the current peer configuration.
func (s *stateClientDirect) onPingTimer() peerState {
	if time.Since(s.lastSeen) <= timeoutInterval {
		s.Send(s.staged, s.syn)
		return nil
	}
	if s.staged.Up {
		s.staged.Up = false
		s.publish(s.staged)
		s.logf("Timeout.")
	}
	return s.OnPeerUpdate(s.peer)
}
// ----------------------------------------------------------------------------
// stateClientRelayed is the client state used when neither side has a
// public address. It handshakes like stateClientDirect, but with
// staged.Direct false. It also probes the addresses from the server's ACK,
// switching to stateClientDirect once a probe is answered.
type stateClientRelayed struct {
	*stateClientDirect

	ack                packetAck                 // Last stored ACK; its PossibleAddrs are probed.
	probes             map[uint64]netip.AddrPort // TODO: something better
	localDiscoveryAddr netip.AddrPort            // TODO: Remove
}

// enterStateClientRelayed logs the transition, restarts the ping timer, and
// builds the relayed client state. Building the state sends the first SYN.
func enterStateClientRelayed(s *pState) peerState {
	s.logf("==> ClientRelayed")
	s.pingTimer.Reset(pingInterval)
	relayed := &stateClientRelayed{probes: make(map[uint64]netip.AddrPort)}
	relayed.stateClientDirect = newStateClientDirect(s)
	return relayed
}
// OnMsg dispatches a message to the relayed client's handlers. Unrecognized
// messages leave the state unchanged.
func (s *stateClientRelayed) OnMsg(raw any) peerState {
	var next peerState = s
	switch msg := raw.(type) {
	case peerUpdateMsg:
		next = s.OnPeerUpdate(msg.Peer)
	case controlMsg[packetAck]:
		s.OnAck(msg)
	case controlMsg[packetProbe]:
		next = s.OnProbe(msg)
	case controlMsg[packetLocalDiscovery]:
		s.OnLocalDiscovery(msg)
	case pingTimerMsg:
		next = s.OnPingTimer()
	default:
		// TODO: Log
	}
	return next
}
// OnAck records the server's ACK so later ping ticks can probe its
// PossibleAddrs. It then applies the direct-client ACK handling.
//
// Only an ACK answering the current SYN is accepted. Previously the ACK was
// stored before any TraceID check, so a stale ACK (for example, a late reply
// to a SYN from before the last state restart) replaced the probe targets,
// even though the embedded handler rejects that same ACK.
func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) {
	if msg.Packet.TraceID != s.syn.TraceID {
		return
	}
	s.ack = msg.Packet
	s.stateClientDirect.OnAck(msg)
	// TODO: Send probes now.
}
// OnProbe handles a probe echoed back by the server. If its TraceID matches
// one we sent this tick, the address it went to is known to work. The route
// is switched to direct on that address and published, and the machine moves
// to stateClientDirect, which sends a fresh SYN. Unknown probes are ignored.
func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState {
	if addr, known := s.probes[msg.Packet.TraceID]; known {
		s.staged.DirectAddr = addr
		s.staged.Direct = true
		s.publish(s.staged)
		return enterStateClientDirect(s.stateClientDirect.pState)
	}
	return s
}
// OnLocalDiscovery remembers the address a local-discovery packet came from.
// The next ping tick probes that address once.
func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
	// The packet arrives from the multicast port, so pair the sender's IP
	// with the peer's configured port rather than using SrcAddr as-is.
	senderIP := msg.SrcAddr.Addr()
	s.localDiscoveryAddr = netip.AddrPortFrom(senderIP, s.peer.Port)
	// TODO: s.sendProbeTo(s.localDiscoveryAddr)
}
// OnPingTimer first runs the direct-client tick, which resends the SYN or
// restarts on timeout. If still relayed, it starts a new probe round: it
// forgets the previous round's TraceIDs, probes each address from the last
// stored ACK, and probes the pending local-discovery address once.
func (s *stateClientRelayed) OnPingTimer() peerState {
	if next := s.stateClientDirect.onPingTimer(); next != nil {
		return next
	}

	// Only replies to this round's probes will be accepted by OnProbe.
	clear(s.probes)

	for _, target := range s.ack.PossibleAddrs {
		if !target.IsValid() {
			break
		}
		s.sendProbeTo(target)
	}

	if lan := s.localDiscoveryAddr; lan.IsValid() {
		s.localDiscoveryAddr = netip.AddrPort{}
		s.sendProbeTo(lan)
	}
	return s
}
// sendProbeTo sends a probe with a fresh TraceID to addr. It records the
// TraceID first so OnProbe can match the echoed reply to addr.
func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) {
	id := newTraceID()
	s.probes[id] = addr
	s.SendTo(packetProbe{TraceID: id}, addr)
}