vppn/node/conn.go
2024-12-18 12:35:47 +01:00

137 lines
2.6 KiB
Go

package node
import (
"log"
"net"
"net/netip"
"sync"
"sync/atomic"
"vppn/fasttime"
)
// connWriter encrypts packets and sends them to peers over a shared UDP
// socket.
type connWriter struct {
	// counters holds, per destination IP, the last packet counter handed out
	// by WriteToPeer (advanced with atomic.AddUint64). The sync/atomic 64-bit
	// functions require 64-bit alignment on 386, ARM and 32-bit MIPS, and only
	// the first word of an allocated struct is guaranteed to be aligned there,
	// so this field must remain first.
	counters [256]uint64

	*net.UDPConn
	lock    sync.Mutex // serializes WriteToPeer, which uses buf
	localIP byte       // written as SourceIP in outgoing headers
	buf     []byte     // scratch buffer passed to encryptPacket
	routing *routingTable
}
// newConnWriter returns a connWriter that sends on conn as localIP, looking
// peers up in routing. Every per-destination counter is seeded with
// fasttime.Now() << 30 (presumably so counters keep increasing across
// restarts — confirm against the duplicate-check window on the receiver).
func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter {
	cw := new(connWriter)
	cw.UDPConn = conn
	cw.localIP = localIP
	cw.buf = make([]byte, bufferSize)
	cw.routing = routing

	for idx := 0; idx < len(cw.counters); idx++ {
		cw.counters[idx] = uint64(fasttime.Now() << 30)
	}
	return cw
}
// WriteTo sends data on the given stream to the peer routed at remoteIP.
//
// If the routing table has no peer for remoteIP, or the peer has no known
// address, the packet is dropped: a message is logged and nil is returned.
// A data-stream packet to a peer that is not marked Up is logged but still
// sent.
func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) error {
	// TODO: Handle mediator.
	p := w.routing.Get(remoteIP)

	switch {
	case p == nil, p.Addr == nil: // evaluated left to right; p.Addr is skipped when p is nil
		log.Printf("No peer: %d", remoteIP)
		return nil
	case stream == streamData && !p.Up:
		log.Printf("Peer down: %d", remoteIP)
	}

	return w.WriteToPeer(p, stream, data)
}
// WriteToPeer encrypts data with peer.SharedKey and sends it to *peer.Addr.
//
// The packet header carries the next counter for peer.IP, this node's
// localIP as the source, peer.IP as the destination and the given stream.
// peer.Addr must be non-nil. Any error from the UDP write is returned.
func (w *connWriter) WriteToPeer(peer *peer, stream byte, data []byte) error {
	w.lock.Lock()
	// Deferred so that a panic below (e.g. a nil peer.Addr) cannot leave the
	// writer locked and block every later send.
	defer w.lock.Unlock()

	remoteIP := peer.IP
	// The counter is taken while holding the lock, so packets from this writer
	// leave the socket in counter order.
	h := header{
		Counter:  atomic.AddUint64(&w.counters[remoteIP], 1),
		SourceIP: w.localIP,
		ViaIP:    0,
		DestIP:   remoteIP,
		Stream:   stream,
	}

	// w.buf is shared scratch space handed to encryptPacket (presumably its
	// output buffer), so encryption and the send must both happen under lock.
	buf := encryptPacket(&h, peer.SharedKey, data, w.buf)
	_, err := w.WriteToUDPAddrPort(buf, *peer.Addr)
	return err
}
// ----------------------------------------------------------------------------
// connReader receives packets from peers on a UDP socket, decrypts them and
// filters out duplicates.
type connReader struct {
	*net.UDPConn
	// localIP is this node's address byte; Read does not currently consult it.
	localIP byte
	// dupChecks holds one duplicate detector per source IP, indexed by
	// header.SourceIP.
	dupChecks [256]*dupCheck
	// routing is used to look up the sending peer (and its SharedKey).
	routing *routingTable
	// buf is the output buffer passed to decryptPacket on every Read, so
	// decrypted payloads returned by Read share this storage.
	buf []byte
}
// newConnReader returns a connReader that reads from conn on behalf of
// localIP, resolving senders through routing. Each possible source IP gets
// its own duplicate checker, created with newDupCheck(0).
func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader {
	var checks [256]*dupCheck
	for idx := 0; idx < len(checks); idx++ {
		checks[idx] = newDupCheck(0)
	}

	return &connReader{
		UDPConn:   conn,
		localIP:   localIP,
		dupChecks: checks,
		routing:   routing,
		buf:       make([]byte, bufferSize),
	}
}
// Read blocks until it receives a packet that is well formed, comes from a
// known peer, decrypts successfully and is not a duplicate, then returns the
// sender's address, the parsed header and the decrypted payload. Packets that
// fail any check are logged and dropped.
//
// buf receives the raw datagram; it is resliced to buf[:bufferSize], so its
// capacity must be at least bufferSize. The returned data is the slice
// produced by decryptPacket into r.buf. NOTE(review): r.buf is reused by every
// call, so data is presumably overwritten by the next Read; callers must be
// done with it first (confirm).
//
// If the socket read fails, the error is returned with zero-valued
// remoteAddr, h and data.
func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte, err error) {
	for {
		n, addr, rerr := r.ReadFromUDPAddrPort(buf[:bufferSize])
		if rerr != nil {
			// Don't leak header/data from a previously dropped packet.
			return netip.AddrPort{}, header{}, nil, rerr
		}

		pkt := buf[:n]
		if n < headerSize {
			log.Printf("Dropping short packet: %d", n)
			continue // Packet is too short to hold a header.
		}

		h.Parse(pkt)
		if want := headerSize + int(h.DataSize); len(pkt) != want {
			log.Printf("Malformed packet: %d != %d", len(pkt), want)
			continue
		}

		p := r.routing.Get(h.SourceIP)
		if p == nil {
			log.Printf("No peer: %d...", h.SourceIP)
			continue
		}

		plain, ok := decryptPacket(p.SharedKey, pkt, r.buf)
		if !ok {
			log.Printf("Decrypt failed...")
			continue
		}

		// Checked only after successful decryption, so packets that fail to
		// decrypt never reach the duplicate checker.
		if r.dupChecks[h.SourceIP].IsDup(h.Counter) {
			log.Printf("Duplicate: %d", h.Counter)
			continue
		}

		return addr, h, plain, nil
	}
}