281 lines
5.6 KiB
Go
281 lines
5.6 KiB
Go
package node
|
|
|
|
import (
|
|
"math/rand"
|
|
"net/netip"
|
|
"time"
|
|
"vppn/m"
|
|
)
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) noPeer() stateFunc {
|
|
return s.peerUpdate(<-s.peerUpdates)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc {
|
|
return func() stateFunc { return s._peerUpdate(peer) }
|
|
}
|
|
|
|
// _peerUpdate applies a new peer configuration and returns the first state
// of the role this node takes for that peer.
//
// It stores peer and clears all staged route info. Via defer, it publishes
// the staged info on every return path. A nil peer leads to noPeer.
//
// For a non-nil peer it builds a control cipher (from s.privKey and
// peer.EncPubKey) and a fresh data cipher. When peer.PublicIP is a valid
// address, it also stages peer.Mediator as the relay setting and
// PublicIP:Port as the remote address.
//
// Role selection:
//   - both sides public, or both non-public: the side with the lower IP
//     (s.localIP < s.remoteIP) is the server, the other is the client;
//   - only the remote is public: we are the client (clientInit);
//   - only we are public: we are the server (serverAccept).
func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc {
	defer s.publish()

	s.peer = peer
	s.staged = peerRouteInfo{}

	// remotePub describes the current peer configuration only. Clear it on
	// every update; otherwise a peer that stops advertising a public IP keeps
	// a stale `true`. That would pick the wrong role below and make clientInit
	// skip relay selection.
	s.remotePub = false

	if s.peer == nil {
		return s.noPeer
	}

	s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
	s.staged.dataCipher = newDataCipher()

	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
		s.remotePub = true
		s.staged.relay = peer.Mediator
		s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
	}

	switch {
	case s.remotePub == s.localPub:
		// Symmetric reachability: break the tie by IP so exactly one side
		// listens and the other dials.
		if s.localIP < s.remoteIP {
			return s.serverAccept
		}
		return s.clientInit
	case s.remotePub:
		return s.clientInit
	default:
		return s.serverAccept
	}
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) serverAccept() stateFunc {
|
|
s.logf("STATE: server-accept")
|
|
s.staged.up = false
|
|
s.staged.dataCipher = nil
|
|
s.staged.remoteAddr = zeroAddrPort
|
|
s.staged.relayIP = 0
|
|
s.publish()
|
|
|
|
var syn synPacket
|
|
|
|
for {
|
|
select {
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
|
|
case synPacket:
|
|
syn = p
|
|
s.staged.remoteAddr = pkt.RemoteAddr
|
|
s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey)
|
|
s.staged.relayIP = syn.RelayIP
|
|
s.publish()
|
|
s.sendControlPacket(synAckPacket{
|
|
TraceID: syn.TraceID,
|
|
RecvAddr: pkt.RemoteAddr,
|
|
})
|
|
|
|
case ackPacket:
|
|
if p.TraceID != syn.TraceID {
|
|
continue
|
|
}
|
|
|
|
// Publish.
|
|
return s.serverConnected(syn.TraceID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) serverConnected(traceID uint64) stateFunc {
|
|
s.logf("STATE: server-connected")
|
|
s.staged.up = true
|
|
s.publish()
|
|
return func() stateFunc {
|
|
return s._serverConnected(traceID)
|
|
}
|
|
}
|
|
|
|
func (s *peerSuper) _serverConnected(traceID uint64) stateFunc {
|
|
|
|
timeoutTimer := time.NewTimer(timeoutInterval)
|
|
defer timeoutTimer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
|
|
case ackPacket:
|
|
if p.TraceID != traceID {
|
|
return s.serverAccept
|
|
}
|
|
|
|
s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
|
|
timeoutTimer.Reset(timeoutInterval)
|
|
}
|
|
|
|
case <-timeoutTimer.C:
|
|
s.logf("server timeout")
|
|
return s.serverAccept
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) clientInit() stateFunc {
|
|
s.logf("STATE: client-init")
|
|
if !s.remotePub {
|
|
// TODO: Check local discovery for IP.
|
|
// TODO: Attempt UDP hole punch.
|
|
// TODO: client-relayed
|
|
return s.clientSelectRelay
|
|
}
|
|
|
|
return s.clientDial
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
// clientSelectRelay repeatedly asks selectRelayIP for a relay until one is
// available, retrying every pingInterval.
//
// When a relay is found, it is staged as s.staged.relayIP, published, and
// the next state is clientDial. A peer update at any time restarts via
// peerUpdate.
func (s *peerSuper) clientSelectRelay() stateFunc {
	s.logf("STATE: client-select-relay")

	// A zero-duration timer expires right away, so the first attempt is made
	// without waiting.
	retry := time.NewTimer(0)
	defer retry.Stop()

	for {
		select {
		case update := <-s.peerUpdates:
			return s.peerUpdate(update)

		case <-retry.C:
			relay := s.selectRelayIP()
			if relay == 0 {
				s.logf("No relay available.")
				retry.Reset(pingInterval)
				continue
			}

			s.logf("Got relay: %d", relay)
			s.staged.relayIP = relay
			s.publish()
			return s.clientDial
		}
	}
}
|
|
|
|
func (s *peerSuper) selectRelayIP() byte {
|
|
possible := make([]byte, 0, 8)
|
|
for i, peer := range s.peers {
|
|
if peer.CanRelay() {
|
|
possible = append(possible, byte(i))
|
|
}
|
|
}
|
|
|
|
if len(possible) == 0 {
|
|
return 0
|
|
}
|
|
return possible[rand.Intn(len(possible))]
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) clientDial() stateFunc {
|
|
s.logf("STATE: client-dial")
|
|
|
|
var (
|
|
syn = synPacket{
|
|
TraceID: newTraceID(),
|
|
SharedKey: s.staged.dataCipher.Key(),
|
|
RelayIP: s.staged.relayIP,
|
|
}
|
|
|
|
timeout = time.NewTimer(dialTimeout)
|
|
)
|
|
|
|
defer timeout.Stop()
|
|
|
|
s.sendControlPacket(syn)
|
|
|
|
for {
|
|
select {
|
|
|
|
case peer := <-s.peerUpdates:
|
|
return s.peerUpdate(peer)
|
|
|
|
case pkt := <-s.controlPackets:
|
|
switch p := pkt.Payload.(type) {
|
|
case synAckPacket:
|
|
if p.TraceID != syn.TraceID {
|
|
continue // Hmm...
|
|
}
|
|
s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
|
|
return s.clientConnected(p)
|
|
}
|
|
|
|
case <-timeout.C:
|
|
return s.clientInit
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerSuper) clientConnected(p synAckPacket) stateFunc {
|
|
s.logf("STATE: client-connected")
|
|
s.staged.up = true
|
|
s.staged.localAddr = p.RecvAddr
|
|
s.publish()
|
|
|
|
return func() stateFunc {
|
|
return s._clientConnected(p.TraceID)
|
|
}
|
|
}
|
|
|
|
// _clientConnected keeps an established client-side connection alive.
//
// Every pingInterval it sends an ackPacket carrying traceID. Each ackPacket
// received with the same trace ID pushes the deadline out by
// timeoutInterval. The state returns to clientInit in two cases:
//   - an ackPacket arrives with a different trace ID;
//   - the deadline passes.
//
// A peer update restarts via peerUpdate. Other control packets are ignored.
func (s *peerSuper) _clientConnected(traceID uint64) stateFunc {
	var (
		keepalive = time.NewTimer(pingInterval)
		deadline  = time.NewTimer(timeoutInterval)
	)
	defer func() {
		deadline.Stop()
		keepalive.Stop()
	}()

	for {
		select {
		case <-deadline.C:
			s.logf("client timeout")
			return s.clientInit

		case <-keepalive.C:
			s.sendControlPacket(ackPacket{TraceID: traceID})
			keepalive.Reset(pingInterval)

		case update := <-s.peerUpdates:
			return s.peerUpdate(update)

		case pkt := <-s.controlPackets:
			ack, isAck := pkt.Payload.(ackPacket)
			if !isAck {
				continue
			}
			if ack.TraceID != traceID {
				// The server is on a different session; reconnect.
				return s.clientInit
			}
			// NOTE(review): before Go 1.23, Reset on a timer that has fired but
			// whose channel was not drained can leave a stale tick queued;
			// confirm the module's Go version.
			deadline.Reset(timeoutInterval)
		}
	}
}
|