vppn/peer/connreader2.go
2025-02-10 19:11:30 +01:00

133 lines
2.6 KiB
Go

package peer
import (
"io"
"log"
"net/netip"
"sync/atomic"
)
// ConnReader pulls datagrams from a UDP read function, decrypts them with the
// sending peer's state from the routing table, and dispatches the result:
// decrypted control messages go to handleControlMsg, data packets addressed to
// localIP are written to iface, and data packets addressed to any other IP are
// passed to forwardData.
//
// The receive buffer and the decryption buffer are reused for every packet
// without any locking, so a ConnReader is not safe for concurrent use.
type ConnReader struct {
	// Input: reads a single datagram into the supplied buffer and reports its
	// length and sender address.
	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)

	// Output: receives decrypted data packets whose destination is localIP.
	iface io.Writer

	// Output: receives decrypted data packets destined for another IP.
	forwardData func(ip byte, pkt []byte)

	// Output: receives decrypted control messages.
	handleControlMsg func(pkt any)

	// localIP is the routing table's LocalIP captured once at construction.
	localIP byte

	// rt holds the current routing table; it is loaded once per packet.
	rt *atomic.Pointer[RoutingTable]

	// buf is the raw receive buffer; decBuf is the scratch buffer handed to
	// the peer's decrypt methods.
	buf, decBuf []byte
}
// NewConnReader wires a ConnReader to its datagram source, output sinks and
// routing table, and allocates fresh receive and decryption buffers via newBuf.
//
// The local IP is read from the routing table exactly once, here. Because
// rt.Load() is dereferenced immediately, rt must already hold a non-nil
// table or this call panics.
func NewConnReader(
	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
	iface io.Writer,
	forwardData func(ip byte, pkt []byte),
	handleControlMsg func(pkt any),
	rt *atomic.Pointer[RoutingTable],
) *ConnReader {
	self := rt.Load().LocalIP

	cr := new(ConnReader)
	cr.readFromUDPAddrPort = readFromUDPAddrPort
	cr.iface = iface
	cr.forwardData = forwardData
	cr.handleControlMsg = handleControlMsg
	cr.localIP = self
	cr.rt = rt
	cr.buf = newBuf()
	cr.decBuf = newBuf()
	return cr
}
// Run handles incoming packets in an endless loop and never returns. A failed
// socket read or a failed write to the interface ends the process through
// log.Fatalf inside the per-packet handlers.
func (r *ConnReader) Run() {
	step := r.handleNextPacket
	for {
		step()
	}
}
// handleNextPacket reads one datagram and routes it by stream ID.
//
// Datagrams shorter than headerSize are dropped silently. The sender address
// is unmapped (an IPv4-mapped IPv6 address becomes plain IPv4) before use.
// The sending peer is looked up by the header's SourceIP in a fresh load of
// the routing table. A read error is fatal.
func (r *ConnReader) handleNextPacket() {
	pkt := r.buf[:bufferSize]
	size, from, err := r.readFromUDPAddrPort(pkt)
	if err != nil {
		log.Fatalf("Failed to read from UDP port: %v", err)
	}
	if size < headerSize {
		return // too short to carry a header
	}
	pkt = pkt[:size]

	// Convert any IPv4-in-IPv6 form to its plain IPv4 address, keeping the port.
	from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())

	hdr := parseHeader(pkt)
	sender := r.rt.Load().Peers[hdr.SourceIP]

	switch id := hdr.StreamID; id {
	case controlStreamID:
		r.handleControlPacket(from, sender, hdr, pkt)
	case dataStreamID:
		r.handleDataPacket(sender, hdr, pkt)
	default:
		r.logf("Unknown stream ID: %d", id)
	}
}
// handleControlPacket decrypts a control-stream packet and passes the
// resulting message to handleControlMsg.
//
// The packet is dropped without logging when the peer has no ControlCipher.
// It is dropped with a log line when its destination is not this node or
// when decryption fails. Decryption writes into the shared r.decBuf.
func (r *ConnReader) handleControlPacket(
	remoteAddr netip.AddrPort,
	peer RemotePeer,
	h header,
	enc []byte,
) {
	switch {
	case peer.ControlCipher == nil:
		return // no control cipher for this peer; drop quietly
	case h.DestIP != r.localIP:
		r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
		return
	}

	msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
	if err == nil {
		r.handleControlMsg(msg)
		return
	}
	r.logf("Failed to decrypt control packet: %v", err)
}
// handleDataPacket decrypts a data-stream packet from peer and delivers it.
//
// Packets from a peer that is not Up, and packets that fail to decrypt, are
// logged and dropped. A decrypted packet whose DestIP is this node is written
// to the interface (a write error is fatal). Any other packet is handed to
// forwardData with its destination IP.
//
// NOTE(review): data is produced from r.decBuf, presumably aliasing it, so
// forwardData/iface likely must not retain the slice after returning. TODO confirm.
func (r *ConnReader) handleDataPacket(
	peer RemotePeer,
	h header,
	enc []byte,
) {
	if !peer.Up {
		r.logf("Not connected (recv).")
		return
	}

	plain, err := peer.DecryptDataPacket(h, enc, r.decBuf)
	if err != nil {
		r.logf("Failed to decrypt data packet: %v", err)
		return
	}

	// Anything not addressed to us is relayed onward.
	if h.DestIP != r.localIP {
		r.forwardData(h.DestIP, plain)
		return
	}

	if _, err = r.iface.Write(plain); err != nil {
		log.Fatalf("Failed to write to interface: %v", err)
	}
}
// logf emits a printf-style message through the standard logger, tagged with
// a "[ConnReader] " prefix.
func (r *ConnReader) logf(format string, args ...any) {
	const tag = "[ConnReader] "
	log.Printf(tag+format, args...)
}