141 lines
2.8 KiB
Go
141 lines
2.8 KiB
Go
package peer
|
|
|
|
import (
|
|
"io"
|
|
"log"
|
|
"net/netip"
|
|
"sync/atomic"
|
|
)
|
|
|
|
type connReader struct {
|
|
// Input
|
|
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)
|
|
|
|
// Output
|
|
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
|
|
iface io.Writer
|
|
handleControlMsg func(fromIP byte, pkt any)
|
|
|
|
localIP byte
|
|
rt *atomic.Pointer[routingTable]
|
|
|
|
buf []byte
|
|
decBuf []byte
|
|
}
|
|
|
|
func newConnReader(
|
|
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
|
|
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
|
|
iface io.Writer,
|
|
handleControlMsg func(fromIP byte, pkt any),
|
|
rt *atomic.Pointer[routingTable],
|
|
) *connReader {
|
|
return &connReader{
|
|
readFromUDPAddrPort: readFromUDPAddrPort,
|
|
writeToUDPAddrPort: writeToUDPAddrPort,
|
|
iface: iface,
|
|
handleControlMsg: handleControlMsg,
|
|
localIP: rt.Load().LocalIP,
|
|
rt: rt,
|
|
buf: newBuf(),
|
|
decBuf: newBuf(),
|
|
}
|
|
}
|
|
|
|
func (r *connReader) Run() {
|
|
for {
|
|
r.handleNextPacket()
|
|
}
|
|
}
|
|
|
|
func (r *connReader) handleNextPacket() {
|
|
buf := r.buf[:bufferSize]
|
|
n, remoteAddr, err := r.readFromUDPAddrPort(buf)
|
|
if err != nil {
|
|
log.Fatalf("Failed to read from UDP port: %v", err)
|
|
}
|
|
|
|
if n < headerSize {
|
|
return
|
|
}
|
|
|
|
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
|
|
|
|
buf = buf[:n]
|
|
h := parseHeader(buf)
|
|
|
|
rt := r.rt.Load()
|
|
peer := rt.Peers[h.SourceIP]
|
|
|
|
switch h.StreamID {
|
|
case controlStreamID:
|
|
r.handleControlPacket(remoteAddr, peer, h, buf)
|
|
case dataStreamID:
|
|
r.handleDataPacket(rt, peer, h, buf)
|
|
default:
|
|
r.logf("Unknown stream ID: %d", h.StreamID)
|
|
}
|
|
}
|
|
|
|
func (r *connReader) handleControlPacket(
|
|
remoteAddr netip.AddrPort,
|
|
peer remotePeer,
|
|
h header,
|
|
enc []byte,
|
|
) {
|
|
if peer.ControlCipher == nil {
|
|
log.Printf("No control cipher for peer: %v", h)
|
|
return
|
|
}
|
|
|
|
if h.DestIP != r.localIP {
|
|
r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
|
|
return
|
|
}
|
|
|
|
msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
|
|
if err != nil {
|
|
r.logf("Failed to decrypt control packet: %v", err)
|
|
return
|
|
}
|
|
|
|
r.handleControlMsg(h.SourceIP, msg)
|
|
}
|
|
|
|
func (r *connReader) handleDataPacket(
|
|
rt *routingTable,
|
|
peer remotePeer,
|
|
h header,
|
|
enc []byte,
|
|
) {
|
|
if !peer.Up {
|
|
r.logf("Not connected (recv).")
|
|
return
|
|
}
|
|
|
|
data, err := peer.DecryptDataPacket(h, enc, r.decBuf)
|
|
if err != nil {
|
|
r.logf("Failed to decrypt data packet: %v", err)
|
|
return
|
|
}
|
|
|
|
if h.DestIP == r.localIP {
|
|
if _, err := r.iface.Write(data); err != nil {
|
|
log.Fatalf("Failed to write to interface: %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
remote := rt.Peers[h.DestIP]
|
|
if !remote.Direct {
|
|
r.logf("Unable to relay data to %d.", h.DestIP)
|
|
return
|
|
}
|
|
|
|
r.writeToUDPAddrPort(data, remote.DirectAddr)
|
|
}
|
|
|
|
func (r *connReader) logf(format string, args ...any) {
|
|
log.Printf("[ConnReader] "+format, args...)
|
|
}
|