package peer

import (
	"io"
	"log"
	"net/netip"
	"sync/atomic"
)

// ConnReader reads datagrams, decrypts them using the sending peer's state
// from the routing table, and dispatches them by stream ID. Control messages
// are passed to handleControlMsg. Data packets addressed to the local IP are
// written to iface, and all other data packets are passed to forwardData.
//
// A ConnReader reuses one pair of scratch buffers (buf, decBuf) for every
// packet, so it should be driven by a single goroutine, normally via Run.
type ConnReader struct {
	// Input

	// readFromUDPAddrPort reads one datagram into the supplied buffer and
	// returns its length and sender. Its signature matches
	// (*net.UDPConn).ReadFromUDPAddrPort; presumably that is what callers
	// pass — TODO confirm.
	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)

	// Output

	// iface receives decrypted data packets whose DestIP equals localIP.
	iface io.Writer
	// forwardData receives decrypted data packets addressed to any other IP.
	forwardData func(ip byte, pkt []byte)
	// handleControlMsg receives decrypted control messages.
	handleControlMsg func(pkt any)

	// localIP is a snapshot of the routing table's LocalIP, taken once in
	// NewConnReader. Later changes to LocalIP in the table are not observed.
	localIP byte
	// rt is loaded once per packet, so peer updates take effect on the
	// next packet without restarting the reader.
	rt *atomic.Pointer[RoutingTable]

	// buf receives raw datagrams; decBuf is scratch space handed to the
	// peer's decrypt methods. Both are reused across packets.
	buf    []byte
	decBuf []byte
}

// NewConnReader returns a ConnReader that reads datagrams with
// readFromUDPAddrPort and delivers decrypted traffic to iface, forwardData,
// and handleControlMsg.
//
// rt must already hold a routing table: its LocalIP is read here, and an
// empty rt (Load returning nil) panics.
func NewConnReader(
	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
	iface io.Writer,
	forwardData func(ip byte, pkt []byte),
	handleControlMsg func(pkt any),
	rt *atomic.Pointer[RoutingTable],
) *ConnReader {
	return &ConnReader{
		readFromUDPAddrPort: readFromUDPAddrPort,
		iface:               iface,
		forwardData:         forwardData,
		handleControlMsg:    handleControlMsg,
		localIP:             rt.Load().LocalIP,
		rt:                  rt,
		buf:                 newBuf(),
		decBuf:              newBuf(),
	}
}

// Run processes incoming packets forever and never returns normally. A read
// error, or a failed write to the interface, terminates the process via
// log.Fatalf.
func (r *ConnReader) Run() {
	for {
		r.handleNextPacket()
	}
}

// handleNextPacket reads a single datagram and dispatches it according to
// its header's StreamID. Datagrams shorter than headerSize are dropped
// without logging.
func (r *ConnReader) handleNextPacket() {
	buf := r.buf[:bufferSize]
	n, remoteAddr, err := r.readFromUDPAddrPort(buf)
	if err != nil {
		log.Fatalf("Failed to read from UDP port: %v", err)
	}
	if n < headerSize {
		return
	}

	// Unmap turns an IPv4-mapped IPv6 sender (::ffff:a.b.c.d) into a plain
	// IPv4 address, presumably so it compares equal to addresses stored for
	// peers — TODO confirm against DecryptControlPacket.
	remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())

	buf = buf[:n]
	h := parseHeader(buf)

	// Load the current routing table for every packet so that peer changes
	// published through rt are seen immediately.
	peer := r.rt.Load().Peers[h.SourceIP]

	switch h.StreamID {
	case controlStreamID:
		r.handleControlPacket(remoteAddr, peer, h, buf)
	case dataStreamID:
		r.handleDataPacket(peer, h, buf)
	default:
		r.logf("Unknown stream ID: %d", h.StreamID)
	}
}

// handleControlPacket decrypts a control packet from peer and passes the
// resulting message to handleControlMsg.
//
// The packet is dropped silently if the peer has no ControlCipher. It is
// dropped with a log line if DestIP is not localIP or if decryption fails.
// remoteAddr is the (unmapped) UDP source address, passed through to
// DecryptControlPacket.
func (r *ConnReader) handleControlPacket(
	remoteAddr netip.AddrPort,
	peer RemotePeer,
	h header,
	enc []byte,
) {
	if peer.ControlCipher == nil {
		return
	}

	if h.DestIP != r.localIP {
		r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
		return
	}

	msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
	if err != nil {
		r.logf("Failed to decrypt control packet: %v", err)
		return
	}

	r.handleControlMsg(msg)
}

// handleDataPacket decrypts a data packet from peer. If the packet is
// addressed to localIP, it is written to iface; otherwise it is passed to
// forwardData with its DestIP.
//
// Packets from a peer that is not Up, or that fail to decrypt, are dropped
// with a log line. A failed write to iface terminates the process.
//
// NOTE(review): data presumably aliases r.decBuf, which is reused for the
// next packet. iface.Write and forwardData must not retain it — confirm.
func (r *ConnReader) handleDataPacket(
	peer RemotePeer,
	h header,
	enc []byte,
) {
	if !peer.Up {
		r.logf("Not connected (recv).")
		return
	}

	data, err := peer.DecryptDataPacket(h, enc, r.decBuf)
	if err != nil {
		r.logf("Failed to decrypt data packet: %v", err)
		return
	}

	if h.DestIP == r.localIP {
		if _, err := r.iface.Write(data); err != nil {
			log.Fatalf("Failed to write to interface: %v", err)
		}
		return
	}

	r.forwardData(h.DestIP, data)
}

// logf logs a formatted message prefixed with "[ConnReader] ".
func (r *ConnReader) logf(format string, args ...any) {
	log.Printf("[ConnReader] "+format, args...)
}