447 lines
9.6 KiB
Go
447 lines
9.6 KiB
Go
package peer
|
|
|
|
import (
|
|
"bytes"
|
|
"net/netip"
|
|
"time"
|
|
"vppn/m"
|
|
)
|
|
|
|
// stateFunc is one state of the remote's state machine. It consumes a single
// message and returns the state that must handle the next message (which may
// be the same state again).
type stateFunc func(message any) stateFunc
|
|
|
|
type remoteFSM struct {
|
|
*Remote
|
|
|
|
pingTimer *time.Ticker
|
|
lastSeen time.Time
|
|
traceID uint64
|
|
probes map[uint64]sentProbe
|
|
sharedKey [32]byte
|
|
|
|
buf []byte
|
|
}
|
|
|
|
func newRemoteFSM(r *Remote) *remoteFSM {
|
|
fsm := &remoteFSM{
|
|
Remote: r,
|
|
pingTimer: time.NewTicker(timeoutInterval),
|
|
probes: map[uint64]sentProbe{},
|
|
buf: make([]byte, bufferSize),
|
|
}
|
|
fsm.pingTimer.Stop()
|
|
return fsm
|
|
}
|
|
|
|
func (r *remoteFSM) Run() {
|
|
go func() {
|
|
for range r.pingTimer.C {
|
|
r.messages <- pingTimerMsg{}
|
|
}
|
|
}()
|
|
state := r.enterDisconnected()
|
|
for msg := range r.messages {
|
|
state = state(msg)
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (r *remoteFSM) enterDisconnected() stateFunc {
|
|
r.updateConf(remoteConfig{})
|
|
return r.stateDisconnected
|
|
}
|
|
|
|
func (r *remoteFSM) stateDisconnected(iMsg any) stateFunc {
|
|
switch msg := iMsg.(type) {
|
|
case peerUpdateMsg:
|
|
return r.enterPeerUpdating(msg.Peer)
|
|
case controlMsg[packetInit]:
|
|
r.logf("Unexpected INIT")
|
|
case controlMsg[packetSyn]:
|
|
r.logf("Unexpected SYN")
|
|
case controlMsg[packetAck]:
|
|
r.logf("Unexpected ACK")
|
|
case controlMsg[packetProbe]:
|
|
r.logf("Unexpected probe")
|
|
case controlMsg[packetLocalDiscovery]:
|
|
// Ignore
|
|
case pingTimerMsg:
|
|
r.logf("Unexpected ping")
|
|
default:
|
|
r.logf("Ignoring message: %#v", iMsg)
|
|
}
|
|
|
|
return r.stateDisconnected
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (r *remoteFSM) enterPeerUpdating(peer *m.Peer) stateFunc {
|
|
if peer == nil {
|
|
return r.enterDisconnected()
|
|
}
|
|
|
|
conf := remoteConfig{
|
|
Peer: peer,
|
|
ControlCipher: newControlCipher(r.PrivKey, peer.PubKey),
|
|
}
|
|
r.updateConf(conf)
|
|
|
|
if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
|
|
if r.LocalAddrValid && r.LocalPeerIP < peer.PeerIP {
|
|
return r.enterServer()
|
|
}
|
|
return r.enterClientInit()
|
|
}
|
|
|
|
if r.LocalAddrValid || r.LocalPeerIP < peer.PeerIP {
|
|
return r.enterServer()
|
|
}
|
|
|
|
return r.enterClientInit()
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (r *remoteFSM) enterServer() stateFunc {
|
|
|
|
conf := r.conf()
|
|
conf.Server = true
|
|
r.updateConf(conf)
|
|
r.logf("==> Server")
|
|
|
|
r.pingTimer.Reset(pingInterval)
|
|
r.lastSeen = time.Now()
|
|
clear(r.sharedKey[:])
|
|
return r.stateServer
|
|
}
|
|
|
|
func (r *remoteFSM) stateServer(iMsg any) stateFunc {
|
|
switch msg := iMsg.(type) {
|
|
case peerUpdateMsg:
|
|
return r.enterPeerUpdating(msg.Peer)
|
|
case controlMsg[packetInit]:
|
|
r.stateServer_onInit(msg)
|
|
case controlMsg[packetSyn]:
|
|
r.stateServer_onSyn(msg)
|
|
case controlMsg[packetAck]:
|
|
r.logf("Unexpected ACK")
|
|
case controlMsg[packetProbe]:
|
|
r.stateServer_onProbe(msg)
|
|
case controlMsg[packetLocalDiscovery]:
|
|
// Ignore
|
|
case pingTimerMsg:
|
|
r.stateServer_onPingTimer()
|
|
default:
|
|
r.logf("Unexpected message: %#v", iMsg)
|
|
}
|
|
|
|
return r.stateServer
|
|
}
|
|
|
|
func (r *remoteFSM) stateServer_onInit(msg controlMsg[packetInit]) {
|
|
conf := r.conf()
|
|
conf.Up = false
|
|
conf.Direct = msg.Packet.Direct
|
|
conf.DirectAddr = msg.SrcAddr
|
|
r.updateConf(conf)
|
|
|
|
init := packetInit{
|
|
TraceID: msg.Packet.TraceID,
|
|
Direct: conf.Direct,
|
|
Version: version,
|
|
}
|
|
|
|
r.sendControl(conf, init.Marshal(r.buf))
|
|
}
|
|
|
|
func (r *remoteFSM) stateServer_onSyn(msg controlMsg[packetSyn]) {
|
|
r.logf("Got SYN: %v", msg.Packet)
|
|
r.lastSeen = time.Now()
|
|
p := msg.Packet
|
|
|
|
// Before we can respond to this packet, we need to make sure the
|
|
// route is setup properly.
|
|
conf := r.conf()
|
|
if !conf.Up || conf.Direct != p.Direct {
|
|
r.logf("Got SYN.")
|
|
}
|
|
|
|
conf.Up = true
|
|
conf.Direct = p.Direct
|
|
conf.DirectAddr = msg.SrcAddr
|
|
|
|
// Update data cipher if the key has changed.
|
|
if !bytes.Equal(r.sharedKey[:], p.SharedKey[:]) {
|
|
conf.DataCipher = newDataCipherFromKey(p.SharedKey)
|
|
copy(r.sharedKey[:], p.SharedKey[:])
|
|
}
|
|
|
|
r.updateConf(conf)
|
|
|
|
r.sendControl(conf, packetAck{
|
|
TraceID: p.TraceID,
|
|
ToAddr: conf.DirectAddr,
|
|
PossibleAddrs: r.PubAddrs.Get(),
|
|
}.Marshal(r.buf))
|
|
|
|
if p.Direct {
|
|
return
|
|
}
|
|
|
|
// Send probes if not a direct connection.
|
|
for _, addr := range msg.Packet.PossibleAddrs {
|
|
if !addr.IsValid() {
|
|
break
|
|
}
|
|
r.logf("Probing %v...", addr)
|
|
r.sendControlToAddr(packetProbe{TraceID: newTraceID()}.Marshal(r.buf), addr)
|
|
}
|
|
}
|
|
|
|
func (r *remoteFSM) stateServer_onProbe(msg controlMsg[packetProbe]) {
|
|
if !msg.SrcAddr.IsValid() {
|
|
return
|
|
}
|
|
|
|
data := packetProbe{TraceID: msg.Packet.TraceID}.Marshal(r.buf)
|
|
r.sendControlToAddr(data, msg.SrcAddr)
|
|
}
|
|
|
|
func (r *remoteFSM) stateServer_onPingTimer() {
|
|
conf := r.conf()
|
|
if time.Since(r.lastSeen) > timeoutInterval && conf.Up {
|
|
conf.Up = false
|
|
r.updateConf(conf)
|
|
r.logf("Timeout.")
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (r *remoteFSM) enterClientInit() stateFunc {
|
|
conf := r.conf()
|
|
ip, ipValid := netip.AddrFromSlice(conf.Peer.PublicIP)
|
|
|
|
conf.Up = false
|
|
conf.Server = false
|
|
conf.Direct = ipValid
|
|
conf.DirectAddr = netip.AddrPortFrom(ip, conf.Peer.Port)
|
|
conf.DataCipher = newDataCipher()
|
|
|
|
r.updateConf(conf)
|
|
r.logf("==> ClientInit")
|
|
|
|
r.lastSeen = time.Now()
|
|
r.pingTimer.Reset(pingInterval)
|
|
r.stateClientInit_sendInit()
|
|
return r.stateClientInit
|
|
}
|
|
|
|
func (r *remoteFSM) stateClientInit(iMsg any) stateFunc {
|
|
switch msg := iMsg.(type) {
|
|
case peerUpdateMsg:
|
|
return r.enterPeerUpdating(msg.Peer)
|
|
case controlMsg[packetInit]:
|
|
return r.stateClientInit_onInit(msg)
|
|
case controlMsg[packetSyn]:
|
|
r.logf("Unexpected SYN")
|
|
case controlMsg[packetAck]:
|
|
r.logf("Unexpected ACK")
|
|
case controlMsg[packetProbe]:
|
|
// Ignore
|
|
case controlMsg[packetLocalDiscovery]:
|
|
// Ignore
|
|
case pingTimerMsg:
|
|
return r.stateClientInit_onPing()
|
|
default:
|
|
r.logf("Unexpected message: %#v", iMsg)
|
|
}
|
|
|
|
return r.stateClientInit
|
|
}
|
|
|
|
func (r *remoteFSM) stateClientInit_sendInit() {
|
|
conf := r.conf()
|
|
r.traceID = newTraceID()
|
|
init := packetInit{
|
|
TraceID: r.traceID,
|
|
Direct: conf.Direct,
|
|
Version: version,
|
|
}
|
|
r.sendControl(conf, init.Marshal(r.buf))
|
|
}
|
|
|
|
func (r *remoteFSM) stateClientInit_onInit(msg controlMsg[packetInit]) stateFunc {
|
|
if msg.Packet.TraceID != r.traceID {
|
|
r.logf("Invalid trace ID on INIT.")
|
|
return r.stateClientInit
|
|
}
|
|
r.logf("Got INIT version %d.", msg.Packet.Version)
|
|
return r.enterClient()
|
|
}
|
|
|
|
func (r *remoteFSM) stateClientInit_onPing() stateFunc {
|
|
if time.Since(r.lastSeen) < timeoutInterval {
|
|
r.stateClientInit_sendInit()
|
|
return r.stateClientInit
|
|
}
|
|
|
|
// Direct connect failed. Try indirect.
|
|
conf := r.conf()
|
|
|
|
if conf.Direct {
|
|
conf.Direct = false
|
|
r.updateConf(conf)
|
|
r.lastSeen = time.Now()
|
|
r.stateClientInit_sendInit()
|
|
r.logf("Direct connection failed. Attempting indirect connection.")
|
|
return r.stateClientInit
|
|
}
|
|
|
|
// Indirect failed. Re-enter init state.
|
|
r.logf("Timeout.")
|
|
return r.enterClientInit()
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (r *remoteFSM) enterClient() stateFunc {
|
|
conf := r.conf()
|
|
r.probes = make(map[uint64]sentProbe, 8)
|
|
|
|
r.traceID = newTraceID()
|
|
r.stateClient_sendSyn(conf)
|
|
|
|
r.pingTimer.Reset(pingInterval)
|
|
r.logf("==> Client")
|
|
return r.stateClient
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient(iMsg any) stateFunc {
|
|
switch msg := iMsg.(type) {
|
|
case peerUpdateMsg:
|
|
return r.enterPeerUpdating(msg.Peer)
|
|
case controlMsg[packetAck]:
|
|
r.stateClient_onAck(msg)
|
|
case controlMsg[packetProbe]:
|
|
r.stateClient_onProbe(msg)
|
|
case controlMsg[packetLocalDiscovery]:
|
|
r.stateClient_onLocalDiscovery(msg)
|
|
case pingTimerMsg:
|
|
return r.stateClient_onPingTimer()
|
|
default:
|
|
r.logf("Ignoring message: %v", iMsg)
|
|
}
|
|
return r.stateClient
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_onAck(msg controlMsg[packetAck]) {
|
|
if msg.Packet.TraceID != r.traceID {
|
|
return
|
|
}
|
|
|
|
r.lastSeen = time.Now()
|
|
|
|
conf := r.conf()
|
|
if !conf.Up {
|
|
conf.Up = true
|
|
r.updateConf(conf)
|
|
r.logf("Got ACK.")
|
|
}
|
|
|
|
if conf.Direct {
|
|
r.PubAddrs.Store(msg.Packet.ToAddr)
|
|
return
|
|
}
|
|
|
|
// Relayed.
|
|
|
|
r.stateClient_cleanProbes()
|
|
|
|
for _, addr := range msg.Packet.PossibleAddrs {
|
|
if !addr.IsValid() {
|
|
break
|
|
}
|
|
r.stateClient_sendProbeTo(addr)
|
|
}
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_cleanProbes() {
|
|
for key, sent := range r.probes {
|
|
if time.Since(sent.SentAt) > pingInterval {
|
|
delete(r.probes, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_sendProbeTo(addr netip.AddrPort) {
|
|
probe := packetProbe{TraceID: newTraceID()}
|
|
r.probes[probe.TraceID] = sentProbe{
|
|
SentAt: time.Now(),
|
|
Addr: addr,
|
|
}
|
|
r.logf("Probing %v...", addr)
|
|
r.sendControlToAddr(probe.Marshal(r.buf), addr)
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_onProbe(msg controlMsg[packetProbe]) {
|
|
conf := r.conf()
|
|
if conf.Direct {
|
|
return
|
|
}
|
|
|
|
r.stateClient_cleanProbes()
|
|
|
|
sent, ok := r.probes[msg.Packet.TraceID]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
conf.Direct = true
|
|
conf.DirectAddr = sent.Addr
|
|
r.updateConf(conf)
|
|
|
|
r.traceID = newTraceID()
|
|
r.stateClient_sendSyn(conf)
|
|
r.logf("Successful probe to %v.", sent.Addr)
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
|
|
conf := r.conf()
|
|
if conf.Direct {
|
|
return
|
|
}
|
|
|
|
// The source port will be the multicast port, so we'll have to
|
|
// construct the correct address using the peer's listed port.
|
|
addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), conf.Peer.Port)
|
|
r.stateClient_sendProbeTo(addr)
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_onPingTimer() stateFunc {
|
|
conf := r.conf()
|
|
|
|
if time.Since(r.lastSeen) > timeoutInterval {
|
|
if conf.Up {
|
|
r.logf("Timeout.")
|
|
}
|
|
return r.enterClientInit()
|
|
}
|
|
|
|
r.traceID = newTraceID()
|
|
r.stateClient_sendSyn(conf)
|
|
return r.stateClient
|
|
}
|
|
|
|
func (r *remoteFSM) stateClient_sendSyn(conf remoteConfig) {
|
|
syn := packetSyn{
|
|
TraceID: r.traceID,
|
|
SharedKey: conf.DataCipher.Key(),
|
|
Direct: conf.Direct,
|
|
PossibleAddrs: r.PubAddrs.Get(),
|
|
}
|
|
|
|
r.sendControl(conf, syn.Marshal(r.buf))
|
|
}
|