356 lines
8.3 KiB
Go
356 lines
8.3 KiB
Go
package peer
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net/netip"
|
|
"strings"
|
|
"time"
|
|
"vppn/m"
|
|
|
|
"git.crumpington.com/lib/go/ratelimiter"
|
|
)
|
|
|
|
type PeerState interface {
|
|
OnPeerUpdate(*m.Peer) PeerState
|
|
OnSyn(controlMsg[PacketSyn]) PeerState
|
|
OnAck(controlMsg[PacketAck])
|
|
OnProbe(controlMsg[PacketProbe]) PeerState
|
|
OnLocalDiscovery(controlMsg[PacketLocalDiscovery])
|
|
OnPingTimer() PeerState
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type State struct {
|
|
// Output.
|
|
publish func(RemotePeer)
|
|
sendControlPacket func(RemotePeer, Marshaller)
|
|
|
|
// Immutable data.
|
|
localIP byte
|
|
remoteIP byte
|
|
privKey []byte
|
|
localAddr netip.AddrPort // If valid, then local peer is publicly accessible.
|
|
|
|
pubAddrs *pubAddrStore
|
|
|
|
// The purpose of this state machine is to manage the RemotePeer object,
|
|
// publishing it as necessary.
|
|
staged RemotePeer // Local copy of shared data. See publish().
|
|
|
|
// Mutable peer data.
|
|
peer *m.Peer
|
|
|
|
// We rate limit per remote endpoint because if we don't we tend to lose
|
|
// packets.
|
|
limiter *ratelimiter.Limiter
|
|
}
|
|
|
|
func (s *State) OnPeerUpdate(peer *m.Peer) PeerState {
|
|
defer func() {
|
|
// Don't defer directly otherwise s.staged will be evaluated immediately
|
|
// and won't reflect changes made in the function.
|
|
s.publish(s.staged)
|
|
}()
|
|
|
|
if peer == nil {
|
|
return EnterStateDisconnected(s)
|
|
}
|
|
|
|
s.peer = peer
|
|
s.staged.Relay = false
|
|
s.staged.Direct = false
|
|
s.staged.DirectAddr = netip.AddrPort{}
|
|
s.staged.PubSignKey = nil
|
|
s.staged.PubSignKey = peer.PubSignKey
|
|
s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey)
|
|
s.staged.DataCipher = newDataCipher()
|
|
|
|
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
|
|
s.staged.Relay = peer.Relay
|
|
s.staged.Direct = true
|
|
s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port)
|
|
|
|
if s.localAddr.IsValid() && s.localIP < s.remoteIP {
|
|
return EnterStateServer(s)
|
|
}
|
|
|
|
return EnterStateClientDirect(s)
|
|
}
|
|
|
|
if s.localAddr.IsValid() {
|
|
s.staged.Direct = true
|
|
return EnterStateServer(s)
|
|
}
|
|
|
|
if s.localIP < s.remoteIP {
|
|
return EnterStateServer(s)
|
|
}
|
|
|
|
return EnterStateClientRelayed(s)
|
|
}
|
|
|
|
func (s *State) logf(format string, args ...any) {
|
|
b := strings.Builder{}
|
|
name := "--"
|
|
if s.peer != nil {
|
|
name = s.peer.Name
|
|
}
|
|
b.WriteString(fmt.Sprintf("%30s: ", name))
|
|
|
|
if s.staged.Direct {
|
|
b.WriteString("DIRECT | ")
|
|
} else {
|
|
b.WriteString("RELAYED | ")
|
|
}
|
|
|
|
if s.staged.Up {
|
|
b.WriteString("UP | ")
|
|
} else {
|
|
b.WriteString("DOWN | ")
|
|
}
|
|
|
|
log.Printf(b.String()+format, args...)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) {
|
|
if !addr.IsValid() {
|
|
return
|
|
}
|
|
route := s.staged
|
|
route.Direct = true
|
|
route.DirectAddr = addr
|
|
s.Send(route, pkt)
|
|
}
|
|
|
|
func (s *State) Send(peer RemotePeer, pkt Marshaller) {
|
|
if err := s.limiter.Limit(); err != nil {
|
|
s.logf("Rate limited.")
|
|
return
|
|
}
|
|
s.sendControlPacket(peer, pkt)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
// StateDisconnected is the state used when there is no peer configuration.
// Its handlers ignore all control packets and timer ticks. StateServer and
// StateClientDirect embed it, so these no-op handlers are their defaults for
// any method they do not override.
type StateDisconnected struct{ *State }
|
|
|
|
func EnterStateDisconnected(s *State) PeerState {
|
|
s.logf("==> Disconnected")
|
|
s.peer = nil
|
|
s.staged.Up = false
|
|
s.staged.Relay = false
|
|
s.staged.Direct = false
|
|
s.staged.DirectAddr = netip.AddrPort{}
|
|
s.staged.PubSignKey = nil
|
|
s.staged.ControlCipher = nil
|
|
s.staged.DataCipher = nil
|
|
s.publish(s.staged)
|
|
return &StateDisconnected{State: s}
|
|
}
|
|
|
|
func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil }
|
|
func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {}
|
|
func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil }
|
|
func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {}
|
|
func (s *StateDisconnected) OnPingTimer() PeerState { return nil }
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type StateServer struct {
|
|
*StateDisconnected
|
|
lastSeen time.Time
|
|
synTraceID uint64
|
|
}
|
|
|
|
func EnterStateServer(s *State) PeerState {
|
|
s.logf("==> Server")
|
|
return &StateServer{StateDisconnected: &StateDisconnected{State: s}}
|
|
}
|
|
|
|
func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState {
|
|
s.lastSeen = time.Now()
|
|
p := msg.Packet
|
|
|
|
// Before we can respond to this packet, we need to make sure the
|
|
// route is setup properly.
|
|
//
|
|
// The client will update the syn's TraceID whenever there's a change.
|
|
// The server will follow the client's request.
|
|
if p.TraceID != s.synTraceID || !s.staged.Up {
|
|
s.synTraceID = p.TraceID
|
|
s.staged.Up = true
|
|
s.staged.Direct = p.Direct
|
|
s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
|
|
s.staged.DirectAddr = msg.SrcAddr
|
|
s.publish(s.staged)
|
|
s.logf("Got SYN.")
|
|
}
|
|
|
|
// Always respond.
|
|
ack := PacketAck{
|
|
TraceID: p.TraceID,
|
|
ToAddr: s.staged.DirectAddr,
|
|
PossibleAddrs: s.pubAddrs.Get(),
|
|
}
|
|
s.Send(s.staged, ack)
|
|
|
|
if p.Direct {
|
|
return nil
|
|
}
|
|
|
|
for _, addr := range msg.Packet.PossibleAddrs {
|
|
if !addr.IsValid() {
|
|
break
|
|
}
|
|
s.SendTo(PacketProbe{TraceID: newTraceID()}, addr)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState {
|
|
if msg.SrcAddr.IsValid() {
|
|
s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *StateServer) OnPingTimer() PeerState {
|
|
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
|
|
s.staged.Up = false
|
|
s.publish(s.staged)
|
|
s.logf("Timeout.")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type StateClientDirect struct {
|
|
*StateDisconnected
|
|
lastSeen time.Time
|
|
syn PacketSyn
|
|
}
|
|
|
|
func EnterStateClientDirect(s *State) PeerState {
|
|
s.logf("==> ClientDirect")
|
|
return NewStateClientDirect(s)
|
|
}
|
|
|
|
func NewStateClientDirect(s *State) *StateClientDirect {
|
|
state := &StateClientDirect{
|
|
StateDisconnected: &StateDisconnected{s},
|
|
lastSeen: time.Now(), // Avoid immediate timeout.
|
|
}
|
|
|
|
state.syn = PacketSyn{
|
|
TraceID: newTraceID(),
|
|
SharedKey: s.staged.DataCipher.Key(),
|
|
Direct: s.staged.Direct,
|
|
PossibleAddrs: s.pubAddrs.Get(),
|
|
}
|
|
state.Send(s.staged, state.syn)
|
|
return state
|
|
}
|
|
|
|
func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) {
|
|
if msg.Packet.TraceID != s.syn.TraceID {
|
|
return
|
|
}
|
|
|
|
s.lastSeen = time.Now()
|
|
|
|
if !s.staged.Up {
|
|
s.staged.Up = true
|
|
s.publish(s.staged)
|
|
s.logf("Got ACK.")
|
|
}
|
|
|
|
s.pubAddrs.Store(msg.Packet.ToAddr)
|
|
}
|
|
|
|
func (s *StateClientDirect) OnPingTimer() PeerState {
|
|
if time.Since(s.lastSeen) > timeoutInterval {
|
|
if s.staged.Up {
|
|
s.staged.Up = false
|
|
s.publish(s.staged)
|
|
s.logf("Timeout.")
|
|
}
|
|
return s.OnPeerUpdate(s.peer)
|
|
}
|
|
|
|
s.Send(s.staged, s.syn)
|
|
return nil
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type StateClientRelayed struct {
|
|
*StateClientDirect
|
|
ack PacketAck
|
|
probes map[uint64]netip.AddrPort
|
|
localDiscoveryAddr netip.AddrPort
|
|
}
|
|
|
|
func EnterStateClientRelayed(s *State) PeerState {
|
|
s.logf("==> ClientRelayed")
|
|
return &StateClientRelayed{
|
|
StateClientDirect: NewStateClientDirect(s),
|
|
probes: map[uint64]netip.AddrPort{},
|
|
}
|
|
}
|
|
|
|
func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) {
|
|
s.ack = msg.Packet
|
|
s.StateClientDirect.OnAck(msg)
|
|
}
|
|
|
|
func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState {
|
|
addr, ok := s.probes[msg.Packet.TraceID]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
s.staged.DirectAddr = addr
|
|
s.staged.Direct = true
|
|
s.publish(s.staged)
|
|
return EnterStateClientDirect(s.StateClientDirect.State)
|
|
}
|
|
|
|
func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) {
|
|
// The source port will be the multicast port, so we'll have to
|
|
// construct the correct address using the peer's listed port.
|
|
s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
|
|
}
|
|
|
|
func (s *StateClientRelayed) OnPingTimer() PeerState {
|
|
if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil {
|
|
return nextState
|
|
}
|
|
|
|
clear(s.probes)
|
|
for _, addr := range s.ack.PossibleAddrs {
|
|
if !addr.IsValid() {
|
|
break
|
|
}
|
|
s.sendProbeTo(addr)
|
|
}
|
|
|
|
if s.localDiscoveryAddr.IsValid() {
|
|
s.sendProbeTo(s.localDiscoveryAddr)
|
|
s.localDiscoveryAddr = netip.AddrPort{}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) {
|
|
probe := PacketProbe{TraceID: newTraceID()}
|
|
s.probes[probe.TraceID] = addr
|
|
s.SendTo(probe, addr)
|
|
}
|