refactor-for-testability #3
| @@ -84,6 +84,7 @@ func (r *connReader) handleControlPacket( | |||||||
| 	enc []byte, | 	enc []byte, | ||||||
| ) { | ) { | ||||||
| 	if peer.ControlCipher == nil { | 	if peer.ControlCipher == nil { | ||||||
|  | 		log.Printf("No control cipher for peer: %v", h) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -125,13 +126,13 @@ func (r *connReader) handleDataPacket( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	relay, ok := rt.GetRelay() | 	remote := rt.Peers[h.DestIP] | ||||||
| 	if !ok { | 	if !remote.Direct { | ||||||
| 		r.logf("Relay not available.") | 		r.logf("Unable to relay data to %d.", h.DestIP) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r.writeToUDPAddrPort(data, relay.DirectAddr) | 	r.writeToUDPAddrPort(data, remote.DirectAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *connReader) logf(format string, args ...any) { | func (r *connReader) logf(format string, args ...any) { | ||||||
|   | |||||||
| @@ -41,6 +41,14 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | |||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
|  | 	case packetTypeInit: | ||||||
|  | 		packet, err := parsePacketInit(buf) | ||||||
|  | 		return controlMsg[packetInit]{ | ||||||
|  | 			SrcIP:   srcIP, | ||||||
|  | 			SrcAddr: srcAddr, | ||||||
|  | 			Packet:  packet, | ||||||
|  | 		}, err | ||||||
|  |  | ||||||
| 	default: | 	default: | ||||||
| 		return nil, errUnknownPacketType | 		return nil, errUnknownPacketType | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|  | 	version = 1 | ||||||
|  |  | ||||||
| 	bufferSize = 1536 | 	bufferSize = 1536 | ||||||
|  |  | ||||||
| 	if_mtu       = 1200 | 	if_mtu       = 1200 | ||||||
|   | |||||||
							
								
								
									
										13
									
								
								peer/logging.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								peer/logging.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "log" | ||||||
|  |  | ||||||
|  | func logPacket(p []byte, notes string) { | ||||||
|  | 	h := parseHeader(p) | ||||||
|  | 	log.Printf(`Sending: Data: %v | From: %d | To:   %d | %s | ||||||
|  | `, | ||||||
|  | 		h.StreamID == dataStreamID, | ||||||
|  | 		h.SourceIP, | ||||||
|  | 		h.DestIP, | ||||||
|  | 		notes) | ||||||
|  | } | ||||||
| @@ -6,19 +6,38 @@ import ( | |||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	packetTypeSyn           = 1 | 	packetTypeSyn           = 1 | ||||||
|  | 	packetTypeInit          = 2 | ||||||
| 	packetTypeAck           = 3 | 	packetTypeAck           = 3 | ||||||
| 	packetTypeProbe         = 4 | 	packetTypeProbe         = 4 | ||||||
| 	packetTypeAddrDiscovery = 5 | 	packetTypeAddrDiscovery = 5 | ||||||
| 	packetTypeInit          = 6 |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type packetInit struct { | type packetInit struct { | ||||||
| 	TraceID uint64 | 	TraceID uint64 | ||||||
|  | 	Direct  bool | ||||||
| 	Version uint64 | 	Version uint64 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (p packetInit) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeInit). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		Bool(p.Direct). | ||||||
|  | 		Uint64(p.Version). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parsePacketInit(buf []byte) (p packetInit, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		Bool(&p.Direct). | ||||||
|  | 		Uint64(&p.Version). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type packetSyn struct { | type packetSyn struct { | ||||||
|   | |||||||
| @@ -43,6 +43,7 @@ type pState struct { | |||||||
| 	limiter *ratelimiter.Limiter | 	limiter *ratelimiter.Limiter | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /* | ||||||
| func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		// Don't defer directly otherwise s.staged will be evaluated immediately | 		// Don't defer directly otherwise s.staged will be evaluated immediately | ||||||
| @@ -78,7 +79,7 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | |||||||
| 			return enterStateServer(s) | 			return enterStateServer(s) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return enterStateClientDirect(s) | 		return enterStateClientinit(s) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.localAddr.IsValid() { | 	if s.localAddr.IsValid() { | ||||||
| @@ -90,8 +91,9 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | |||||||
| 		return enterStateServer(s) | 		return enterStateServer(s) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return enterStateClientRelayed(s) | 	return enterStateClientinit(s) | ||||||
| } | } | ||||||
|  | */ | ||||||
|  |  | ||||||
| func (s *pState) logf(format string, args ...any) { | func (s *pState) logf(format string, args ...any) { | ||||||
| 	b := strings.Builder{} | 	b := strings.Builder{} | ||||||
| @@ -140,6 +142,7 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | /* | ||||||
| type stateDisconnected struct{ *pState } | type stateDisconnected struct{ *pState } | ||||||
|  |  | ||||||
| func enterStateDisconnected(s *pState) peerState { | func enterStateDisconnected(s *pState) peerState { | ||||||
| @@ -181,6 +184,8 @@ func (s *stateServer) OnMsg(rawMsg any) peerState { | |||||||
| 	switch msg := rawMsg.(type) { | 	switch msg := rawMsg.(type) { | ||||||
| 	case peerUpdateMsg: | 	case peerUpdateMsg: | ||||||
| 		return s.OnPeerUpdate(msg.Peer) | 		return s.OnPeerUpdate(msg.Peer) | ||||||
|  | 	case controlMsg[packetInit]: | ||||||
|  | 		return s.OnInit(msg) | ||||||
| 	case controlMsg[packetSyn]: | 	case controlMsg[packetSyn]: | ||||||
| 		return s.OnSyn(msg) | 		return s.OnSyn(msg) | ||||||
| 	case controlMsg[packetProbe]: | 	case controlMsg[packetProbe]: | ||||||
| @@ -193,6 +198,21 @@ func (s *stateServer) OnMsg(rawMsg any) peerState { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateServer) OnInit(msg controlMsg[packetInit]) peerState { | ||||||
|  | 	s.logf("Responding to INIT.") | ||||||
|  | 	route := s.staged | ||||||
|  | 	route.Direct = msg.Packet.Direct | ||||||
|  | 	route.DirectAddr = msg.SrcAddr | ||||||
|  |  | ||||||
|  | 	s.Send(route, packetInit{ | ||||||
|  | 		TraceID: msg.Packet.TraceID, | ||||||
|  | 		Direct:  route.Direct, | ||||||
|  | 		Version: version, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | ||||||
| 	s.lastSeen = time.Now() | 	s.lastSeen = time.Now() | ||||||
| 	p := msg.Packet | 	p := msg.Packet | ||||||
| @@ -252,6 +272,75 @@ func (s *stateServer) OnPingTimer() peerState { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type stateClientInit struct { | ||||||
|  | 	*stateDisconnected | ||||||
|  | 	startedAt time.Time | ||||||
|  | 	traceID   uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateClientinit(s *pState) peerState { | ||||||
|  | 	s.logf("==> ClientInit") | ||||||
|  | 	s.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
|  | 	state := &stateClientInit{ | ||||||
|  | 		stateDisconnected: &stateDisconnected{s}, | ||||||
|  | 		startedAt:         time.Now(), | ||||||
|  | 		traceID:           newTraceID(), | ||||||
|  | 	} | ||||||
|  | 	state.Send(s.staged, packetInit{ | ||||||
|  | 		TraceID: state.traceID, | ||||||
|  | 		Direct:  s.staged.Direct, | ||||||
|  | 		Version: version, | ||||||
|  | 	}) | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return s.OnPeerUpdate(msg.Peer) | ||||||
|  | 	case controlMsg[packetInit]: | ||||||
|  | 		return s.onInit(msg) | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.onPing() | ||||||
|  | 	default: | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState { | ||||||
|  | 	if msg.Packet.TraceID != s.traceID { | ||||||
|  | 		s.logf("Invalid trace ID on INIT.") | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | 	s.logf("Got INIT version %d.", msg.Packet.Version) | ||||||
|  | 	return s.nextState() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit) onPing() peerState { | ||||||
|  | 	if time.Since(s.startedAt) > timeoutInterval { | ||||||
|  | 		s.logf("Init timeout. Assuming version 1.") | ||||||
|  | 		return s.nextState() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.traceID = newTraceID() | ||||||
|  | 	s.Send(s.staged, packetInit{ | ||||||
|  | 		TraceID: s.traceID, | ||||||
|  | 		Direct:  s.staged.Direct, | ||||||
|  | 		Version: version, | ||||||
|  | 	}) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit) nextState() peerState { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return enterStateClientDirect(s.pState) | ||||||
|  | 	} | ||||||
|  | 	return enterStateClientRelayed(s.pState) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type stateClientDirect struct { | type stateClientDirect struct { | ||||||
| 	*stateDisconnected | 	*stateDisconnected | ||||||
| 	lastSeen time.Time | 	lastSeen time.Time | ||||||
| @@ -418,3 +507,4 @@ func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { | |||||||
| 	s.probes[probe.TraceID] = addr | 	s.probes[probe.TraceID] = addr | ||||||
| 	s.SendTo(probe, addr) | 	s.SendTo(probe, addr) | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -53,6 +53,10 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | |||||||
| 	h.State = h.State.OnMsg(peerUpdateMsg{p}) | 	h.State = h.State.OnMsg(peerUpdateMsg{p}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) { | ||||||
|  | 	h.State = h.State.OnMsg(msg) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | ||||||
| 	h.State = h.State.OnMsg(msg) | 	h.State = h.State.OnMsg(msg) | ||||||
| } | } | ||||||
| @@ -110,6 +114,7 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDire | |||||||
|  |  | ||||||
| 	h.PeerUpdate(peer) | 	h.PeerUpdate(peer) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
| 	return assertType[*stateClientDirect](t, h.State) | 	return assertType[*stateClientDirect](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -124,7 +124,7 @@ type peerSuper struct { | |||||||
| func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { | func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { | ||||||
| 	return &peerSuper{ | 	return &peerSuper{ | ||||||
| 		messages:  make(chan any, 8), | 		messages:  make(chan any, 8), | ||||||
| 		state:     state.OnPeerUpdate(nil), | 		state:     initPeerState(state, nil), | ||||||
| 		pingTimer: pingTimer, | 		pingTimer: pingTimer, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,9 +1,7 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"log" |  | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"runtime/debug" |  | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -27,14 +25,13 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (store *pubAddrStore) Store(add netip.AddrPort) { | func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||||
| 	store.lock.Lock() |  | ||||||
| 	defer store.lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if store.localPub { | 	if store.localPub { | ||||||
| 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) |  | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	store.lock.Lock() | ||||||
|  | 	defer store.lock.Unlock() | ||||||
|  |  | ||||||
| 	if !add.IsValid() { | 	if !add.IsValid() { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										85
									
								
								peer/state-clientdirect.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								peer/state-clientdirect.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,85 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type stateClientDirect2 struct { | ||||||
|  | 	*peerData | ||||||
|  | 	lastSeen time.Time | ||||||
|  | 	syn      packetSyn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateClientDirect2(data *peerData, directAddr netip.AddrPort) peerState { | ||||||
|  | 	data.staged.Relay = data.peer.Relay | ||||||
|  | 	data.staged.Direct = true | ||||||
|  | 	data.staged.DirectAddr = directAddr | ||||||
|  | 	data.publish(data.staged) | ||||||
|  |  | ||||||
|  | 	state := &stateClientDirect2{ | ||||||
|  | 		peerData: data, | ||||||
|  | 		lastSeen: time.Now(), | ||||||
|  | 		syn: packetSyn{ | ||||||
|  | 			TraceID:   newTraceID(), | ||||||
|  | 			SharedKey: data.staged.DataCipher.Key(), | ||||||
|  | 			Direct:    true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	state.Send(state.staged, state.syn) | ||||||
|  |  | ||||||
|  | 	data.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
|  | 	state.logf("==> ClientDirect") | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientDirect2) logf(str string, args ...any) { | ||||||
|  | 	s.peerData.logf("CLNT | "+str, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientDirect2) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return initPeerState(s.peerData, msg.Peer) | ||||||
|  | 	case controlMsg[packetAck]: | ||||||
|  | 		return s.onAck(msg) | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.onPingTimer() | ||||||
|  | 	case controlMsg[packetLocalDiscovery]: | ||||||
|  | 		return s | ||||||
|  | 	default: | ||||||
|  | 		s.logf("Ignoring message: %v", raw) | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientDirect2) onAck(msg controlMsg[packetAck]) peerState { | ||||||
|  | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  |  | ||||||
|  | 	if !s.staged.Up { | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Got ACK.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientDirect2) onPingTimer() peerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
|  | 		if s.staged.Up { | ||||||
|  | 			s.logf("Timeout.") | ||||||
|  | 		} | ||||||
|  | 		return initPeerState(s.peerData, s.peer) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Send(s.staged, s.syn) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
							
								
								
									
										93
									
								
								peer/state-clientinit.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								peer/state-clientinit.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type stateClientInit2 struct { | ||||||
|  | 	*peerData | ||||||
|  | 	startedAt time.Time | ||||||
|  | 	traceID   uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateClientInit2(data *peerData) peerState { | ||||||
|  | 	ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) | ||||||
|  |  | ||||||
|  | 	data.staged.Up = false | ||||||
|  | 	data.staged.Relay = false | ||||||
|  | 	data.staged.Direct = ipValid | ||||||
|  | 	data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) | ||||||
|  | 	data.staged.PubSignKey = data.peer.PubSignKey | ||||||
|  | 	data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) | ||||||
|  | 	data.staged.DataCipher = newDataCipher() | ||||||
|  |  | ||||||
|  | 	data.publish(data.staged) | ||||||
|  |  | ||||||
|  | 	state := &stateClientInit2{ | ||||||
|  | 		peerData:  data, | ||||||
|  | 		startedAt: time.Now(), | ||||||
|  | 		traceID:   newTraceID(), | ||||||
|  | 	} | ||||||
|  | 	state.sendInit() | ||||||
|  |  | ||||||
|  | 	data.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
|  | 	state.logf("==> ClientInit") | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit2) logf(str string, args ...any) { | ||||||
|  | 	s.peerData.logf("INIT | "+str, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit2) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return initPeerState(s.peerData, msg.Peer) | ||||||
|  | 	case controlMsg[packetInit]: | ||||||
|  | 		return s.onInit(msg) | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.onPing() | ||||||
|  | 	default: | ||||||
|  | 		s.logf("Ignoring message: %v", raw) | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { | ||||||
|  | 	if msg.Packet.TraceID != s.traceID { | ||||||
|  | 		s.logf("Invalid trace ID on INIT.") | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | 	s.logf("Got INIT version %d.", msg.Packet.Version) | ||||||
|  | 	return s.nextState() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit2) onPing() peerState { | ||||||
|  | 	if time.Since(s.startedAt) > timeoutInterval { | ||||||
|  | 		s.logf("Init timeout. Assuming version 1.") | ||||||
|  | 		return s.nextState() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.sendInit() | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit2) sendInit() { | ||||||
|  | 	s.traceID = newTraceID() | ||||||
|  | 	init := packetInit{ | ||||||
|  | 		TraceID: s.traceID, | ||||||
|  | 		Direct:  s.staged.Direct, | ||||||
|  | 		Version: version, | ||||||
|  | 	} | ||||||
|  | 	s.Send(s.staged, init) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientInit2) nextState() peerState { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return enterStateClientDirect2(s.peerData, s.staged.DirectAddr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return enterStateClientRelayed2(s.peerData) | ||||||
|  | } | ||||||
							
								
								
									
										142
									
								
								peer/state-clientrelayed.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								peer/state-clientrelayed.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,142 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type sentProbe struct { | ||||||
|  | 	SentAt time.Time | ||||||
|  | 	Addr   netip.AddrPort | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type stateClientRelayed2 struct { | ||||||
|  | 	*peerData | ||||||
|  | 	lastSeen time.Time | ||||||
|  | 	syn      packetSyn | ||||||
|  | 	probes   map[uint64]sentProbe | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateClientRelayed2(data *peerData) peerState { | ||||||
|  | 	data.staged.Relay = false | ||||||
|  | 	data.staged.Direct = false | ||||||
|  | 	data.staged.DirectAddr = netip.AddrPort{} | ||||||
|  | 	data.publish(data.staged) | ||||||
|  |  | ||||||
|  | 	state := &stateClientRelayed2{ | ||||||
|  | 		peerData: data, | ||||||
|  | 		lastSeen: time.Now(), | ||||||
|  | 		syn: packetSyn{ | ||||||
|  | 			TraceID:       newTraceID(), | ||||||
|  | 			SharedKey:     data.staged.DataCipher.Key(), | ||||||
|  | 			Direct:        false, | ||||||
|  | 			PossibleAddrs: data.pubAddrs.Get(), | ||||||
|  | 		}, | ||||||
|  | 		probes: map[uint64]sentProbe{}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	state.Send(state.staged, state.syn) | ||||||
|  |  | ||||||
|  | 	data.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
|  | 	state.logf("==> ClientRelayed") | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) logf(str string, args ...any) { | ||||||
|  | 	s.peerData.logf("CLNT | "+str, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return initPeerState(s.peerData, msg.Peer) | ||||||
|  | 	case controlMsg[packetAck]: | ||||||
|  | 		return s.onAck(msg) | ||||||
|  | 	case controlMsg[packetProbe]: | ||||||
|  | 		return s.onProbe(msg) | ||||||
|  | 	case controlMsg[packetLocalDiscovery]: | ||||||
|  | 		return s.onLocalDiscovery(msg) | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.onPingTimer() | ||||||
|  | 	default: | ||||||
|  | 		s.logf("Ignoring message: %v", raw) | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | ||||||
|  | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  |  | ||||||
|  | 	if !s.staged.Up { | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Got ACK.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
|  |  | ||||||
|  | 	for _, addr := range msg.Packet.PossibleAddrs { | ||||||
|  | 		if !addr.IsValid() { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		s.sendProbeTo(addr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.cleanProbes() | ||||||
|  |  | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) onPingTimer() peerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
|  | 		if s.staged.Up { | ||||||
|  | 			s.logf("Timeout.") | ||||||
|  | 		} | ||||||
|  | 		return initPeerState(s.peerData, s.peer) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Send(s.staged, s.syn) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||||
|  | 	s.cleanProbes() | ||||||
|  |  | ||||||
|  | 	sent, ok := s.probes[msg.Packet.TraceID] | ||||||
|  | 	if !ok { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.logf("Successful probe.") | ||||||
|  | 	return enterStateClientDirect2(s.peerData, sent.Addr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) peerState { | ||||||
|  | 	// The source port will be the multicast port, so we'll have to | ||||||
|  | 	// construct the correct address using the peer's listed port. | ||||||
|  | 	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
|  | 	s.sendProbeTo(addr) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) cleanProbes() { | ||||||
|  | 	for key, sent := range s.probes { | ||||||
|  | 		if time.Since(sent.SentAt) > pingInterval { | ||||||
|  | 			delete(s.probes, key) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { | ||||||
|  | 	probe := packetProbe{TraceID: newTraceID()} | ||||||
|  | 	s.probes[probe.TraceID] = sentProbe{ | ||||||
|  | 		SentAt: time.Now(), | ||||||
|  | 		Addr:   addr, | ||||||
|  | 	} | ||||||
|  | 	s.SendTo(probe, addr) | ||||||
|  | } | ||||||
							
								
								
									
										33
									
								
								peer/state-disconnected.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								peer/state-disconnected.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "net/netip" | ||||||
|  |  | ||||||
|  | type stateDisconnected2 struct { | ||||||
|  | 	*peerData | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateDisconnected2(data *peerData) peerState { | ||||||
|  | 	data.staged.Up = false | ||||||
|  | 	data.staged.Relay = false | ||||||
|  | 	data.staged.Direct = false | ||||||
|  | 	data.staged.DirectAddr = netip.AddrPort{} | ||||||
|  | 	data.staged.PubSignKey = nil | ||||||
|  | 	data.staged.ControlCipher = nil | ||||||
|  | 	data.staged.DataCipher = nil | ||||||
|  |  | ||||||
|  | 	data.publish(data.staged) | ||||||
|  |  | ||||||
|  | 	data.pingTimer.Stop() | ||||||
|  |  | ||||||
|  | 	return &stateDisconnected2{data} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateDisconnected2) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return initPeerState(s.peerData, msg.Peer) | ||||||
|  | 	default: | ||||||
|  | 		s.logf("Ignoring message: %v", raw) | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										127
									
								
								peer/state-server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								peer/state-server.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,127 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type stateServer2 struct { | ||||||
|  | 	*peerData | ||||||
|  | 	lastSeen   time.Time | ||||||
|  | 	synTraceID uint64 // Last syn trace ID. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateServer2(data *peerData) peerState { | ||||||
|  | 	data.staged.Up = false | ||||||
|  | 	data.staged.Relay = false | ||||||
|  | 	data.staged.Direct = false | ||||||
|  | 	data.staged.DirectAddr = netip.AddrPort{} | ||||||
|  | 	data.staged.PubSignKey = data.peer.PubSignKey | ||||||
|  | 	data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) | ||||||
|  | 	data.staged.DataCipher = nil | ||||||
|  |  | ||||||
|  | 	data.publish(data.staged) | ||||||
|  |  | ||||||
|  | 	data.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
|  | 	state := &stateServer2{peerData: data} | ||||||
|  | 	state.logf("==> Server") | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer2) logf(str string, args ...any) { | ||||||
|  | 	s.peerData.logf("SRVR | "+str, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer2) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return initPeerState(s.peerData, msg.Peer) | ||||||
|  | 	case controlMsg[packetInit]: | ||||||
|  | 		return s.onInit(msg) | ||||||
|  | 	case controlMsg[packetSyn]: | ||||||
|  | 		return s.onSyn(msg) | ||||||
|  | 	case controlMsg[packetProbe]: | ||||||
|  | 		return s.onProbe(msg) | ||||||
|  | 	case controlMsg[packetLocalDiscovery]: | ||||||
|  | 		return s | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.onPingTimer() | ||||||
|  | 	default: | ||||||
|  | 		s.logf("Ignoring message: %v", raw) | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer2) onInit(msg controlMsg[packetInit]) peerState { | ||||||
|  | 	s.staged.Up = false | ||||||
|  | 	s.staged.Direct = msg.Packet.Direct | ||||||
|  | 	s.staged.DirectAddr = msg.SrcAddr | ||||||
|  | 	s.publish(s.staged) | ||||||
|  |  | ||||||
|  | 	init := packetInit{ | ||||||
|  | 		TraceID: msg.Packet.TraceID, | ||||||
|  | 		Direct:  s.staged.Direct, | ||||||
|  | 		Version: version, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Send(s.staged, init) | ||||||
|  |  | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  | 	p := msg.Packet | ||||||
|  |  | ||||||
|  | 	// Before we can respond to this packet, we need to make sure the | ||||||
|  | 	// route is setup properly. | ||||||
|  | 	// | ||||||
|  | 	// The client will update the syn's TraceID whenever there's a change. | ||||||
|  | 	// The server will follow the client's request. | ||||||
|  | 	if p.TraceID != s.synTraceID || !s.staged.Up { | ||||||
|  | 		s.synTraceID = p.TraceID | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.staged.Direct = p.Direct | ||||||
|  | 		s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) | ||||||
|  | 		s.staged.DirectAddr = msg.SrcAddr | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Got SYN.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Always respond. | ||||||
|  | 	s.Send(s.staged, packetAck{ | ||||||
|  | 		TraceID:       p.TraceID, | ||||||
|  | 		ToAddr:        s.staged.DirectAddr, | ||||||
|  | 		PossibleAddrs: s.pubAddrs.Get(), | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if p.Direct { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, addr := range msg.Packet.PossibleAddrs { | ||||||
|  | 		if !addr.IsValid() { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		s.SendTo(packetProbe{TraceID: newTraceID()}, addr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||||
|  | 	if msg.SrcAddr.IsValid() { | ||||||
|  | 		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||||
|  | 	} | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateServer2) onPingTimer() peerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||||
|  | 		s.staged.Up = false | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Timeout.") | ||||||
|  | 	} | ||||||
|  | 	return s | ||||||
|  | } | ||||||
							
								
								
									
										28
									
								
								peer/statedata.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								peer/statedata.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"vppn/m" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type peerData = pState | ||||||
|  |  | ||||||
|  | func initPeerState(data *peerData, peer *m.Peer) peerState { | ||||||
|  | 	data.peer = peer | ||||||
|  |  | ||||||
|  | 	if peer == nil { | ||||||
|  | 		return enterStateDisconnected2(data) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||||
|  | 		if data.localAddr.IsValid() && data.localIP < data.remoteIP { | ||||||
|  | 			return enterStateServer2(data) | ||||||
|  | 		} | ||||||
|  | 		return enterStateClientInit2(data) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if data.localAddr.IsValid() || data.localIP < data.remoteIP { | ||||||
|  | 		return enterStateServer2(data) | ||||||
|  | 	} | ||||||
|  | 	return enterStateClientInit2(data) | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user