Better address discovery.

This commit is contained in:
jdl 2025-01-12 20:31:36 +01:00
parent d495ba9be7
commit 2bdd76e689
10 changed files with 253 additions and 164 deletions

View File

@ -3,65 +3,65 @@ package node
import ( import (
"log" "log"
"net/netip" "net/netip"
"runtime/debug"
"sort"
"time" "time"
) )
func addrDiscoveryServer() { type pubAddrStore struct {
var ( lastSeen map[netip.AddrPort]time.Time
buf1 = make([]byte, bufferSize) addrList []netip.AddrPort
buf2 = make([]byte, bufferSize) }
)
for { func newPubAddrStore() *pubAddrStore {
msg := <-discoveryMessages return &pubAddrStore{
p := msg.Packet lastSeen: map[netip.AddrPort]time.Time{},
addrList: make([]netip.AddrPort, 0, 32),
route := routingTable[msg.SrcIP].Load()
if route == nil || !route.RemoteAddr.IsValid() {
continue
}
_sendControlPacket(addrDiscoveryPacket{
TraceID: p.TraceID,
ToAddr: msg.SrcAddr,
}, *route, buf1, buf2)
} }
} }
func addrDiscoveryClient() { func (store *pubAddrStore) Store(add netip.AddrPort) {
var ( if localPub {
checkInterval = 8 * time.Second log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack())
timer = time.NewTimer(4 * time.Second) return
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
addrPacket addrDiscoveryPacket
lAddr netip.AddrPort
)
for {
select {
case msg := <-discoveryMessages:
p := msg.Packet
if p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr {
continue
} }
log.Printf("Discovered local address: %v", p.ToAddr) if _, exists := store.lastSeen[add]; !exists {
lAddr = p.ToAddr store.addrList = append(store.addrList, add)
localAddr.Store(&p.ToAddr)
case <-timer.C:
timer.Reset(checkInterval)
route := getRelayRoute()
if route == nil {
continue
}
addrPacket.TraceID = newTraceID()
_sendControlPacket(addrPacket, *route, buf1, buf2)
}
} }
store.lastSeen[add] = time.Now()
store.sort()
}
func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) {
if localPub {
addrs[0] = localAddr
return
}
copy(addrs[:], store.addrList)
return
}
func (store *pubAddrStore) Clean() {
if localPub {
return
}
for ip, lastSeen := range store.lastSeen {
if time.Since(lastSeen) > timeoutInterval {
delete(store.lastSeen, ip)
}
}
store.addrList = store.addrList[:0]
for ip := range store.lastSeen {
store.addrList = append(store.addrList, ip)
}
store.sort()
}
func (store *pubAddrStore) sort() {
sort.Slice(store.addrList, func(i, j int) bool {
return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]])
})
} }

View File

@ -0,0 +1,29 @@
package node
import (
"net/netip"
"testing"
"time"
)
func TestPubAddrStore(t *testing.T) {
s := newPubAddrStore()
l := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22),
}
for i := range l {
s.Store(l[i])
time.Sleep(time.Millisecond)
}
s.Clean()
l2 := s.Get()
if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] {
t.Fatal(l, l2)
}
}

View File

