Single thread for supervisor.
This commit is contained in:
		| @@ -66,12 +66,7 @@ var ( | |||||||
| 		return | 		return | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	messages [256]chan any = func() (out [256]chan any) { | 	messages = make(chan any, 512) | ||||||
| 		for i := range out { |  | ||||||
| 			out[i] = make(chan any, 256) |  | ||||||
| 		} |  | ||||||
| 		return |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	// Global routing table. | 	// Global routing table. | ||||||
| 	routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { | 	routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { | ||||||
|   | |||||||
| @@ -81,10 +81,12 @@ func (hp *hubPoller) pollHub() { | |||||||
| func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||||
| 	for i, peer := range state.Peers { | 	for i, peer := range state.Peers { | ||||||
| 		if i != int(localIP) { | 		if i != int(localIP) { | ||||||
| 			if peer != nil && peer.Version != hp.versions[i] { | 			if peer == nil || peer.Version != hp.versions[i] { | ||||||
| 				messages[i] <- peerUpdateMsg{Peer: state.Peers[i]} | 				messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]} | ||||||
|  | 				if peer != nil { | ||||||
| 					hp.versions[i] = peer.Version | 					hp.versions[i] = peer.Version | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -59,8 +59,9 @@ func recvLocalDiscovery(conn *net.UDPConn) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		select { | 		select { | ||||||
| 		case messages[h.SourceIP] <- msg: | 		case messages <- msg: | ||||||
| 		default: | 		default: | ||||||
|  | 			log.Printf("Dropping local discovery message.") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -86,7 +87,7 @@ func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) { | |||||||
| 	h.Parse(raw[signOverhead:]) | 	h.Parse(raw[signOverhead:]) | ||||||
| 	route := routingTable[h.SourceIP].Load() | 	route := routingTable[h.SourceIP].Load() | ||||||
| 	if route == nil || route.PubSignKey == nil { | 	if route == nil || route.PubSignKey == nil { | ||||||
| 		log.Printf("Missing signing key") | 		log.Printf("Missing signing key: %d", h.SourceIP) | ||||||
| 		ok = false | 		ok = false | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -159,11 +159,6 @@ func main() { | |||||||
| 	privKey = config.PrivKey | 	privKey = config.PrivKey | ||||||
| 	privSignKey = config.PrivSignKey | 	privSignKey = config.PrivSignKey | ||||||
|  |  | ||||||
| 	// Start supervisors. |  | ||||||
| 	for i := range 256 { |  | ||||||
| 		go newPeerSupervisor(i).Run() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if localPub { | 	if localPub { | ||||||
| 		go addrDiscoveryServer() | 		go addrDiscoveryServer() | ||||||
| 	} else { | 	} else { | ||||||
| @@ -174,15 +169,12 @@ func main() { | |||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for range time.Tick(pingInterval) { | 		for range time.Tick(pingInterval) { | ||||||
| 			for i := range messages { | 			messages <- pingTimerMsg{} | ||||||
| 				select { |  | ||||||
| 				case messages[i] <- pingTimerMsg{}: |  | ||||||
| 				default: |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
|  | 	go startPeerSuper() | ||||||
|  |  | ||||||
| 	go newHubPoller().Run() | 	go newHubPoller().Run() | ||||||
| 	go readFromConn(conn) | 	go readFromConn(conn) | ||||||
| 	readFromIFace(iface) | 	readFromIFace(iface) | ||||||
| @@ -272,7 +264,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	select { | 	select { | ||||||
| 	case messages[h.SourceIP] <- msg: | 	case messages <- msg: | ||||||
| 	default: | 	default: | ||||||
| 		log.Printf("Dropping control packet.") | 		log.Printf("Dropping control packet.") | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -25,8 +25,8 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | |||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case packetTypeSynAck: | 	case packetTypeSynAck: | ||||||
| 		packet, err := parseSynAckPacket(buf) | 		packet, err := parseAckPacket(buf) | ||||||
| 		return controlMsg[synAckPacket]{ | 		return controlMsg[ackPacket]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| @@ -56,6 +56,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | |||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type peerUpdateMsg struct { | type peerUpdateMsg struct { | ||||||
|  | 	PeerIP byte | ||||||
| 	Peer   *m.Peer | 	Peer   *m.Peer | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -49,12 +49,12 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type synAckPacket struct { | type ackPacket struct { | ||||||
| 	TraceID  uint64 | 	TraceID  uint64 | ||||||
| 	FromAddr netip.AddrPort | 	FromAddr netip.AddrPort | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p synAckPacket) Marshal(buf []byte) []byte { | func (p ackPacket) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(packetTypeSynAck). | 		Byte(packetTypeSynAck). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| @@ -62,7 +62,7 @@ func (p synAckPacket) Marshal(buf []byte) []byte { | |||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { | func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		AddrPort(&p.FromAddr). | 		AddrPort(&p.FromAddr). | ||||||
|   | |||||||
| @@ -25,12 +25,12 @@ func TestPacketSyn(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestPacketSynAck(t *testing.T) { | func TestPacketSynAck(t *testing.T) { | ||||||
| 	in := synAckPacket{ | 	in := ackPacket{ | ||||||
| 		TraceID:  newTraceID(), | 		TraceID:  newTraceID(), | ||||||
| 		FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), | 		FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) | 	out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,354 +0,0 @@ | |||||||
| package node |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"log" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"sync/atomic" |  | ||||||
| 	"time" |  | ||||||
| 	"vppn/m" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	pingInterval    = 8 * time.Second |  | ||||||
| 	timeoutInterval = 25 * time.Second |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type peerSupervisor struct { |  | ||||||
| 	// The purpose of this state machine is to manage this published data. |  | ||||||
| 	published *atomic.Pointer[peerRoute] |  | ||||||
| 	staged    peerRoute // Local copy of shared data. See publish(). |  | ||||||
|  |  | ||||||
| 	// Immutable data. |  | ||||||
| 	remoteIP byte // Remote VPN IP. |  | ||||||
|  |  | ||||||
| 	// Mutable peer data. |  | ||||||
| 	peer      *m.Peer |  | ||||||
| 	remotePub bool |  | ||||||
|  |  | ||||||
| 	// Incoming events. |  | ||||||
| 	messages chan any |  | ||||||
|  |  | ||||||
| 	// Buffers for sending control packets. |  | ||||||
| 	buf1 []byte |  | ||||||
| 	buf2 []byte |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func newPeerSupervisor(i int) *peerSupervisor { |  | ||||||
| 	return &peerSupervisor{ |  | ||||||
| 		published: routingTable[i], |  | ||||||
| 		remoteIP:  byte(i), |  | ||||||
| 		messages:  messages[i], |  | ||||||
| 		buf1:      make([]byte, bufferSize), |  | ||||||
| 		buf2:      make([]byte, bufferSize), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type stateFunc func() stateFunc |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) Run() { |  | ||||||
| 	state := s.noPeer |  | ||||||
| 	for { |  | ||||||
| 		state = state() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { |  | ||||||
| 	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2) |  | ||||||
| 	time.Sleep(500 * time.Millisecond) // Rate limit packets. |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) sendControlPacketTo( |  | ||||||
| 	pkt interface{ Marshal([]byte) []byte }, |  | ||||||
| 	addr netip.AddrPort, |  | ||||||
| ) { |  | ||||||
| 	if !addr.IsValid() { |  | ||||||
| 		s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	route := s.staged |  | ||||||
| 	route.Direct = true |  | ||||||
| 	route.RemoteAddr = addr |  | ||||||
| 	_sendControlPacket(pkt, route, s.buf1, s.buf2) |  | ||||||
| 	time.Sleep(500 * time.Millisecond) // Rate limit packets. |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) logf(msg string, args ...any) { |  | ||||||
| 	log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) publish() { |  | ||||||
| 	data := s.staged |  | ||||||
| 	s.published.Store(&data) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) noPeer() stateFunc { |  | ||||||
| 	for { |  | ||||||
| 		rawMsg := <-s.messages |  | ||||||
| 		if msg, ok := rawMsg.(peerUpdateMsg); ok { |  | ||||||
| 			return s.peerUpdate(msg.Peer) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc { |  | ||||||
| 	return func() stateFunc { return s._peerUpdate(peer) } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { |  | ||||||
| 	defer s.publish() |  | ||||||
|  |  | ||||||
| 	s.peer = peer |  | ||||||
| 	s.staged = peerRoute{} |  | ||||||
|  |  | ||||||
| 	if s.peer == nil { |  | ||||||
| 		return s.noPeer |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	s.staged.IP = s.remoteIP |  | ||||||
| 	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) |  | ||||||
| 	s.staged.PubSignKey = peer.PubSignKey |  | ||||||
| 	s.staged.DataCipher = newDataCipher() |  | ||||||
|  |  | ||||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { |  | ||||||
| 		s.remotePub = true |  | ||||||
| 		s.staged.Relay = peer.Relay |  | ||||||
| 		s.staged.Direct = true |  | ||||||
| 		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) |  | ||||||
| 	} else if localPub { |  | ||||||
| 		s.staged.Direct = true |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if s.remotePub == localPub { |  | ||||||
| 		if localIP < s.remoteIP { |  | ||||||
| 			return s.server |  | ||||||
| 		} |  | ||||||
| 		return s.client |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if s.remotePub { |  | ||||||
| 		return s.client |  | ||||||
| 	} |  | ||||||
| 	return s.server |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) server() stateFunc { |  | ||||||
| 	logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) } |  | ||||||
|  |  | ||||||
| 	logf("DOWN") |  | ||||||
|  |  | ||||||
| 	var ( |  | ||||||
| 		syn      synPacket |  | ||||||
| 		lastSeen = time.Now() |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		rawMsg := <-s.messages |  | ||||||
| 		switch msg := rawMsg.(type) { |  | ||||||
|  |  | ||||||
| 		case peerUpdateMsg: |  | ||||||
| 			return s.peerUpdate(msg.Peer) |  | ||||||
|  |  | ||||||
| 		case controlMsg[synPacket]: |  | ||||||
| 			p := msg.Packet |  | ||||||
| 			lastSeen = time.Now() |  | ||||||
|  |  | ||||||
| 			// Before we can respond to this packet, we need to make sure the |  | ||||||
| 			// route is setup properly. |  | ||||||
| 			// |  | ||||||
| 			// The client will update the syn's TraceID whenever there's a change. |  | ||||||
| 			// The server will follow the client's request. |  | ||||||
| 			if p.TraceID != syn.TraceID || !s.staged.Up { |  | ||||||
| 				if p.Direct { |  | ||||||
| 					logf("UP - Direct") |  | ||||||
| 				} else { |  | ||||||
| 					logf("UP - Relayed") |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				syn = p |  | ||||||
| 				s.staged.Up = true |  | ||||||
| 				s.staged.Direct = syn.Direct |  | ||||||
| 				s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) |  | ||||||
| 				s.staged.RemoteAddr = msg.SrcAddr |  | ||||||
|  |  | ||||||
| 				s.publish() |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// We should always respond. |  | ||||||
| 			ack := synAckPacket{ |  | ||||||
| 				TraceID:  syn.TraceID, |  | ||||||
| 				FromAddr: getLocalAddr(), |  | ||||||
| 			} |  | ||||||
| 			s.sendControlPacket(ack) |  | ||||||
|  |  | ||||||
| 			if s.staged.Direct { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if !syn.FromAddr.IsValid() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			probe := probePacket{TraceID: newTraceID()} |  | ||||||
| 			s.sendControlPacketTo(probe, syn.FromAddr) |  | ||||||
|  |  | ||||||
| 		case controlMsg[probePacket]: |  | ||||||
| 			if !msg.SrcAddr.IsValid() { |  | ||||||
| 				logf("Invalid probe address") |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) |  | ||||||
|  |  | ||||||
| 		case pingTimerMsg: |  | ||||||
| 			if time.Since(lastSeen) > timeoutInterval && s.staged.Up { |  | ||||||
| 				logf("Connection timeout") |  | ||||||
| 				s.staged.Up = false |  | ||||||
| 				s.publish() |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func (s *peerSupervisor) client() stateFunc { |  | ||||||
| 	logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) } |  | ||||||
|  |  | ||||||
| 	logf("DOWN") |  | ||||||
|  |  | ||||||
| 	var ( |  | ||||||
| 		syn = synPacket{ |  | ||||||
| 			TraceID:   newTraceID(), |  | ||||||
| 			SharedKey: s.staged.DataCipher.Key(), |  | ||||||
| 			Direct:    s.staged.Direct, |  | ||||||
| 			FromAddr:  getLocalAddr(), |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		lastSeen = time.Now() |  | ||||||
| 		ack      synAckPacket |  | ||||||
|  |  | ||||||
| 		probe     probePacket |  | ||||||
| 		probeAddr netip.AddrPort |  | ||||||
|  |  | ||||||
| 		localProbe     probePacket |  | ||||||
| 		localProbeAddr netip.AddrPort |  | ||||||
|  |  | ||||||
| 		lastLocalAddr netip.AddrPort |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	s.sendControlPacket(syn) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		rawMsg := <-s.messages |  | ||||||
| 		switch msg := rawMsg.(type) { |  | ||||||
|  |  | ||||||
| 		case peerUpdateMsg: |  | ||||||
| 			return s.peerUpdate(msg.Peer) |  | ||||||
|  |  | ||||||
| 		case controlMsg[synAckPacket]: |  | ||||||
| 			p := msg.Packet |  | ||||||
|  |  | ||||||
| 			if p.TraceID != syn.TraceID { |  | ||||||
| 				continue // Hmm... |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			lastSeen = time.Now() |  | ||||||
| 			ack = msg.Packet |  | ||||||
|  |  | ||||||
| 			if !s.staged.Up { |  | ||||||
| 				if s.staged.Direct { |  | ||||||
| 					logf("UP - Direct") |  | ||||||
| 				} else { |  | ||||||
| 					logf("UP - Relayed") |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				s.staged.Up = true |  | ||||||
| 				s.publish() |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 		case controlMsg[probePacket]: |  | ||||||
| 			if s.staged.Direct { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			p := msg.Packet |  | ||||||
|  |  | ||||||
| 			if p.TraceID != localProbe.TraceID && p.TraceID != probe.TraceID { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Upgrade connection. |  | ||||||
|  |  | ||||||
| 			s.staged.Direct = true |  | ||||||
| 			if p.TraceID == localProbe.TraceID { |  | ||||||
| 				logf("UP - Local") |  | ||||||
| 				s.staged.RemoteAddr = localProbeAddr |  | ||||||
| 			} else { |  | ||||||
| 				logf("UP - Direct") |  | ||||||
| 				s.staged.RemoteAddr = probeAddr |  | ||||||
| 			} |  | ||||||
| 			s.publish() |  | ||||||
|  |  | ||||||
| 			syn.TraceID = newTraceID() |  | ||||||
| 			syn.Direct = true |  | ||||||
| 			syn.FromAddr = getLocalAddr() |  | ||||||
| 			s.sendControlPacket(syn) |  | ||||||
|  |  | ||||||
| 		case controlMsg[localDiscoveryPacket]: |  | ||||||
| 			if s.staged.Direct { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Send probe. |  | ||||||
| 			// |  | ||||||
| 			// The source port will be the multicast port, so we'll have to |  | ||||||
| 			// construct the correct address using the peer's listed port. |  | ||||||
| 			localProbe = probePacket{TraceID: newTraceID()} |  | ||||||
| 			localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) |  | ||||||
| 			s.sendControlPacketTo(localProbe, localProbeAddr) |  | ||||||
|  |  | ||||||
| 		case pingTimerMsg: |  | ||||||
| 			if time.Since(lastSeen) > timeoutInterval { |  | ||||||
| 				if s.staged.Up { |  | ||||||
| 					logf("Connection timeout") |  | ||||||
| 				} |  | ||||||
| 				return s.peerUpdate(s.peer) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			syn.FromAddr = getLocalAddr() |  | ||||||
| 			if syn.FromAddr != lastLocalAddr { |  | ||||||
| 				syn.TraceID = newTraceID() |  | ||||||
| 				lastLocalAddr = syn.FromAddr |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			s.sendControlPacket(syn) |  | ||||||
|  |  | ||||||
| 			if s.staged.Direct { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if !ack.FromAddr.IsValid() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			probe = probePacket{TraceID: newTraceID()} |  | ||||||
| 			probeAddr = ack.FromAddr |  | ||||||
|  |  | ||||||
| 			s.sendControlPacketTo(probe, ack.FromAddr) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
							
								
								
									
										392
									
								
								node/supervisor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										392
									
								
								node/supervisor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,392 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"time" | ||||||
|  | 	"vppn/m" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/go/ratelimiter" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	pingInterval    = 8 * time.Second | ||||||
|  | 	timeoutInterval = 25 * time.Second | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func startPeerSuper() { | ||||||
|  | 	peers := [256]peerState{} | ||||||
|  | 	for i := range peers { | ||||||
|  | 		data := &peerStateData{ | ||||||
|  | 			published: routingTable[i], | ||||||
|  | 			remoteIP:  byte(i), | ||||||
|  | 			buf1:      make([]byte, bufferSize), | ||||||
|  | 			buf2:      make([]byte, bufferSize), | ||||||
|  | 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||||
|  | 				FillPeriod:   50 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 1, | ||||||
|  | 			}), | ||||||
|  | 		} | ||||||
|  | 		peers[i] = data.OnPeerUpdate(nil) | ||||||
|  | 	} | ||||||
|  | 	go runPeerSuper(peers) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func runPeerSuper(peers [256]peerState) { | ||||||
|  | 	for raw := range messages { | ||||||
|  | 		switch msg := raw.(type) { | ||||||
|  |  | ||||||
|  | 		case peerUpdateMsg: | ||||||
|  | 			peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer) | ||||||
|  |  | ||||||
|  | 		case controlMsg[synPacket]: | ||||||
|  | 			peers[msg.SrcIP].OnSyn(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[ackPacket]: | ||||||
|  | 			peers[msg.SrcIP].OnAck(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[probePacket]: | ||||||
|  | 			peers[msg.SrcIP].OnProbe(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[localDiscoveryPacket]: | ||||||
|  | 			peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||||
|  |  | ||||||
|  | 		case pingTimerMsg: | ||||||
|  | 			for i := range peers { | ||||||
|  | 				if newState := peers[i].OnPingTimer(); newState != nil { | ||||||
|  | 					peers[i] = newState | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 		default: | ||||||
|  | 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type peerState interface { | ||||||
|  | 	OnPeerUpdate(*m.Peer) peerState | ||||||
|  | 	OnSyn(controlMsg[synPacket]) | ||||||
|  | 	OnAck(controlMsg[ackPacket]) | ||||||
|  | 	OnProbe(controlMsg[probePacket]) | ||||||
|  | 	OnLocalDiscovery(controlMsg[localDiscoveryPacket]) | ||||||
|  | 	OnPingTimer() peerState | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type peerStateData struct { | ||||||
|  | 	// The purpose of this state machine is to manage this published data. | ||||||
|  | 	published *atomic.Pointer[peerRoute] | ||||||
|  | 	staged    peerRoute // Local copy of shared data. See publish(). | ||||||
|  |  | ||||||
|  | 	// Immutable data. | ||||||
|  | 	remoteIP byte // Remote VPN IP. | ||||||
|  |  | ||||||
|  | 	// Mutable peer data. | ||||||
|  | 	peer      *m.Peer | ||||||
|  | 	remotePub bool | ||||||
|  |  | ||||||
|  | 	// Buffers for sending control packets. | ||||||
|  | 	buf1 []byte | ||||||
|  | 	buf2 []byte | ||||||
|  |  | ||||||
|  | 	// For logging. Set per-state. | ||||||
|  | 	client bool | ||||||
|  |  | ||||||
|  | 	limiter *ratelimiter.Limiter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | ||||||
|  | 	s.limiter.Limit() | ||||||
|  | 	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *peerStateData) sendControlPacketTo( | ||||||
|  | 	pkt interface{ Marshal([]byte) []byte }, | ||||||
|  | 	addr netip.AddrPort, | ||||||
|  | ) { | ||||||
|  | 	if !addr.IsValid() { | ||||||
|  | 		s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	route := s.staged | ||||||
|  | 	route.Direct = true | ||||||
|  | 	route.RemoteAddr = addr | ||||||
|  | 	s.limiter.Limit() | ||||||
|  | 	_sendControlPacket(pkt, route, s.buf1, s.buf2) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (s *peerStateData) publish() { | ||||||
|  | 	data := s.staged | ||||||
|  | 	s.published.Store(&data) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *peerStateData) logf(format string, args ...any) { | ||||||
|  | 	b := strings.Builder{} | ||||||
|  | 	b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name)) | ||||||
|  |  | ||||||
|  | 	if s.client { | ||||||
|  | 		b.WriteString("CLIENT|") | ||||||
|  | 	} else { | ||||||
|  | 		b.WriteString("SERVER|") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		b.WriteString("DIRECT |") | ||||||
|  | 	} else { | ||||||
|  | 		b.WriteString("RELAYED|") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.staged.Up { | ||||||
|  | 		b.WriteString("UP  |") | ||||||
|  | 	} else { | ||||||
|  | 		b.WriteString("DOWN|") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Printf(b.String()+format, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { | ||||||
|  | 	defer s.publish() | ||||||
|  |  | ||||||
|  | 	if peer == nil { | ||||||
|  | 		return enterStateDisconnected(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.peer = peer | ||||||
|  | 	s.staged.IP = s.remoteIP | ||||||
|  | 	s.staged.PubSignKey = peer.PubSignKey | ||||||
|  | 	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) | ||||||
|  | 	s.staged.DataCipher = newDataCipher() | ||||||
|  |  | ||||||
|  | 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||||
|  | 		s.remotePub = true | ||||||
|  | 		s.staged.Relay = peer.Relay | ||||||
|  | 		s.staged.Direct = true | ||||||
|  | 		s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) | ||||||
|  | 	} else if localPub { | ||||||
|  | 		s.staged.Direct = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.remotePub == localPub { | ||||||
|  | 		if localIP < s.remoteIP { | ||||||
|  | 			return enterStateServer(s) | ||||||
|  | 		} | ||||||
|  | 		return enterStateClient(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.remotePub { | ||||||
|  | 		return enterStateClient(s) | ||||||
|  | 	} | ||||||
|  | 	return enterStateServer(s) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type stateDisconnected struct { | ||||||
|  | 	*peerStateData | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateDisconnected(s *peerStateData) peerState { | ||||||
|  | 	s.peer = nil | ||||||
|  | 	s.staged = peerRoute{} | ||||||
|  | 	s.publish() | ||||||
|  | 	return &stateDisconnected{s} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateDisconnected) OnSyn(controlMsg[synPacket])                       {} | ||||||
|  | func (s *stateDisconnected) OnAck(controlMsg[ackPacket])                       {} | ||||||
|  | func (s *stateDisconnected) OnProbe(controlMsg[probePacket])                   {} | ||||||
|  | func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {} | ||||||
|  |  | ||||||
|  | func (s *stateDisconnected) OnPingTimer() peerState { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type stateServer struct { | ||||||
|  | 	*stateDisconnected | ||||||
|  | 	lastSeen   time.Time | ||||||
|  | 	synTraceID uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateServer(s *peerStateData) peerState { | ||||||
|  | 	s.client = false | ||||||
|  | 	return &stateServer{stateDisconnected: &stateDisconnected{s}} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  | 	p := msg.Packet | ||||||
|  |  | ||||||
|  | 	// Before we can respond to this packet, we need to make sure the | ||||||
|  | 	// route is setup properly. | ||||||
|  | 	// | ||||||
|  | 	// The client will update the syn's TraceID whenever there's a change. | ||||||
|  | 	// The server will follow the client's request. | ||||||
|  | 	if p.TraceID != s.synTraceID || !s.staged.Up { | ||||||
|  | 		s.synTraceID = p.TraceID | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.staged.Direct = p.Direct | ||||||
|  | 		s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) | ||||||
|  | 		s.staged.RemoteAddr = msg.SrcAddr | ||||||
|  | 		s.publish() | ||||||
|  | 		s.logf("Got syn.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Always respond. | ||||||
|  | 	ack := ackPacket{ | ||||||
|  | 		TraceID:  p.TraceID, | ||||||
|  | 		FromAddr: getLocalAddr(), | ||||||
|  | 	} | ||||||
|  | 	s.sendControlPacket(ack) | ||||||
|  |  | ||||||
|  | 	if !s.staged.Direct && p.FromAddr.IsValid() { | ||||||
|  | 		s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer) OnProbe(msg controlMsg[probePacket]) { | ||||||
|  | 	if !msg.SrcAddr.IsValid() { | ||||||
|  | 		s.logf("Invalid probe address.") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer) OnPingTimer() peerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||||
|  | 		s.staged.Up = false | ||||||
|  | 		s.publish() | ||||||
|  | 		s.logf("Connection timeout.") | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type stateClient struct { | ||||||
|  | 	*stateDisconnected | ||||||
|  |  | ||||||
|  | 	lastSeen time.Time | ||||||
|  | 	syn      synPacket | ||||||
|  | 	ack      ackPacket | ||||||
|  |  | ||||||
|  | 	probeTraceID uint64 | ||||||
|  | 	probeAddr    netip.AddrPort | ||||||
|  |  | ||||||
|  | 	localProbeTraceID uint64 | ||||||
|  | 	localProbeAddr    netip.AddrPort | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateClient(s *peerStateData) peerState { | ||||||
|  | 	s.client = true | ||||||
|  | 	ss := &stateClient{stateDisconnected: &stateDisconnected{s}} | ||||||
|  | 	ss.syn = synPacket{ | ||||||
|  | 		TraceID:   newTraceID(), | ||||||
|  | 		SharedKey: s.staged.DataCipher.Key(), | ||||||
|  | 		Direct:    s.staged.Direct, | ||||||
|  | 		FromAddr:  getLocalAddr(), | ||||||
|  | 	} | ||||||
|  | 	ss.sendSyn() | ||||||
|  | 	return ss | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | ||||||
|  | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
|  | 		s.logf("Ack has incorrect trace ID") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.ack = msg.Packet | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  |  | ||||||
|  | 	if !s.staged.Up { | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.logf("Got ack.") | ||||||
|  | 		s.publish() | ||||||
|  | 	} else { | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch msg.Packet.TraceID { | ||||||
|  | 	case s.probeTraceID: | ||||||
|  | 		s.staged.RemoteAddr = s.probeAddr | ||||||
|  | 	case s.localProbeTraceID: | ||||||
|  | 		s.staged.RemoteAddr = s.localProbeAddr | ||||||
|  | 	default: | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.staged.Direct = true | ||||||
|  | 	s.publish() | ||||||
|  |  | ||||||
|  | 	s.syn.TraceID = newTraceID() | ||||||
|  | 	s.syn.Direct = true | ||||||
|  | 	s.syn.FromAddr = getLocalAddr() | ||||||
|  | 	s.sendControlPacket(s.syn) | ||||||
|  |  | ||||||
|  | 	s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Send probe. | ||||||
|  | 	// | ||||||
|  | 	// The source port will be the multicast port, so we'll have to | ||||||
|  | 	// construct the correct address using the peer's listed port. | ||||||
|  | 	s.localProbeTraceID = newTraceID() | ||||||
|  | 	s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
|  | 	s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) OnPingTimer() peerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
|  | 		if s.staged.Up { | ||||||
|  | 			s.logf("Connection timeout.") | ||||||
|  | 		} | ||||||
|  | 		return s.OnPeerUpdate(s.peer) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.sendSyn() | ||||||
|  |  | ||||||
|  | 	if !s.staged.Direct && s.ack.FromAddr.IsValid() { | ||||||
|  | 		s.probeTraceID = newTraceID() | ||||||
|  | 		s.probeAddr = s.ack.FromAddr | ||||||
|  | 		s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) sendSyn() { | ||||||
|  | 	localAddr := getLocalAddr() | ||||||
|  | 	if localAddr != s.syn.FromAddr { | ||||||
|  | 		s.syn.TraceID = newTraceID() | ||||||
|  | 		s.syn.FromAddr = localAddr | ||||||
|  | 	} | ||||||
|  | 	s.sendControlPacket(s.syn) | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user