package stage3 import ( "fmt" "io" "log" "net" "net/netip" "runtime/debug" ) var ( network = []byte{10, 1, 1, 0} serverIP = byte(1) clientIP = byte(2) port = uint16(5151) netName = "testnet" bufferSize = if_mtu * 2 ) func must(err error) { if err != nil { panic(err) } } func RunServer() { defer func() { if r := recover(); r != nil { fmt.Printf("%v", r) debug.PrintStack() } }() iface, err := openInterface(network, serverIP, netName) must(err) myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) must(err) conn, err := net.ListenUDP("udp", myAddr) must(err) // Get remoteAddr from a packet. buf := make([]byte, 8) _, remoteAddr, err := conn.ReadFromUDPAddrPort(buf) log.Printf("Got remote addr: %v", remoteAddr) must(err) go readFromIFace(iface, conn, serverIP, clientIP, remoteAddr) readFromConn(iface, conn, clientIP) } func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, localIP, remoteIP byte, remoteAddr netip.AddrPort) { var ( n int packet = make([]byte, bufferSize) version byte ip byte err error counter uint64 buf = make([]byte, bufferSize) ) for { n, err = iface.Read(packet[:bufferSize]) must(err) packet = packet[:n] if len(packet) < 20 { log.Printf("Dropping small packet: %d", n) continue } packet = packet[:n] version = packet[0] >> 4 switch version { case 4: ip = packet[19] case 6: ip = packet[39] default: log.Printf("Dropping packet with IP version: %d", version) continue } if ip != remoteIP { log.Printf("Dropping packet for incorrect IP: %d", ip) continue } h := packetHeader{SrcIP: localIP, Counter: counter} counter++ buf = buf[:headerSize+len(packet)] h.Marshal(buf) copy(buf[headerSize:], packet) _, err = conn.WriteToUDPAddrPort(buf, remoteAddr) must(err) } } func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte) { var ( n int packet = make([]byte, bufferSize) err error counter uint64 run uint64 h packetHeader ) for { // We assume that we're only receiving packets from one source. n, err = conn.Read(packet[:bufferSize]) must(err) packet = packet[:n] if len(packet) < headerSize { fmt.Print("_") continue } h.Parse(packet) if h.SrcIP != remoteIP { fmt.Print("?") continue } if h.Counter == counter+1 { run++ counter = h.Counter } else if h.Counter > counter+1 { fmt.Printf("x(%d/%d)", h.Counter-counter+1, run) run = 0 counter = h.Counter } else if h.Counter <= counter { //log.Printf("Skipped late packet: -%d", counter-h.Counter) //continue fmt.Print("<") } _, err = iface.Write(packet[headerSize:]) must(err) } }