Compare commits
	
		
			5 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | a0b5058544 | ||
|  | 232681fac6 | ||
|  | 6e7a2456b2 | ||
|  | 970490b17b | ||
|  | 2bdd76e689 | 
| @@ -3,65 +3,69 @@ package node | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"runtime/debug" | ||||
| 	"sort" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func addrDiscoveryServer() { | ||||
| 	var ( | ||||
| 		buf1 = make([]byte, bufferSize) | ||||
| 		buf2 = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		msg := <-discoveryMessages | ||||
| 		p := msg.Packet | ||||
|  | ||||
| 		route := routingTable[msg.SrcIP].Load() | ||||
| 		if route == nil || !route.RemoteAddr.IsValid() { | ||||
| 			continue | ||||
| type pubAddrStore struct { | ||||
| 	lastSeen map[netip.AddrPort]time.Time | ||||
| 	addrList []netip.AddrPort | ||||
| } | ||||
|  | ||||
| 		_sendControlPacket(addrDiscoveryPacket{ | ||||
| 			TraceID: p.TraceID, | ||||
| 			ToAddr:  msg.SrcAddr, | ||||
| 		}, *route, buf1, buf2) | ||||
| func newPubAddrStore() *pubAddrStore { | ||||
| 	return &pubAddrStore{ | ||||
| 		lastSeen: map[netip.AddrPort]time.Time{}, | ||||
| 		addrList: make([]netip.AddrPort, 0, 32), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func addrDiscoveryClient() { | ||||
| 	var ( | ||||
| 		checkInterval = 8 * time.Second | ||||
| 		timer         = time.NewTimer(4 * time.Second) | ||||
|  | ||||
| 		buf1 = make([]byte, bufferSize) | ||||
| 		buf2 = make([]byte, bufferSize) | ||||
|  | ||||
| 		addrPacket addrDiscoveryPacket | ||||
| 		lAddr      netip.AddrPort | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case msg := <-discoveryMessages: | ||||
| 			p := msg.Packet | ||||
| 			if p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr { | ||||
| 				continue | ||||
| func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||
| 	if localPub { | ||||
| 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 			log.Printf("Discovered local address: %v", p.ToAddr) | ||||
| 			lAddr = p.ToAddr | ||||
| 			localAddr.Store(&p.ToAddr) | ||||
|  | ||||
| 		case <-timer.C: | ||||
| 			timer.Reset(checkInterval) | ||||
|  | ||||
| 			route := getRelayRoute() | ||||
| 			if route == nil { | ||||
| 				continue | ||||
| 	if !add.IsValid() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 			addrPacket.TraceID = newTraceID() | ||||
| 			_sendControlPacket(addrPacket, *route, buf1, buf2) | ||||
| 	if _, exists := store.lastSeen[add]; !exists { | ||||
| 		store.addrList = append(store.addrList, add) | ||||
| 	} | ||||
| 	store.lastSeen[add] = time.Now() | ||||
| 	store.sort() | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | ||||
| 	if localPub { | ||||
| 		addrs[0] = localAddr | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	copy(addrs[:], store.addrList) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) Clean() { | ||||
| 	if localPub { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for ip, lastSeen := range store.lastSeen { | ||||
| 		if time.Since(lastSeen) > timeoutInterval { | ||||
| 			delete(store.lastSeen, ip) | ||||
| 		} | ||||
| 	} | ||||
| 	store.addrList = store.addrList[:0] | ||||
| 	for ip := range store.lastSeen { | ||||
| 		store.addrList = append(store.addrList, ip) | ||||
| 	} | ||||
| 	store.sort() | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) sort() { | ||||
| 	sort.Slice(store.addrList, func(i, j int) bool { | ||||
| 		return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) | ||||
| 	}) | ||||
| } | ||||
|   | ||||
							
								
								
									
										29
									
								
								node/addrdiscovery_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								node/addrdiscovery_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestPubAddrStore(t *testing.T) { | ||||
| 	s := newPubAddrStore() | ||||
|  | ||||
| 	l := []netip.AddrPort{ | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), | ||||
| 	} | ||||
|  | ||||
| 	for i := range l { | ||||
| 		s.Store(l[i]) | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 	} | ||||
|  | ||||
| 	s.Clean() | ||||
|  | ||||
| 	l2 := s.Get() | ||||
| 	if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { | ||||
| 		t.Fatal(l, l2) | ||||
| 	} | ||||
| } | ||||
| @@ -1,7 +1,6 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| @@ -12,13 +11,6 @@ func getRelayRoute() *peerRoute { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getLocalAddr() netip.AddrPort { | ||||
| 	if a := localAddr.Load(); a != nil { | ||||
| 		return *a | ||||
| 	} | ||||
| 	return netip.AddrPort{} | ||||
| } | ||||
|  | ||||
| func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { | ||||
| 	buf := pkt.Marshal(buf2) | ||||
| 	h := header{ | ||||
|   | ||||
| @@ -41,6 +41,7 @@ var ( | ||||
| 	netName     string | ||||
| 	localIP     byte | ||||
| 	localPub    bool | ||||
| 	localAddr   netip.AddrPort | ||||
| 	privKey     []byte | ||||
| 	privSignKey []byte | ||||
|  | ||||
| @@ -67,7 +68,7 @@ var ( | ||||
| 	}() | ||||
|  | ||||
| 	// Messages for the supervisor. | ||||
| 	messages = make(chan any, 512) | ||||
| 	messages = make(chan any, 1024) | ||||
|  | ||||
| 	// Global routing table. | ||||
| 	routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { | ||||
| @@ -78,10 +79,8 @@ var ( | ||||
| 		return | ||||
| 	}() | ||||
|  | ||||
| 	// Managed by the addrDiscovery* functions. | ||||
| 	discoveryMessages = make(chan controlMsg[addrDiscoveryPacket], 256) | ||||
|  | ||||
| 	// Managed by the relayManager. | ||||
| 	localAddr = &atomic.Pointer[netip.AddrPort]{} | ||||
| 	relayIP = &atomic.Pointer[byte]{} | ||||
|  | ||||
| 	publicAddrs = newPubAddrStore() | ||||
| ) | ||||
|   | ||||
| @@ -25,7 +25,7 @@ func sendLocalDiscovery(conn *net.UDPConn) { | ||||
| 		buf2 = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	for range time.Tick(32 * time.Second) { | ||||
| 	for range time.Tick(16 * time.Second) { | ||||
| 		signed := buildLocalDiscoveryPacket(buf1, buf2) | ||||
| 		if _, err := conn.WriteToUDP(signed, multicastAddr); err != nil { | ||||
| 			log.Printf("Failed to write multicast UDP packet: %v", err) | ||||
|   | ||||
							
								
								
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -152,17 +152,13 @@ func main() { | ||||
| 	ip, ok := netip.AddrFromSlice(config.PublicIP) | ||||
| 	if ok { | ||||
| 		localPub = true | ||||
| 		addr := netip.AddrPortFrom(ip, config.Port) | ||||
| 		localAddr.Store(&addr) | ||||
| 		localAddr = netip.AddrPortFrom(ip, config.Port) | ||||
| 	} | ||||
|  | ||||
| 	privKey = config.PrivKey | ||||
| 	privSignKey = config.PrivSignKey | ||||
|  | ||||
| 	if localPub { | ||||
| 		go addrDiscoveryServer() | ||||
| 	} else { | ||||
| 		go addrDiscoveryClient() | ||||
| 	if !localPub { | ||||
| 		go relayManager() | ||||
| 		go localDiscovery() | ||||
| 	} | ||||
| @@ -177,6 +173,7 @@ func main() { | ||||
|  | ||||
| 	go newHubPoller().Run() | ||||
| 	go readFromConn(conn) | ||||
|  | ||||
| 	readFromIFace(iface) | ||||
| } | ||||
|  | ||||
| @@ -232,7 +229,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | ||||
| 	} | ||||
|  | ||||
| 	if h.DestIP != localIP { | ||||
| 		log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP) | ||||
| 		log.Printf("Incorrect destination IP on control packet: %#v", h) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -258,11 +255,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if dm, ok := msg.(controlMsg[addrDiscoveryPacket]); ok { | ||||
| 		discoveryMessages <- dm | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	select { | ||||
| 	case messages <- msg: | ||||
| 	default: | ||||
|   | ||||
| @@ -24,7 +24,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	case packetTypeSynAck: | ||||
| 	case packetTypeAck: | ||||
| 		packet, err := parseAckPacket(buf) | ||||
| 		return controlMsg[ackPacket]{ | ||||
| 			SrcIP:   srcIP, | ||||
| @@ -40,14 +40,6 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	case packetTypeAddrDiscovery: | ||||
| 		packet, err := parseAddrDiscoveryPacket(buf) | ||||
| 		return controlMsg[addrDiscoveryPacket]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	default: | ||||
| 		return nil, errUnknownPacketType | ||||
| 	} | ||||
|   | ||||
| @@ -63,12 +63,20 @@ func (w *binWriter) Int64(x int64) *binWriter { | ||||
| } | ||||
|  | ||||
| func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | ||||
| 	w.Bool(addrPort.IsValid()) | ||||
| 	addr := addrPort.Addr().As16() | ||||
| 	copy(w.b[w.i:w.i+16], addr[:]) | ||||
| 	w.i += 16 | ||||
| 	return w.Uint16(addrPort.Port()) | ||||
| } | ||||
|  | ||||
| func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { | ||||
| 	for _, addrPort := range l { | ||||
| 		w.AddrPort(addrPort) | ||||
| 	} | ||||
| 	return w | ||||
| } | ||||
|  | ||||
| func (w *binWriter) Build() []byte { | ||||
| 	return w.b[:w.i] | ||||
| } | ||||
| @@ -146,15 +154,34 @@ func (r *binReader) Int64(x *int64) *binReader { | ||||
| } | ||||
|  | ||||
| func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | ||||
| 	if !r.hasBytes(18) { | ||||
| 	if !r.hasBytes(19) { | ||||
| 		return r | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		valid bool | ||||
| 		port  uint16 | ||||
| 	) | ||||
|  | ||||
| 	r.Bool(&valid) | ||||
| 	addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() | ||||
| 	r.i += 16 | ||||
|  | ||||
| 	var port uint16 | ||||
| 	r.Uint16(&port) | ||||
|  | ||||
| 	if valid { | ||||
| 		*x = netip.AddrPortFrom(addr, port) | ||||
| 	} else { | ||||
| 		*x = netip.AddrPort{} | ||||
| 	} | ||||
|  | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { | ||||
| 	for i := range x { | ||||
| 		r.AddrPort(&x[i]) | ||||
| 	} | ||||
| 	return r | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -12,15 +12,30 @@ func TestBinWriteRead(t *testing.T) { | ||||
| 	type Item struct { | ||||
| 		Type     byte | ||||
| 		TraceID  uint64 | ||||
| 		Addrs    [8]netip.AddrPort | ||||
| 		DestAddr netip.AddrPort | ||||
| 	} | ||||
|  | ||||
| 	in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} | ||||
| 	in := Item{ | ||||
| 		1, | ||||
| 		2, | ||||
| 		[8]netip.AddrPort{}, | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), | ||||
| 	} | ||||
|  | ||||
| 	in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) | ||||
| 	in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) | ||||
| 	in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) | ||||
| 	in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) | ||||
| 	in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) | ||||
| 	in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) | ||||
| 	in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) | ||||
|  | ||||
| 	buf = newBinWriter(buf). | ||||
| 		Byte(in.Type). | ||||
| 		Uint64(in.TraceID). | ||||
| 		AddrPort(in.DestAddr). | ||||
| 		AddrPortArray(in.Addrs). | ||||
| 		Build() | ||||
|  | ||||
| 	out := Item{} | ||||
| @@ -29,6 +44,7 @@ func TestBinWriteRead(t *testing.T) { | ||||
| 		Byte(&out.Type). | ||||
| 		Uint64(&out.TraceID). | ||||
| 		AddrPort(&out.DestAddr). | ||||
| 		AddrPortArray(&out.Addrs). | ||||
| 		Error() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
|   | ||||
| @@ -24,7 +24,7 @@ type synPacket struct { | ||||
| 	TraceID       uint64   // TraceID to match response w/ request. | ||||
| 	SharedKey     [32]byte // Our shared key. | ||||
| 	Direct        bool | ||||
| 	FromAddr  netip.AddrPort // The client's sending address. | ||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||
| } | ||||
|  | ||||
| func (p synPacket) Marshal(buf []byte) []byte { | ||||
| @@ -33,7 +33,14 @@ func (p synPacket) Marshal(buf []byte) []byte { | ||||
| 		Uint64(p.TraceID). | ||||
| 		SharedKey(p.SharedKey). | ||||
| 		Bool(p.Direct). | ||||
| 		AddrPort(p.FromAddr). | ||||
| 		AddrPort(p.PossibleAddrs[0]). | ||||
| 		AddrPort(p.PossibleAddrs[1]). | ||||
| 		AddrPort(p.PossibleAddrs[2]). | ||||
| 		AddrPort(p.PossibleAddrs[3]). | ||||
| 		AddrPort(p.PossibleAddrs[4]). | ||||
| 		AddrPort(p.PossibleAddrs[5]). | ||||
| 		AddrPort(p.PossibleAddrs[6]). | ||||
| 		AddrPort(p.PossibleAddrs[7]). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| @@ -42,7 +49,14 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
| 		Uint64(&p.TraceID). | ||||
| 		SharedKey(&p.SharedKey). | ||||
| 		Bool(&p.Direct). | ||||
| 		AddrPort(&p.FromAddr). | ||||
| 		AddrPort(&p.PossibleAddrs[0]). | ||||
| 		AddrPort(&p.PossibleAddrs[1]). | ||||
| 		AddrPort(&p.PossibleAddrs[2]). | ||||
| 		AddrPort(&p.PossibleAddrs[3]). | ||||
| 		AddrPort(&p.PossibleAddrs[4]). | ||||
| 		AddrPort(&p.PossibleAddrs[5]). | ||||
| 		AddrPort(&p.PossibleAddrs[6]). | ||||
| 		AddrPort(&p.PossibleAddrs[7]). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
| @@ -51,44 +65,39 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
|  | ||||
| type ackPacket struct { | ||||
| 	TraceID       uint64 | ||||
| 	FromAddr netip.AddrPort | ||||
| 	ToAddr        netip.AddrPort | ||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||
| } | ||||
|  | ||||
| func (p ackPacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeSynAck). | ||||
| 		Byte(packetTypeAck). | ||||
| 		Uint64(p.TraceID). | ||||
| 		AddrPort(p.FromAddr). | ||||
| 		AddrPort(p.ToAddr). | ||||
| 		AddrPort(p.PossibleAddrs[0]). | ||||
| 		AddrPort(p.PossibleAddrs[1]). | ||||
| 		AddrPort(p.PossibleAddrs[2]). | ||||
| 		AddrPort(p.PossibleAddrs[3]). | ||||
| 		AddrPort(p.PossibleAddrs[4]). | ||||
| 		AddrPort(p.PossibleAddrs[5]). | ||||
| 		AddrPort(p.PossibleAddrs[6]). | ||||
| 		AddrPort(p.PossibleAddrs[7]). | ||||
| 		Build() | ||||
|  | ||||
| } | ||||
|  | ||||
| func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.FromAddr). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type addrDiscoveryPacket struct { | ||||
| 	TraceID uint64 | ||||
| 	ToAddr  netip.AddrPort | ||||
| } | ||||
|  | ||||
| func (p addrDiscoveryPacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeAddrDiscovery). | ||||
| 		Uint64(p.TraceID). | ||||
| 		AddrPort(p.ToAddr). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.ToAddr). | ||||
| 		AddrPort(&p.PossibleAddrs[0]). | ||||
| 		AddrPort(&p.PossibleAddrs[1]). | ||||
| 		AddrPort(&p.PossibleAddrs[2]). | ||||
| 		AddrPort(&p.PossibleAddrs[3]). | ||||
| 		AddrPort(&p.PossibleAddrs[4]). | ||||
| 		AddrPort(&p.PossibleAddrs[5]). | ||||
| 		AddrPort(&p.PossibleAddrs[6]). | ||||
| 		AddrPort(&p.PossibleAddrs[7]). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -14,7 +14,7 @@ import ( | ||||
|  | ||||
| const ( | ||||
| 	pingInterval    = 8 * time.Second | ||||
| 	timeoutInterval = 25 * time.Second | ||||
| 	timeoutInterval = 30 * time.Second | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -28,7 +28,7 @@ func startPeerSuper() { | ||||
| 			buf1:      make([]byte, bufferSize), | ||||
| 			buf2:      make([]byte, bufferSize), | ||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 				FillPeriod:   50 * time.Millisecond, | ||||
| 				FillPeriod:   20 * time.Millisecond, | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| @@ -57,6 +57,7 @@ func runPeerSuper(peers [256]peerState) { | ||||
| 			peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			publicAddrs.Clean() | ||||
| 			for i := range peers { | ||||
| 				if newState := peers[i].OnPingTimer(); newState != nil { | ||||
| 					peers[i] = newState | ||||
| @@ -101,6 +102,8 @@ type peerStateData struct { | ||||
| 	// For logging. Set per-state. | ||||
| 	client bool | ||||
|  | ||||
| 	// We rate limit per remote endpoint because if we don't we tend to lose | ||||
| 	// packets. | ||||
| 	limiter *ratelimiter.Limiter | ||||
| } | ||||
|  | ||||
| @@ -171,10 +174,13 @@ func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { | ||||
| 	} | ||||
|  | ||||
| 	s.peer = peer | ||||
| 	s.staged.IP = s.remoteIP | ||||
| 	s.staged.PubSignKey = peer.PubSignKey | ||||
| 	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) | ||||
| 	s.staged.DataCipher = newDataCipher() | ||||
| 	s.staged = peerRoute{ | ||||
| 		IP:            s.remoteIP, | ||||
| 		PubSignKey:    peer.PubSignKey, | ||||
| 		ControlCipher: newControlCipher(privKey, peer.PubKey), | ||||
| 		DataCipher:    newDataCipher(), | ||||
| 	} | ||||
| 	s.remotePub = false | ||||
|  | ||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||
| 		s.remotePub = true | ||||
| @@ -255,12 +261,21 @@ func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { | ||||
| 	// Always respond. | ||||
| 	ack := ackPacket{ | ||||
| 		TraceID:       p.TraceID, | ||||
| 		FromAddr: getLocalAddr(), | ||||
| 		ToAddr:        s.staged.RemoteAddr, | ||||
| 		PossibleAddrs: publicAddrs.Get(), | ||||
| 	} | ||||
| 	s.sendControlPacket(ack) | ||||
|  | ||||
| 	if !s.staged.Direct && p.FromAddr.IsValid() { | ||||
| 		s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr) | ||||
| 	if s.staged.Direct { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Not direct => send probes. | ||||
| 	for _, addr := range p.PossibleAddrs { | ||||
| 		if !addr.IsValid() { | ||||
| 			break | ||||
| 		} | ||||
| 		s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -290,26 +305,34 @@ type stateClient struct { | ||||
| 	syn      synPacket | ||||
| 	ack      ackPacket | ||||
|  | ||||
| 	probeTraceID uint64 | ||||
| 	probeAddr    netip.AddrPort | ||||
|  | ||||
| 	localProbeTraceID uint64 | ||||
| 	localProbeAddr    netip.AddrPort | ||||
| 	probes             map[uint64]netip.AddrPort | ||||
| 	localDiscoveryAddr netip.AddrPort | ||||
| } | ||||
|  | ||||
| func enterStateClient(s *peerStateData) peerState { | ||||
| 	s.client = true | ||||
| 	ss := &stateClient{stateDisconnected: &stateDisconnected{s}} | ||||
| 	ss := &stateClient{ | ||||
| 		stateDisconnected: &stateDisconnected{s}, | ||||
| 		probes:            map[uint64]netip.AddrPort{}, | ||||
| 	} | ||||
|  | ||||
| 	ss.syn = synPacket{ | ||||
| 		TraceID:       newTraceID(), | ||||
| 		SharedKey:     s.staged.DataCipher.Key(), | ||||
| 		Direct:        s.staged.Direct, | ||||
| 		FromAddr:  getLocalAddr(), | ||||
| 		PossibleAddrs: publicAddrs.Get(), | ||||
| 	} | ||||
| 	ss.sendSyn() | ||||
| 	ss.sendControlPacket(ss.syn) | ||||
|  | ||||
| 	return ss | ||||
| } | ||||
|  | ||||
| func (s *stateClient) sendProbeTo(addr netip.AddrPort) { | ||||
| 	probe := probePacket{TraceID: newTraceID()} | ||||
| 	s.probes[probe.TraceID] = addr | ||||
| 	s.sendControlPacketTo(probe, addr) | ||||
| } | ||||
|  | ||||
| func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | ||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | ||||
| 		s.logf("Ack has incorrect trace ID") | ||||
| @@ -323,7 +346,11 @@ func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | ||||
| 		s.staged.Up = true | ||||
| 		s.logf("Got ack.") | ||||
| 		s.publish() | ||||
| 	} else { | ||||
| 	} | ||||
|  | ||||
| 	// Store possible public address if we're not a public node. | ||||
| 	if !localPub && s.remotePub { | ||||
| 		publicAddrs.Store(msg.Packet.ToAddr) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -332,21 +359,18 @@ func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	switch msg.Packet.TraceID { | ||||
| 	case s.probeTraceID: | ||||
| 		s.staged.RemoteAddr = s.probeAddr | ||||
| 	case s.localProbeTraceID: | ||||
| 		s.staged.RemoteAddr = s.localProbeAddr | ||||
| 	default: | ||||
| 	addr, ok := s.probes[msg.Packet.TraceID] | ||||
| 	if !ok { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	s.staged.RemoteAddr = addr | ||||
| 	s.staged.Direct = true | ||||
| 	s.publish() | ||||
|  | ||||
| 	s.syn.TraceID = newTraceID() | ||||
| 	s.syn.Direct = true | ||||
| 	s.syn.FromAddr = getLocalAddr() | ||||
| 	s.syn.PossibleAddrs = [8]netip.AddrPort{} | ||||
| 	s.sendControlPacket(s.syn) | ||||
|  | ||||
| 	s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) | ||||
| @@ -357,13 +381,9 @@ func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Send probe. | ||||
| 	// | ||||
| 	// The source port will be the multicast port, so we'll have to | ||||
| 	// construct the correct address using the peer's listed port. | ||||
| 	s.localProbeTraceID = newTraceID() | ||||
| 	s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| 	s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr) | ||||
| 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| } | ||||
|  | ||||
| func (s *stateClient) OnPingTimer() peerState { | ||||
| @@ -374,22 +394,24 @@ func (s *stateClient) OnPingTimer() peerState { | ||||
| 		return s.OnPeerUpdate(s.peer) | ||||
| 	} | ||||
|  | ||||
| 	s.sendSyn() | ||||
| 	s.sendControlPacket(s.syn) | ||||
|  | ||||
| 	if !s.staged.Direct && s.ack.FromAddr.IsValid() { | ||||
| 		s.probeTraceID = newTraceID() | ||||
| 		s.probeAddr = s.ack.FromAddr | ||||
| 		s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr) | ||||
| 	if s.staged.Direct { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	clear(s.probes) | ||||
| 	for _, addr := range s.ack.PossibleAddrs { | ||||
| 		if !addr.IsValid() { | ||||
| 			break | ||||
| 		} | ||||
| 		s.sendProbeTo(addr) | ||||
| 	} | ||||
|  | ||||
| 	if s.localDiscoveryAddr.IsValid() { | ||||
| 		s.sendProbeTo(s.localDiscoveryAddr) | ||||
| 		s.localDiscoveryAddr = netip.AddrPort{} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *stateClient) sendSyn() { | ||||
| 	localAddr := getLocalAddr() | ||||
| 	if localAddr != s.syn.FromAddr { | ||||
| 		s.syn.TraceID = newTraceID() | ||||
| 		s.syn.FromAddr = localAddr | ||||
| 	} | ||||
| 	s.sendControlPacket(s.syn) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user