fewer-routines #2

Merged
johnnylee merged 3 commits from fewer-routines into main 2025-01-04 12:28:41 +00:00
9 changed files with 414 additions and 385 deletions
Showing only changes of commit b65effa830 - Show all commits

View File

@ -66,12 +66,7 @@ var (
return return
}() }()
messages [256]chan any = func() (out [256]chan any) { messages = make(chan any, 512)
for i := range out {
out[i] = make(chan any, 256)
}
return
}()
// Global routing table. // Global routing table.
routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) {

View File

@ -81,9 +81,11 @@ func (hp *hubPoller) pollHub() {
func (hp *hubPoller) applyNetworkState(state m.NetworkState) { func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
for i, peer := range state.Peers { for i, peer := range state.Peers {
if i != int(localIP) { if i != int(localIP) {
if peer != nil && peer.Version != hp.versions[i] { if peer == nil || peer.Version != hp.versions[i] {
messages[i] <- peerUpdateMsg{Peer: state.Peers[i]} messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}
hp.versions[i] = peer.Version if peer != nil {
hp.versions[i] = peer.Version
}
} }
} }
} }

View File

@ -59,8 +59,9 @@ func recvLocalDiscovery(conn *net.UDPConn) {
} }
select { select {
case messages[h.SourceIP] <- msg: case messages <- msg:
default: default:
log.Printf("Dropping local discovery message.")
} }
} }
} }
@ -86,7 +87,7 @@ func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) {
h.Parse(raw[signOverhead:]) h.Parse(raw[signOverhead:])
route := routingTable[h.SourceIP].Load() route := routingTable[h.SourceIP].Load()
if route == nil || route.PubSignKey == nil { if route == nil || route.PubSignKey == nil {
log.Printf("Missing signing key") log.Printf("Missing signing key: %d", h.SourceIP)
ok = false ok = false
return return
} }

View File

@ -159,11 +159,6 @@ func main() {
privKey = config.PrivKey privKey = config.PrivKey
privSignKey = config.PrivSignKey privSignKey = config.PrivSignKey
// Start supervisors.
for i := range 256 {
go newPeerSupervisor(i).Run()
}
if localPub { if localPub {
go addrDiscoveryServer() go addrDiscoveryServer()
} else { } else {
@ -174,15 +169,12 @@ func main() {
go func() { go func() {
for range time.Tick(pingInterval) { for range time.Tick(pingInterval) {
for i := range messages { messages <- pingTimerMsg{}
select {
case messages[i] <- pingTimerMsg{}:
default:
}
}
} }
}() }()
go startPeerSuper()
go newHubPoller().Run() go newHubPoller().Run()
go readFromConn(conn) go readFromConn(conn)
readFromIFace(iface) readFromIFace(iface)
@ -272,7 +264,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
} }
select { select {
case messages[h.SourceIP] <- msg: case messages <- msg:
default: default:
log.Printf("Dropping control packet.") log.Printf("Dropping control packet.")
} }

View File

@ -25,8 +25,8 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error
}, err }, err
case packetTypeSynAck: case packetTypeSynAck:
packet, err := parseSynAckPacket(buf) packet, err := parseAckPacket(buf)
return controlMsg[synAckPacket]{ return controlMsg[ackPacket]{
SrcIP: srcIP, SrcIP: srcIP,
SrcAddr: srcAddr, SrcAddr: srcAddr,
Packet: packet, Packet: packet,
@ -56,7 +56,8 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type peerUpdateMsg struct { type peerUpdateMsg struct {
Peer *m.Peer PeerIP byte
Peer *m.Peer
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------

View File

@ -49,12 +49,12 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type synAckPacket struct { type ackPacket struct {
TraceID uint64 TraceID uint64
FromAddr netip.AddrPort FromAddr netip.AddrPort
} }
func (p synAckPacket) Marshal(buf []byte) []byte { func (p ackPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf). return newBinWriter(buf).
Byte(packetTypeSynAck). Byte(packetTypeSynAck).
Uint64(p.TraceID). Uint64(p.TraceID).
@ -62,7 +62,7 @@ func (p synAckPacket) Marshal(buf []byte) []byte {
Build() Build()
} }
func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { func parseAckPacket(buf []byte) (p ackPacket, err error) {
err = newBinReader(buf[1:]). err = newBinReader(buf[1:]).
Uint64(&p.TraceID). Uint64(&p.TraceID).
AddrPort(&p.FromAddr). AddrPort(&p.FromAddr).

View File

@ -25,12 +25,12 @@ func TestPacketSyn(t *testing.T) {
} }
func TestPacketSynAck(t *testing.T) { func TestPacketSynAck(t *testing.T) {
in := synAckPacket{ in := ackPacket{
TraceID: newTraceID(), TraceID: newTraceID(),
FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
} }
out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize)))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,354 +0,0 @@
package node
import (
"fmt"
"log"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
const (
pingInterval = 8 * time.Second
timeoutInterval = 25 * time.Second
)
// ----------------------------------------------------------------------------
type peerSupervisor struct {
// The purpose of this state machine is to manage this published data.
published *atomic.Pointer[peerRoute]
staged peerRoute // Local copy of shared data. See publish().
// Immutable data.
remoteIP byte // Remote VPN IP.
// Mutable peer data.
peer *m.Peer
remotePub bool
// Incoming events.
messages chan any
// Buffers for sending control packets.
buf1 []byte
buf2 []byte
}
func newPeerSupervisor(i int) *peerSupervisor {
return &peerSupervisor{
published: routingTable[i],
remoteIP: byte(i),
messages: messages[i],
buf1: make([]byte, bufferSize),
buf2: make([]byte, bufferSize),
}
}
type stateFunc func() stateFunc
func (s *peerSupervisor) Run() {
state := s.noPeer
for {
state = state()
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
func (s *peerSupervisor) sendControlPacketTo(
pkt interface{ Marshal([]byte) []byte },
addr netip.AddrPort,
) {
if !addr.IsValid() {
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
return
}
route := s.staged
route.Direct = true
route.RemoteAddr = addr
_sendControlPacket(pkt, route, s.buf1, s.buf2)
time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) logf(msg string, args ...any) {
log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) publish() {
data := s.staged
s.published.Store(&data)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) noPeer() stateFunc {
for {
rawMsg := <-s.messages
if msg, ok := rawMsg.(peerUpdateMsg); ok {
return s.peerUpdate(msg.Peer)
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
return func() stateFunc { return s._peerUpdate(peer) }
}
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
defer s.publish()
s.peer = peer
s.staged = peerRoute{}
if s.peer == nil {
return s.noPeer
}
s.staged.IP = s.remoteIP
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
s.staged.PubSignKey = peer.PubSignKey
s.staged.DataCipher = newDataCipher()
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true
s.staged.Relay = peer.Relay
s.staged.Direct = true
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
} else if localPub {
s.staged.Direct = true
}
if s.remotePub == localPub {
if localIP < s.remoteIP {
return s.server
}
return s.client
}
if s.remotePub {
return s.client
}
return s.server
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) server() stateFunc {
logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) }
logf("DOWN")
var (
syn synPacket
lastSeen = time.Now()
)
for {
rawMsg := <-s.messages
switch msg := rawMsg.(type) {
case peerUpdateMsg:
return s.peerUpdate(msg.Peer)
case controlMsg[synPacket]:
p := msg.Packet
lastSeen = time.Now()
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
//
// The client will update the syn's TraceID whenever there's a change.
// The server will follow the client's request.
if p.TraceID != syn.TraceID || !s.staged.Up {
if p.Direct {
logf("UP - Direct")
} else {
logf("UP - Relayed")
}
syn = p
s.staged.Up = true
s.staged.Direct = syn.Direct
s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
s.staged.RemoteAddr = msg.SrcAddr
s.publish()
}
// We should always respond.
ack := synAckPacket{
TraceID: syn.TraceID,
FromAddr: getLocalAddr(),
}
s.sendControlPacket(ack)
if s.staged.Direct {
continue
}
if !syn.FromAddr.IsValid() {
continue
}
probe := probePacket{TraceID: newTraceID()}
s.sendControlPacketTo(probe, syn.FromAddr)
case controlMsg[probePacket]:
if !msg.SrcAddr.IsValid() {
logf("Invalid probe address")
continue
}
s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
case pingTimerMsg:
if time.Since(lastSeen) > timeoutInterval && s.staged.Up {
logf("Connection timeout")
s.staged.Up = false
s.publish()
}
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) client() stateFunc {
logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) }
logf("DOWN")
var (
syn = synPacket{
TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(),
Direct: s.staged.Direct,
FromAddr: getLocalAddr(),
}
lastSeen = time.Now()
ack synAckPacket
probe probePacket
probeAddr netip.AddrPort
localProbe probePacket
localProbeAddr netip.AddrPort
lastLocalAddr netip.AddrPort
)
s.sendControlPacket(syn)
for {
rawMsg := <-s.messages
switch msg := rawMsg.(type) {
case peerUpdateMsg:
return s.peerUpdate(msg.Peer)
case controlMsg[synAckPacket]:
p := msg.Packet
if p.TraceID != syn.TraceID {
continue // Hmm...
}
lastSeen = time.Now()
ack = msg.Packet
if !s.staged.Up {
if s.staged.Direct {
logf("UP - Direct")
} else {
logf("UP - Relayed")
}
s.staged.Up = true
s.publish()
}
case controlMsg[probePacket]:
if s.staged.Direct {
continue
}
p := msg.Packet
if p.TraceID != localProbe.TraceID && p.TraceID != probe.TraceID {
continue
}
// Upgrade connection.
s.staged.Direct = true
if p.TraceID == localProbe.TraceID {
logf("UP - Local")
s.staged.RemoteAddr = localProbeAddr
} else {
logf("UP - Direct")
s.staged.RemoteAddr = probeAddr
}
s.publish()
syn.TraceID = newTraceID()
syn.Direct = true
syn.FromAddr = getLocalAddr()
s.sendControlPacket(syn)
case controlMsg[localDiscoveryPacket]:
if s.staged.Direct {
continue
}
// Send probe.
//
// The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port.
localProbe = probePacket{TraceID: newTraceID()}
localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
s.sendControlPacketTo(localProbe, localProbeAddr)
case pingTimerMsg:
if time.Since(lastSeen) > timeoutInterval {
if s.staged.Up {
logf("Connection timeout")
}
return s.peerUpdate(s.peer)
}
syn.FromAddr = getLocalAddr()
if syn.FromAddr != lastLocalAddr {
syn.TraceID = newTraceID()
lastLocalAddr = syn.FromAddr
}
s.sendControlPacket(syn)
if s.staged.Direct {
continue
}
if !ack.FromAddr.IsValid() {
continue
}
probe = probePacket{TraceID: newTraceID()}
probeAddr = ack.FromAddr
s.sendControlPacketTo(probe, ack.FromAddr)
}
}
}

392
node/supervisor.go Normal file
View File

@ -0,0 +1,392 @@
package node
import (
"fmt"
"log"
"net/netip"
"strings"
"sync/atomic"
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
)
const (
pingInterval = 8 * time.Second
timeoutInterval = 25 * time.Second
)
// ----------------------------------------------------------------------------
func startPeerSuper() {
peers := [256]peerState{}
for i := range peers {
data := &peerStateData{
published: routingTable[i],
remoteIP: byte(i),
buf1: make([]byte, bufferSize),
buf2: make([]byte, bufferSize),
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 50 * time.Millisecond,
MaxWaitCount: 1,
}),
}
peers[i] = data.OnPeerUpdate(nil)
}
go runPeerSuper(peers)
}
func runPeerSuper(peers [256]peerState) {
for raw := range messages {
switch msg := raw.(type) {
case peerUpdateMsg:
peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer)
case controlMsg[synPacket]:
peers[msg.SrcIP].OnSyn(msg)
case controlMsg[ackPacket]:
peers[msg.SrcIP].OnAck(msg)
case controlMsg[probePacket]:
peers[msg.SrcIP].OnProbe(msg)
case controlMsg[localDiscoveryPacket]:
peers[msg.SrcIP].OnLocalDiscovery(msg)
case pingTimerMsg:
for i := range peers {
if newState := peers[i].OnPingTimer(); newState != nil {
peers[i] = newState
}
}
default:
log.Printf("WARNING: unknown message type: %+v", msg)
}
}
}
// ----------------------------------------------------------------------------
type peerState interface {
OnPeerUpdate(*m.Peer) peerState
OnSyn(controlMsg[synPacket])
OnAck(controlMsg[ackPacket])
OnProbe(controlMsg[probePacket])
OnLocalDiscovery(controlMsg[localDiscoveryPacket])
OnPingTimer() peerState
}
// ----------------------------------------------------------------------------
type peerStateData struct {
// The purpose of this state machine is to manage this published data.
published *atomic.Pointer[peerRoute]
staged peerRoute // Local copy of shared data. See publish().
// Immutable data.
remoteIP byte // Remote VPN IP.
// Mutable peer data.
peer *m.Peer
remotePub bool
// Buffers for sending control packets.
buf1 []byte
buf2 []byte
// For logging. Set per-state.
client bool
limiter *ratelimiter.Limiter
}
// ----------------------------------------------------------------------------
func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
s.limiter.Limit()
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
}
func (s *peerStateData) sendControlPacketTo(
pkt interface{ Marshal([]byte) []byte },
addr netip.AddrPort,
) {
if !addr.IsValid() {
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
return
}
route := s.staged
route.Direct = true
route.RemoteAddr = addr
s.limiter.Limit()
_sendControlPacket(pkt, route, s.buf1, s.buf2)
}
// ----------------------------------------------------------------------------
func (s *peerStateData) publish() {
data := s.staged
s.published.Store(&data)
}
func (s *peerStateData) logf(format string, args ...any) {
b := strings.Builder{}
b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name))
if s.client {
b.WriteString("CLIENT|")
} else {
b.WriteString("SERVER|")
}
if s.staged.Direct {
b.WriteString("DIRECT |")
} else {
b.WriteString("RELAYED|")
}
if s.staged.Up {
b.WriteString("UP |")
} else {
b.WriteString("DOWN|")
}
log.Printf(b.String()+format, args...)
}
// ----------------------------------------------------------------------------
func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState {
defer s.publish()
if peer == nil {
return enterStateDisconnected(s)
}
s.peer = peer
s.staged.IP = s.remoteIP
s.staged.PubSignKey = peer.PubSignKey
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
s.staged.DataCipher = newDataCipher()
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true
s.staged.Relay = peer.Relay
s.staged.Direct = true
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
} else if localPub {
s.staged.Direct = true
}
if s.remotePub == localPub {
if localIP < s.remoteIP {
return enterStateServer(s)
}
return enterStateClient(s)
}
if s.remotePub {
return enterStateClient(s)
}
return enterStateServer(s)
}
// ----------------------------------------------------------------------------
type stateDisconnected struct {
*peerStateData
}
func enterStateDisconnected(s *peerStateData) peerState {
s.peer = nil
s.staged = peerRoute{}
s.publish()
return &stateDisconnected{s}
}
func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {}
func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {}
func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {}
func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {}
func (s *stateDisconnected) OnPingTimer() peerState {
return nil
}
// ----------------------------------------------------------------------------
type stateServer struct {
*stateDisconnected
lastSeen time.Time
synTraceID uint64
}
func enterStateServer(s *peerStateData) peerState {
s.client = false
return &stateServer{stateDisconnected: &stateDisconnected{s}}
}
func (s *stateServer) OnSyn(msg controlMsg[synPacket]) {
s.lastSeen = time.Now()
p := msg.Packet
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
//
// The client will update the syn's TraceID whenever there's a change.
// The server will follow the client's request.
if p.TraceID != s.synTraceID || !s.staged.Up {
s.synTraceID = p.TraceID
s.staged.Up = true
s.staged.Direct = p.Direct
s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
s.staged.RemoteAddr = msg.SrcAddr
s.publish()
s.logf("Got syn.")
}
// Always respond.
ack := ackPacket{
TraceID: p.TraceID,
FromAddr: getLocalAddr(),
}
s.sendControlPacket(ack)
if !s.staged.Direct && p.FromAddr.IsValid() {
s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr)
}
}
func (s *stateServer) OnProbe(msg controlMsg[probePacket]) {
if !msg.SrcAddr.IsValid() {
s.logf("Invalid probe address.")
return
}
s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
}
func (s *stateServer) OnPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
s.staged.Up = false
s.publish()
s.logf("Connection timeout.")
}
return nil
}
// ----------------------------------------------------------------------------
type stateClient struct {
*stateDisconnected
lastSeen time.Time
syn synPacket
ack ackPacket
probeTraceID uint64
probeAddr netip.AddrPort
localProbeTraceID uint64
localProbeAddr netip.AddrPort
}
func enterStateClient(s *peerStateData) peerState {
s.client = true
ss := &stateClient{stateDisconnected: &stateDisconnected{s}}
ss.syn = synPacket{
TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(),
Direct: s.staged.Direct,
FromAddr: getLocalAddr(),
}
ss.sendSyn()
return ss
}
func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
if msg.Packet.TraceID != s.syn.TraceID {
s.logf("Ack has incorrect trace ID")
return
}
s.ack = msg.Packet
s.lastSeen = time.Now()
if !s.staged.Up {
s.staged.Up = true
s.logf("Got ack.")
s.publish()
} else {
}
}
func (s *stateClient) OnProbe(msg controlMsg[probePacket]) {
if s.staged.Direct {
return
}
switch msg.Packet.TraceID {
case s.probeTraceID:
s.staged.RemoteAddr = s.probeAddr
case s.localProbeTraceID:
s.staged.RemoteAddr = s.localProbeAddr
default:
return
}
s.staged.Direct = true
s.publish()
s.syn.TraceID = newTraceID()
s.syn.Direct = true
s.syn.FromAddr = getLocalAddr()
s.sendControlPacket(s.syn)
s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String())
}
func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) {
if s.staged.Direct {
return
}
// Send probe.
//
// The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port.
s.localProbeTraceID = newTraceID()
s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr)
}
func (s *stateClient) OnPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval {
if s.staged.Up {
s.logf("Connection timeout.")
}
return s.OnPeerUpdate(s.peer)
}
s.sendSyn()
if !s.staged.Direct && s.ack.FromAddr.IsValid() {
s.probeTraceID = newTraceID()
s.probeAddr = s.ack.FromAddr
s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr)
}
return nil
}
func (s *stateClient) sendSyn() {
localAddr := getLocalAddr()
if localAddr != s.syn.FromAddr {
s.syn.TraceID = newTraceID()
s.syn.FromAddr = localAddr
}
s.sendControlPacket(s.syn)
}