fewer-routines #2
| @@ -25,9 +25,6 @@ scp hub user@<remote>:~/ | ||||
| Create systemd file in `/etc/systemd/system/hub.service | ||||
|  | ||||
| ``` | ||||
| Description=hub | ||||
| Requires=network.target | ||||
|  | ||||
| [Service] | ||||
| AmbientCapabilities=CAP_NET_BIND_SERVICE | ||||
| Type=simple | ||||
| @@ -65,9 +62,6 @@ Create systemd file in `/etc/systemd/system/vppn.service`. | ||||
|  | ||||
|  | ||||
| ``` | ||||
| Description=vppn | ||||
| Requires=network.target | ||||
|  | ||||
| [Service] | ||||
| AmbientCapabilities=CAP_NET_BIND_SERVICE CAP_NET_ADMIN | ||||
| Type=simple | ||||
|   | ||||
| @@ -66,12 +66,8 @@ var ( | ||||
| 		return | ||||
| 	}() | ||||
|  | ||||
| 	messages [256]chan any = func() (out [256]chan any) { | ||||
| 		for i := range out { | ||||
| 			out[i] = make(chan any, 256) | ||||
| 		} | ||||
| 		return | ||||
| 	}() | ||||
| 	// Messages for the supervisor. | ||||
| 	messages = make(chan any, 512) | ||||
|  | ||||
| 	// Global routing table. | ||||
| 	routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { | ||||
|   | ||||
| @@ -81,9 +81,11 @@ func (hp *hubPoller) pollHub() { | ||||
| func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||
| 	for i, peer := range state.Peers { | ||||
| 		if i != int(localIP) { | ||||
| 			if peer != nil && peer.Version != hp.versions[i] { | ||||
| 				messages[i] <- peerUpdateMsg{Peer: state.Peers[i]} | ||||
| 				hp.versions[i] = peer.Version | ||||
| 			if peer == nil || peer.Version != hp.versions[i] { | ||||
| 				messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]} | ||||
| 				if peer != nil { | ||||
| 					hp.versions[i] = peer.Version | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -59,8 +59,9 @@ func recvLocalDiscovery(conn *net.UDPConn) { | ||||
| 		} | ||||
|  | ||||
| 		select { | ||||
| 		case messages[h.SourceIP] <- msg: | ||||
| 		case messages <- msg: | ||||
| 		default: | ||||
| 			log.Printf("Dropping local discovery message.") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -86,7 +87,7 @@ func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) { | ||||
| 	h.Parse(raw[signOverhead:]) | ||||
| 	route := routingTable[h.SourceIP].Load() | ||||
| 	if route == nil || route.PubSignKey == nil { | ||||
| 		log.Printf("Missing signing key") | ||||
| 		log.Printf("Missing signing key: %d", h.SourceIP) | ||||
| 		ok = false | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -159,11 +159,6 @@ func main() { | ||||
| 	privKey = config.PrivKey | ||||
| 	privSignKey = config.PrivSignKey | ||||
|  | ||||
| 	// Start supervisors. | ||||
| 	for i := range 256 { | ||||
| 		go newPeerSupervisor(i).Run() | ||||
| 	} | ||||
|  | ||||
| 	if localPub { | ||||
| 		go addrDiscoveryServer() | ||||
| 	} else { | ||||
| @@ -174,15 +169,12 @@ func main() { | ||||
|  | ||||
| 	go func() { | ||||
| 		for range time.Tick(pingInterval) { | ||||
| 			for i := range messages { | ||||
| 				select { | ||||
| 				case messages[i] <- pingTimerMsg{}: | ||||
| 				default: | ||||
| 				} | ||||
| 			} | ||||
| 			messages <- pingTimerMsg{} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	go startPeerSuper() | ||||
|  | ||||
| 	go newHubPoller().Run() | ||||
| 	go readFromConn(conn) | ||||
| 	readFromIFace(iface) | ||||
| @@ -272,7 +264,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | ||||
| 	} | ||||
|  | ||||
| 	select { | ||||
| 	case messages[h.SourceIP] <- msg: | ||||
| 	case messages <- msg: | ||||
| 	default: | ||||
| 		log.Printf("Dropping control packet.") | ||||
| 	} | ||||
|   | ||||
| @@ -25,8 +25,8 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | ||||
| 		}, err | ||||
|  | ||||
| 	case packetTypeSynAck: | ||||
| 		packet, err := parseSynAckPacket(buf) | ||||
| 		return controlMsg[synAckPacket]{ | ||||
| 		packet, err := parseAckPacket(buf) | ||||
| 		return controlMsg[ackPacket]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
| @@ -56,7 +56,8 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type peerUpdateMsg struct { | ||||
| 	Peer *m.Peer | ||||
| 	PeerIP byte | ||||
| 	Peer   *m.Peer | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|   | ||||
| @@ -49,12 +49,12 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type synAckPacket struct { | ||||
| type ackPacket struct { | ||||
| 	TraceID  uint64 | ||||
| 	FromAddr netip.AddrPort | ||||
| } | ||||
|  | ||||
| func (p synAckPacket) Marshal(buf []byte) []byte { | ||||
| func (p ackPacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeSynAck). | ||||
| 		Uint64(p.TraceID). | ||||
| @@ -62,7 +62,7 @@ func (p synAckPacket) Marshal(buf []byte) []byte { | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { | ||||
| func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.FromAddr). | ||||
|   | ||||
| @@ -25,12 +25,12 @@ func TestPacketSyn(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestPacketSynAck(t *testing.T) { | ||||
| 	in := synAckPacket{ | ||||
| 	in := ackPacket{ | ||||
| 		TraceID:  newTraceID(), | ||||
| 		FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), | ||||
| 	} | ||||
|  | ||||
| 	out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) | ||||
| 	out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,354 +0,0 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	pingInterval    = 8 * time.Second | ||||
| 	timeoutInterval = 25 * time.Second | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type peerSupervisor struct { | ||||
| 	// The purpose of this state machine is to manage this published data. | ||||
| 	published *atomic.Pointer[peerRoute] | ||||
| 	staged    peerRoute // Local copy of shared data. See publish(). | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	remoteIP byte // Remote VPN IP. | ||||
|  | ||||
| 	// Mutable peer data. | ||||
| 	peer      *m.Peer | ||||
| 	remotePub bool | ||||
|  | ||||
| 	// Incoming events. | ||||
| 	messages chan any | ||||
|  | ||||
| 	// Buffers for sending control packets. | ||||
| 	buf1 []byte | ||||
| 	buf2 []byte | ||||
| } | ||||
|  | ||||
| func newPeerSupervisor(i int) *peerSupervisor { | ||||
| 	return &peerSupervisor{ | ||||
| 		published: routingTable[i], | ||||
| 		remoteIP:  byte(i), | ||||
| 		messages:  messages[i], | ||||
| 		buf1:      make([]byte, bufferSize), | ||||
| 		buf2:      make([]byte, bufferSize), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type stateFunc func() stateFunc | ||||
|  | ||||
| func (s *peerSupervisor) Run() { | ||||
| 	state := s.noPeer | ||||
| 	for { | ||||
| 		state = state() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | ||||
| 	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2) | ||||
| 	time.Sleep(500 * time.Millisecond) // Rate limit packets. | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) sendControlPacketTo( | ||||
| 	pkt interface{ Marshal([]byte) []byte }, | ||||
| 	addr netip.AddrPort, | ||||
| ) { | ||||
| 	if !addr.IsValid() { | ||||
| 		s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) | ||||
| 		return | ||||
| 	} | ||||
| 	route := s.staged | ||||
| 	route.Direct = true | ||||
| 	route.RemoteAddr = addr | ||||
| 	_sendControlPacket(pkt, route, s.buf1, s.buf2) | ||||
| 	time.Sleep(500 * time.Millisecond) // Rate limit packets. | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) logf(msg string, args ...any) { | ||||
| 	log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) publish() { | ||||
| 	data := s.staged | ||||
| 	s.published.Store(&data) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) noPeer() stateFunc { | ||||
| 	for { | ||||
| 		rawMsg := <-s.messages | ||||
| 		if msg, ok := rawMsg.(peerUpdateMsg); ok { | ||||
| 			return s.peerUpdate(msg.Peer) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc { | ||||
| 	return func() stateFunc { return s._peerUpdate(peer) } | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { | ||||
| 	defer s.publish() | ||||
|  | ||||
| 	s.peer = peer | ||||
| 	s.staged = peerRoute{} | ||||
|  | ||||
| 	if s.peer == nil { | ||||
| 		return s.noPeer | ||||
| 	} | ||||
|  | ||||
| 	s.staged.IP = s.remoteIP | ||||
| 	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) | ||||
| 	s.staged.PubSignKey = peer.PubSignKey | ||||
| 	s.staged.DataCipher = newDataCipher() | ||||
|  | ||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||
| 		s.remotePub = true | ||||
| 		s.staged.Relay = peer.Relay | ||||
| 		s.staged.Direct = true | ||||
| 		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) | ||||
| 	} else if localPub { | ||||
| 		s.staged.Direct = true | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub == localPub { | ||||
| 		if localIP < s.remoteIP { | ||||
| 			return s.server | ||||
| 		} | ||||
| 		return s.client | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub { | ||||
| 		return s.client | ||||
| 	} | ||||
| 	return s.server | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) server() stateFunc { | ||||
| 	logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) } | ||||
|  | ||||
| 	logf("DOWN") | ||||
|  | ||||
| 	var ( | ||||
| 		syn      synPacket | ||||
| 		lastSeen = time.Now() | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		rawMsg := <-s.messages | ||||
| 		switch msg := rawMsg.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			return s.peerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[synPacket]: | ||||
| 			p := msg.Packet | ||||
| 			lastSeen = time.Now() | ||||
|  | ||||
| 			// Before we can respond to this packet, we need to make sure the | ||||
| 			// route is setup properly. | ||||
| 			// | ||||
| 			// The client will update the syn's TraceID whenever there's a change. | ||||
| 			// The server will follow the client's request. | ||||
| 			if p.TraceID != syn.TraceID || !s.staged.Up { | ||||
| 				if p.Direct { | ||||
| 					logf("UP - Direct") | ||||
| 				} else { | ||||
| 					logf("UP - Relayed") | ||||
| 				} | ||||
|  | ||||
| 				syn = p | ||||
| 				s.staged.Up = true | ||||
| 				s.staged.Direct = syn.Direct | ||||
| 				s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) | ||||
| 				s.staged.RemoteAddr = msg.SrcAddr | ||||
|  | ||||
| 				s.publish() | ||||
| 			} | ||||
|  | ||||
| 			// We should always respond. | ||||
| 			ack := synAckPacket{ | ||||
| 				TraceID:  syn.TraceID, | ||||
| 				FromAddr: getLocalAddr(), | ||||
| 			} | ||||
| 			s.sendControlPacket(ack) | ||||
|  | ||||
| 			if s.staged.Direct { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if !syn.FromAddr.IsValid() { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			probe := probePacket{TraceID: newTraceID()} | ||||
| 			s.sendControlPacketTo(probe, syn.FromAddr) | ||||
|  | ||||
| 		case controlMsg[probePacket]: | ||||
| 			if !msg.SrcAddr.IsValid() { | ||||
| 				logf("Invalid probe address") | ||||
| 				continue | ||||
| 			} | ||||
| 			s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			if time.Since(lastSeen) > timeoutInterval && s.staged.Up { | ||||
| 				logf("Connection timeout") | ||||
| 				s.staged.Up = false | ||||
| 				s.publish() | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) client() stateFunc { | ||||
| 	logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) } | ||||
|  | ||||
| 	logf("DOWN") | ||||
|  | ||||
| 	var ( | ||||
| 		syn = synPacket{ | ||||
| 			TraceID:   newTraceID(), | ||||
| 			SharedKey: s.staged.DataCipher.Key(), | ||||
| 			Direct:    s.staged.Direct, | ||||
| 			FromAddr:  getLocalAddr(), | ||||
| 		} | ||||
|  | ||||
| 		lastSeen = time.Now() | ||||
| 		ack      synAckPacket | ||||
|  | ||||
| 		probe     probePacket | ||||
| 		probeAddr netip.AddrPort | ||||
|  | ||||
| 		localProbe     probePacket | ||||
| 		localProbeAddr netip.AddrPort | ||||
|  | ||||
| 		lastLocalAddr netip.AddrPort | ||||
| 	) | ||||
|  | ||||
| 	s.sendControlPacket(syn) | ||||
|  | ||||
| 	for { | ||||
| 		rawMsg := <-s.messages | ||||
| 		switch msg := rawMsg.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			return s.peerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[synAckPacket]: | ||||
| 			p := msg.Packet | ||||
|  | ||||
| 			if p.TraceID != syn.TraceID { | ||||
| 				continue // Hmm... | ||||
| 			} | ||||
|  | ||||
| 			lastSeen = time.Now() | ||||
| 			ack = msg.Packet | ||||
|  | ||||
| 			if !s.staged.Up { | ||||
| 				if s.staged.Direct { | ||||
| 					logf("UP - Direct") | ||||
| 				} else { | ||||
| 					logf("UP - Relayed") | ||||
| 				} | ||||
|  | ||||
| 				s.staged.Up = true | ||||
| 				s.publish() | ||||
| 			} | ||||
|  | ||||
| 		case controlMsg[probePacket]: | ||||
| 			if s.staged.Direct { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			p := msg.Packet | ||||
|  | ||||
| 			if p.TraceID != localProbe.TraceID && p.TraceID != probe.TraceID { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// Upgrade connection. | ||||
|  | ||||
| 			s.staged.Direct = true | ||||
| 			if p.TraceID == localProbe.TraceID { | ||||
| 				logf("UP - Local") | ||||
| 				s.staged.RemoteAddr = localProbeAddr | ||||
| 			} else { | ||||
| 				logf("UP - Direct") | ||||
| 				s.staged.RemoteAddr = probeAddr | ||||
| 			} | ||||
| 			s.publish() | ||||
|  | ||||
| 			syn.TraceID = newTraceID() | ||||
| 			syn.Direct = true | ||||
| 			syn.FromAddr = getLocalAddr() | ||||
| 			s.sendControlPacket(syn) | ||||
|  | ||||
| 		case controlMsg[localDiscoveryPacket]: | ||||
| 			if s.staged.Direct { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// Send probe. | ||||
| 			// | ||||
| 			// The source port will be the multicast port, so we'll have to | ||||
| 			// construct the correct address using the peer's listed port. | ||||
| 			localProbe = probePacket{TraceID: newTraceID()} | ||||
| 			localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| 			s.sendControlPacketTo(localProbe, localProbeAddr) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			if time.Since(lastSeen) > timeoutInterval { | ||||
| 				if s.staged.Up { | ||||
| 					logf("Connection timeout") | ||||
| 				} | ||||
| 				return s.peerUpdate(s.peer) | ||||
| 			} | ||||
|  | ||||
| 			syn.FromAddr = getLocalAddr() | ||||
| 			if syn.FromAddr != lastLocalAddr { | ||||
| 				syn.TraceID = newTraceID() | ||||
| 				lastLocalAddr = syn.FromAddr | ||||
| 			} | ||||
|  | ||||
| 			s.sendControlPacket(syn) | ||||
|  | ||||
| 			if s.staged.Direct { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if !ack.FromAddr.IsValid() { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			probe = probePacket{TraceID: newTraceID()} | ||||
| 			probeAddr = ack.FromAddr | ||||
|  | ||||
| 			s.sendControlPacketTo(probe, ack.FromAddr) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										392
									
								
								node/supervisor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										392
									
								
								node/supervisor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,392 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
|  | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	pingInterval    = 8 * time.Second | ||||
| 	timeoutInterval = 25 * time.Second | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func startPeerSuper() { | ||||
| 	peers := [256]peerState{} | ||||
| 	for i := range peers { | ||||
| 		data := &peerStateData{ | ||||
| 			published: routingTable[i], | ||||
| 			remoteIP:  byte(i), | ||||
| 			buf1:      make([]byte, bufferSize), | ||||
| 			buf2:      make([]byte, bufferSize), | ||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 				FillPeriod:   50 * time.Millisecond, | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| 		peers[i] = data.OnPeerUpdate(nil) | ||||
| 	} | ||||
| 	go runPeerSuper(peers) | ||||
| } | ||||
|  | ||||
| func runPeerSuper(peers [256]peerState) { | ||||
| 	for raw := range messages { | ||||
| 		switch msg := raw.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[synPacket]: | ||||
| 			peers[msg.SrcIP].OnSyn(msg) | ||||
|  | ||||
| 		case controlMsg[ackPacket]: | ||||
| 			peers[msg.SrcIP].OnAck(msg) | ||||
|  | ||||
| 		case controlMsg[probePacket]: | ||||
| 			peers[msg.SrcIP].OnProbe(msg) | ||||
|  | ||||
| 		case controlMsg[localDiscoveryPacket]: | ||||
| 			peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			for i := range peers { | ||||
| 				if newState := peers[i].OnPingTimer(); newState != nil { | ||||
| 					peers[i] = newState | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type peerState interface { | ||||
| 	OnPeerUpdate(*m.Peer) peerState | ||||
| 	OnSyn(controlMsg[synPacket]) | ||||
| 	OnAck(controlMsg[ackPacket]) | ||||
| 	OnProbe(controlMsg[probePacket]) | ||||
| 	OnLocalDiscovery(controlMsg[localDiscoveryPacket]) | ||||
| 	OnPingTimer() peerState | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type peerStateData struct { | ||||
| 	// The purpose of this state machine is to manage this published data. | ||||
| 	published *atomic.Pointer[peerRoute] | ||||
| 	staged    peerRoute // Local copy of shared data. See publish(). | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	remoteIP byte // Remote VPN IP. | ||||
|  | ||||
| 	// Mutable peer data. | ||||
| 	peer      *m.Peer | ||||
| 	remotePub bool | ||||
|  | ||||
| 	// Buffers for sending control packets. | ||||
| 	buf1 []byte | ||||
| 	buf2 []byte | ||||
|  | ||||
| 	// For logging. Set per-state. | ||||
| 	client bool | ||||
|  | ||||
| 	limiter *ratelimiter.Limiter | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | ||||
| 	s.limiter.Limit() | ||||
| 	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2) | ||||
| } | ||||
|  | ||||
| func (s *peerStateData) sendControlPacketTo( | ||||
| 	pkt interface{ Marshal([]byte) []byte }, | ||||
| 	addr netip.AddrPort, | ||||
| ) { | ||||
| 	if !addr.IsValid() { | ||||
| 		s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) | ||||
| 		return | ||||
| 	} | ||||
| 	route := s.staged | ||||
| 	route.Direct = true | ||||
| 	route.RemoteAddr = addr | ||||
| 	s.limiter.Limit() | ||||
| 	_sendControlPacket(pkt, route, s.buf1, s.buf2) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerStateData) publish() { | ||||
| 	data := s.staged | ||||
| 	s.published.Store(&data) | ||||
| } | ||||
|  | ||||
| func (s *peerStateData) logf(format string, args ...any) { | ||||
| 	b := strings.Builder{} | ||||
| 	b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name)) | ||||
|  | ||||
| 	if s.client { | ||||
| 		b.WriteString("CLIENT|") | ||||
| 	} else { | ||||
| 		b.WriteString("SERVER|") | ||||
| 	} | ||||
|  | ||||
| 	if s.staged.Direct { | ||||
| 		b.WriteString("DIRECT |") | ||||
| 	} else { | ||||
| 		b.WriteString("RELAYED|") | ||||
| 	} | ||||
|  | ||||
| 	if s.staged.Up { | ||||
| 		b.WriteString("UP  |") | ||||
| 	} else { | ||||
| 		b.WriteString("DOWN|") | ||||
| 	} | ||||
|  | ||||
| 	log.Printf(b.String()+format, args...) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { | ||||
| 	defer s.publish() | ||||
|  | ||||
| 	if peer == nil { | ||||
| 		return enterStateDisconnected(s) | ||||
| 	} | ||||
|  | ||||
| 	s.peer = peer | ||||
| 	s.staged.IP = s.remoteIP | ||||
| 	s.staged.PubSignKey = peer.PubSignKey | ||||
| 	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) | ||||
| 	s.staged.DataCipher = newDataCipher() | ||||
|  | ||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||
| 		s.remotePub = true | ||||
| 		s.staged.Relay = peer.Relay | ||||
| 		s.staged.Direct = true | ||||
| 		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) | ||||
| 	} else if localPub { | ||||
| 		s.staged.Direct = true | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub == localPub { | ||||
| 		if localIP < s.remoteIP { | ||||
| 			return enterStateServer(s) | ||||
| 		} | ||||
| 		return enterStateClient(s) | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub { | ||||
| 		return enterStateClient(s) | ||||
| 	} | ||||
| 	return enterStateServer(s) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateDisconnected struct { | ||||
| 	*peerStateData | ||||
| } | ||||
|  | ||||
| func enterStateDisconnected(s *peerStateData) peerState { | ||||
| 	s.peer = nil | ||||
| 	s.staged = peerRoute{} | ||||
| 	s.publish() | ||||
| 	return &stateDisconnected{s} | ||||
| } | ||||
|  | ||||
| func (s *stateDisconnected) OnSyn(controlMsg[synPacket])                       {} | ||||
| func (s *stateDisconnected) OnAck(controlMsg[ackPacket])                       {} | ||||
| func (s *stateDisconnected) OnProbe(controlMsg[probePacket])                   {} | ||||
| func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {} | ||||
|  | ||||
| func (s *stateDisconnected) OnPingTimer() peerState { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateServer struct { | ||||
| 	*stateDisconnected | ||||
| 	lastSeen   time.Time | ||||
| 	synTraceID uint64 | ||||
| } | ||||
|  | ||||
| func enterStateServer(s *peerStateData) peerState { | ||||
| 	s.client = false | ||||
| 	return &stateServer{stateDisconnected: &stateDisconnected{s}} | ||||
| } | ||||
|  | ||||
| func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { | ||||
| 	s.lastSeen = time.Now() | ||||
| 	p := msg.Packet | ||||
|  | ||||
| 	// Before we can respond to this packet, we need to make sure the | ||||
| 	// route is setup properly. | ||||
| 	// | ||||
| 	// The client will update the syn's TraceID whenever there's a change. | ||||
| 	// The server will follow the client's request. | ||||
| 	if p.TraceID != s.synTraceID || !s.staged.Up { | ||||
| 		s.synTraceID = p.TraceID | ||||
| 		s.staged.Up = true | ||||
| 		s.staged.Direct = p.Direct | ||||
| 		s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) | ||||
| 		s.staged.RemoteAddr = msg.SrcAddr | ||||
| 		s.publish() | ||||
| 		s.logf("Got syn.") | ||||
| 	} | ||||
|  | ||||
| 	// Always respond. | ||||
| 	ack := ackPacket{ | ||||
| 		TraceID:  p.TraceID, | ||||
| 		FromAddr: getLocalAddr(), | ||||
| 	} | ||||
| 	s.sendControlPacket(ack) | ||||
|  | ||||
| 	if !s.staged.Direct && p.FromAddr.IsValid() { | ||||
| 		s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateServer) OnProbe(msg controlMsg[probePacket]) { | ||||
| 	if !msg.SrcAddr.IsValid() { | ||||
| 		s.logf("Invalid probe address.") | ||||
| 		return | ||||
| 	} | ||||
| 	s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||
| } | ||||
|  | ||||
| func (s *stateServer) OnPingTimer() peerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||
| 		s.staged.Up = false | ||||
| 		s.publish() | ||||
| 		s.logf("Connection timeout.") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateClient struct { | ||||
| 	*stateDisconnected | ||||
|  | ||||
| 	lastSeen time.Time | ||||
| 	syn      synPacket | ||||
| 	ack      ackPacket | ||||
|  | ||||
| 	probeTraceID uint64 | ||||
| 	probeAddr    netip.AddrPort | ||||
|  | ||||
| 	localProbeTraceID uint64 | ||||
| 	localProbeAddr    netip.AddrPort | ||||
| } | ||||
|  | ||||
| func enterStateClient(s *peerStateData) peerState { | ||||
| 	s.client = true | ||||
| 	ss := &stateClient{stateDisconnected: &stateDisconnected{s}} | ||||
| 	ss.syn = synPacket{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: s.staged.DataCipher.Key(), | ||||
| 		Direct:    s.staged.Direct, | ||||
| 		FromAddr:  getLocalAddr(), | ||||
| 	} | ||||
| 	ss.sendSyn() | ||||
| 	return ss | ||||
| } | ||||
|  | ||||
| func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | ||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | ||||
| 		s.logf("Ack has incorrect trace ID") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	s.ack = msg.Packet | ||||
| 	s.lastSeen = time.Now() | ||||
|  | ||||
| 	if !s.staged.Up { | ||||
| 		s.staged.Up = true | ||||
| 		s.logf("Got ack.") | ||||
| 		s.publish() | ||||
| 	} else { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { | ||||
| 	if s.staged.Direct { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	switch msg.Packet.TraceID { | ||||
| 	case s.probeTraceID: | ||||
| 		s.staged.RemoteAddr = s.probeAddr | ||||
| 	case s.localProbeTraceID: | ||||
| 		s.staged.RemoteAddr = s.localProbeAddr | ||||
| 	default: | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	s.staged.Direct = true | ||||
| 	s.publish() | ||||
|  | ||||
| 	s.syn.TraceID = newTraceID() | ||||
| 	s.syn.Direct = true | ||||
| 	s.syn.FromAddr = getLocalAddr() | ||||
| 	s.sendControlPacket(s.syn) | ||||
|  | ||||
| 	s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) | ||||
| } | ||||
|  | ||||
| func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { | ||||
| 	if s.staged.Direct { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Send probe. | ||||
| 	// | ||||
| 	// The source port will be the multicast port, so we'll have to | ||||
| 	// construct the correct address using the peer's listed port. | ||||
| 	s.localProbeTraceID = newTraceID() | ||||
| 	s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| 	s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr) | ||||
| } | ||||
|  | ||||
| func (s *stateClient) OnPingTimer() peerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval { | ||||
| 		if s.staged.Up { | ||||
| 			s.logf("Connection timeout.") | ||||
| 		} | ||||
| 		return s.OnPeerUpdate(s.peer) | ||||
| 	} | ||||
|  | ||||
| 	s.sendSyn() | ||||
|  | ||||
| 	if !s.staged.Direct && s.ack.FromAddr.IsValid() { | ||||
| 		s.probeTraceID = newTraceID() | ||||
| 		s.probeAddr = s.ack.FromAddr | ||||
| 		s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *stateClient) sendSyn() { | ||||
| 	localAddr := getLocalAddr() | ||||
| 	if localAddr != s.syn.FromAddr { | ||||
| 		s.syn.TraceID = newTraceID() | ||||
| 		s.syn.FromAddr = localAddr | ||||
| 	} | ||||
| 	s.sendControlPacket(s.syn) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user