fewer-routines #2
@ -66,12 +66,7 @@ var (
|
||||
return
|
||||
}()
|
||||
|
||||
messages [256]chan any = func() (out [256]chan any) {
|
||||
for i := range out {
|
||||
out[i] = make(chan any, 256)
|
||||
}
|
||||
return
|
||||
}()
|
||||
messages = make(chan any, 512)
|
||||
|
||||
// Global routing table.
|
||||
routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) {
|
||||
|
@ -81,10 +81,12 @@ func (hp *hubPoller) pollHub() {
|
||||
func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
|
||||
for i, peer := range state.Peers {
|
||||
if i != int(localIP) {
|
||||
if peer != nil && peer.Version != hp.versions[i] {
|
||||
messages[i] <- peerUpdateMsg{Peer: state.Peers[i]}
|
||||
if peer == nil || peer.Version != hp.versions[i] {
|
||||
messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}
|
||||
if peer != nil {
|
||||
hp.versions[i] = peer.Version
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -59,8 +59,9 @@ func recvLocalDiscovery(conn *net.UDPConn) {
|
||||
}
|
||||
|
||||
select {
|
||||
case messages[h.SourceIP] <- msg:
|
||||
case messages <- msg:
|
||||
default:
|
||||
log.Printf("Dropping local discovery message.")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -86,7 +87,7 @@ func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) {
|
||||
h.Parse(raw[signOverhead:])
|
||||
route := routingTable[h.SourceIP].Load()
|
||||
if route == nil || route.PubSignKey == nil {
|
||||
log.Printf("Missing signing key")
|
||||
log.Printf("Missing signing key: %d", h.SourceIP)
|
||||
ok = false
|
||||
return
|
||||
}
|
||||
|
16
node/main.go
16
node/main.go
@ -159,11 +159,6 @@ func main() {
|
||||
privKey = config.PrivKey
|
||||
privSignKey = config.PrivSignKey
|
||||
|
||||
// Start supervisors.
|
||||
for i := range 256 {
|
||||
go newPeerSupervisor(i).Run()
|
||||
}
|
||||
|
||||
if localPub {
|
||||
go addrDiscoveryServer()
|
||||
} else {
|
||||
@ -174,15 +169,12 @@ func main() {
|
||||
|
||||
go func() {
|
||||
for range time.Tick(pingInterval) {
|
||||
for i := range messages {
|
||||
select {
|
||||
case messages[i] <- pingTimerMsg{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
messages <- pingTimerMsg{}
|
||||
}
|
||||
}()
|
||||
|
||||
go startPeerSuper()
|
||||
|
||||
go newHubPoller().Run()
|
||||
go readFromConn(conn)
|
||||
readFromIFace(iface)
|
||||
@ -272,7 +264,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
|
||||
}
|
||||
|
||||
select {
|
||||
case messages[h.SourceIP] <- msg:
|
||||
case messages <- msg:
|
||||
default:
|
||||
log.Printf("Dropping control packet.")
|
||||
}
|
||||
|
@ -25,8 +25,8 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error
|
||||
}, err
|
||||
|
||||
case packetTypeSynAck:
|
||||
packet, err := parseSynAckPacket(buf)
|
||||
return controlMsg[synAckPacket]{
|
||||
packet, err := parseAckPacket(buf)
|
||||
return controlMsg[ackPacket]{
|
||||
SrcIP: srcIP,
|
||||
SrcAddr: srcAddr,
|
||||
Packet: packet,
|
||||
@ -56,6 +56,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type peerUpdateMsg struct {
|
||||
PeerIP byte
|
||||
Peer *m.Peer
|
||||
}
|
||||
|
||||
|
@ -49,12 +49,12 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type synAckPacket struct {
|
||||
type ackPacket struct {
|
||||
TraceID uint64
|
||||
FromAddr netip.AddrPort
|
||||
}
|
||||
|
||||
func (p synAckPacket) Marshal(buf []byte) []byte {
|
||||
func (p ackPacket) Marshal(buf []byte) []byte {
|
||||
return newBinWriter(buf).
|
||||
Byte(packetTypeSynAck).
|
||||
Uint64(p.TraceID).
|
||||
@ -62,7 +62,7 @@ func (p synAckPacket) Marshal(buf []byte) []byte {
|
||||
Build()
|
||||
}
|
||||
|
||||
func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
|
||||
func parseAckPacket(buf []byte) (p ackPacket, err error) {
|
||||
err = newBinReader(buf[1:]).
|
||||
Uint64(&p.TraceID).
|
||||
AddrPort(&p.FromAddr).
|
||||
|
@ -25,12 +25,12 @@ func TestPacketSyn(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPacketSynAck(t *testing.T) {
|
||||
in := synAckPacket{
|
||||
in := ackPacket{
|
||||
TraceID: newTraceID(),
|
||||
FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
|
||||
}
|
||||
|
||||
out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize)))
|
||||
out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1,354 +0,0 @@
|
||||
package node
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"vppn/m"
|
||||
)
|
||||
|
||||
const (
|
||||
pingInterval = 8 * time.Second
|
||||
timeoutInterval = 25 * time.Second
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type peerSupervisor struct {
|
||||
// The purpose of this state machine is to manage this published data.
|
||||
published *atomic.Pointer[peerRoute]
|
||||
staged peerRoute // Local copy of shared data. See publish().
|
||||
|
||||
// Immutable data.
|
||||
remoteIP byte // Remote VPN IP.
|
||||
|
||||
// Mutable peer data.
|
||||
peer *m.Peer
|
||||
remotePub bool
|
||||
|
||||
// Incoming events.
|
||||
messages chan any
|
||||
|
||||
// Buffers for sending control packets.
|
||||
buf1 []byte
|
||||
buf2 []byte
|
||||
}
|
||||
|
||||
func newPeerSupervisor(i int) *peerSupervisor {
|
||||
return &peerSupervisor{
|
||||
published: routingTable[i],
|
||||
remoteIP: byte(i),
|
||||
messages: messages[i],
|
||||
buf1: make([]byte, bufferSize),
|
||||
buf2: make([]byte, bufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
type stateFunc func() stateFunc
|
||||
|
||||
func (s *peerSupervisor) Run() {
|
||||
state := s.noPeer
|
||||
for {
|
||||
state = state()
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
|
||||
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
|
||||
time.Sleep(500 * time.Millisecond) // Rate limit packets.
|
||||
}
|
||||
|
||||
func (s *peerSupervisor) sendControlPacketTo(
|
||||
pkt interface{ Marshal([]byte) []byte },
|
||||
addr netip.AddrPort,
|
||||
) {
|
||||
if !addr.IsValid() {
|
||||
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
|
||||
return
|
||||
}
|
||||
route := s.staged
|
||||
route.Direct = true
|
||||
route.RemoteAddr = addr
|
||||
_sendControlPacket(pkt, route, s.buf1, s.buf2)
|
||||
time.Sleep(500 * time.Millisecond) // Rate limit packets.
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) logf(msg string, args ...any) {
|
||||
log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) publish() {
|
||||
data := s.staged
|
||||
s.published.Store(&data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) noPeer() stateFunc {
|
||||
for {
|
||||
rawMsg := <-s.messages
|
||||
if msg, ok := rawMsg.(peerUpdateMsg); ok {
|
||||
return s.peerUpdate(msg.Peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
|
||||
return func() stateFunc { return s._peerUpdate(peer) }
|
||||
}
|
||||
|
||||
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
|
||||
defer s.publish()
|
||||
|
||||
s.peer = peer
|
||||
s.staged = peerRoute{}
|
||||
|
||||
if s.peer == nil {
|
||||
return s.noPeer
|
||||
}
|
||||
|
||||
s.staged.IP = s.remoteIP
|
||||
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
|
||||
s.staged.PubSignKey = peer.PubSignKey
|
||||
s.staged.DataCipher = newDataCipher()
|
||||
|
||||
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
|
||||
s.remotePub = true
|
||||
s.staged.Relay = peer.Relay
|
||||
s.staged.Direct = true
|
||||
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
|
||||
} else if localPub {
|
||||
s.staged.Direct = true
|
||||
}
|
||||
|
||||
if s.remotePub == localPub {
|
||||
if localIP < s.remoteIP {
|
||||
return s.server
|
||||
}
|
||||
return s.client
|
||||
}
|
||||
|
||||
if s.remotePub {
|
||||
return s.client
|
||||
}
|
||||
return s.server
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) server() stateFunc {
|
||||
logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) }
|
||||
|
||||
logf("DOWN")
|
||||
|
||||
var (
|
||||
syn synPacket
|
||||
lastSeen = time.Now()
|
||||
)
|
||||
|
||||
for {
|
||||
rawMsg := <-s.messages
|
||||
switch msg := rawMsg.(type) {
|
||||
|
||||
case peerUpdateMsg:
|
||||
return s.peerUpdate(msg.Peer)
|
||||
|
||||
case controlMsg[synPacket]:
|
||||
p := msg.Packet
|
||||
lastSeen = time.Now()
|
||||
|
||||
// Before we can respond to this packet, we need to make sure the
|
||||
// route is setup properly.
|
||||
//
|
||||
// The client will update the syn's TraceID whenever there's a change.
|
||||
// The server will follow the client's request.
|
||||
if p.TraceID != syn.TraceID || !s.staged.Up {
|
||||
if p.Direct {
|
||||
logf("UP - Direct")
|
||||
} else {
|
||||
logf("UP - Relayed")
|
||||
}
|
||||
|
||||
syn = p
|
||||
s.staged.Up = true
|
||||
s.staged.Direct = syn.Direct
|
||||
s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
|
||||
s.staged.RemoteAddr = msg.SrcAddr
|
||||
|
||||
s.publish()
|
||||
}
|
||||
|
||||
// We should always respond.
|
||||
ack := synAckPacket{
|
||||
TraceID: syn.TraceID,
|
||||
FromAddr: getLocalAddr(),
|
||||
}
|
||||
s.sendControlPacket(ack)
|
||||
|
||||
if s.staged.Direct {
|
||||
continue
|
||||
}
|
||||
|
||||
if !syn.FromAddr.IsValid() {
|
||||
continue
|
||||
}
|
||||
|
||||
probe := probePacket{TraceID: newTraceID()}
|
||||
s.sendControlPacketTo(probe, syn.FromAddr)
|
||||
|
||||
case controlMsg[probePacket]:
|
||||
if !msg.SrcAddr.IsValid() {
|
||||
logf("Invalid probe address")
|
||||
continue
|
||||
}
|
||||
s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
|
||||
|
||||
case pingTimerMsg:
|
||||
if time.Since(lastSeen) > timeoutInterval && s.staged.Up {
|
||||
logf("Connection timeout")
|
||||
s.staged.Up = false
|
||||
s.publish()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerSupervisor) client() stateFunc {
|
||||
logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) }
|
||||
|
||||
logf("DOWN")
|
||||
|
||||
var (
|
||||
syn = synPacket{
|
||||
TraceID: newTraceID(),
|
||||
SharedKey: s.staged.DataCipher.Key(),
|
||||
Direct: s.staged.Direct,
|
||||
FromAddr: getLocalAddr(),
|
||||
}
|
||||
|
||||
lastSeen = time.Now()
|
||||
ack synAckPacket
|
||||
|
||||
probe probePacket
|
||||
probeAddr netip.AddrPort
|
||||
|
||||
localProbe probePacket
|
||||
localProbeAddr netip.AddrPort
|
||||
|
||||
lastLocalAddr netip.AddrPort
|
||||
)
|
||||
|
||||
s.sendControlPacket(syn)
|
||||
|
||||
for {
|
||||
rawMsg := <-s.messages
|
||||
switch msg := rawMsg.(type) {
|
||||
|
||||
case peerUpdateMsg:
|
||||
return s.peerUpdate(msg.Peer)
|
||||
|
||||
case controlMsg[synAckPacket]:
|
||||
p := msg.Packet
|
||||
|
||||
if p.TraceID != syn.TraceID {
|
||||
continue // Hmm...
|
||||
}
|
||||
|
||||
lastSeen = time.Now()
|
||||
ack = msg.Packet
|
||||
|
||||
if !s.staged.Up {
|
||||
if s.staged.Direct {
|
||||
logf("UP - Direct")
|
||||
} else {
|
||||
logf("UP - Relayed")
|
||||
}
|
||||
|
||||
s.staged.Up = true
|
||||
s.publish()
|
||||
}
|
||||
|
||||
case controlMsg[probePacket]:
|
||||
if s.staged.Direct {
|
||||
continue
|
||||
}
|
||||
|
||||
p := msg.Packet
|
||||
|
||||
if p.TraceID != localProbe.TraceID && p.TraceID != probe.TraceID {
|
||||
continue
|
||||
}
|
||||
|
||||
// Upgrade connection.
|
||||
|
||||
s.staged.Direct = true
|
||||
if p.TraceID == localProbe.TraceID {
|
||||
logf("UP - Local")
|
||||
s.staged.RemoteAddr = localProbeAddr
|
||||
} else {
|
||||
logf("UP - Direct")
|
||||
s.staged.RemoteAddr = probeAddr
|
||||
}
|
||||
s.publish()
|
||||
|
||||
syn.TraceID = newTraceID()
|
||||
syn.Direct = true
|
||||
syn.FromAddr = getLocalAddr()
|
||||
s.sendControlPacket(syn)
|
||||
|
||||
case controlMsg[localDiscoveryPacket]:
|
||||
if s.staged.Direct {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send probe.
|
||||
//
|
||||
// The source port will be the multicast port, so we'll have to
|
||||
// construct the correct address using the peer's listed port.
|
||||
localProbe = probePacket{TraceID: newTraceID()}
|
||||
localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
|
||||
s.sendControlPacketTo(localProbe, localProbeAddr)
|
||||
|
||||
case pingTimerMsg:
|
||||
if time.Since(lastSeen) > timeoutInterval {
|
||||
if s.staged.Up {
|
||||
logf("Connection timeout")
|
||||
}
|
||||
return s.peerUpdate(s.peer)
|
||||
}
|
||||
|
||||
syn.FromAddr = getLocalAddr()
|
||||
if syn.FromAddr != lastLocalAddr {
|
||||
syn.TraceID = newTraceID()
|
||||
lastLocalAddr = syn.FromAddr
|
||||
}
|
||||
|
||||
s.sendControlPacket(syn)
|
||||
|
||||
if s.staged.Direct {
|
||||
continue
|
||||
}
|
||||
|
||||
if !ack.FromAddr.IsValid() {
|
||||
continue
|
||||
}
|
||||
|
||||
probe = probePacket{TraceID: newTraceID()}
|
||||
probeAddr = ack.FromAddr
|
||||
|
||||
s.sendControlPacketTo(probe, ack.FromAddr)
|
||||
}
|
||||
}
|
||||
}
|
392
node/supervisor.go
Normal file
392
node/supervisor.go
Normal file
@ -0,0 +1,392 @@
|
||||
package node
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"vppn/m"
|
||||
|
||||
"git.crumpington.com/lib/go/ratelimiter"
|
||||
)
|
||||
|
||||
const (
|
||||
pingInterval = 8 * time.Second
|
||||
timeoutInterval = 25 * time.Second
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func startPeerSuper() {
|
||||
peers := [256]peerState{}
|
||||
for i := range peers {
|
||||
data := &peerStateData{
|
||||
published: routingTable[i],
|
||||
remoteIP: byte(i),
|
||||
buf1: make([]byte, bufferSize),
|
||||
buf2: make([]byte, bufferSize),
|
||||
limiter: ratelimiter.New(ratelimiter.Config{
|
||||
FillPeriod: 50 * time.Millisecond,
|
||||
MaxWaitCount: 1,
|
||||
}),
|
||||
}
|
||||
peers[i] = data.OnPeerUpdate(nil)
|
||||
}
|
||||
go runPeerSuper(peers)
|
||||
}
|
||||
|
||||
func runPeerSuper(peers [256]peerState) {
|
||||
for raw := range messages {
|
||||
switch msg := raw.(type) {
|
||||
|
||||
case peerUpdateMsg:
|
||||
peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer)
|
||||
|
||||
case controlMsg[synPacket]:
|
||||
peers[msg.SrcIP].OnSyn(msg)
|
||||
|
||||
case controlMsg[ackPacket]:
|
||||
peers[msg.SrcIP].OnAck(msg)
|
||||
|
||||
case controlMsg[probePacket]:
|
||||
peers[msg.SrcIP].OnProbe(msg)
|
||||
|
||||
case controlMsg[localDiscoveryPacket]:
|
||||
peers[msg.SrcIP].OnLocalDiscovery(msg)
|
||||
|
||||
case pingTimerMsg:
|
||||
for i := range peers {
|
||||
if newState := peers[i].OnPingTimer(); newState != nil {
|
||||
peers[i] = newState
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
log.Printf("WARNING: unknown message type: %+v", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type peerState interface {
|
||||
OnPeerUpdate(*m.Peer) peerState
|
||||
OnSyn(controlMsg[synPacket])
|
||||
OnAck(controlMsg[ackPacket])
|
||||
OnProbe(controlMsg[probePacket])
|
||||
OnLocalDiscovery(controlMsg[localDiscoveryPacket])
|
||||
OnPingTimer() peerState
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type peerStateData struct {
|
||||
// The purpose of this state machine is to manage this published data.
|
||||
published *atomic.Pointer[peerRoute]
|
||||
staged peerRoute // Local copy of shared data. See publish().
|
||||
|
||||
// Immutable data.
|
||||
remoteIP byte // Remote VPN IP.
|
||||
|
||||
// Mutable peer data.
|
||||
peer *m.Peer
|
||||
remotePub bool
|
||||
|
||||
// Buffers for sending control packets.
|
||||
buf1 []byte
|
||||
buf2 []byte
|
||||
|
||||
// For logging. Set per-state.
|
||||
client bool
|
||||
|
||||
limiter *ratelimiter.Limiter
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
|
||||
s.limiter.Limit()
|
||||
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
|
||||
}
|
||||
|
||||
func (s *peerStateData) sendControlPacketTo(
|
||||
pkt interface{ Marshal([]byte) []byte },
|
||||
addr netip.AddrPort,
|
||||
) {
|
||||
if !addr.IsValid() {
|
||||
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
|
||||
return
|
||||
}
|
||||
route := s.staged
|
||||
route.Direct = true
|
||||
route.RemoteAddr = addr
|
||||
s.limiter.Limit()
|
||||
_sendControlPacket(pkt, route, s.buf1, s.buf2)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerStateData) publish() {
|
||||
data := s.staged
|
||||
s.published.Store(&data)
|
||||
}
|
||||
|
||||
func (s *peerStateData) logf(format string, args ...any) {
|
||||
b := strings.Builder{}
|
||||
b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name))
|
||||
|
||||
if s.client {
|
||||
b.WriteString("CLIENT|")
|
||||
} else {
|
||||
b.WriteString("SERVER|")
|
||||
}
|
||||
|
||||
if s.staged.Direct {
|
||||
b.WriteString("DIRECT |")
|
||||
} else {
|
||||
b.WriteString("RELAYED|")
|
||||
}
|
||||
|
||||
if s.staged.Up {
|
||||
b.WriteString("UP |")
|
||||
} else {
|
||||
b.WriteString("DOWN|")
|
||||
}
|
||||
|
||||
log.Printf(b.String()+format, args...)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState {
|
||||
defer s.publish()
|
||||
|
||||
if peer == nil {
|
||||
return enterStateDisconnected(s)
|
||||
}
|
||||
|
||||
s.peer = peer
|
||||
s.staged.IP = s.remoteIP
|
||||
s.staged.PubSignKey = peer.PubSignKey
|
||||
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
|
||||
s.staged.DataCipher = newDataCipher()
|
||||
|
||||
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
|
||||
s.remotePub = true
|
||||
s.staged.Relay = peer.Relay
|
||||
s.staged.Direct = true
|
||||
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
|
||||
} else if localPub {
|
||||
s.staged.Direct = true
|
||||
}
|
||||
|
||||
if s.remotePub == localPub {
|
||||
if localIP < s.remoteIP {
|
||||
return enterStateServer(s)
|
||||
}
|
||||
return enterStateClient(s)
|
||||
}
|
||||
|
||||
if s.remotePub {
|
||||
return enterStateClient(s)
|
||||
}
|
||||
return enterStateServer(s)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type stateDisconnected struct {
|
||||
*peerStateData
|
||||
}
|
||||
|
||||
func enterStateDisconnected(s *peerStateData) peerState {
|
||||
s.peer = nil
|
||||
s.staged = peerRoute{}
|
||||
s.publish()
|
||||
return &stateDisconnected{s}
|
||||
}
|
||||
|
||||
func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {}
|
||||
func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {}
|
||||
func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {}
|
||||
func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {}
|
||||
|
||||
func (s *stateDisconnected) OnPingTimer() peerState {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type stateServer struct {
|
||||
*stateDisconnected
|
||||
lastSeen time.Time
|
||||
synTraceID uint64
|
||||
}
|
||||
|
||||
func enterStateServer(s *peerStateData) peerState {
|
||||
s.client = false
|
||||
return &stateServer{stateDisconnected: &stateDisconnected{s}}
|
||||
}
|
||||
|
||||
func (s *stateServer) OnSyn(msg controlMsg[synPacket]) {
|
||||
s.lastSeen = time.Now()
|
||||
p := msg.Packet
|
||||
|
||||
// Before we can respond to this packet, we need to make sure the
|
||||
// route is setup properly.
|
||||
//
|
||||
// The client will update the syn's TraceID whenever there's a change.
|
||||
// The server will follow the client's request.
|
||||
if p.TraceID != s.synTraceID || !s.staged.Up {
|
||||
s.synTraceID = p.TraceID
|
||||
s.staged.Up = true
|
||||
s.staged.Direct = p.Direct
|
||||
s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
|
||||
s.staged.RemoteAddr = msg.SrcAddr
|
||||
s.publish()
|
||||
s.logf("Got syn.")
|
||||
}
|
||||
|
||||
// Always respond.
|
||||
ack := ackPacket{
|
||||
TraceID: p.TraceID,
|
||||
FromAddr: getLocalAddr(),
|
||||
}
|
||||
s.sendControlPacket(ack)
|
||||
|
||||
if !s.staged.Direct && p.FromAddr.IsValid() {
|
||||
s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateServer) OnProbe(msg controlMsg[probePacket]) {
|
||||
if !msg.SrcAddr.IsValid() {
|
||||
s.logf("Invalid probe address.")
|
||||
return
|
||||
}
|
||||
s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
|
||||
}
|
||||
|
||||
func (s *stateServer) OnPingTimer() peerState {
|
||||
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
|
||||
s.staged.Up = false
|
||||
s.publish()
|
||||
s.logf("Connection timeout.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type stateClient struct {
|
||||
*stateDisconnected
|
||||
|
||||
lastSeen time.Time
|
||||
syn synPacket
|
||||
ack ackPacket
|
||||
|
||||
probeTraceID uint64
|
||||
probeAddr netip.AddrPort
|
||||
|
||||
localProbeTraceID uint64
|
||||
localProbeAddr netip.AddrPort
|
||||
}
|
||||
|
||||
func enterStateClient(s *peerStateData) peerState {
|
||||
s.client = true
|
||||
ss := &stateClient{stateDisconnected: &stateDisconnected{s}}
|
||||
ss.syn = synPacket{
|
||||
TraceID: newTraceID(),
|
||||
SharedKey: s.staged.DataCipher.Key(),
|
||||
Direct: s.staged.Direct,
|
||||
FromAddr: getLocalAddr(),
|
||||
}
|
||||
ss.sendSyn()
|
||||
return ss
|
||||
}
|
||||
|
||||
func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
|
||||
if msg.Packet.TraceID != s.syn.TraceID {
|
||||
s.logf("Ack has incorrect trace ID")
|
||||
return
|
||||
}
|
||||
|
||||
s.ack = msg.Packet
|
||||
s.lastSeen = time.Now()
|
||||
|
||||
if !s.staged.Up {
|
||||
s.staged.Up = true
|
||||
s.logf("Got ack.")
|
||||
s.publish()
|
||||
} else {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateClient) OnProbe(msg controlMsg[probePacket]) {
|
||||
if s.staged.Direct {
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Packet.TraceID {
|
||||
case s.probeTraceID:
|
||||
s.staged.RemoteAddr = s.probeAddr
|
||||
case s.localProbeTraceID:
|
||||
s.staged.RemoteAddr = s.localProbeAddr
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
s.staged.Direct = true
|
||||
s.publish()
|
||||
|
||||
s.syn.TraceID = newTraceID()
|
||||
s.syn.Direct = true
|
||||
s.syn.FromAddr = getLocalAddr()
|
||||
s.sendControlPacket(s.syn)
|
||||
|
||||
s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String())
|
||||
}
|
||||
|
||||
func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) {
|
||||
if s.staged.Direct {
|
||||
return
|
||||
}
|
||||
|
||||
// Send probe.
|
||||
//
|
||||
// The source port will be the multicast port, so we'll have to
|
||||
// construct the correct address using the peer's listed port.
|
||||
s.localProbeTraceID = newTraceID()
|
||||
s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
|
||||
s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr)
|
||||
}
|
||||
|
||||
func (s *stateClient) OnPingTimer() peerState {
|
||||
if time.Since(s.lastSeen) > timeoutInterval {
|
||||
if s.staged.Up {
|
||||
s.logf("Connection timeout.")
|
||||
}
|
||||
return s.OnPeerUpdate(s.peer)
|
||||
}
|
||||
|
||||
s.sendSyn()
|
||||
|
||||
if !s.staged.Direct && s.ack.FromAddr.IsValid() {
|
||||
s.probeTraceID = newTraceID()
|
||||
s.probeAddr = s.ack.FromAddr
|
||||
s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stateClient) sendSyn() {
|
||||
localAddr := getLocalAddr()
|
||||
if localAddr != s.syn.FromAddr {
|
||||
s.syn.TraceID = newTraceID()
|
||||
s.syn.FromAddr = localAddr
|
||||
}
|
||||
s.sendControlPacket(s.syn)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user