vppn/node/peer-super-states.go
2024-12-22 19:17:58 +01:00

277 lines
5.5 KiB
Go

package node
import (
"math/rand"
"net/netip"
"time"
"vppn/m"
)
// ----------------------------------------------------------------------------
// noPeer is the idle state used while no peer is configured. It blocks until
// the next value arrives on s.peerUpdates and then hands it to peerUpdate.
func (s *peerSuper) noPeer() stateFunc {
	next := <-s.peerUpdates
	return s.peerUpdate(next)
}
// ----------------------------------------------------------------------------
// peerUpdate wraps the given peer configuration in a state function. When the
// state machine runs it, the configuration is applied by _peerUpdate.
func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc {
	apply := func() stateFunc {
		return s._peerUpdate(peer)
	}
	return apply
}
// _peerUpdate installs a new peer configuration and picks the next state.
//
// It clears all staged routing data. A nil peer leads back to noPeer.
// Otherwise it stages a fresh control cipher and data cipher. When the peer
// advertises a valid public IP, it also stages that address/port and the
// peer's Mediator flag. The staged data is published when this function
// returns (deferred).
//
// Role selection:
//   - both sides equally public/non-public: the side whose localIP is lower
//     than remoteIP accepts (server); the other side dials (client).
//   - otherwise: dial if the remote is public, accept if it is not.
func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRoutingData{}
	// remotePub is derived from the current peer's PublicIP only. Reset it so
	// a peer that lost its public IP (or was removed) is not still treated as
	// publicly reachable.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
	s.staged.dataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.relay = peer.Mediator
		s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
	}

	if s.remotePub == s.localPub {
		// Symmetric reachability: break the tie deterministically by IP.
		if s.localIP < s.remoteIP {
			return s.serverAccept
		}
		return s.clientInit
	}

	if s.remotePub {
		return s.clientInit
	}
	return s.serverAccept
}
// ----------------------------------------------------------------------------
// serverAccept is the passive side of the handshake.
//
// On entry it marks the peer down, clears the staged data cipher, remote
// address and relay IP, and publishes that. It then waits for:
//   - a peer update, which restarts configuration via peerUpdate;
//   - a synPacket, whose sender address, shared key and relay IP are staged
//     and published before a SYN-ACK with the same trace ID is sent back;
//   - an ackPacket matching the most recent SYN's trace ID, which moves to
//     serverConnected.
func (s *peerSuper) serverAccept() stateFunc {
	s.logf("STATE: server-accept")
	s.staged.up = false
	s.staged.dataCipher = nil
	s.staged.remoteAddr = zeroAddrPort
	s.staged.relayIP = 0
	s.publish()

	var (
		syn    synPacket
		gotSyn bool // set once a SYN has been received in this state
	)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case synPacket:
				syn = p
				gotSyn = true
				s.staged.remoteAddr = pkt.RemoteAddr
				s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey)
				s.staged.relayIP = syn.RelayIP
				s.publish()
				s.sendControlPacket(newSynAckPacket(p.TraceID))

			case ackPacket:
				// Without a SYN, `syn` is the zero value. An ACK must not be
				// able to complete the handshake by matching its zero
				// TraceID, since no remote address or cipher has been staged.
				if !gotSyn || p.TraceID != syn.TraceID {
					continue
				}
				return s.serverConnected(syn.TraceID)
			}
		}
	}
}
// ----------------------------------------------------------------------------
// serverConnected marks the peer as up, publishes the staged data, and
// returns the state that keeps the server-side session alive for traceID.
func (s *peerSuper) serverConnected(traceID uint64) stateFunc {
	s.logf("STATE: server-connected")
	s.staged.up = true
	s.publish()

	var keepAlive stateFunc = func() stateFunc {
		return s._serverConnected(traceID)
	}
	return keepAlive
}
// _serverConnected maintains an established server-side session.
//
// Every ackPacket carrying traceID is echoed back and restarts the timeout.
// An ACK with a different trace ID, or no matching ACK within
// timeoutInterval, drops back to serverAccept. A peer update restarts
// configuration.
func (s *peerSuper) _serverConnected(traceID uint64) stateFunc {
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer timeoutTimer.Stop()

	// resetTimeout restarts the timeout. It first stops the timer and discards
	// any expiry that fired but has not been received yet. Before Go 1.23,
	// calling Reset alone could leave that stale value in timeoutTimer.C and
	// trigger a spurious timeout on the next iteration.
	resetTimeout := func() {
		if !timeoutTimer.Stop() {
			select {
			case <-timeoutTimer.C:
			default:
			}
		}
		timeoutTimer.Reset(timeoutInterval)
	}

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.serverAccept
				}
				s.sendControlPacket(ackPacket{TraceID: traceID})
				resetTimeout()
			}

		case <-timeoutTimer.C:
			s.logf("server timeout")
			return s.serverAccept
		}
	}
}
// ----------------------------------------------------------------------------
// clientInit starts the active side of the handshake. A publicly reachable
// remote is dialed directly; otherwise a relay is selected first.
func (s *peerSuper) clientInit() stateFunc {
	s.logf("STATE: client-init")
	if s.remotePub {
		return s.clientDial
	}
	// TODO: Check local discovery for IP.
	// TODO: Attempt UDP hole punch.
	// TODO: client-relayed
	return s.clientSelectRelay
}
// ----------------------------------------------------------------------------
// clientSelectRelay looks for a relay-capable peer. It tries immediately and
// then every pingInterval. Once one is found, its IP is staged and published
// and the state moves to clientDial. A peer update aborts the search.
func (s *peerSuper) clientSelectRelay() stateFunc {
	s.logf("STATE: client-select-relay")

	retry := time.NewTimer(0) // zero duration: first attempt happens at once
	defer retry.Stop()

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)
		case <-retry.C:
		}

		relayIP := s.selectRelayIP()
		if relayIP == 0 {
			s.logf("No relay available.")
			retry.Reset(pingInterval) // C was just drained, so Reset is safe
			continue
		}

		s.logf("Got relay: %d", relayIP)
		s.staged.relayIP = relayIP
		s.publish()
		return s.clientDial
	}
}
// selectRelayIP picks a random index from s.peers whose entry reports
// CanRelay(), and returns it as a byte. It returns 0 when no entry can relay.
func (s *peerSuper) selectRelayIP() byte {
	candidates := make([]byte, 0, 8)
	for ix, candidate := range s.peers {
		if !candidate.CanRelay() {
			continue
		}
		candidates = append(candidates, byte(ix))
	}

	count := len(candidates)
	if count == 0 {
		return 0
	}
	return candidates[rand.Intn(count)]
}
// ----------------------------------------------------------------------------
// clientDial sends a SYN to the remote. The SYN carries a new trace ID, the
// staged data cipher's key and the staged relay IP. clientDial then waits
// up to dialTimeout for the matching SYN-ACK. On success it sends an ACK and
// moves to clientConnected. A timeout returns to clientInit, and a peer
// update restarts configuration. SYN-ACKs with other trace IDs and other
// packet types are ignored.
func (s *peerSuper) clientDial() stateFunc {
	s.logf("STATE: client-dial")

	syn := synPacket{
		TraceID:   newTraceID(),
		SharedKey: s.staged.dataCipher.Key(),
		RelayIP:   s.staged.relayIP,
	}
	dialTimer := time.NewTimer(dialTimeout)
	defer dialTimer.Stop()

	s.sendControlPacket(syn)

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case <-dialTimer.C:
			return s.clientInit

		case pkt := <-s.controlPackets:
			synAck, isSynAck := pkt.Payload.(synAckPacket)
			if !isSynAck || synAck.TraceID != syn.TraceID {
				continue
			}
			s.sendControlPacket(ackPacket{TraceID: syn.TraceID})
			return s.clientConnected(syn.TraceID)
		}
	}
}
// ----------------------------------------------------------------------------
// clientConnected marks the peer as up, publishes the staged data, and
// returns the state that keeps the client-side session alive for traceID.
func (s *peerSuper) clientConnected(traceID uint64) stateFunc {
	s.logf("STATE: client-connected")
	s.staged.up = true
	s.publish()

	var keepAlive stateFunc = func() stateFunc {
		return s._clientConnected(traceID)
	}
	return keepAlive
}
// _clientConnected maintains an established client-side session.
//
// Every pingInterval it sends an ackPacket carrying traceID as a keepalive.
// Each ackPacket with traceID received back restarts the timeout. An ACK with
// a different trace ID, or no matching ACK within timeoutInterval, returns to
// clientInit. A peer update restarts configuration.
func (s *peerSuper) _clientConnected(traceID uint64) stateFunc {
	pingTimer := time.NewTimer(pingInterval)
	timeoutTimer := time.NewTimer(timeoutInterval)
	defer pingTimer.Stop()
	defer timeoutTimer.Stop()

	// resetTimeout restarts the timeout. It first stops the timer and discards
	// any expiry that fired but has not been received yet. Before Go 1.23,
	// calling Reset alone could leave that stale value in timeoutTimer.C and
	// trigger a spurious timeout on the next iteration.
	resetTimeout := func() {
		if !timeoutTimer.Stop() {
			select {
			case <-timeoutTimer.C:
			default:
			}
		}
		timeoutTimer.Reset(timeoutInterval)
	}

	for {
		select {
		case peer := <-s.peerUpdates:
			return s.peerUpdate(peer)

		case pkt := <-s.controlPackets:
			switch p := pkt.Payload.(type) {
			case ackPacket:
				if p.TraceID != traceID {
					return s.clientInit
				}
				resetTimeout()
			}

		case <-pingTimer.C:
			s.sendControlPacket(ackPacket{TraceID: traceID})
			// pingTimer.C was just received from, so a plain Reset is safe.
			pingTimer.Reset(pingInterval)

		case <-timeoutTimer.C:
			s.logf("client timeout")
			return s.clientInit
		}
	}
}