@ -1,7 +1,6 @@
package node package node
import ( import (
"net/netip"
"sync/atomic" "sync/atomic"
) )
@ -12,13 +11,6 @@ func getRelayRoute() *peerRoute {
return nil return nil
} }
func getLocalAddr() netip.AddrPort {
if a := localAddr.Load(); a != nil {
return *a
}
return netip.AddrPort{}
}
func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) {
buf := pkt.Marshal(buf2) buf := pkt.Marshal(buf2)
h := header{ h := header{

View File

@ -41,6 +41,7 @@ var (
netName string netName string
localIP byte localIP byte
localPub bool localPub bool
localAddr netip.AddrPort
privKey []byte privKey []byte
privSignKey []byte privSignKey []byte
@ -78,10 +79,8 @@ var (
return return
}() }()
// Managed by the addrDiscovery* functions.
discoveryMessages = make(chan controlMsg[addrDiscoveryPacket], 256)
// Managed by the relayManager. // Managed by the relayManager.
localAddr = &atomic.Pointer[netip.AddrPort]{}
relayIP = &atomic.Pointer[byte]{} relayIP = &atomic.Pointer[byte]{}
publicAddrs = newPubAddrStore()
) )

View File

@ -152,17 +152,13 @@ func main() {
ip, ok := netip.AddrFromSlice(config.PublicIP) ip, ok := netip.AddrFromSlice(config.PublicIP)
if ok { if ok {
localPub = true localPub = true
addr := netip.AddrPortFrom(ip, config.Port) localAddr = netip.AddrPortFrom(ip, config.Port)
localAddr.Store(&addr)
} }
privKey = config.PrivKey privKey = config.PrivKey
privSignKey = config.PrivSignKey privSignKey = config.PrivSignKey
if localPub { if !localPub {
go addrDiscoveryServer()
} else {
go addrDiscoveryClient()
go relayManager() go relayManager()
go localDiscovery() go localDiscovery()
} }
@ -177,6 +173,7 @@ func main() {
go newHubPoller().Run() go newHubPoller().Run()
go readFromConn(conn) go readFromConn(conn)
readFromIFace(iface) readFromIFace(iface)
} }
@ -232,7 +229,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
} }
if h.DestIP != localIP { if h.DestIP != localIP {
log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP) log.Printf("Incorrect destination IP on control packet: %#v", h)
return return
} }
@ -258,11 +255,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
return return
} }
if dm, ok := msg.(controlMsg[addrDiscoveryPacket]); ok {
discoveryMessages <- dm
return
}
select { select {
case messages <- msg: case messages <- msg:
default: default:

View File

@ -24,7 +24,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error
Packet: packet, Packet: packet,
}, err }, err
case packetTypeSynAck: case packetTypeAck:
packet, err := parseAckPacket(buf) packet, err := parseAckPacket(buf)
return controlMsg[ackPacket]{ return controlMsg[ackPacket]{
SrcIP: srcIP, SrcIP: srcIP,
@ -40,14 +40,6 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error
Packet: packet, Packet: packet,
}, err }, err
case packetTypeAddrDiscovery:
packet, err := parseAddrDiscoveryPacket(buf)
return controlMsg[addrDiscoveryPacket]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
default: default:
return nil, errUnknownPacketType return nil, errUnknownPacketType
} }

View File

@ -63,12 +63,20 @@ func (w *binWriter) Int64(x int64) *binWriter {
} }
func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
w.Bool(addrPort.IsValid())
addr := addrPort.Addr().As16() addr := addrPort.Addr().As16()
copy(w.b[w.i:w.i+16], addr[:]) copy(w.b[w.i:w.i+16], addr[:])
w.i += 16 w.i += 16
return w.Uint16(addrPort.Port()) return w.Uint16(addrPort.Port())
} }
func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter {
for _, addrPort := range l {
w.AddrPort(addrPort)
}
return w
}
func (w *binWriter) Build() []byte { func (w *binWriter) Build() []byte {
return w.b[:w.i] return w.b[:w.i]
} }
@ -146,15 +154,34 @@ func (r *binReader) Int64(x *int64) *binReader {
} }
func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
if !r.hasBytes(18) { if !r.hasBytes(19) {
return r return r
} }
var (
valid bool
port uint16
)
r.Bool(&valid)
addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
r.i += 16 r.i += 16
var port uint16
r.Uint16(&port) r.Uint16(&port)
if valid {
*x = netip.AddrPortFrom(addr, port) *x = netip.AddrPortFrom(addr, port)
} else {
*x = netip.AddrPort{}
}
return r
}
func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader {
for i := range x {
r.AddrPort(&x[i])
}
return r return r
} }

View File

