vppn/peer/connreader.go
2025-08-26 15:45:06 +02:00

142 lines
2.9 KiB
Go

package peer
import (
"io"
"log"
"net/netip"
"sync/atomic"
)
// connReader pulls packets from a UDP read callback and dispatches them by
// stream ID: control packets are decrypted and handed to handleControlMsg,
// data packets are decrypted and either written to iface or forwarded to
// another peer through writeToUDPAddrPort.
type connReader struct {
	// Input: fills the supplied buffer with one packet and reports the
	// number of bytes read and the sender's address.
	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)

	// Output: sends a packet to the given address. Used when relaying data
	// packets that are addressed to another peer.
	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)

	// Output: receives decrypted data packets addressed to localIP.
	iface io.Writer

	// Output: receives decrypted control messages together with the
	// source IP taken from the packet header.
	handleControlMsg func(fromIP byte, pkt any)

	// localIP is compared against the destination IP in packet headers.
	localIP byte

	// rt is loaded once per packet to look up the sending peer and, when
	// relaying, the destination peer.
	rt *atomic.Pointer[routingTable]

	// Buffers reused across packets: buf holds the raw packet as read,
	// decBuf is passed to the peer's decrypt methods.
	buf, decBuf []byte
}
// newConnReader assembles a connReader from the given I/O callbacks and
// routing table pointer.
//
// The local IP is read from the routing table once, at construction time, so
// rt must already hold a table when this is called. Both scratch buffers are
// allocated with newBuf.
func newConnReader(
	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
	iface io.Writer,
	handleControlMsg func(fromIP byte, pkt any),
	rt *atomic.Pointer[routingTable],
) *connReader {
	localIP := rt.Load().LocalIP

	return &connReader{
		readFromUDPAddrPort: readFromUDPAddrPort,
		writeToUDPAddrPort:  writeToUDPAddrPort,
		iface:               iface,
		handleControlMsg:    handleControlMsg,
		localIP:             localIP,
		rt:                  rt,
		buf:                 newBuf(),
		decBuf:              newBuf(),
	}
}
// Run processes incoming packets one at a time in an endless loop. It never
// returns; handleNextPacket terminates the process if reading fails.
func (r *connReader) Run() {
	for {
		r.handleNextPacket()
	}
}
// handleNextPacket reads a single packet and routes it to the control or data
// handler according to the stream ID in its header.
//
// A read error terminates the process via log.Fatalf. Packets shorter than
// headerSize are dropped without logging. Packets carrying an unrecognised
// stream ID are logged and dropped.
func (r *connReader) handleNextPacket() {
	pkt := r.buf[:bufferSize]

	n, from, err := r.readFromUDPAddrPort(pkt)
	if err != nil {
		log.Fatalf("Failed to read from UDP port: %v", err)
	}
	if n < headerSize {
		return
	}
	pkt = pkt[:n]

	// Strip any IPv4-mapped IPv6 prefix from the sender's address.
	from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())

	h := parseHeader(pkt)
	table := r.rt.Load()
	sender := table.Peers[h.SourceIP]

	switch id := h.StreamID; id {
	case controlStreamID:
		r.handleControlPacket(from, sender, h, pkt)
	case dataStreamID:
		r.handleDataPacket(table, sender, h, pkt)
	default:
		r.logf("Unknown stream ID: %d", id)
	}
}
// handleControlPacket validates and decrypts a control packet, then passes
// the resulting message to handleControlMsg along with the header's source IP.
//
// The packet is logged and dropped when the sending peer has no control
// cipher, when the header's destination is not the local IP, or when
// decryption fails.
func (r *connReader) handleControlPacket(
	remoteAddr netip.AddrPort,
	peer remotePeer,
	h header,
	enc []byte,
) {
	// Guards are evaluated in order; the first one that matches drops the
	// packet.
	switch {
	case peer.ControlCipher == nil:
		r.logf("No control cipher for peer: %d", h.SourceIP)
		return
	case h.DestIP != r.localIP:
		r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
		return
	}

	decoded, decErr := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
	if decErr != nil {
		r.logf("Failed to decrypt control packet: %v", decErr)
		return
	}

	r.handleControlMsg(h.SourceIP, decoded)
}
// handleDataPacket decrypts a data packet and either delivers it to the local
// interface or relays it to the peer it is addressed to.
//
// rt is the routing table snapshot used to look up the relay destination,
// peer is the sender's entry, h is the parsed header and enc is the raw
// packet as received.
//
// Packets from a peer that is not up, packets that fail to decrypt, and
// packets addressed to a peer without a direct route are logged and dropped.
// Write failures on either output are logged and otherwise ignored, so a
// single bad packet cannot stop the reader.
func (r *connReader) handleDataPacket(
	rt *routingTable,
	peer remotePeer,
	h header,
	enc []byte,
) {
	if !peer.Up {
		r.logf("Not connected (recv).")
		return
	}

	data, err := peer.DecryptDataPacket(h, enc, r.decBuf)
	if err != nil {
		r.logf("Failed to decrypt data packet: %v", err)
		return
	}

	// Addressed to this node: hand the payload to the local interface.
	if h.DestIP == r.localIP {
		if _, err := r.iface.Write(data); err != nil {
			// Could be invalid data from peer. Don't crash.
			r.logf("Failed to write to interface: %v", err)
		}
		return
	}

	// Addressed to another peer: relay only if it is directly reachable.
	remote := rt.Peers[h.DestIP]
	if !remote.Direct {
		r.logf("Unable to relay data to %d.", h.DestIP)
		return
	}

	// Report relay failures instead of discarding the error silently.
	if _, err := r.writeToUDPAddrPort(data, remote.DirectAddr); err != nil {
		r.logf("Failed to relay data to %d: %v", h.DestIP, err)
	}
}
// logf writes a formatted message to the standard logger, prefixed with
// "[ConnReader] " so the output can be attributed to this component.
func (r *connReader) logf(format string, args ...any) {
	const prefix = "[ConnReader] "
	log.Printf(prefix+format, args...)
}