sym-encryption #1
| @@ -23,6 +23,8 @@ type peerRoute struct { | ||||
| 	ControlCipher *controlCipher | ||||
| 	DataCipher    *dataCipher | ||||
| 	RemoteAddr    netip.AddrPort // Remote address if directly connected. | ||||
| 	// TODO: Remove this and use global localAddr and relayIP. | ||||
| 	// Replace w/ a Direct boolean. | ||||
| 	LocalAddr netip.AddrPort // Local address as seen by the remote. | ||||
| 	RelayIP   byte           // Non-zero if we should relay. | ||||
| } | ||||
| @@ -32,6 +34,7 @@ var ( | ||||
| 	netName    string | ||||
| 	localIP    byte | ||||
| 	localPub   bool | ||||
| 	localAddr  netip.AddrPort | ||||
| 	privateKey []byte | ||||
|  | ||||
| 	// Shared interface for writing. | ||||
| @@ -54,4 +57,9 @@ var ( | ||||
|  | ||||
| 	// Global routing table. | ||||
| 	routingTable [256]*atomic.Pointer[peerRoute] | ||||
|  | ||||
| 	// TODO: use relay for local address discovery. This should be new stream ID, | ||||
| 	// managed by a single thread. | ||||
| 	// localAddr *atomic.Pointer[netip.AddrPort] | ||||
| 	// relayIP *atomic.Pointer[byte] | ||||
| ) | ||||
|   | ||||
| @@ -58,7 +58,6 @@ func (hp *hubPoller) Run() { | ||||
| func (hp *hubPoller) pollHub() { | ||||
| 	var state m.NetworkState | ||||
|  | ||||
| 	log.Printf("Fetching peer state...") | ||||
| 	resp, err := hp.client.Do(hp.req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to fetch peer state: %v", err) | ||||
|   | ||||
							
								
								
									
										12
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -105,7 +105,13 @@ func main(listenIP string, port uint16) { | ||||
|  | ||||
| 	// Intialize globals. | ||||
| 	localIP = config.PeerIP | ||||
| 	localPub = addrIsValid(config.PublicIP) | ||||
|  | ||||
| 	ip, ok := netip.AddrFromSlice(config.PublicIP) | ||||
| 	if ok { | ||||
| 		localPub = true | ||||
| 		localAddr = netip.AddrPortFrom(ip, config.Port) | ||||
| 	} | ||||
|  | ||||
| 	privateKey = config.PrivKey | ||||
|  | ||||
| 	_iface = newIFWriter(iface) | ||||
| @@ -178,6 +184,8 @@ func readFromConn(conn *net.UDPConn) { | ||||
| 		case controlStreamID: | ||||
| 			handleControlPacket(remoteAddr, h, data, decBuf) | ||||
|  | ||||
| 			// TODO: discoveryStreamID | ||||
|  | ||||
| 		case dataStreamID: | ||||
| 			handleDataPacket(h, data, decBuf) | ||||
|  | ||||
| @@ -217,7 +225,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | ||||
|  | ||||
| 	pkt := controlPacket{ | ||||
| 		SrcIP:   h.SourceIP, | ||||
| 		RemoteAddr: addr, | ||||
| 		SrcAddr: addr, | ||||
| 	} | ||||
|  | ||||
| 	if err := pkt.ParsePayload(out); err != nil { | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package node | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| ) | ||||
|  | ||||
| @@ -14,13 +15,14 @@ const ( | ||||
| 	packetTypeSyn = iota + 1 | ||||
| 	packetTypeSynAck | ||||
| 	packetTypeAck | ||||
| 	packetTypeProbe | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type controlPacket struct { | ||||
| 	SrcIP   byte | ||||
| 	RemoteAddr netip.AddrPort | ||||
| 	SrcAddr netip.AddrPort | ||||
| 	Payload any | ||||
| } | ||||
|  | ||||
| @@ -30,8 +32,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { | ||||
| 		p.Payload, err = parseSynPacket(buf) | ||||
| 	case packetTypeSynAck: | ||||
| 		p.Payload, err = parseSynAckPacket(buf) | ||||
| 	case packetTypeAck: | ||||
| 		p.Payload, err = parseAckPacket(buf) | ||||
| 	case packetTypeProbe: | ||||
| 		log.Printf("Got probe...") | ||||
| 		p.Payload, err = parseProbePacket(buf) | ||||
| 	default: | ||||
| 		return errUnknownPacketType | ||||
| 	} | ||||
| @@ -44,6 +47,7 @@ type synPacket struct { | ||||
| 	TraceID   uint64   // TraceID to match response w/ request. | ||||
| 	SharedKey [32]byte // Our shared key. | ||||
| 	RelayIP   byte | ||||
| 	FromAddr  netip.AddrPort // The client's sending address. | ||||
| } | ||||
|  | ||||
| func (p synPacket) Marshal(buf []byte) []byte { | ||||
| @@ -52,6 +56,7 @@ func (p synPacket) Marshal(buf []byte) []byte { | ||||
| 		Uint64(p.TraceID). | ||||
| 		SharedKey(p.SharedKey). | ||||
| 		Byte(p.RelayIP). | ||||
| 		AddrPort(p.FromAddr). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| @@ -60,6 +65,7 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
| 		Uint64(&p.TraceID). | ||||
| 		SharedKey(&p.SharedKey). | ||||
| 		Byte(&p.RelayIP). | ||||
| 		AddrPort(&p.FromAddr). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
| @@ -68,47 +74,54 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
|  | ||||
| type synAckPacket struct { | ||||
| 	TraceID  uint64 | ||||
| 	RecvAddr netip.AddrPort | ||||
| 	FromAddr netip.AddrPort | ||||
| 	ToAddr   netip.AddrPort | ||||
| } | ||||
|  | ||||
| func (p synAckPacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeSynAck). | ||||
| 		Uint64(p.TraceID). | ||||
| 		AddrPort(p.RecvAddr). | ||||
| 		AddrPort(p.FromAddr). | ||||
| 		AddrPort(p.ToAddr). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.RecvAddr). | ||||
| 		AddrPort(&p.FromAddr). | ||||
| 		AddrPort(&p.ToAddr). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type ackPacket struct { | ||||
| type addrDiscoveryPacket struct { | ||||
| 	TraceID  uint64 | ||||
| 	SendAddr netip.AddrPort // Address of the sender. | ||||
| 	RecvAddr netip.AddrPort // Address of the recipient as seen by sender. | ||||
| 	FromAddr netip.AddrPort | ||||
| 	ToAddr   netip.AddrPort | ||||
| } | ||||
|  | ||||
| func (p ackPacket) Marshal(buf []byte) []byte { | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // A probeReqPacket is sent from a client to a server to determine if direct | ||||
| // UDP communication can be used. | ||||
| type probePacket struct { | ||||
| 	TraceID uint64 | ||||
| } | ||||
|  | ||||
| func (p probePacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeAck). | ||||
| 		Byte(packetTypeProbe). | ||||
| 		Uint64(p.TraceID). | ||||
| 		AddrPort(p.SendAddr). | ||||
| 		AddrPort(p.RecvAddr). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||
| func parseProbePacket(buf []byte) (p probePacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.SendAddr). | ||||
| 		AddrPort(&p.RecvAddr). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -66,6 +66,36 @@ func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte | ||||
| 	_sendControlPacket(pkt, s.staged, s.buf1, s.buf2) | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) sendControlPacketTo( | ||||
| 	pkt interface{ Marshal([]byte) []byte }, | ||||
| 	addr netip.AddrPort, | ||||
| ) { | ||||
| 	if !addr.IsValid() { | ||||
| 		s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) | ||||
| 		return | ||||
| 	} | ||||
| 	route := s.staged | ||||
| 	route.RelayIP = 0 | ||||
| 	route.RemoteAddr = addr | ||||
| 	_sendControlPacket(pkt, route, s.buf1, s.buf2) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) getLocalAddr() netip.AddrPort { | ||||
| 	if localPub { | ||||
| 		return localAddr | ||||
| 	} | ||||
|  | ||||
| 	if s.staged.RelayIP != 0 { | ||||
| 		if addr := routingTable[s.staged.RelayIP].Load().LocalAddr; addr.IsValid() { | ||||
| 			return addr | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s.staged.LocalAddr | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) logf(msg string, args ...any) { | ||||
| @@ -113,7 +143,7 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { | ||||
|  | ||||
| 	if s.remotePub == localPub { | ||||
| 		if localIP < s.remoteIP { | ||||
| 			return s.serverAccept | ||||
| 			return s.server | ||||
| 		} | ||||
| 		return s.clientInit | ||||
| 	} | ||||
| @@ -121,18 +151,13 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { | ||||
| 	if s.remotePub { | ||||
| 		return s.clientInit | ||||
| 	} | ||||
| 	return s.serverAccept | ||||
| 	return s.server | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) serverAccept() stateFunc { | ||||
| 	s.logf("STATE: server-accept") | ||||
| 	s.staged.Up = false | ||||
| 	s.staged.DataCipher = nil | ||||
| 	s.staged.RemoteAddr = zeroAddrPort | ||||
| 	s.staged.RelayIP = 0 | ||||
| 	s.publish() | ||||
| func (s *peerSupervisor) server() stateFunc { | ||||
| 	s.logf("STATE: server") | ||||
|  | ||||
| 	var syn synPacket | ||||
|  | ||||
| @@ -145,61 +170,38 @@ func (s *peerSupervisor) serverAccept() stateFunc { | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case synPacket: | ||||
| 				// Before we can respond to this packet, we need to make sure the | ||||
| 				// route is setup properly. | ||||
| 				if p.TraceID != syn.TraceID { | ||||
| 					syn = p | ||||
| 				s.staged.RemoteAddr = pkt.RemoteAddr | ||||
| 					s.staged.Up = true | ||||
| 					s.staged.RemoteAddr = pkt.SrcAddr | ||||
| 					s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) | ||||
| 					s.staged.RelayIP = syn.RelayIP | ||||
| 					s.staged.LocalAddr = s.getLocalAddr() | ||||
| 					s.publish() | ||||
| 				s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr}) | ||||
|  | ||||
| 			case ackPacket: | ||||
| 				if p.TraceID != syn.TraceID { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				// Publish. | ||||
| 				return s.serverConnected(syn.TraceID) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 				// We should always respond. | ||||
| 				s.sendControlPacket(synAckPacket{ | ||||
| 					TraceID:  syn.TraceID, | ||||
| 					FromAddr: s.staged.LocalAddr, | ||||
| 					ToAddr:   pkt.SrcAddr, | ||||
| 				}) | ||||
|  | ||||
| 				// If we're relayed, attempt to probe the client. | ||||
| 				if s.staged.RelayIP != 0 && syn.FromAddr.IsValid() { | ||||
| 					probe := probePacket{TraceID: newTraceID()} | ||||
| 					s.logf("SERVER sending probe %v: %v", probe, syn.FromAddr) | ||||
| 					s.sendControlPacketTo(probe, syn.FromAddr) | ||||
| 				} | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc { | ||||
| 	s.logf("STATE: server-connected") | ||||
| 	s.staged.Up = true | ||||
| 	s.publish() | ||||
| 	return func() stateFunc { | ||||
| 		return s._serverConnected(traceID) | ||||
| 			case probePacket: | ||||
| 				s.logf("SERVER got probe: %v", p) | ||||
| 				s.logf("SERVER sending probe: %v", pkt.SrcAddr) | ||||
| 				s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc { | ||||
|  | ||||
| 	timeoutTimer := time.NewTimer(timeoutInterval) | ||||
| 	defer timeoutTimer.Stop() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case ackPacket: | ||||
| 				if p.TraceID != traceID { | ||||
| 					return s.serverAccept | ||||
| 				} | ||||
| 				s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr}) | ||||
| 				timeoutTimer.Reset(timeoutInterval) | ||||
| 			} | ||||
|  | ||||
| 		case <-timeoutTimer.C: | ||||
| 			s.logf("server timeout") | ||||
| 			return s.serverAccept | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -208,13 +210,10 @@ func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc { | ||||
| func (s *peerSupervisor) clientInit() stateFunc { | ||||
| 	s.logf("STATE: client-init") | ||||
| 	if !s.remotePub { | ||||
| 		// TODO: Check local discovery for IP. | ||||
| 		// TODO: Attempt UDP hole punch. | ||||
| 		// TODO: client-relayed | ||||
| 		return s.clientSelectRelay | ||||
| 	} | ||||
|  | ||||
| 	return s.clientDial | ||||
| 	return s.client | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -237,7 +236,7 @@ func (s *peerSupervisor) clientSelectRelay() stateFunc { | ||||
| 				s.staged.RelayIP = relay.IP | ||||
| 				s.staged.LocalAddr = relay.LocalAddr | ||||
| 				s.publish() | ||||
| 				return s.clientDial | ||||
| 				return s.client | ||||
| 			} | ||||
|  | ||||
| 			s.logf("No relay available.") | ||||
| @@ -264,20 +263,26 @@ func (s *peerSupervisor) selectRelay() *peerRoute { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) clientDial() stateFunc { | ||||
| 	s.logf("STATE: client-dial") | ||||
| func (s *peerSupervisor) client() stateFunc { | ||||
| 	s.logf("STATE: client") | ||||
|  | ||||
| 	var ( | ||||
| 		syn = synPacket{ | ||||
| 			TraceID:   newTraceID(), | ||||
| 			SharedKey: s.staged.DataCipher.Key(), | ||||
| 			RelayIP:   s.staged.RelayIP, | ||||
| 			FromAddr:  s.getLocalAddr(), | ||||
| 		} | ||||
| 		ack synAckPacket | ||||
|  | ||||
| 		timeout = time.NewTimer(dialTimeout) | ||||
| 		probe = probePacket{TraceID: newTraceID()} | ||||
|  | ||||
| 		timeoutTimer = time.NewTimer(timeoutInterval) | ||||
| 		pingTimer    = time.NewTimer(pingInterval) | ||||
| 	) | ||||
|  | ||||
| 	defer timeout.Stop() | ||||
| 	defer timeoutTimer.Stop() | ||||
| 	defer pingTimer.Stop() | ||||
|  | ||||
| 	s.sendControlPacket(syn) | ||||
|  | ||||
| @@ -289,64 +294,36 @@ func (s *peerSupervisor) clientDial() stateFunc { | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case synAckPacket: | ||||
| 				if p.TraceID != syn.TraceID { | ||||
| 					s.logf("Bad traceID?") | ||||
| 					continue // Hmm... | ||||
| 				} | ||||
| 				s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr}) | ||||
| 				return s.clientConnected(p) | ||||
| 			} | ||||
|  | ||||
| 		case <-timeout.C: | ||||
| 			return s.clientInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc { | ||||
| 	s.logf("STATE: client-connected") | ||||
| 	s.staged.Up = true | ||||
| 	s.staged.LocalAddr = p.RecvAddr | ||||
| 	s.publish() | ||||
|  | ||||
| 	return func() stateFunc { | ||||
| 		return s._clientConnected(p.TraceID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc { | ||||
|  | ||||
| 	pingTimer := time.NewTimer(pingInterval) | ||||
| 	timeoutTimer := time.NewTimer(timeoutInterval) | ||||
|  | ||||
| 	defer pingTimer.Stop() | ||||
| 	defer timeoutTimer.Stop() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case ackPacket: | ||||
| 				if p.TraceID != traceID { | ||||
| 					return s.clientInit | ||||
| 				} | ||||
| 				ack = p | ||||
| 				timeoutTimer.Reset(timeoutInterval) | ||||
|  | ||||
| 				if !s.staged.Up { | ||||
| 					s.staged.Up = true | ||||
| 					s.staged.LocalAddr = p.ToAddr | ||||
| 					s.publish() | ||||
| 				} | ||||
|  | ||||
| 			case probePacket: | ||||
| 				s.logf("CLIENT got probe: %v", p) | ||||
| 			} | ||||
|  | ||||
| 		case <-pingTimer.C: | ||||
| 			s.sendControlPacket(ackPacket{TraceID: traceID}) | ||||
| 			s.sendControlPacket(syn) | ||||
| 			pingTimer.Reset(pingInterval) | ||||
|  | ||||
| 		case <-timeoutTimer.C: | ||||
| 			s.logf("client timeout") | ||||
| 			return s.clientInit | ||||
| 			if s.staged.RelayIP != 0 && ack.FromAddr.IsValid() { | ||||
| 				s.logf("CLIENT sending probe %v: %v", probe, ack.FromAddr) | ||||
| 				s.sendControlPacketTo(probe, ack.FromAddr) | ||||
| 			} | ||||
|  | ||||
| 		case <-timeoutTimer.C: | ||||
| 			return s.clientInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user