vppn/stage2/server.go
2024-12-16 20:51:30 +01:00

113 lines
2.0 KiB
Go

package stage2
import (
"fmt"
"io"
"log"
"net"
"net/netip"
"runtime/debug"
)
// Package-level settings shared by the stage2 server code.
var (
	// network is the 4-byte address handed to openInterface together with
	// serverIP. NOTE(review): presumably the /24 network address whose last
	// octet is replaced by the host byte — confirm against openInterface.
	network []byte = []byte{10, 1, 1, 0}

	// serverIP is the host byte passed to openInterface by RunServer.
	serverIP byte = 1

	// clientIP is the byte readFromIFace compares against the last byte of
	// each packet's destination address.
	clientIP byte = 2

	// port is the UDP port RunServer listens on.
	port uint16 = 5151

	// netName is the name passed to openInterface.
	netName string = "testnet"

	// bufferSize is the length of the packet buffers: twice if_mtu.
	bufferSize = if_mtu * 2
)
// must panics with err when err is non-nil and otherwise does nothing.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
// RunServer opens the local interface, listens for UDP on port, learns the
// peer's address from the first datagram it receives, and then relays packets
// in both directions: interface -> peer in a background goroutine, and
// peer -> interface on the calling goroutine.
//
// Any setup error panics via must. A panic raised on the calling goroutine is
// recovered here, printed together with a stack trace, and RunServer returns.
func RunServer() {
	// Only panics on this goroutine are caught; a panic inside the
	// readFromIFace goroutine started below is not recovered by this handler.
	defer func() {
		if r := recover(); r != nil {
			fmt.Printf("%v", r)
			debug.PrintStack()
		}
	}()

	iface, err := openInterface(network, serverIP, netName)
	must(err)

	myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
	must(err)

	conn, err := net.ListenUDP("udp", myAddr)
	must(err)

	// Get remoteAddr from a packet. The payload itself is discarded; only the
	// sender's address is used.
	buf := make([]byte, 8)
	_, remoteAddr, err := conn.ReadFromUDPAddrPort(buf)
	// Check the error before logging so a failed read never reports a
	// zero-value address as if it were the peer.
	must(err)
	log.Printf("Got remote addr: %v", remoteAddr)

	go readFromIFace(iface, conn, clientIP, remoteAddr)
	readFromConn(iface, conn)
}
// readFromIFace reads IP packets from iface forever and forwards those whose
// destination address ends in remoteIP to remoteAddr over conn.
//
// Packets that are too short to contain the header of their IP version, that
// carry an IP version other than 4 or 6, or whose destination does not end in
// remoteIP are logged and dropped. A read or write error panics via must.
func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte, remoteAddr netip.AddrPort) {
	const (
		// Minimum IPv4 header length; the destination address is its last
		// four bytes (offsets 16-19).
		ipv4HeaderLen = 20
		// Fixed IPv6 header length; the destination address is its last
		// sixteen bytes (offsets 24-39).
		ipv6HeaderLen = 40
	)

	buf := make([]byte, bufferSize)
	for {
		n, err := iface.Read(buf)
		must(err)

		packet := buf[:n]
		if len(packet) < ipv4HeaderLen {
			log.Printf("Dropping small packet: %d", n)
			continue
		}

		// The IP version is the high nibble of the first byte.
		version := packet[0] >> 4

		var ip byte
		switch version {
		case 4:
			ip = packet[ipv4HeaderLen-1]
		case 6:
			// The generic check above only guarantees 20 bytes; an IPv6
			// packet needs the full 40-byte header before byte 39 is read.
			if len(packet) < ipv6HeaderLen {
				log.Printf("Dropping small packet: %d", n)
				continue
			}
			ip = packet[ipv6HeaderLen-1]
		default:
			log.Printf("Dropping packet with IP version: %d", version)
			continue
		}

		if ip != remoteIP {
			log.Printf("Dropping packet for incorrect IP: %d", ip)
			continue
		}

		_, err = conn.WriteToUDPAddrPort(packet, remoteAddr)
		must(err)
	}
}
// readFromConn copies every datagram received on conn to iface, forever.
// A read or write error panics via must.
func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn) {
	buf := make([]byte, bufferSize)
	for {
		// We assume that we're only receiving packets from one source, so
		// the sender's address is not inspected.
		size, readErr := conn.Read(buf)
		must(readErr)

		_, writeErr := iface.Write(buf[:size])
		must(writeErr)
	}
}