148 lines
2.7 KiB
Go
148 lines
2.7 KiB
Go
package stage3
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/netip"
|
|
"runtime/debug"
|
|
)
|
|
|
|
var (
|
|
network = []byte{10, 1, 1, 0}
|
|
serverIP = byte(1)
|
|
clientIP = byte(2)
|
|
port = uint16(5151)
|
|
netName = "testnet"
|
|
bufferSize = if_mtu * 2
|
|
)
|
|
|
|
// must panics with err when err is non-nil and is a no-op otherwise.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
|
|
|
|
func RunServer() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
fmt.Printf("%v", r)
|
|
debug.PrintStack()
|
|
}
|
|
}()
|
|
|
|
iface, err := openInterface(network, serverIP, netName)
|
|
must(err)
|
|
|
|
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
|
|
must(err)
|
|
|
|
conn, err := net.ListenUDP("udp", myAddr)
|
|
must(err)
|
|
|
|
// Get remoteAddr from a packet.
|
|
buf := make([]byte, 8)
|
|
_, remoteAddr, err := conn.ReadFromUDPAddrPort(buf)
|
|
log.Printf("Got remote addr: %v", remoteAddr)
|
|
must(err)
|
|
|
|
go readFromIFace(iface, conn, serverIP, clientIP, remoteAddr)
|
|
readFromConn(iface, conn, clientIP)
|
|
}
|
|
|
|
func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, localIP, remoteIP byte, remoteAddr netip.AddrPort) {
|
|
var (
|
|
n int
|
|
packet = make([]byte, bufferSize)
|
|
version byte
|
|
ip byte
|
|
err error
|
|
counter uint64
|
|
buf = make([]byte, bufferSize)
|
|
)
|
|
|
|
for {
|
|
n, err = iface.Read(packet[:bufferSize])
|
|
must(err)
|
|
packet = packet[:n]
|
|
|
|
if len(packet) < 20 {
|
|
log.Printf("Dropping small packet: %d", n)
|
|
continue
|
|
}
|
|
|
|
packet = packet[:n]
|
|
version = packet[0] >> 4
|
|
|
|
switch version {
|
|
case 4:
|
|
ip = packet[19]
|
|
case 6:
|
|
ip = packet[39]
|
|
default:
|
|
log.Printf("Dropping packet with IP version: %d", version)
|
|
continue
|
|
}
|
|
|
|
if ip != remoteIP {
|
|
log.Printf("Dropping packet for incorrect IP: %d", ip)
|
|
continue
|
|
}
|
|
|
|
h := packetHeader{SrcIP: localIP, Counter: counter}
|
|
counter++
|
|
buf = buf[:headerSize+len(packet)]
|
|
h.Marshal(buf)
|
|
copy(buf[headerSize:], packet)
|
|
|
|
_, err = conn.WriteToUDPAddrPort(buf, remoteAddr)
|
|
must(err)
|
|
}
|
|
}
|
|
|
|
func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte) {
|
|
var (
|
|
n int
|
|
packet = make([]byte, bufferSize)
|
|
err error
|
|
counter uint64
|
|
run uint64
|
|
h packetHeader
|
|
)
|
|
|
|
for {
|
|
// We assume that we're only receiving packets from one source.
|
|
n, err = conn.Read(packet[:bufferSize])
|
|
must(err)
|
|
|
|
packet = packet[:n]
|
|
if len(packet) < headerSize {
|
|
fmt.Print("_")
|
|
continue
|
|
}
|
|
|
|
h.Parse(packet)
|
|
if h.SrcIP != remoteIP {
|
|
fmt.Print("?")
|
|
continue
|
|
}
|
|
|
|
if h.Counter == counter+1 {
|
|
run++
|
|
counter = h.Counter
|
|
} else if h.Counter > counter+1 {
|
|
fmt.Printf("x(%d/%d)", h.Counter-counter+1, run)
|
|
run = 0
|
|
counter = h.Counter
|
|
} else if h.Counter <= counter {
|
|
//log.Printf("Skipped late packet: -%d", counter-h.Counter)
|
|
//continue
|
|
fmt.Print("<")
|
|
}
|
|
|
|
_, err = iface.Write(packet[headerSize:])
|
|
must(err)
|
|
}
|
|
}
|