353 lines
7.5 KiB
Go
353 lines
7.5 KiB
Go
package node
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"math/rand"
|
|
"net/netip"
|
|
"sync/atomic"
|
|
"time"
|
|
"vppn/m"
|
|
)
|
|
|
|
// Timing parameters for the peer connection state machine.
const (
	// dialTimeout bounds how long the client-dial state waits for a matching
	// syn-ack before falling back to client-init.
	dialTimeout = 8 * time.Second

	// connectTimeout is not referenced by any state visible in this file.
	// NOTE(review): presumably used elsewhere in the package — confirm.
	connectTimeout = 6 * time.Second

	// pingInterval is the period between the client's keep-alive acks, and
	// also the retry delay when no relay is currently available.
	pingInterval = 6 * time.Second

	// timeoutInterval is how long a connected state tolerates silence (no
	// matching ack) before dropping the connection.
	timeoutInterval = 20 * time.Second
)
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type peerSupervisor struct {
|
|
// The purpose of this state machine is to manage this published data.
|
|
published *atomic.Pointer[peerRoute]
|
|
staged peerRoute // Local copy of shared data. See publish().
|
|
|
|
// Immutable data.
|
|
remoteIP byte // Remote VPN IP.
|
|
|
|
// Mutable peer data.
|
|
peer *m.Peer
|
|
remotePub bool
|
|
|
|
// Incoming events.
|
|
peerUpdates chan *m.Peer
|
|
controlPackets chan controlPacket
|
|
|
|
// Buffers for sending control packets.
|
|
buf1 []byte
|
|
buf2 []byte
|
|
}
|
|
|
|
func newPeerSupervisor(i int) *peerSupervisor {
|
|
return &peerSupervisor{
|
|
published: routingTable[i],
|
|
remoteIP: byte(i),
|
|
peerUpdates: peerUpdates[i],
|
|
controlPackets: controlPackets[i],
|
|
buf1: make([]byte, bufferSize),
|
|
buf2: make([]byte, bufferSize),
|
|
}
|
|
}
|
|
|
|
// stateFunc is one state of the peer state machine: calling it runs that
// state and yields the state to run next.
type stateFunc func() stateFunc
|
|
|
|
func (s *peerSupervisor) Run() {
|
|
state := s.noPeer
|
|
for {
|
|
state = state()
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
|
|
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) logf(msg string, args ...any) {
|
|
log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) publish() {
|
|
data := s.staged
|
|
s.published.Store(&data)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) noPeer() stateFunc {
|
|
return s.peerUpdate(<-s.peerUpdates)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
|
|
return func() stateFunc { return s._peerUpdate(peer) }
|
|
}
|
|
|
|
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
|
|
defer s.publish()
|
|
|
|
s.peer = peer
|
|
s.staged = peerRoute{}
|
|
|
|
if s.peer == nil {
|
|
return s.noPeer
|
|
}
|
|
|
|
s.staged.IP = s.remoteIP
|
|
s.staged.ControlCipher = newControlCipher(privateKey, peer.EncPubKey)
|
|
s.staged.DataCipher = newDataCipher()
|
|
|
|
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
|
|
s.remotePub = true
|
|
s.staged.Relay = peer.Mediator
|
|
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
|
|
}
|
|
|
|
if s.remotePub == localPub {
|
|
if localIP < s.remoteIP {
|
|
return s.serverAccept
|
|
}
|
|
return s.clientInit
|
|
}
|
|
|
|
if s.remotePub {
|
|
return s.clientInit
|
|
}
|
|
return s.serverAccept
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) serverAccept() stateFunc {
|
|
s.logf("STATE: server-accept")
|
|
s.staged.Up = false
|
|
s.staged.DataCipher = nil
|
|
s.staged.RemoteAddr = zeroAddrPort
|
|
s.staged.RelayIP = 0
|
|
s.publish()
|
|
|
|
var syn synPacket
|
|
|
|
for {
|
|
select {
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
|
|
case synPacket:
|
|
syn = p
|
|
s.staged.RemoteAddr = pkt.RemoteAddr
|
|
s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
|
|
s.staged.RelayIP = syn.RelayIP
|
|
s.publish()
|
|
s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
|
|
|
|
case ackPacket:
|
|
if p.TraceID != syn.TraceID {
|
|
continue
|
|
}
|
|
|
|
// Publish.
|
|
return s.serverConnected(syn.TraceID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc {
|
|
s.logf("STATE: server-connected")
|
|
s.staged.Up = true
|
|
s.publish()
|
|
return func() stateFunc {
|
|
return s._serverConnected(traceID)
|
|
}
|
|
}
|
|
|
|
func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc {
|
|
|
|
timeoutTimer := time.NewTimer(timeoutInterval)
|
|
defer timeoutTimer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
|
|
case ackPacket:
|
|
if p.TraceID != traceID {
|
|
return s.serverAccept
|
|
}
|
|
s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
|
|
timeoutTimer.Reset(timeoutInterval)
|
|
}
|
|
|
|
case <-timeoutTimer.C:
|
|
s.logf("server timeout")
|
|
return s.serverAccept
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) clientInit() stateFunc {
|
|
s.logf("STATE: client-init")
|
|
if !s.remotePub {
|
|
// TODO: Check local discovery for IP.
|
|
// TODO: Attempt UDP hole punch.
|
|
// TODO: client-relayed
|
|
return s.clientSelectRelay
|
|
}
|
|
|
|
return s.clientDial
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) clientSelectRelay() stateFunc {
|
|
s.logf("STATE: client-select-relay")
|
|
|
|
timer := time.NewTimer(0)
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case <-timer.C:
|
|
relay := s.selectRelay()
|
|
if relay != nil {
|
|
s.logf("Got relay: %d", relay.IP)
|
|
s.staged.RelayIP = relay.IP
|
|
s.staged.LocalAddr = relay.LocalAddr
|
|
s.publish()
|
|
return s.clientDial
|
|
}
|
|
|
|
s.logf("No relay available.")
|
|
timer.Reset(pingInterval)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *peerSupervisor) selectRelay() *peerRoute {
|
|
possible := make([]*peerRoute, 0, 8)
|
|
for i := range routingTable {
|
|
route := routingTable[i].Load()
|
|
if !route.Up || !route.Relay {
|
|
continue
|
|
}
|
|
possible = append(possible, route)
|
|
}
|
|
|
|
if len(possible) == 0 {
|
|
return nil
|
|
}
|
|
return possible[rand.Intn(len(possible))]
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) clientDial() stateFunc {
|
|
s.logf("STATE: client-dial")
|
|
|
|
var (
|
|
syn = synPacket{
|
|
TraceID: newTraceID(),
|
|
SharedKey: s.staged.DataCipher.Key(),
|
|
RelayIP: s.staged.RelayIP,
|
|
}
|
|
|
|
timeout = time.NewTimer(dialTimeout)
|
|
)
|
|
|
|
defer timeout.Stop()
|
|
|
|
s.sendControlPacket(syn)
|
|
|
|
for {
|
|
select {
|
|
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
case synAckPacket:
|
|
if p.TraceID != syn.TraceID {
|
|
continue // Hmm...
|
|
}
|
|
s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
|
|
return s.clientConnected(p)
|
|
}
|
|
|
|
case <-timeout.C:
|
|
return s.clientInit
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc {
|
|
s.logf("STATE: client-connected")
|
|
s.staged.Up = true
|
|
s.staged.LocalAddr = p.RecvAddr
|
|
s.publish()
|
|
|
|
return func() stateFunc {
|
|
return s._clientConnected(p.TraceID)
|
|
}
|
|
}
|
|
|
|
func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc {
|
|
|
|
pingTimer := time.NewTimer(pingInterval)
|
|
timeoutTimer := time.NewTimer(timeoutInterval)
|
|
|
|
defer pingTimer.Stop()
|
|
defer timeoutTimer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
|
|
case ackPacket:
|
|
if p.TraceID != traceID {
|
|
return s.clientInit
|
|
}
|
|
timeoutTimer.Reset(timeoutInterval)
|
|
}
|
|
|
|
case <-pingTimer.C:
|
|
s.sendControlPacket(ackPacket{TraceID: traceID})
|
|
pingTimer.Reset(pingInterval)
|
|
|
|
case <-timeoutTimer.C:
|
|
s.logf("client timeout")
|
|
return s.clientInit
|
|
|
|
}
|
|
}
|
|
}
|