396 lines
9.0 KiB
Go
396 lines
9.0 KiB
Go
package node
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net/netip"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
"vppn/m"
|
|
|
|
"git.crumpington.com/lib/go/ratelimiter"
|
|
)
|
|
|
|
const (
	// timeoutInterval is how long a peer may go unheard (no syn on the
	// server side, no matching ack on the client side) before its
	// connection is treated as timed out.
	timeoutInterval = 25 * time.Second

	// pingInterval is presumably the period of the pingTimerMsg ticks
	// consumed by runPeerSuper (the ticker itself is not in this section).
	pingInterval = 8 * time.Second
)
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func startPeerSuper() {
|
|
peers := [256]peerState{}
|
|
for i := range peers {
|
|
data := &peerStateData{
|
|
published: routingTable[i],
|
|
remoteIP: byte(i),
|
|
buf1: make([]byte, bufferSize),
|
|
buf2: make([]byte, bufferSize),
|
|
limiter: ratelimiter.New(ratelimiter.Config{
|
|
FillPeriod: 50 * time.Millisecond,
|
|
MaxWaitCount: 1,
|
|
}),
|
|
}
|
|
peers[i] = data.OnPeerUpdate(nil)
|
|
}
|
|
go runPeerSuper(peers)
|
|
}
|
|
|
|
func runPeerSuper(peers [256]peerState) {
|
|
for raw := range messages {
|
|
switch msg := raw.(type) {
|
|
|
|
case peerUpdateMsg:
|
|
peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer)
|
|
|
|
case controlMsg[synPacket]:
|
|
peers[msg.SrcIP].OnSyn(msg)
|
|
|
|
case controlMsg[ackPacket]:
|
|
peers[msg.SrcIP].OnAck(msg)
|
|
|
|
case controlMsg[probePacket]:
|
|
peers[msg.SrcIP].OnProbe(msg)
|
|
|
|
case controlMsg[localDiscoveryPacket]:
|
|
peers[msg.SrcIP].OnLocalDiscovery(msg)
|
|
|
|
case pingTimerMsg:
|
|
for i := range peers {
|
|
if newState := peers[i].OnPingTimer(); newState != nil {
|
|
peers[i] = newState
|
|
}
|
|
}
|
|
|
|
default:
|
|
log.Printf("WARNING: unknown message type: %+v", msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type peerState interface {
|
|
OnPeerUpdate(*m.Peer) peerState
|
|
OnSyn(controlMsg[synPacket])
|
|
OnAck(controlMsg[ackPacket])
|
|
OnProbe(controlMsg[probePacket])
|
|
OnLocalDiscovery(controlMsg[localDiscoveryPacket])
|
|
OnPingTimer() peerState
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type peerStateData struct {
|
|
// The purpose of this state machine is to manage this published data.
|
|
published *atomic.Pointer[peerRoute]
|
|
staged peerRoute // Local copy of shared data. See publish().
|
|
|
|
// Immutable data.
|
|
remoteIP byte // Remote VPN IP.
|
|
|
|
// Mutable peer data.
|
|
peer *m.Peer
|
|
remotePub bool
|
|
|
|
// Buffers for sending control packets.
|
|
buf1 []byte
|
|
buf2 []byte
|
|
|
|
// For logging. Set per-state.
|
|
client bool
|
|
|
|
limiter *ratelimiter.Limiter
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
|
|
s._sendControlPacket(pkt, s.staged)
|
|
}
|
|
|
|
func (s *peerStateData) sendControlPacketTo(pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort) {
|
|
if !addr.IsValid() {
|
|
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
|
|
return
|
|
}
|
|
route := s.staged
|
|
route.Direct = true
|
|
route.RemoteAddr = addr
|
|
s._sendControlPacket(pkt, route)
|
|
}
|
|
|
|
func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute) {
|
|
if err := s.limiter.Limit(); err != nil {
|
|
s.logf("Not sending control packet: rate limited.") // Shouldn't happen.
|
|
return
|
|
}
|
|
_sendControlPacket(pkt, route, s.buf1, s.buf2)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerStateData) publish() {
|
|
data := s.staged
|
|
s.published.Store(&data)
|
|
}
|
|
|
|
func (s *peerStateData) logf(format string, args ...any) {
|
|
b := strings.Builder{}
|
|
b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name))
|
|
|
|
if s.client {
|
|
b.WriteString("CLIENT|")
|
|
} else {
|
|
b.WriteString("SERVER|")
|
|
}
|
|
|
|
if s.staged.Direct {
|
|
b.WriteString("DIRECT |")
|
|
} else {
|
|
b.WriteString("RELAYED|")
|
|
}
|
|
|
|
if s.staged.Up {
|
|
b.WriteString("UP |")
|
|
} else {
|
|
b.WriteString("DOWN|")
|
|
}
|
|
|
|
log.Printf(b.String()+format, args...)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState {
|
|
defer s.publish()
|
|
|
|
if peer == nil {
|
|
return enterStateDisconnected(s)
|
|
}
|
|
|
|
s.peer = peer
|
|
s.staged.IP = s.remoteIP
|
|
s.staged.PubSignKey = peer.PubSignKey
|
|
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
|
|
s.staged.DataCipher = newDataCipher()
|
|
|
|
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
|
|
s.remotePub = true
|
|
s.staged.Relay = peer.Relay
|
|
s.staged.Direct = true
|
|
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
|
|
} else if localPub {
|
|
s.staged.Direct = true
|
|
}
|
|
|
|
if s.remotePub == localPub {
|
|
if localIP < s.remoteIP {
|
|
return enterStateServer(s)
|
|
}
|
|
return enterStateClient(s)
|
|
}
|
|
|
|
if s.remotePub {
|
|
return enterStateClient(s)
|
|
}
|
|
return enterStateServer(s)
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type stateDisconnected struct {
|
|
*peerStateData
|
|
}
|
|
|
|
func enterStateDisconnected(s *peerStateData) peerState {
|
|
s.peer = nil
|
|
s.staged = peerRoute{}
|
|
s.publish()
|
|
return &stateDisconnected{s}
|
|
}
|
|
|
|
func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {}
|
|
func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {}
|
|
func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {}
|
|
func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {}
|
|
|
|
func (s *stateDisconnected) OnPingTimer() peerState {
|
|
return nil
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type stateServer struct {
|
|
*stateDisconnected
|
|
lastSeen time.Time
|
|
synTraceID uint64
|
|
}
|
|
|
|
func enterStateServer(s *peerStateData) peerState {
|
|
s.client = false
|
|
return &stateServer{stateDisconnected: &stateDisconnected{s}}
|
|
}
|
|
|
|
func (s *stateServer) OnSyn(msg controlMsg[synPacket]) {
|
|
s.lastSeen = time.Now()
|
|
p := msg.Packet
|
|
|
|
// Before we can respond to this packet, we need to make sure the
|
|
// route is setup properly.
|
|
//
|
|
// The client will update the syn's TraceID whenever there's a change.
|
|
// The server will follow the client's request.
|
|
if p.TraceID != s.synTraceID || !s.staged.Up {
|
|
s.synTraceID = p.TraceID
|
|
s.staged.Up = true
|
|
s.staged.Direct = p.Direct
|
|
s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
|
|
s.staged.RemoteAddr = msg.SrcAddr
|
|
s.publish()
|
|
s.logf("Got syn.")
|
|
}
|
|
|
|
// Always respond.
|
|
ack := ackPacket{
|
|
TraceID: p.TraceID,
|
|
FromAddr: getLocalAddr(),
|
|
}
|
|
s.sendControlPacket(ack)
|
|
|
|
if !s.staged.Direct && p.FromAddr.IsValid() {
|
|
s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr)
|
|
}
|
|
}
|
|
|
|
func (s *stateServer) OnProbe(msg controlMsg[probePacket]) {
|
|
if !msg.SrcAddr.IsValid() {
|
|
s.logf("Invalid probe address.")
|
|
return
|
|
}
|
|
s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
|
|
}
|
|
|
|
func (s *stateServer) OnPingTimer() peerState {
|
|
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
|
|
s.staged.Up = false
|
|
s.publish()
|
|
s.logf("Connection timeout.")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
type stateClient struct {
|
|
*stateDisconnected
|
|
|
|
lastSeen time.Time
|
|
syn synPacket
|
|
ack ackPacket
|
|
|
|
probeTraceID uint64
|
|
probeAddr netip.AddrPort
|
|
|
|
localProbeTraceID uint64
|
|
localProbeAddr netip.AddrPort
|
|
}
|
|
|
|
func enterStateClient(s *peerStateData) peerState {
|
|
s.client = true
|
|
ss := &stateClient{stateDisconnected: &stateDisconnected{s}}
|
|
ss.syn = synPacket{
|
|
TraceID: newTraceID(),
|
|
SharedKey: s.staged.DataCipher.Key(),
|
|
Direct: s.staged.Direct,
|
|
FromAddr: getLocalAddr(),
|
|
}
|
|
ss.sendSyn()
|
|
return ss
|
|
}
|
|
|
|
func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
|
|
if msg.Packet.TraceID != s.syn.TraceID {
|
|
s.logf("Ack has incorrect trace ID")
|
|
return
|
|
}
|
|
|
|
s.ack = msg.Packet
|
|
s.lastSeen = time.Now()
|
|
|
|
if !s.staged.Up {
|
|
s.staged.Up = true
|
|
s.logf("Got ack.")
|
|
s.publish()
|
|
} else {
|
|
}
|
|
}
|
|
|
|
func (s *stateClient) OnProbe(msg controlMsg[probePacket]) {
|
|
if s.staged.Direct {
|
|
return
|
|
}
|
|
|
|
switch msg.Packet.TraceID {
|
|
case s.probeTraceID:
|
|
s.staged.RemoteAddr = s.probeAddr
|
|
case s.localProbeTraceID:
|
|
s.staged.RemoteAddr = s.localProbeAddr
|
|
default:
|
|
return
|
|
}
|
|
|
|
s.staged.Direct = true
|
|
s.publish()
|
|
|
|
s.syn.TraceID = newTraceID()
|
|
s.syn.Direct = true
|
|
s.syn.FromAddr = getLocalAddr()
|
|
s.sendControlPacket(s.syn)
|
|
|
|
s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String())
|
|
}
|
|
|
|
func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) {
|
|
if s.staged.Direct {
|
|
return
|
|
}
|
|
|
|
// Send probe.
|
|
//
|
|
// The source port will be the multicast port, so we'll have to
|
|
// construct the correct address using the peer's listed port.
|
|
s.localProbeTraceID = newTraceID()
|
|
s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
|
|
s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr)
|
|
}
|
|
|
|
func (s *stateClient) OnPingTimer() peerState {
|
|
if time.Since(s.lastSeen) > timeoutInterval {
|
|
if s.staged.Up {
|
|
s.logf("Connection timeout.")
|
|
}
|
|
return s.OnPeerUpdate(s.peer)
|
|
}
|
|
|
|
s.sendSyn()
|
|
|
|
if !s.staged.Direct && s.ack.FromAddr.IsValid() {
|
|
s.probeTraceID = newTraceID()
|
|
s.probeAddr = s.ack.FromAddr
|
|
s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *stateClient) sendSyn() {
|
|
localAddr := getLocalAddr()
|
|
if localAddr != s.syn.FromAddr {
|
|
s.syn.TraceID = newTraceID()
|
|
s.syn.FromAddr = localAddr
|
|
}
|
|
s.sendControlPacket(s.syn)
|
|
}
|