133 lines
2.6 KiB
Go
133 lines
2.6 KiB
Go
package peer
|
|
|
|
import (
|
|
"io"
|
|
"log"
|
|
"net/netip"
|
|
"sync/atomic"
|
|
)
|
|
|
|
type ConnReader struct {
|
|
// Input
|
|
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)
|
|
|
|
// Output
|
|
iface io.Writer
|
|
forwardData func(ip byte, pkt []byte)
|
|
handleControlMsg func(pkt any)
|
|
|
|
localIP byte
|
|
rt *atomic.Pointer[RoutingTable]
|
|
|
|
buf []byte
|
|
decBuf []byte
|
|
}
|
|
|
|
func NewConnReader(
|
|
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
|
|
iface io.Writer,
|
|
forwardData func(ip byte, pkt []byte),
|
|
handleControlMsg func(pkt any),
|
|
rt *atomic.Pointer[RoutingTable],
|
|
) *ConnReader {
|
|
return &ConnReader{
|
|
readFromUDPAddrPort: readFromUDPAddrPort,
|
|
iface: iface,
|
|
forwardData: forwardData,
|
|
handleControlMsg: handleControlMsg,
|
|
localIP: rt.Load().LocalIP,
|
|
rt: rt,
|
|
buf: newBuf(),
|
|
decBuf: newBuf(),
|
|
}
|
|
}
|
|
|
|
func (r *ConnReader) Run() {
|
|
for {
|
|
r.handleNextPacket()
|
|
}
|
|
}
|
|
|
|
func (r *ConnReader) handleNextPacket() {
|
|
buf := r.buf[:bufferSize]
|
|
n, remoteAddr, err := r.readFromUDPAddrPort(buf)
|
|
if err != nil {
|
|
log.Fatalf("Failed to read from UDP port: %v", err)
|
|
}
|
|
|
|
if n < headerSize {
|
|
return
|
|
}
|
|
|
|
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
|
|
|
|
buf = buf[:n]
|
|
h := parseHeader(buf)
|
|
|
|
peer := r.rt.Load().Peers[h.SourceIP]
|
|
//peer := rt.Peers[h.SourceIP]
|
|
|
|
switch h.StreamID {
|
|
case controlStreamID:
|
|
r.handleControlPacket(remoteAddr, peer, h, buf)
|
|
case dataStreamID:
|
|
r.handleDataPacket(peer, h, buf)
|
|
default:
|
|
r.logf("Unknown stream ID: %d", h.StreamID)
|
|
}
|
|
}
|
|
|
|
func (r *ConnReader) handleControlPacket(
|
|
remoteAddr netip.AddrPort,
|
|
peer RemotePeer,
|
|
h header,
|
|
enc []byte,
|
|
) {
|
|
if peer.ControlCipher == nil {
|
|
return
|
|
}
|
|
|
|
if h.DestIP != r.localIP {
|
|
r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
|
|
return
|
|
}
|
|
|
|
msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
|
|
if err != nil {
|
|
r.logf("Failed to decrypt control packet: %v", err)
|
|
return
|
|
}
|
|
|
|
r.handleControlMsg(msg)
|
|
}
|
|
|
|
func (r *ConnReader) handleDataPacket(
|
|
peer RemotePeer,
|
|
h header,
|
|
enc []byte,
|
|
) {
|
|
if !peer.Up {
|
|
r.logf("Not connected (recv).")
|
|
return
|
|
}
|
|
|
|
data, err := peer.DecryptDataPacket(h, enc, r.decBuf)
|
|
if err != nil {
|
|
r.logf("Failed to decrypt data packet: %v", err)
|
|
return
|
|
}
|
|
|
|
if h.DestIP == r.localIP {
|
|
if _, err := r.iface.Write(data); err != nil {
|
|
log.Fatalf("Failed to write to interface: %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
r.forwardData(h.DestIP, data)
|
|
}
|
|
|
|
func (r *ConnReader) logf(format string, args ...any) {
|
|
log.Printf("[ConnReader] "+format, args...)
|
|
}
|