refactor-for-testability #3
@ -84,7 +84,7 @@ func (r *connReader) handleControlPacket(
|
|||||||
enc []byte,
|
enc []byte,
|
||||||
) {
|
) {
|
||||||
if peer.ControlCipher == nil {
|
if peer.ControlCipher == nil {
|
||||||
log.Printf("No control cipher for peer: %v", h)
|
r.logf("No control cipher for peer: %d", h.SourceIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,8 +18,10 @@ const (
|
|||||||
dataCipherOverhead = 16
|
dataCipherOverhead = 16
|
||||||
signOverhead = 64
|
signOverhead = 64
|
||||||
|
|
||||||
pingInterval = 8 * time.Second
|
pingInterval = 8 * time.Second
|
||||||
timeoutInterval = 30 * time.Second
|
timeoutInterval = 30 * time.Second
|
||||||
|
broadcastInterval = 16 * time.Second
|
||||||
|
broadcastErrorTimeoutInterval = 8 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
||||||
|
@ -50,11 +50,15 @@ func newHubPoller(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hp *hubPoller) logf(s string, args ...any) {
|
||||||
|
log.Printf("[HubPoller] "+s, args...)
|
||||||
|
}
|
||||||
|
|
||||||
func (hp *hubPoller) Run() {
|
func (hp *hubPoller) Run() {
|
||||||
state, err := loadNetworkState(hp.netName)
|
state, err := loadNetworkState(hp.netName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to load network state: %v", err)
|
hp.logf("Failed to load network state: %v", err)
|
||||||
log.Printf("Polling hub...")
|
hp.logf("Polling hub...")
|
||||||
hp.pollHub()
|
hp.pollHub()
|
||||||
} else {
|
} else {
|
||||||
hp.applyNetworkState(state)
|
hp.applyNetworkState(state)
|
||||||
@ -70,25 +74,25 @@ func (hp *hubPoller) pollHub() {
|
|||||||
|
|
||||||
resp, err := hp.client.Do(hp.req)
|
resp, err := hp.client.Do(hp.req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to fetch peer state: %v", err)
|
hp.logf("Failed to fetch peer state: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to read body from hub: %v", err)
|
hp.logf("Failed to read body from hub: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &state); err != nil {
|
if err := json.Unmarshal(body, &state); err != nil {
|
||||||
log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body)
|
hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hp.applyNetworkState(state)
|
hp.applyNetworkState(state)
|
||||||
|
|
||||||
if err := storeNetworkState(hp.netName, state); err != nil {
|
if err := storeNetworkState(hp.netName, state); err != nil {
|
||||||
log.Printf("Failed to store network state: %v", err)
|
hp.logf("Failed to store network state: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import "log"
|
|
||||||
|
|
||||||
func logPacket(p []byte, notes string) {
|
|
||||||
h := parseHeader(p)
|
|
||||||
log.Printf(`Sending: Data: %v | From: %d | To: %d | %s
|
|
||||||
`,
|
|
||||||
h.StreamID == dataStreamID,
|
|
||||||
h.SourceIP,
|
|
||||||
h.DestIP,
|
|
||||||
notes)
|
|
||||||
}
|
|
@ -12,12 +12,12 @@ func runMCReader(
|
|||||||
handleControlMsg func(destIP byte, msg any),
|
handleControlMsg func(destIP byte, msg any),
|
||||||
) {
|
) {
|
||||||
for {
|
for {
|
||||||
runMCReader2(rt, handleControlMsg)
|
runMCReaderInner(rt, handleControlMsg)
|
||||||
time.Sleep(8 * time.Second)
|
time.Sleep(broadcastErrorTimeoutInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runMCReader2(
|
func runMCReaderInner(
|
||||||
rt *atomic.Pointer[routingTable],
|
rt *atomic.Pointer[routingTable],
|
||||||
handleControlMsg func(destIP byte, msg any),
|
handleControlMsg func(destIP byte, msg any),
|
||||||
) {
|
) {
|
||||||
|
@ -41,10 +41,10 @@ func runMCWriter(localIP byte, signingKey []byte) {
|
|||||||
|
|
||||||
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
|
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to bind to multicast address: %v", err)
|
log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for range time.Tick(8 * time.Second) {
|
for range time.Tick(broadcastInterval) {
|
||||||
_, err := conn.WriteToUDP(discoveryPacket, multicastAddr)
|
_, err := conn.WriteToUDP(discoveryPacket, multicastAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[MCWriter] Failed to write multicast: %v", err)
|
log.Printf("[MCWriter] Failed to write multicast: %v", err)
|
||||||
|
28
peer/peer.go
28
peer/peer.go
@ -32,10 +32,14 @@ type peerConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newPeerMain(conf peerConfig) *peerMain {
|
func newPeerMain(conf peerConfig) *peerMain {
|
||||||
|
logf := func(s string, args ...any) {
|
||||||
|
log.Printf("[Main] "+s, args...)
|
||||||
|
}
|
||||||
|
|
||||||
config, err := loadPeerConfig(conf.NetName)
|
config, err := loadPeerConfig(conf.NetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to load configuration: %v", err)
|
logf("Failed to load configuration: %v", err)
|
||||||
log.Printf("Initializing...")
|
logf("Initializing...")
|
||||||
initPeerWithHub(conf)
|
initPeerWithHub(conf)
|
||||||
|
|
||||||
config, err = loadPeerConfig(conf.NetName)
|
config, err = loadPeerConfig(conf.NetName)
|
||||||
@ -54,7 +58,7 @@ func newPeerMain(conf peerConfig) *peerMain {
|
|||||||
log.Fatalf("Failed to resolve UDP address: %v", err)
|
log.Fatalf("Failed to resolve UDP address: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Listening on %v...", myAddr)
|
logf("Listening on %v...", myAddr)
|
||||||
conn, err := net.ListenUDP("udp", myAddr)
|
conn, err := net.ListenUDP("udp", myAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to open UDP port: %v", err)
|
log.Fatalf("Failed to open UDP port: %v", err)
|
||||||
@ -69,15 +73,15 @@ func newPeerMain(conf peerConfig) *peerMain {
|
|||||||
writeLock.Lock()
|
writeLock.Lock()
|
||||||
n, err = conn.WriteToUDPAddrPort(b, addr)
|
n, err = conn.WriteToUDPAddrPort(b, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to write packet: %v", err)
|
logf("Failed to write packet: %v", err)
|
||||||
}
|
}
|
||||||
writeLock.Unlock()
|
writeLock.Unlock()
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var localAddr netip.AddrPort
|
var localAddr netip.AddrPort
|
||||||
ip, ok := netip.AddrFromSlice(config.PublicIP)
|
ip, localAddrValid := netip.AddrFromSlice(config.PublicIP)
|
||||||
if ok {
|
if localAddrValid {
|
||||||
localAddr = netip.AddrPortFrom(ip, config.Port)
|
localAddr = netip.AddrPortFrom(ip, config.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,12 +109,18 @@ func newPeerMain(conf peerConfig) *peerMain {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *peerMain) Run() {
|
func (p *peerMain) Run() {
|
||||||
|
|
||||||
go p.ifReader.Run()
|
go p.ifReader.Run()
|
||||||
go p.connReader.Run()
|
go p.connReader.Run()
|
||||||
p.super.Start()
|
p.super.Start()
|
||||||
go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey)
|
|
||||||
go runMCReader(p.rt, p.super.HandleControlMsg)
|
if !p.rt.Load().LocalAddr.IsValid() {
|
||||||
p.hubPoller.Run()
|
go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey)
|
||||||
|
go runMCReader(p.rt, p.super.HandleControlMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.hubPoller.Run()
|
||||||
|
select {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func initPeerWithHub(conf peerConfig) {
|
func initPeerWithHub(conf peerConfig) {
|
||||||
|
162
peer/state-client.go
Normal file
162
peer/state-client.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sentProbe struct {
|
||||||
|
SentAt time.Time
|
||||||
|
Addr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateClient struct {
|
||||||
|
*peerData
|
||||||
|
lastSeen time.Time
|
||||||
|
syn packetSyn
|
||||||
|
probes map[uint64]sentProbe
|
||||||
|
}
|
||||||
|
|
||||||
|
func enterStateClient(data *peerData) peerState {
|
||||||
|
ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP)
|
||||||
|
|
||||||
|
data.staged.Relay = data.peer.Relay && ipValid
|
||||||
|
data.staged.Direct = ipValid
|
||||||
|
data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port)
|
||||||
|
data.publish(data.staged)
|
||||||
|
|
||||||
|
state := &stateClient{
|
||||||
|
peerData: data,
|
||||||
|
lastSeen: time.Now(),
|
||||||
|
syn: packetSyn{
|
||||||
|
TraceID: newTraceID(),
|
||||||
|
SharedKey: data.staged.DataCipher.Key(),
|
||||||
|
Direct: data.staged.Direct,
|
||||||
|
PossibleAddrs: data.pubAddrs.Get(),
|
||||||
|
},
|
||||||
|
probes: map[uint64]sentProbe{},
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Send(state.staged, state.syn)
|
||||||
|
|
||||||
|
data.pingTimer.Reset(pingInterval)
|
||||||
|
|
||||||
|
state.logf("==> Client")
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) logf(str string, args ...any) {
|
||||||
|
s.peerData.logf("CLNT | "+str, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) OnMsg(raw any) peerState {
|
||||||
|
switch msg := raw.(type) {
|
||||||
|
case peerUpdateMsg:
|
||||||
|
return initPeerState(s.peerData, msg.Peer)
|
||||||
|
case controlMsg[packetAck]:
|
||||||
|
s.onAck(msg)
|
||||||
|
case controlMsg[packetProbe]:
|
||||||
|
return s.onProbe(msg)
|
||||||
|
case controlMsg[packetLocalDiscovery]:
|
||||||
|
s.onLocalDiscovery(msg)
|
||||||
|
case pingTimerMsg:
|
||||||
|
return s.onPingTimer()
|
||||||
|
default:
|
||||||
|
s.logf("Ignoring message: %v", raw)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) onAck(msg controlMsg[packetAck]) {
|
||||||
|
if msg.Packet.TraceID != s.syn.TraceID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lastSeen = time.Now()
|
||||||
|
|
||||||
|
if !s.staged.Up {
|
||||||
|
s.staged.Up = true
|
||||||
|
s.publish(s.staged)
|
||||||
|
s.logf("Got ACK.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.staged.Direct {
|
||||||
|
s.pubAddrs.Store(msg.Packet.ToAddr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relayed below.
|
||||||
|
|
||||||
|
s.cleanProbes()
|
||||||
|
|
||||||
|
for _, addr := range msg.Packet.PossibleAddrs {
|
||||||
|
if !addr.IsValid() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
s.sendProbeTo(addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) onPingTimer() peerState {
|
||||||
|
if time.Since(s.lastSeen) > timeoutInterval {
|
||||||
|
if s.staged.Up {
|
||||||
|
s.logf("Timeout.")
|
||||||
|
}
|
||||||
|
return initPeerState(s.peerData, s.peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Send(s.staged, s.syn)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) onProbe(msg controlMsg[packetProbe]) peerState {
|
||||||
|
if s.staged.Direct {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cleanProbes()
|
||||||
|
|
||||||
|
sent, ok := s.probes[msg.Packet.TraceID]
|
||||||
|
if !ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
s.staged.Direct = true
|
||||||
|
s.staged.DirectAddr = sent.Addr
|
||||||
|
s.publish(s.staged)
|
||||||
|
|
||||||
|
s.syn.TraceID = newTraceID()
|
||||||
|
s.syn.Direct = true
|
||||||
|
s.Send(s.staged, s.syn)
|
||||||
|
s.logf("Successful probe.")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
|
||||||
|
if s.staged.Direct {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// The source port will be the multicast port, so we'll have to
|
||||||
|
// construct the correct address using the peer's listed port.
|
||||||
|
addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
|
||||||
|
s.sendProbeTo(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) cleanProbes() {
|
||||||
|
for key, sent := range s.probes {
|
||||||
|
if time.Since(sent.SentAt) > pingInterval {
|
||||||
|
delete(s.probes, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateClient) sendProbeTo(addr netip.AddrPort) {
|
||||||
|
probe := packetProbe{TraceID: newTraceID()}
|
||||||
|
s.probes[probe.TraceID] = sentProbe{
|
||||||
|
SentAt: time.Now(),
|
||||||
|
Addr: addr,
|
||||||
|
}
|
||||||
|
s.logf("Probing %v...", addr)
|
||||||
|
s.SendTo(probe, addr)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user