Better address discovery.
This commit is contained in:
		| @@ -3,65 +3,65 @@ package node | |||||||
| import ( | import ( | ||||||
| 	"log" | 	"log" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
|  | 	"runtime/debug" | ||||||
|  | 	"sort" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func addrDiscoveryServer() { | type pubAddrStore struct { | ||||||
| 	var ( | 	lastSeen map[netip.AddrPort]time.Time | ||||||
| 		buf1 = make([]byte, bufferSize) | 	addrList []netip.AddrPort | ||||||
| 		buf2 = make([]byte, bufferSize) |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		msg := <-discoveryMessages |  | ||||||
| 		p := msg.Packet |  | ||||||
|  |  | ||||||
| 		route := routingTable[msg.SrcIP].Load() |  | ||||||
| 		if route == nil || !route.RemoteAddr.IsValid() { |  | ||||||
| 			continue |  | ||||||
| } | } | ||||||
|  |  | ||||||
| 		_sendControlPacket(addrDiscoveryPacket{ | func newPubAddrStore() *pubAddrStore { | ||||||
| 			TraceID: p.TraceID, | 	return &pubAddrStore{ | ||||||
| 			ToAddr:  msg.SrcAddr, | 		lastSeen: map[netip.AddrPort]time.Time{}, | ||||||
| 		}, *route, buf1, buf2) | 		addrList: make([]netip.AddrPort, 0, 32), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func addrDiscoveryClient() { | func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||||
| 	var ( | 	if localPub { | ||||||
| 		checkInterval = 8 * time.Second | 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) | ||||||
| 		timer         = time.NewTimer(4 * time.Second) | 		return | ||||||
|  |  | ||||||
| 		buf1 = make([]byte, bufferSize) |  | ||||||
| 		buf2 = make([]byte, bufferSize) |  | ||||||
|  |  | ||||||
| 		addrPacket addrDiscoveryPacket |  | ||||||
| 		lAddr      netip.AddrPort |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case msg := <-discoveryMessages: |  | ||||||
| 			p := msg.Packet |  | ||||||
| 			if p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr { |  | ||||||
| 				continue |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 			log.Printf("Discovered local address: %v", p.ToAddr) | 	if _, exists := store.lastSeen[add]; !exists { | ||||||
| 			lAddr = p.ToAddr | 		store.addrList = append(store.addrList, add) | ||||||
| 			localAddr.Store(&p.ToAddr) | 	} | ||||||
|  | 	store.lastSeen[add] = time.Now() | ||||||
| 		case <-timer.C: | 	store.sort() | ||||||
| 			timer.Reset(checkInterval) |  | ||||||
|  |  | ||||||
| 			route := getRelayRoute() |  | ||||||
| 			if route == nil { |  | ||||||
| 				continue |  | ||||||
| } | } | ||||||
|  |  | ||||||
| 			addrPacket.TraceID = newTraceID() | func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | ||||||
| 			_sendControlPacket(addrPacket, *route, buf1, buf2) | 	if localPub { | ||||||
|  | 		addrs[0] = localAddr | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	copy(addrs[:], store.addrList) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (store *pubAddrStore) Clean() { | ||||||
|  | 	if localPub { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for ip, lastSeen := range store.lastSeen { | ||||||
|  | 		if time.Since(lastSeen) > timeoutInterval { | ||||||
|  | 			delete(store.lastSeen, ip) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	store.addrList = store.addrList[:0] | ||||||
|  | 	for ip := range store.lastSeen { | ||||||
|  | 		store.addrList = append(store.addrList, ip) | ||||||
|  | 	} | ||||||
|  | 	store.sort() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (store *pubAddrStore) sort() { | ||||||
|  | 	sort.Slice(store.addrList, func(i, j int) bool { | ||||||
|  | 		return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								node/addrdiscovery_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								node/addrdiscovery_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestPubAddrStore(t *testing.T) { | ||||||
|  | 	s := newPubAddrStore() | ||||||
|  |  | ||||||
|  | 	l := []netip.AddrPort{ | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range l { | ||||||
|  | 		s.Store(l[i]) | ||||||
|  | 		time.Sleep(time.Millisecond) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Clean() | ||||||
|  |  | ||||||
|  | 	l2 := s.Get() | ||||||
|  | 	if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { | ||||||
|  | 		t.Fatal(l, l2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,7 +1,6 @@ | |||||||
| package node | package node | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"net/netip" |  | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -12,13 +11,6 @@ func getRelayRoute() *peerRoute { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func getLocalAddr() netip.AddrPort { |  | ||||||
| 	if a := localAddr.Load(); a != nil { |  | ||||||
| 		return *a |  | ||||||
| 	} |  | ||||||
| 	return netip.AddrPort{} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { | func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { | ||||||
| 	buf := pkt.Marshal(buf2) | 	buf := pkt.Marshal(buf2) | ||||||
| 	h := header{ | 	h := header{ | ||||||
|   | |||||||
| @@ -41,6 +41,7 @@ var ( | |||||||
| 	netName     string | 	netName     string | ||||||
| 	localIP     byte | 	localIP     byte | ||||||
| 	localPub    bool | 	localPub    bool | ||||||
|  | 	localAddr   netip.AddrPort | ||||||
| 	privKey     []byte | 	privKey     []byte | ||||||
| 	privSignKey []byte | 	privSignKey []byte | ||||||
|  |  | ||||||
| @@ -78,10 +79,8 @@ var ( | |||||||
| 		return | 		return | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// Managed by the addrDiscovery* functions. |  | ||||||
| 	discoveryMessages = make(chan controlMsg[addrDiscoveryPacket], 256) |  | ||||||
|  |  | ||||||
| 	// Managed by the relayManager. | 	// Managed by the relayManager. | ||||||
| 	localAddr = &atomic.Pointer[netip.AddrPort]{} |  | ||||||
| 	relayIP = &atomic.Pointer[byte]{} | 	relayIP = &atomic.Pointer[byte]{} | ||||||
|  |  | ||||||
|  | 	publicAddrs = newPubAddrStore() | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -152,17 +152,13 @@ func main() { | |||||||
| 	ip, ok := netip.AddrFromSlice(config.PublicIP) | 	ip, ok := netip.AddrFromSlice(config.PublicIP) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		localPub = true | 		localPub = true | ||||||
| 		addr := netip.AddrPortFrom(ip, config.Port) | 		localAddr = netip.AddrPortFrom(ip, config.Port) | ||||||
| 		localAddr.Store(&addr) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	privKey = config.PrivKey | 	privKey = config.PrivKey | ||||||
| 	privSignKey = config.PrivSignKey | 	privSignKey = config.PrivSignKey | ||||||
|  |  | ||||||
| 	if localPub { | 	if !localPub { | ||||||
| 		go addrDiscoveryServer() |  | ||||||
| 	} else { |  | ||||||
| 		go addrDiscoveryClient() |  | ||||||
| 		go relayManager() | 		go relayManager() | ||||||
| 		go localDiscovery() | 		go localDiscovery() | ||||||
| 	} | 	} | ||||||
| @@ -177,6 +173,7 @@ func main() { | |||||||
|  |  | ||||||
| 	go newHubPoller().Run() | 	go newHubPoller().Run() | ||||||
| 	go readFromConn(conn) | 	go readFromConn(conn) | ||||||
|  |  | ||||||
| 	readFromIFace(iface) | 	readFromIFace(iface) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -232,7 +229,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if h.DestIP != localIP { | 	if h.DestIP != localIP { | ||||||
| 		log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP) | 		log.Printf("Incorrect destination IP on control packet: %#v", h) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -258,11 +255,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if dm, ok := msg.(controlMsg[addrDiscoveryPacket]); ok { |  | ||||||
| 		discoveryMessages <- dm |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	select { | 	select { | ||||||
| 	case messages <- msg: | 	case messages <- msg: | ||||||
| 	default: | 	default: | ||||||
|   | |||||||
| @@ -24,7 +24,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | |||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case packetTypeSynAck: | 	case packetTypeAck: | ||||||
| 		packet, err := parseAckPacket(buf) | 		packet, err := parseAckPacket(buf) | ||||||
| 		return controlMsg[ackPacket]{ | 		return controlMsg[ackPacket]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| @@ -40,14 +40,6 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | |||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case packetTypeAddrDiscovery: |  | ||||||
| 		packet, err := parseAddrDiscoveryPacket(buf) |  | ||||||
| 		return controlMsg[addrDiscoveryPacket]{ |  | ||||||
| 			SrcIP:   srcIP, |  | ||||||
| 			SrcAddr: srcAddr, |  | ||||||
| 			Packet:  packet, |  | ||||||
| 		}, err |  | ||||||
|  |  | ||||||
| 	default: | 	default: | ||||||
| 		return nil, errUnknownPacketType | 		return nil, errUnknownPacketType | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -63,12 +63,20 @@ func (w *binWriter) Int64(x int64) *binWriter { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | ||||||
|  | 	w.Bool(addrPort.IsValid()) | ||||||
| 	addr := addrPort.Addr().As16() | 	addr := addrPort.Addr().As16() | ||||||
| 	copy(w.b[w.i:w.i+16], addr[:]) | 	copy(w.b[w.i:w.i+16], addr[:]) | ||||||
| 	w.i += 16 | 	w.i += 16 | ||||||
| 	return w.Uint16(addrPort.Port()) | 	return w.Uint16(addrPort.Port()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { | ||||||
|  | 	for _, addrPort := range l { | ||||||
|  | 		w.AddrPort(addrPort) | ||||||
|  | 	} | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
| func (w *binWriter) Build() []byte { | func (w *binWriter) Build() []byte { | ||||||
| 	return w.b[:w.i] | 	return w.b[:w.i] | ||||||
| } | } | ||||||
| @@ -146,15 +154,34 @@ func (r *binReader) Int64(x *int64) *binReader { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | ||||||
| 	if !r.hasBytes(18) { | 	if !r.hasBytes(19) { | ||||||
| 		return r | 		return r | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	var ( | ||||||
|  | 		valid bool | ||||||
|  | 		port  uint16 | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	r.Bool(&valid) | ||||||
| 	addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() | 	addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() | ||||||
| 	r.i += 16 | 	r.i += 16 | ||||||
|  |  | ||||||
| 	var port uint16 |  | ||||||
| 	r.Uint16(&port) | 	r.Uint16(&port) | ||||||
|  |  | ||||||
|  | 	if valid { | ||||||
| 		*x = netip.AddrPortFrom(addr, port) | 		*x = netip.AddrPortFrom(addr, port) | ||||||
|  | 	} else { | ||||||
|  | 		*x = netip.AddrPort{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { | ||||||
|  | 	for i := range x { | ||||||
|  | 		r.AddrPort(&x[i]) | ||||||
|  | 	} | ||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -12,15 +12,30 @@ func TestBinWriteRead(t *testing.T) { | |||||||
| 	type Item struct { | 	type Item struct { | ||||||
| 		Type     byte | 		Type     byte | ||||||
| 		TraceID  uint64 | 		TraceID  uint64 | ||||||
|  | 		Addrs    [8]netip.AddrPort | ||||||
| 		DestAddr netip.AddrPort | 		DestAddr netip.AddrPort | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} | 	in := Item{ | ||||||
|  | 		1, | ||||||
|  | 		2, | ||||||
|  | 		[8]netip.AddrPort{}, | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) | ||||||
|  | 	in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) | ||||||
|  | 	in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) | ||||||
|  | 	in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) | ||||||
|  | 	in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) | ||||||
|  | 	in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) | ||||||
|  | 	in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) | ||||||
|  |  | ||||||
| 	buf = newBinWriter(buf). | 	buf = newBinWriter(buf). | ||||||
| 		Byte(in.Type). | 		Byte(in.Type). | ||||||
| 		Uint64(in.TraceID). | 		Uint64(in.TraceID). | ||||||
| 		AddrPort(in.DestAddr). | 		AddrPort(in.DestAddr). | ||||||
|  | 		AddrPortArray(in.Addrs). | ||||||
| 		Build() | 		Build() | ||||||
|  |  | ||||||
| 	out := Item{} | 	out := Item{} | ||||||
| @@ -29,6 +44,7 @@ func TestBinWriteRead(t *testing.T) { | |||||||
| 		Byte(&out.Type). | 		Byte(&out.Type). | ||||||
| 		Uint64(&out.TraceID). | 		Uint64(&out.TraceID). | ||||||
| 		AddrPort(&out.DestAddr). | 		AddrPort(&out.DestAddr). | ||||||
|  | 		AddrPortArray(&out.Addrs). | ||||||
| 		Error() | 		Error() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|   | |||||||
| @@ -24,7 +24,7 @@ type synPacket struct { | |||||||
| 	TraceID       uint64   // TraceID to match response w/ request. | 	TraceID       uint64   // TraceID to match response w/ request. | ||||||
| 	SharedKey     [32]byte // Our shared key. | 	SharedKey     [32]byte // Our shared key. | ||||||
| 	Direct        bool | 	Direct        bool | ||||||
| 	FromAddr  netip.AddrPort // The client's sending address. | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p synPacket) Marshal(buf []byte) []byte { | func (p synPacket) Marshal(buf []byte) []byte { | ||||||
| @@ -33,7 +33,14 @@ func (p synPacket) Marshal(buf []byte) []byte { | |||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		SharedKey(p.SharedKey). | 		SharedKey(p.SharedKey). | ||||||
| 		Bool(p.Direct). | 		Bool(p.Direct). | ||||||
| 		AddrPort(p.FromAddr). | 		AddrPort(p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[7]). | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -42,7 +49,14 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | |||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		SharedKey(&p.SharedKey). | 		SharedKey(&p.SharedKey). | ||||||
| 		Bool(&p.Direct). | 		Bool(&p.Direct). | ||||||
| 		AddrPort(&p.FromAddr). | 		AddrPort(&p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[7]). | ||||||
| 		Error() | 		Error() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| @@ -51,44 +65,39 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | |||||||
|  |  | ||||||
| type ackPacket struct { | type ackPacket struct { | ||||||
| 	TraceID       uint64 | 	TraceID       uint64 | ||||||
| 	FromAddr netip.AddrPort | 	ToAddr        netip.AddrPort | ||||||
|  | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p ackPacket) Marshal(buf []byte) []byte { | func (p ackPacket) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(packetTypeSynAck). | 		Byte(packetTypeAck). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		AddrPort(p.FromAddr). | 		AddrPort(p.ToAddr). | ||||||
|  | 		AddrPort(p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[7]). | ||||||
| 		Build() | 		Build() | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseAckPacket(buf []byte) (p ackPacket, err error) { | func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||||
| 	err = newBinReader(buf[1:]). |  | ||||||
| 		Uint64(&p.TraceID). |  | ||||||
| 		AddrPort(&p.FromAddr). |  | ||||||
| 		Error() |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type addrDiscoveryPacket struct { |  | ||||||
| 	TraceID uint64 |  | ||||||
| 	ToAddr  netip.AddrPort |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (p addrDiscoveryPacket) Marshal(buf []byte) []byte { |  | ||||||
| 	return newBinWriter(buf). |  | ||||||
| 		Byte(packetTypeAddrDiscovery). |  | ||||||
| 		Uint64(p.TraceID). |  | ||||||
| 		AddrPort(p.ToAddr). |  | ||||||
| 		Build() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) { |  | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		AddrPort(&p.ToAddr). | 		AddrPort(&p.ToAddr). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[7]). | ||||||
| 		Error() | 		Error() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|   | |||||||
| @@ -14,7 +14,7 @@ import ( | |||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	pingInterval    = 8 * time.Second | 	pingInterval    = 8 * time.Second | ||||||
| 	timeoutInterval = 25 * time.Second | 	timeoutInterval = 30 * time.Second | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
| @@ -28,7 +28,7 @@ func startPeerSuper() { | |||||||
| 			buf1:      make([]byte, bufferSize), | 			buf1:      make([]byte, bufferSize), | ||||||
| 			buf2:      make([]byte, bufferSize), | 			buf2:      make([]byte, bufferSize), | ||||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ | 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||||
| 				FillPeriod:   50 * time.Millisecond, | 				FillPeriod:   20 * time.Millisecond, | ||||||
| 				MaxWaitCount: 1, | 				MaxWaitCount: 1, | ||||||
| 			}), | 			}), | ||||||
| 		} | 		} | ||||||
| @@ -57,6 +57,7 @@ func runPeerSuper(peers [256]peerState) { | |||||||
| 			peers[msg.SrcIP].OnLocalDiscovery(msg) | 			peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||||
|  |  | ||||||
| 		case pingTimerMsg: | 		case pingTimerMsg: | ||||||
|  | 			publicAddrs.Clean() | ||||||
| 			for i := range peers { | 			for i := range peers { | ||||||
| 				if newState := peers[i].OnPingTimer(); newState != nil { | 				if newState := peers[i].OnPingTimer(); newState != nil { | ||||||
| 					peers[i] = newState | 					peers[i] = newState | ||||||
| @@ -171,10 +172,13 @@ func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.peer = peer | 	s.peer = peer | ||||||
| 	s.staged.IP = s.remoteIP | 	s.staged = peerRoute{ | ||||||
| 	s.staged.PubSignKey = peer.PubSignKey | 		IP:            s.remoteIP, | ||||||
| 	s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) | 		PubSignKey:    peer.PubSignKey, | ||||||
| 	s.staged.DataCipher = newDataCipher() | 		ControlCipher: newControlCipher(privKey, peer.PubKey), | ||||||
|  | 		DataCipher:    newDataCipher(), | ||||||
|  | 	} | ||||||
|  | 	s.remotePub = false | ||||||
|  |  | ||||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||||
| 		s.remotePub = true | 		s.remotePub = true | ||||||
| @@ -255,12 +259,20 @@ func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { | |||||||
| 	// Always respond. | 	// Always respond. | ||||||
| 	ack := ackPacket{ | 	ack := ackPacket{ | ||||||
| 		TraceID:       p.TraceID, | 		TraceID:       p.TraceID, | ||||||
| 		FromAddr: getLocalAddr(), | 		ToAddr:        s.staged.RemoteAddr, | ||||||
|  | 		PossibleAddrs: publicAddrs.Get(), | ||||||
| 	} | 	} | ||||||
| 	s.sendControlPacket(ack) | 	s.sendControlPacket(ack) | ||||||
|  |  | ||||||
| 	if !s.staged.Direct && p.FromAddr.IsValid() { | 	if s.staged.Direct { | ||||||
| 		s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr) | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Not direct => send probes. | ||||||
|  | 	for _, addr := range p.PossibleAddrs { | ||||||
|  | 		if addr.IsValid() { | ||||||
|  | 			s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -290,26 +302,35 @@ type stateClient struct { | |||||||
| 	syn      synPacket | 	syn      synPacket | ||||||
| 	ack      ackPacket | 	ack      ackPacket | ||||||
|  |  | ||||||
| 	probeTraceID uint64 | 	probes             map[uint64]netip.AddrPort | ||||||
| 	probeAddr    netip.AddrPort | 	localDiscoveryAddr chan netip.AddrPort | ||||||
|  |  | ||||||
| 	localProbeTraceID uint64 |  | ||||||
| 	localProbeAddr    netip.AddrPort |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func enterStateClient(s *peerStateData) peerState { | func enterStateClient(s *peerStateData) peerState { | ||||||
| 	s.client = true | 	s.client = true | ||||||
| 	ss := &stateClient{stateDisconnected: &stateDisconnected{s}} | 	ss := &stateClient{ | ||||||
|  | 		stateDisconnected:  &stateDisconnected{s}, | ||||||
|  | 		probes:             map[uint64]netip.AddrPort{}, | ||||||
|  | 		localDiscoveryAddr: make(chan netip.AddrPort, 1), | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	ss.syn = synPacket{ | 	ss.syn = synPacket{ | ||||||
| 		TraceID:       newTraceID(), | 		TraceID:       newTraceID(), | ||||||
| 		SharedKey:     s.staged.DataCipher.Key(), | 		SharedKey:     s.staged.DataCipher.Key(), | ||||||
| 		Direct:        s.staged.Direct, | 		Direct:        s.staged.Direct, | ||||||
| 		FromAddr:  getLocalAddr(), | 		PossibleAddrs: publicAddrs.Get(), | ||||||
| 	} | 	} | ||||||
| 	ss.sendSyn() | 	ss.sendControlPacket(ss.syn) | ||||||
|  |  | ||||||
| 	return ss | 	return ss | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) sendProbeTo(addr netip.AddrPort) { | ||||||
|  | 	probe := probePacket{TraceID: newTraceID()} | ||||||
|  | 	s.probes[probe.TraceID] = addr | ||||||
|  | 	s.sendControlPacketTo(probe, addr) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | ||||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
| 		s.logf("Ack has incorrect trace ID") | 		s.logf("Ack has incorrect trace ID") | ||||||
| @@ -324,6 +345,12 @@ func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { | |||||||
| 		s.logf("Got ack.") | 		s.logf("Got ack.") | ||||||
| 		s.publish() | 		s.publish() | ||||||
| 	} else { | 	} else { | ||||||
|  | 		// TODO: What???? | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Store possible public address if we're not a public node. | ||||||
|  | 	if !localPub && s.remotePub { | ||||||
|  | 		publicAddrs.Store(msg.Packet.ToAddr) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -332,21 +359,18 @@ func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	switch msg.Packet.TraceID { | 	addr, ok := s.probes[msg.Packet.TraceID] | ||||||
| 	case s.probeTraceID: | 	if !ok { | ||||||
| 		s.staged.RemoteAddr = s.probeAddr |  | ||||||
| 	case s.localProbeTraceID: |  | ||||||
| 		s.staged.RemoteAddr = s.localProbeAddr |  | ||||||
| 	default: |  | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	s.staged.RemoteAddr = addr | ||||||
| 	s.staged.Direct = true | 	s.staged.Direct = true | ||||||
| 	s.publish() | 	s.publish() | ||||||
|  |  | ||||||
| 	s.syn.TraceID = newTraceID() | 	s.syn.TraceID = newTraceID() | ||||||
| 	s.syn.Direct = true | 	s.syn.Direct = true | ||||||
| 	s.syn.FromAddr = getLocalAddr() | 	s.syn.PossibleAddrs = [8]netip.AddrPort{} | ||||||
| 	s.sendControlPacket(s.syn) | 	s.sendControlPacket(s.syn) | ||||||
|  |  | ||||||
| 	s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) | 	s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) | ||||||
| @@ -361,9 +385,14 @@ func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { | |||||||
| 	// | 	// | ||||||
| 	// The source port will be the multicast port, so we'll have to | 	// The source port will be the multicast port, so we'll have to | ||||||
| 	// construct the correct address using the peer's listed port. | 	// construct the correct address using the peer's listed port. | ||||||
| 	s.localProbeTraceID = newTraceID() | 	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
| 	s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) |  | ||||||
| 	s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr) | 	select { | ||||||
|  | 	case s.localDiscoveryAddr <- addr: | ||||||
|  | 		// OK. | ||||||
|  | 	default: | ||||||
|  | 		log.Printf("Local discovery packet dropped.") | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClient) OnPingTimer() peerState { | func (s *stateClient) OnPingTimer() peerState { | ||||||
| @@ -374,22 +403,26 @@ func (s *stateClient) OnPingTimer() peerState { | |||||||
| 		return s.OnPeerUpdate(s.peer) | 		return s.OnPeerUpdate(s.peer) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.sendSyn() | 	s.sendControlPacket(s.syn) | ||||||
|  |  | ||||||
| 	if !s.staged.Direct && s.ack.FromAddr.IsValid() { | 	if s.staged.Direct { | ||||||
| 		s.probeTraceID = newTraceID() | 		return nil | ||||||
| 		s.probeAddr = s.ack.FromAddr | 	} | ||||||
| 		s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr) |  | ||||||
|  | 	clear(s.probes) | ||||||
|  | 	for _, ip := range publicAddrs.Get() { | ||||||
|  | 		if !ip.IsValid() { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		s.sendProbeTo(ip) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case addr := <-s.localDiscoveryAddr: | ||||||
|  | 		s.sendProbeTo(addr) | ||||||
|  | 	default: | ||||||
|  | 		// Nothing to do. | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClient) sendSyn() { |  | ||||||
| 	localAddr := getLocalAddr() |  | ||||||
| 	if localAddr != s.syn.FromAddr { |  | ||||||
| 		s.syn.TraceID = newTraceID() |  | ||||||
| 		s.syn.FromAddr = localAddr |  | ||||||
| 	} |  | ||||||
| 	s.sendControlPacket(s.syn) |  | ||||||
| } |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user