@ -12,15 +12,30 @@ func TestBinWriteRead(t *testing.T) {
type Item struct { type Item struct {
Type byte Type byte
TraceID uint64 TraceID uint64
Addrs [8]netip.AddrPort
DestAddr netip.AddrPort DestAddr netip.AddrPort
} }
in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} in := Item{
1,
2,
[8]netip.AddrPort{},
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22),
}
in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20)
in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22)
in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23)
in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24)
in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25)
in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26)
in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27)
buf = newBinWriter(buf). buf = newBinWriter(buf).
Byte(in.Type). Byte(in.Type).
Uint64(in.TraceID). Uint64(in.TraceID).
AddrPort(in.DestAddr). AddrPort(in.DestAddr).
AddrPortArray(in.Addrs).
Build() Build()
out := Item{} out := Item{}
@ -29,6 +44,7 @@ func TestBinWriteRead(t *testing.T) {
Byte(&out.Type). Byte(&out.Type).
Uint64(&out.TraceID). Uint64(&out.TraceID).
AddrPort(&out.DestAddr). AddrPort(&out.DestAddr).
AddrPortArray(&out.Addrs).
Error() Error()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -24,7 +24,7 @@ type synPacket struct {
TraceID uint64 // TraceID to match response w/ request. TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key. SharedKey [32]byte // Our shared key.
Direct bool Direct bool
FromAddr netip.AddrPort // The client's sending address. PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
} }
func (p synPacket) Marshal(buf []byte) []byte { func (p synPacket) Marshal(buf []byte) []byte {
@ -33,7 +33,14 @@ func (p synPacket) Marshal(buf []byte) []byte {
Uint64(p.TraceID). Uint64(p.TraceID).
SharedKey(p.SharedKey). SharedKey(p.SharedKey).
Bool(p.Direct). Bool(p.Direct).
AddrPort(p.FromAddr). AddrPort(p.PossibleAddrs[0]).
AddrPort(p.PossibleAddrs[1]).
AddrPort(p.PossibleAddrs[2]).
AddrPort(p.PossibleAddrs[3]).
AddrPort(p.PossibleAddrs[4]).
AddrPort(p.PossibleAddrs[5]).
AddrPort(p.PossibleAddrs[6]).
AddrPort(p.PossibleAddrs[7]).
Build() Build()
} }
@ -42,7 +49,14 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
Uint64(&p.TraceID). Uint64(&p.TraceID).
SharedKey(&p.SharedKey). SharedKey(&p.SharedKey).
Bool(&p.Direct). Bool(&p.Direct).
AddrPort(&p.FromAddr). AddrPort(&p.PossibleAddrs[0]).
AddrPort(&p.PossibleAddrs[1]).
AddrPort(&p.PossibleAddrs[2]).
AddrPort(&p.PossibleAddrs[3]).
AddrPort(&p.PossibleAddrs[4]).
AddrPort(&p.PossibleAddrs[5]).
AddrPort(&p.PossibleAddrs[6]).
AddrPort(&p.PossibleAddrs[7]).
Error() Error()
return return
} }
@ -51,44 +65,39 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
type ackPacket struct { type ackPacket struct {
TraceID uint64 TraceID uint64
FromAddr netip.AddrPort ToAddr netip.AddrPort
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
} }
func (p ackPacket) Marshal(buf []byte) []byte { func (p ackPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf). return newBinWriter(buf).
Byte(packetTypeSynAck). Byte(packetTypeAck).
Uint64(p.TraceID). Uint64(p.TraceID).
AddrPort(p.FromAddr). AddrPort(p.ToAddr).
AddrPort(p.PossibleAddrs[0]).
AddrPort(p.PossibleAddrs[1]).
AddrPort(p.PossibleAddrs[2]).
AddrPort(p.PossibleAddrs[3]).
AddrPort(p.PossibleAddrs[4]).
AddrPort(p.PossibleAddrs[5]).
AddrPort(p.PossibleAddrs[6]).
AddrPort(p.PossibleAddrs[7]).
Build() Build()
} }
func parseAckPacket(buf []byte) (p ackPacket, err error) { func parseAckPacket(buf []byte) (p ackPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.FromAddr).
Error()
return
}
// ----------------------------------------------------------------------------
type addrDiscoveryPacket struct {
TraceID uint64
ToAddr netip.AddrPort
}
func (p addrDiscoveryPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeAddrDiscovery).
Uint64(p.TraceID).
AddrPort(p.ToAddr).
Build()
}
func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) {
err = newBinReader(buf[1:]). err = newBinReader(buf[1:]).
Uint64(&p.TraceID). Uint64(&p.TraceID).
AddrPort(&p.ToAddr). AddrPort(&p.ToAddr).
AddrPort(&p.PossibleAddrs[0]).
AddrPort(&p.PossibleAddrs[1]).
AddrPort(&p.PossibleAddrs[2]).
AddrPort(&p.PossibleAddrs[3]).
AddrPort(&p.PossibleAddrs[4]).
AddrPort(&p.PossibleAddrs[5]).
AddrPort(&p.PossibleAddrs[6]).
AddrPort(&p.PossibleAddrs[7]).
Error() Error()
return return
} }

