vppn/stage3/server.go
2024-12-16 20:51:30 +01:00

148 lines
2.7 KiB
Go

package stage3
import (
"fmt"
"io"
"log"
"net"
"net/netip"
"runtime/debug"
)
// network holds the tunnel network's address octets. openInterface receives it
// together with a host byte (serverIP on this side); presumably the final 0
// octet is replaced by that host byte — TODO confirm against openInterface.
var network = []byte{10, 1, 1, 0}

// Host bytes identifying each end of the tunnel. They are sent as
// packetHeader.SrcIP and matched against the last octet of a packet's
// destination address.
var (
	serverIP byte = 1
	clientIP byte = 2
)

// port is the UDP port the server listens on.
var port uint16 = 5151

// netName is passed to openInterface (presumably as the interface name).
var netName = "testnet"

// bufferSize is the length of the packet buffers used in both forwarding
// directions: twice if_mtu.
var bufferSize = if_mtu * 2
// must panics with err when err is non-nil and does nothing otherwise. This
// program treats every such error as fatal.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
// RunServer runs the server end of the tunnel.
//
// It opens the tunnel interface and listens for UDP on port. It waits for a
// first datagram, whose contents are ignored, to learn the client's address.
// After that it forwards packets in both directions:
//   - a background goroutine sends interface -> client (readFromIFace);
//   - the calling goroutine delivers client -> interface (readFromConn).
//
// A panic on the calling goroutine (including one raised by must) is recovered
// and printed with a stack trace. Panics in the readFromIFace goroutine are
// not covered by this recover.
//
// conn and iface are intentionally not closed on return. readFromIFace keeps
// using them on its own goroutine and panics (via must) on any I/O error, so
// closing them here would turn one reported failure into a process crash.
func RunServer() {
	defer func() {
		if r := recover(); r != nil {
			fmt.Printf("%v\n", r)
			debug.PrintStack()
		}
	}()

	iface, err := openInterface(network, serverIP, netName)
	must(err)

	myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
	must(err)

	conn, err := net.ListenUDP("udp", myAddr)
	must(err)

	// Get remoteAddr from a packet. Check the error before using the address,
	// so that a failed read does not log a zero-valued address.
	buf := make([]byte, 8)
	_, remoteAddr, err := conn.ReadFromUDPAddrPort(buf)
	must(err)
	log.Printf("Got remote addr: %v", remoteAddr)

	go readFromIFace(iface, conn, serverIP, clientIP, remoteAddr)
	readFromConn(iface, conn, clientIP)
}
// readFromIFace forwards IP packets read from iface to the peer at remoteAddr
// over conn. It never returns; read and write errors are fatal (must panics).
//
// A packet is forwarded only when the last octet of its destination address
// equals remoteIP. That octet is byte 19 for IPv4 and byte 39 for IPv6.
// Packets that are too short to hold that byte, or that are not IPv4/IPv6,
// are dropped and logged.
//
// Each forwarded datagram is a packetHeader followed by the raw packet. The
// header carries localIP and a counter that starts at 0 and increments by one
// for every forwarded packet.
func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, localIP, remoteIP byte, remoteAddr netip.AddrPort) {
	const (
		ipv4HeaderLen = 20 // minimum IPv4 header; destination ends at byte 19
		ipv6HeaderLen = 40 // fixed IPv6 header; destination ends at byte 39
	)

	var (
		packet = make([]byte, bufferSize)
		// buf must hold a header plus a packet of up to len(packet) bytes,
		// so a read that fills packet completely still fits.
		buf     = make([]byte, headerSize+len(packet))
		counter uint64
	)

	for {
		n, err := iface.Read(packet)
		must(err)
		pkt := packet[:n]

		if len(pkt) < ipv4HeaderLen {
			log.Printf("Dropping small packet: %d", n)
			continue
		}

		var ip byte
		switch version := pkt[0] >> 4; version {
		case 4:
			ip = pkt[ipv4HeaderLen-1]
		case 6:
			// The IPv4 length check above is not enough to read byte 39.
			if len(pkt) < ipv6HeaderLen {
				log.Printf("Dropping small packet: %d", n)
				continue
			}
			ip = pkt[ipv6HeaderLen-1]
		default:
			log.Printf("Dropping packet with IP version: %d", version)
			continue
		}

		if ip != remoteIP {
			log.Printf("Dropping packet for incorrect IP: %d", ip)
			continue
		}

		h := packetHeader{SrcIP: localIP, Counter: counter}
		counter++

		out := buf[:headerSize+len(pkt)]
		h.Marshal(out)
		copy(out[headerSize:], pkt)
		_, err = conn.WriteToUDPAddrPort(out, remoteAddr)
		must(err)
	}
}
// readFromConn receives datagrams from conn and writes each enclosed IP packet
// to iface. It never returns; read and write errors are fatal (must panics).
//
// It assumes a single peer. Datagrams shorter than headerSize are dropped and
// print "_"; datagrams whose header SrcIP is not remoteIP are dropped and
// print "?".
//
// For accepted datagrams the header counter is tracked and progress markers
// are printed to stdout:
//   - "x(missing/run)" when a gap is seen. missing is the number of skipped
//     counter values; run is the number of in-order packets since the
//     previous gap.
//   - "<" for a duplicate or late packet. Such packets are still delivered.
//
// The first accepted datagram sets the counter baseline and is treated as
// in order.
func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte) {
	var (
		packet  = make([]byte, bufferSize)
		h       packetHeader
		counter uint64 // highest counter accepted so far
		run     uint64 // in-order packets since the last gap
		synced  bool   // whether counter holds a value from a real packet
	)

	for {
		// We assume that we're only receiving packets from one source.
		n, err := conn.Read(packet[:bufferSize])
		must(err)
		pkt := packet[:n]

		if len(pkt) < headerSize {
			fmt.Print("_")
			continue
		}

		h.Parse(pkt)
		if h.SrcIP != remoteIP {
			fmt.Print("?")
			continue
		}

		switch {
		case !synced || h.Counter == counter+1:
			// The sender's counter starts at 0, so without the synced flag the
			// very first packet would be misreported as late.
			synced = true
			run++
			counter = h.Counter
		case h.Counter > counter+1:
			// Values counter+1 .. h.Counter-1 were skipped.
			fmt.Printf("x(%d/%d)", h.Counter-counter-1, run)
			run = 0
			counter = h.Counter
		default:
			// h.Counter <= counter: a duplicate or a packet overtaken by a
			// newer one. It is still delivered and the counter is not rewound.
			fmt.Print("<")
		}

		_, err = iface.Write(pkt[headerSize:])
		must(err)
	}
}