113 lines
2.0 KiB
Go
113 lines
2.0 KiB
Go
package stage2
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/netip"
|
|
"runtime/debug"
|
|
)
|
|
|
|
var (
|
|
network = []byte{10, 1, 1, 0}
|
|
serverIP = byte(1)
|
|
clientIP = byte(2)
|
|
port = uint16(5151)
|
|
netName = "testnet"
|
|
bufferSize = if_mtu * 2
|
|
)
|
|
|
|
// must panics with err if it is non-nil and does nothing otherwise. It is a
// terse way to treat setup and I/O failures as fatal.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
|
|
|
|
func RunServer() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
fmt.Printf("%v", r)
|
|
debug.PrintStack()
|
|
}
|
|
}()
|
|
|
|
iface, err := openInterface(network, serverIP, netName)
|
|
must(err)
|
|
|
|
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
|
|
must(err)
|
|
|
|
conn, err := net.ListenUDP("udp", myAddr)
|
|
must(err)
|
|
|
|
// Get remoteAddr from a packet.
|
|
buf := make([]byte, 8)
|
|
_, remoteAddr, err := conn.ReadFromUDPAddrPort(buf)
|
|
log.Printf("Got remote addr: %v", remoteAddr)
|
|
must(err)
|
|
|
|
go readFromIFace(iface, conn, clientIP, remoteAddr)
|
|
readFromConn(iface, conn)
|
|
}
|
|
|
|
func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte, remoteAddr netip.AddrPort) {
|
|
var (
|
|
n int
|
|
packet = make([]byte, bufferSize)
|
|
version byte
|
|
ip byte
|
|
err error
|
|
)
|
|
|
|
for {
|
|
n, err = iface.Read(packet[:bufferSize])
|
|
must(err)
|
|
packet = packet[:n]
|
|
|
|
if len(packet) < 20 {
|
|
log.Printf("Dropping small packet: %d", n)
|
|
continue
|
|
}
|
|
|
|
packet = packet[:n]
|
|
version = packet[0] >> 4
|
|
|
|
switch version {
|
|
case 4:
|
|
ip = packet[19]
|
|
case 6:
|
|
ip = packet[39]
|
|
default:
|
|
log.Printf("Dropping packet with IP version: %d", version)
|
|
continue
|
|
}
|
|
|
|
if ip != remoteIP {
|
|
log.Printf("Dropping packet for incorrect IP: %d", ip)
|
|
continue
|
|
}
|
|
|
|
_, err = conn.WriteToUDPAddrPort(packet, remoteAddr)
|
|
must(err)
|
|
}
|
|
}
|
|
|
|
func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn) {
|
|
var (
|
|
n int
|
|
packet = make([]byte, bufferSize)
|
|
err error
|
|
)
|
|
|
|
for {
|
|
// We assume that we're only receiving packets from one source.
|
|
n, err = conn.Read(packet[:bufferSize])
|
|
must(err)
|
|
|
|
packet = packet[:n]
|
|
_, err = iface.Write(packet)
|
|
must(err)
|
|
}
|
|
}
|