View File

@ -14,7 +14,7 @@ import (
const ( const (
pingInterval = 8 * time.Second pingInterval = 8 * time.Second
timeoutInterval = 25 * time.Second timeoutInterval = 30 * time.Second
) )
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -28,7 +28,7 @@ func startPeerSuper() {
buf1: make([]byte, bufferSize), buf1: make([]byte, bufferSize),
buf2: make([]byte, bufferSize), buf2: make([]byte, bufferSize),
limiter: ratelimiter.New(ratelimiter.Config{ limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 50 * time.Millisecond, FillPeriod: 20 * time.Millisecond,
MaxWaitCount: 1, MaxWaitCount: 1,
}), }),
} }
@ -57,6 +57,7 @@ func runPeerSuper(peers [256]peerState) {
peers[msg.SrcIP].OnLocalDiscovery(msg) peers[msg.SrcIP].OnLocalDiscovery(msg)
case pingTimerMsg: case pingTimerMsg:
publicAddrs.Clean()
for i := range peers { for i := range peers {
if newState := peers[i].OnPingTimer(); newState != nil { if newState := peers[i].OnPingTimer(); newState != nil {
peers[i] = newState peers[i] = newState
@ -171,10 +172,13 @@ func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState {
} }
s.peer = peer s.peer = peer
s.staged.IP = s.remoteIP s.staged = peerRoute{
s.staged.PubSignKey = peer.PubSignKey IP: s.remoteIP,
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) PubSignKey: peer.PubSignKey,
s.staged.DataCipher = newDataCipher() ControlCipher: newControlCipher(privKey, peer.PubKey),
DataCipher: newDataCipher(),
}
s.remotePub = false
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true s.remotePub = true
@ -255,12 +259,20 @@ func (s *stateServer) OnSyn(msg controlMsg[synPacket]) {
// Always respond. // Always respond.
ack := ackPacket{ ack := ackPacket{
TraceID: p.TraceID, TraceID: p.TraceID,
FromAddr: getLocalAddr(), ToAddr: s.staged.RemoteAddr,
PossibleAddrs: publicAddrs.Get(),
} }
s.sendControlPacket(ack) s.sendControlPacket(ack)
if !s.staged.Direct && p.FromAddr.IsValid() { if s.staged.Direct {
s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr) return
}
// Not direct => send probes.
for _, addr := range p.PossibleAddrs {
if addr.IsValid() {
s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr)
}
} }
} }
@ -290,26 +302,35 @@ type stateClient struct {
syn synPacket syn synPacket
ack ackPacket ack ackPacket
probeTraceID uint64 probes map[uint64]netip.AddrPort
probeAddr netip.AddrPort localDiscoveryAddr chan netip.AddrPort
localProbeTraceID uint64
localProbeAddr netip.AddrPort
} }
func enterStateClient(s *peerStateData) peerState { func enterStateClient(s *peerStateData) peerState {
s.client = true s.client = true
ss := &stateClient{stateDisconnected: &stateDisconnected{s}} ss := &stateClient{
stateDisconnected: &stateDisconnected{s},
probes: map[uint64]netip.AddrPort{},
localDiscoveryAddr: make(chan netip.AddrPort, 1),
}
ss.syn = synPacket{ ss.syn = synPacket{
TraceID: newTraceID(), TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(), SharedKey: s.staged.DataCipher.Key(),
Direct: s.staged.Direct, Direct: s.staged.Direct,
FromAddr: getLocalAddr(), PossibleAddrs: publicAddrs.Get(),
} }
ss.sendSyn() ss.sendControlPacket(ss.syn)
return ss return ss
} }
func (s *stateClient) sendProbeTo(addr netip.AddrPort) {
probe := probePacket{TraceID: newTraceID()}
s.probes[probe.TraceID] = addr
s.sendControlPacketTo(probe, addr)
}
func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
if msg.Packet.TraceID != s.syn.TraceID { if msg.Packet.TraceID != s.syn.TraceID {
s.logf("Ack has incorrect trace ID") s.logf("Ack has incorrect trace ID")
@ -324,6 +345,12 @@ func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
s.logf("Got ack.") s.logf("Got ack.")
s.publish() s.publish()
} else { } else {
// TODO: What????
}
// Store possible public address if we're not a public node.
if !localPub && s.remotePub {
publicAddrs.Store(msg.Packet.ToAddr)
} }
} }
@ -332,21 +359,18 @@ func (s *stateClient) OnProbe(msg controlMsg[probePacket]) {
return return
} }
switch msg.Packet.TraceID { addr, ok := s.probes[msg.Packet.TraceID]
case s.probeTraceID: if !ok {
s.staged.RemoteAddr = s.probeAddr
case s.localProbeTraceID:
s.staged.RemoteAddr = s.localProbeAddr
default:
return return
} }
s.staged.RemoteAddr = addr
s.staged.Direct = true s.staged.Direct = true
s.publish() s.publish()
s.syn.TraceID = newTraceID() s.syn.TraceID = newTraceID()
s.syn.Direct = true s.syn.Direct = true
s.syn.FromAddr = getLocalAddr() s.syn.PossibleAddrs = [8]netip.AddrPort{}
s.sendControlPacket(s.syn) s.sendControlPacket(s.syn)
s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String())
@ -361,9 +385,14 @@ func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) {
// //
// The source port will be the multicast port, so we'll have to // The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port. // construct the correct address using the peer's listed port.
s.localProbeTraceID = newTraceID() addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr) select {
case s.localDiscoveryAddr <- addr:
// OK.
default:
log.Printf("Local discovery packet dropped.")
}
} }
func (s *stateClient) OnPingTimer() peerState { func (s *stateClient) OnPingTimer() peerState {
@ -374,22 +403,26 @@ func (s *stateClient) OnPingTimer() peerState {
return s.OnPeerUpdate(s.peer) return s.OnPeerUpdate(s.peer)
} }
s.sendSyn() s.sendControlPacket(s.syn)
if !s.staged.Direct && s.ack.FromAddr.IsValid() { if s.staged.Direct {
s.probeTraceID = newTraceID() return nil
s.probeAddr = s.ack.FromAddr }
s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr)
clear(s.probes)
for _, ip := range publicAddrs.Get() {
if !ip.IsValid() {
break
}
s.sendProbeTo(ip)
}
select {
case addr := <-s.localDiscoveryAddr:
s.sendProbeTo(addr)
default:
// Nothing to do.
} }
return nil return nil
} }
func (s *stateClient) sendSyn() {
localAddr := getLocalAddr()
if localAddr != s.syn.FromAddr {
s.syn.TraceID = newTraceID()
s.syn.FromAddr = localAddr
}
s.sendControlPacket(s.syn)
}