diff --git a/node/bitset.go b/node/bitset.go new file mode 100644 index 0000000..a9024cb --- /dev/null +++ b/node/bitset.go @@ -0,0 +1,21 @@ +package node + +const bitSetSize = 512 // Multiple of 64. + +type bitSet [bitSetSize / 64]uint64 + +func (bs *bitSet) Set(i int) { + bs[i/64] |= 1 << (i % 64) +} + +func (bs *bitSet) Clear(i int) { + bs[i/64] &= ^(1 << (i % 64)) +} + +func (bs *bitSet) ClearAll() { + clear(bs[:]) +} + +func (bs *bitSet) Get(i int) bool { + return bs[i/64]&(1<<(i%64)) != 0 +} diff --git a/node/bitset_test.go b/node/bitset_test.go new file mode 100644 index 0000000..bd3307a --- /dev/null +++ b/node/bitset_test.go @@ -0,0 +1,48 @@ +package node + +import ( + "math/rand" + "testing" +) + +func TestBitSet(t *testing.T) { + state := make([]bool, bitSetSize) + for i := range state { + state[i] = rand.Float32() > 0.5 + } + + bs := bitSet{} + + for i := range state { + if state[i] { + bs.Set(i) + } + } + + for i := range state { + if bs.Get(i) != state[i] { + t.Fatal(i, state[i], bs.Get(i)) + } + } + + for i := range state { + if rand.Float32() > 0.5 { + state[i] = false + bs.Clear(i) + } + } + + for i := range state { + if bs.Get(i) != state[i] { + t.Fatal(i, state[i], bs.Get(i)) + } + } + + bs.ClearAll() + + for i := range state { + if bs.Get(i) { + t.Fatal(i, bs.Get(i)) + } + } +} diff --git a/node/cmd/client/build.sh b/node/cmd/client/build.sh new file mode 100755 index 0000000..c7d72ea --- /dev/null +++ b/node/cmd/client/build.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +go build +sudo setcap cap_net_admin+iep ./client +./client 144.76.78.93 diff --git a/node/cmd/client/main.go b/node/cmd/client/main.go new file mode 100644 index 0000000..e98c2cd --- /dev/null +++ b/node/cmd/client/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "log" + "os" + "vppn/node" +) + +func main() { + if len(os.Args) != 2 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + n := node.NewTmpNodeClient() + n.RunClient(os.Args[1]) +} diff --git a/node/cmd/server/build.sh b/node/cmd/server/build.sh new file mode 100755 index 0000000..fcc5787 --- /dev/null +++ b/node/cmd/server/build.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +go build +ssh kevin "killall server" +scp server kevin:/home/jdl/tmp/ +ssh root@kevin "sudo setcap cap_net_admin+iep /home/jdl/tmp/server" +ssh kevin "/home/jdl/tmp/server" diff --git a/node/cmd/server/main.go b/node/cmd/server/main.go new file mode 100644 index 0000000..47272cb --- /dev/null +++ b/node/cmd/server/main.go @@ -0,0 +1,8 @@ +package main + +import "vppn/node" + +func main() { + n := node.NewTmpNodeServer() + n.RunServer() +} diff --git a/node/conn.go b/node/conn.go new file mode 100644 index 0000000..0a03c22 --- /dev/null +++ b/node/conn.go @@ -0,0 +1,103 @@ +package node + +import ( + "log" + "net" + "net/netip" + "vppn/fasttime" +) + +type connWriter struct { + *net.UDPConn + localIP byte + buf []byte + counters [256]uint64 + lookup func(byte) *peer +} + +func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connWriter { + w := &connWriter{ + UDPConn: conn, + localIP: localIP, + buf: make([]byte, bufferSize), + lookup: lookup, + } + + for i := range w.counters { + w.counters[i] = uint64(fasttime.Now() << 30) + } + + return w +} + +func (w *connWriter) WriteTo(remoteIP, packetType byte, data []byte) error { + peer := w.lookup(remoteIP) + if peer == nil || peer.Addr == nil { + log.Printf("No peer: %d", remoteIP) + return nil + } + + w.counters[remoteIP]++ + + h := header{ + Counter: w.counters[remoteIP], + SourceIP: w.localIP, + ViaIP: 0, + DestIP: remoteIP, + PacketType: packetType, + DataSize: uint16(len(data)), + } + + buf := w.buf[:len(data)+headerSize] + h.Marshal(buf) + copy(buf[headerSize:], data) + + _, err := w.WriteToUDPAddrPort(buf, *peer.Addr) + return err +} + +// ---------------------------------------------------------------------------- + +type connReader struct { + *net.UDPConn + localIP byte + counters [256]uint64 + lookup func(byte) *peer +} + +func newConnReader(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connReader { + return &connReader{ + UDPConn: conn, + localIP: localIP, + lookup: lookup, + } +} + +func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte, err error) { + var n int + + for { + n, remoteAddr, err = r.ReadFromUDPAddrPort(buf[:bufferSize]) + if err != nil { + return + } + + buf = buf[:n] + + if n < headerSize { + continue // Packet it soo short. + } + + h.Parse(buf) + data = buf[headerSize:] + if len(data) != int(h.DataSize) { + continue // Packet is corrupt. + } + + if h.Counter > r.counters[h.SourceIP] { + r.counters[h.SourceIP] = h.Counter + } + + return + } +} diff --git a/node/dupcheck.go b/node/dupcheck.go new file mode 100644 index 0000000..fac7a72 --- /dev/null +++ b/node/dupcheck.go @@ -0,0 +1,76 @@ +package node + +type dupCheck struct { + bitSet + head int + tail int + headCounter uint64 + tailCounter uint64 // Also next expected counter value. +} + +func newDupCheck(headCounter uint64) *dupCheck { + return &dupCheck{ + headCounter: headCounter, + tailCounter: headCounter + 1, + tail: 1, + } +} + +func (dc *dupCheck) IsDup(counter uint64) bool { + + // Before head => it's late, say it's a dup. + if counter < dc.headCounter { + return true + } + + // It's within the counter bounds. + if counter < dc.tailCounter { + index := (int(counter-dc.headCounter) + dc.head) % bitSetSize + if dc.Get(index) { + return true + } + + dc.Set(index) + return false + } + + // It's more than 1 beyond the tail. + delta := counter - dc.tailCounter + + // Full clear. + if delta >= bitSetSize { + dc.ClearAll() + dc.Set(0) + + dc.tail = 1 + dc.head = 2 + dc.tailCounter = counter + 1 + dc.headCounter = dc.tailCounter - bitSetSize + + return false + } + + // Clear if necessary. + for i := 0; i < int(delta); i++ { + dc.put(false) + } + + dc.put(true) + return false +} + +func (dc *dupCheck) put(set bool) { + if set { + dc.Set(dc.tail) + } else { + dc.Clear(dc.tail) + } + + dc.tail = (dc.tail + 1) % bitSetSize + dc.tailCounter++ + + if dc.head == dc.tail { + dc.head = (dc.head + 1) % bitSetSize + dc.headCounter++ + } +} diff --git a/node/dupcheck_test.go b/node/dupcheck_test.go new file mode 100644 index 0000000..9a939b5 --- /dev/null +++ b/node/dupcheck_test.go @@ -0,0 +1,57 @@ +package node + +import ( + "log" + "testing" +) + +func TestDupCheck(t *testing.T) { + dc := newDupCheck(0) + + for i := range bitSetSize { + if dc.IsDup(uint64(i)) { + t.Fatal("!") + } + } + + type TestCase struct { + Counter uint64 + Dup bool + } + + testCases := []TestCase{ + {0, true}, + {1, true}, + {2, true}, + {3, true}, + {63, true}, + {256, true}, + {510, true}, + {511, true}, + {512, false}, + {0, true}, + {512, true}, + {513, false}, + {517, false}, + {512, true}, + {513, true}, + {514, false}, + {515, false}, + {516, false}, + {517, true}, + {2512, false}, + {2000, true}, + {2001, false}, + {4000, false}, + {4000 - 512, true}, // Too old. + {4000 - 511, false}, // Just in the window. + } + + for i, tc := range testCases { + if ok := dc.IsDup(tc.Counter); ok != tc.Dup { + log.Printf("%b", dc.bitSet) + log.Printf("%+v", *dc) + t.Fatal(i, ok, tc) + } + } +} diff --git a/node/globals.go b/node/globals.go new file mode 100644 index 0000000..172e6ef --- /dev/null +++ b/node/globals.go @@ -0,0 +1,3 @@ +package node + +const bufferSize = if_mtu + 128 diff --git a/node/header.go b/node/header.go new file mode 100644 index 0000000..44affb9 --- /dev/null +++ b/node/header.go @@ -0,0 +1,32 @@ +package node + +import "unsafe" + +const headerSize = 24 + +type header struct { + Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. + SourceIP byte + ViaIP byte + DestIP byte + PacketType byte // The packet type. See PACKET_* constants. + DataSize uint16 // Data size following associated data. +} + +func (hdr *header) Parse(nb []byte) { + hdr.Counter = *(*uint64)(unsafe.Pointer(&nb[0])) + hdr.SourceIP = nb[8] + hdr.ViaIP = nb[9] + hdr.DestIP = nb[10] + hdr.PacketType = nb[11] + hdr.DataSize = *(*uint16)(unsafe.Pointer(&nb[12])) +} + +func (hdr header) Marshal(buf []byte) { + *(*uint64)(unsafe.Pointer(&buf[0])) = hdr.Counter + buf[8] = hdr.SourceIP + buf[9] = hdr.ViaIP + buf[10] = hdr.DestIP + buf[11] = hdr.PacketType + *(*uint16)(unsafe.Pointer(&buf[12])) = hdr.DataSize +} diff --git a/node/header_test.go b/node/header_test.go new file mode 100644 index 0000000..e4ff3a3 --- /dev/null +++ b/node/header_test.go @@ -0,0 +1,23 @@ +package node + +import "testing" + +func TestHeaderMarshalParse(t *testing.T) { + nIn := header{ + Counter: 3212, + SourceIP: 34, + ViaIP: 20, + DestIP: 200, + PacketType: 44, + DataSize: 1235, + } + + buf := make([]byte, headerSize) + nIn.Marshal(buf) + + nOut := header{} + nOut.Parse(buf) + if nIn != nOut { + t.Fatal(nIn, nOut) + } +} diff --git a/node/interface.go b/node/interface.go new file mode 100644 index 0000000..2dc6ba6 --- /dev/null +++ b/node/interface.go @@ -0,0 +1,177 @@ +package node + +import ( + "fmt" + "io" + "net" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +// Get next packet, returning packet, ip, and possible error. +func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { + var ( + version byte + ip byte + ) + for { + n, err := iface.Read(buf[:cap(buf)]) + if err != nil { + return nil, ip, err + } + + if n < 20 { + continue // Packet too short. + } + + buf = buf[:n] + version = buf[0] >> 4 + + switch version { + case 4: + ip = buf[19] + case 6: + if len(buf) < 40 { + continue // Packet too short. + } + ip = buf[39] + default: + continue // Invalid version. + } + + return buf, ip, nil + } +} + +const ( + if_mtu = 1200 + if_queue_len = 1000 +) + +func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { + if len(network) != 4 { + return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) + } + ip := net.IPv4(network[0], network[1], network[2], localIP) + + ////////////////////////// + // Create TUN Interface // + ////////////////////////// + + tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open TUN device: %w", err) + } + + // New interface request. + req, err := unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) + } + + // Flags: + // + // IFF_NO_PI => don't add packet info data to packets sent to the interface. + // IFF_TUN => create a TUN device handling IP packets. + req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) + + err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) + if err != nil { + return nil, fmt.Errorf("failed to set TUN device settings: %w", err) + } + + // Name may not be exactly the same? + name = req.Name() + + ///////////// + // Set MTU // + ///////////// + + // We need a socket file descriptor to set other options for some reason. + sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return nil, fmt.Errorf("failed to open socket: %w", err) + } + defer unix.Close(sockFD) + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create MTU interface request: %w", err) + } + + req.SetUint32(if_mtu) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { + return nil, fmt.Errorf("failed to set interface MTU: %w", err) + } + + ////////////////////// + // Set Queue Length // + ////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + req.SetUint16(if_queue_len) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { + return nil, fmt.Errorf("failed to set interface queue length: %w", err) + } + + ///////////////////// + // Set IP and Mask // + ///////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + if err := req.SetInet4Addr(ip.To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request IP: %w", err) + } + + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { + return nil, fmt.Errorf("failed to set interface IP: %w", err) + } + + // SET MASK - must happen after setting address. + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create mask interface request: %w", err) + } + + if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request mask: %w", err) + } + + if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { + return nil, fmt.Errorf("failed to set interface mask: %w", err) + } + + //////////////////////// + // Bring Interface Up // + //////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create up interface request: %w", err) + } + + // Get current flags. + if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to get interface flags: %w", err) + } + + flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING + + // Set UP flag / broadcast flags. + req.SetUint16(flags) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to set interface up: %w", err) + } + + return os.NewFile(uintptr(tunFD), "tun"), nil +} diff --git a/node/peer.go b/node/peer.go new file mode 100644 index 0000000..53479d2 --- /dev/null +++ b/node/peer.go @@ -0,0 +1,33 @@ +package node + +import ( + "net/netip" + "sync/atomic" +) + +type peer struct { + IP byte + // TODO: Version + Addr *netip.AddrPort + // TODO: ViaIP + // TODO: EncPubKey + // TODO: SignPrivKey +} + +type peerRepo [256]*atomic.Pointer[peer] + +func newPeerRepo() peerRepo { + pr := peerRepo{} + for i := range pr { + pr[i] = &atomic.Pointer[peer]{} + } + return pr +} + +func (pr peerRepo) Get(ip byte) *peer { + return pr[ip].Load() +} + +func (pr peerRepo) Set(ip byte, p *peer) { + pr[ip].Store(p) +} diff --git a/node/tmp-server.go b/node/tmp-server.go new file mode 100644 index 0000000..04f6c0a --- /dev/null +++ b/node/tmp-server.go @@ -0,0 +1,164 @@ +package node + +import ( + "fmt" + "io" + "log" + "net" + "net/netip" + "runtime/debug" +) + +var ( + network = []byte{10, 1, 1, 0} + serverIP = byte(1) + clientIP = byte(2) + port = uint16(5151) + netName = "testnet" +) + +func must(err error) { + if err != nil { + panic(err) + } +} + +type TmpNode struct { + network []byte + localIP byte + peers peerRepo + port uint16 + netName string + iface io.ReadWriteCloser + w *connWriter + r *connReader +} + +func NewTmpNodeServer() *TmpNode { + n := &TmpNode{ + localIP: serverIP, + network: network, + peers: newPeerRepo(), + port: port, + netName: netName, + } + + var err error + n.iface, err = openInterface(n.network, n.localIP, n.netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + n.w = newConnWriter(conn, n.localIP, n.peers.Get) + n.r = newConnReader(conn, n.localIP, n.peers.Get) + + return n +} + +func NewTmpNodeClient() *TmpNode { + n := &TmpNode{ + localIP: clientIP, + network: network, + peers: newPeerRepo(), + port: port, + netName: netName, + } + + var err error + n.iface, err = openInterface(n.network, n.localIP, n.netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + n.w = newConnWriter(conn, n.localIP, n.peers.Get) + n.r = newConnReader(conn, n.localIP, n.peers.Get) + + return n +} + +func (n *TmpNode) RunServer() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + // Get remoteAddr from a packet. + buf := make([]byte, bufferSize) + remoteAddr, h, _, err := n.r.Read(buf) + must(err) + log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr) + must(err) + + n.peers.Set(h.SourceIP, &peer{ + IP: h.SourceIP, + Addr: &remoteAddr, + }) + + go n.readFromIFace() + n.readFromConn() +} + +func (n *TmpNode) RunClient(srvAddrStr string) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port)) + must(err) + + log.Printf("Setting %d => %v", serverIP, serverAddr) + n.peers.Set(serverIP, &peer{ + IP: serverIP, + Addr: &serverAddr, + }) + + must(n.w.WriteTo(serverIP, 1, []byte{1, 2, 3, 4, 5, 6, 7, 8})) + + go n.readFromIFace() + n.readFromConn() +} + +func (n *TmpNode) readFromIFace() { + var ( + buf = make([]byte, bufferSize) + packet []byte + remoteIP byte + err error + ) + + for { + packet, remoteIP, err = readNextPacket(n.iface, buf) + must(err) + must(n.w.WriteTo(remoteIP, 1, packet)) + } +} + +func (node *TmpNode) readFromConn() { + var ( + buf = make([]byte, bufferSize) + packet []byte + err error + ) + + for { + _, _, packet, err = node.r.Read(buf) + must(err) + // We assume that we're only receiving packets from one source. + + _, err = node.iface.Write(packet) + must(err) + } +} diff --git a/stage1/README.md b/stage1/README.md new file mode 100644 index 0000000..546f4de --- /dev/null +++ b/stage1/README.md @@ -0,0 +1 @@ +## Stage1: Point-to-point Tunnel w/ no Encryption diff --git a/stage1/client.go b/stage1/client.go new file mode 100644 index 0000000..42d8e03 --- /dev/null +++ b/stage1/client.go @@ -0,0 +1,32 @@ +package stage1 + +import ( + "fmt" + "net" + "net/netip" + "runtime/debug" +) + +func RunClient(serverAddrStr string) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + iface, err := openInterface(network, clientIP, netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", serverAddrStr, port)) + must(err) + + go readFromIFace(iface, conn, serverIP, serverAddr) + readFromConn(iface, conn) +} diff --git a/stage1/cmd/client/build.sh b/stage1/cmd/client/build.sh new file mode 100755 index 0000000..951ca95 --- /dev/null +++ b/stage1/cmd/client/build.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +go build + +scp client kevin:/home/jdl/tmp +ssh root@home "setcap cap_net_admin+iep /home/jdl/tmp/client" +ssh home "/home/jdl/tmp/client 192.168.1.21" diff --git a/stage1/cmd/client/main.go b/stage1/cmd/client/main.go new file mode 100644 index 0000000..60ccfbf --- /dev/null +++ b/stage1/cmd/client/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "log" + "os" + "vppn/stage1" +) + +func main() { + if len(os.Args) != 2 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + stage1.RunClient(os.Args[1]) +} diff --git a/stage1/cmd/server/build.sh b/stage1/cmd/server/build.sh new file mode 100755 index 0000000..0c89238 --- /dev/null +++ b/stage1/cmd/server/build.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +go build +sudo setcap cap_net_admin+iep server diff --git a/stage1/cmd/server/main.go b/stage1/cmd/server/main.go new file mode 100644 index 0000000..5c5cd36 --- /dev/null +++ b/stage1/cmd/server/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "log" + "os" + "vppn/stage1" +) + +func main() { + if len(os.Args) != 2 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + stage1.RunServer(os.Args[1]) +} diff --git a/stage1/interface.go b/stage1/interface.go new file mode 100644 index 0000000..1e587a2 --- /dev/null +++ b/stage1/interface.go @@ -0,0 +1,142 @@ +package stage1 + +import ( + "fmt" + "io" + "net" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +const ( + if_mtu = 1200 + if_queue_len = 1000 +) + +func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { + if len(network) != 4 { + return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) + } + ip := net.IPv4(network[0], network[1], network[2], localIP) + + ////////////////////////// + // Create TUN Interface // + ////////////////////////// + + tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open TUN device: %w", err) + } + + // New interface request. + req, err := unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) + } + + // Flags: + // + // IFF_NO_PI => don't add packet info data to packets sent to the interface. + // IFF_TUN => create a TUN device handling IP packets. + req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) + + err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) + if err != nil { + return nil, fmt.Errorf("failed to set TUN device settings: %w", err) + } + + // Name may not be exactly the same? + name = req.Name() + + ///////////// + // Set MTU // + ///////////// + + // We need a socket file descriptor to set other options for some reason. + sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return nil, fmt.Errorf("failed to open socket: %w", err) + } + defer unix.Close(sockFD) + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create MTU interface request: %w", err) + } + + req.SetUint32(if_mtu) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { + return nil, fmt.Errorf("failed to set interface MTU: %w", err) + } + + ////////////////////// + // Set Queue Length // + ////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + req.SetUint16(if_queue_len) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { + return nil, fmt.Errorf("failed to set interface queue length: %w", err) + } + + ///////////////////// + // Set IP and Mask // + ///////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + if err := req.SetInet4Addr(ip.To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request IP: %w", err) + } + + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { + return nil, fmt.Errorf("failed to set interface IP: %w", err) + } + + // SET MASK - must happen after setting address. + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create mask interface request: %w", err) + } + + if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request mask: %w", err) + } + + if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { + return nil, fmt.Errorf("failed to set interface mask: %w", err) + } + + //////////////////////// + // Bring Interface Up // + //////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create up interface request: %w", err) + } + + // Get current flags. + if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to get interface flags: %w", err) + } + + flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING + + // Set UP flag / broadcast flags. + req.SetUint16(flags) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to set interface up: %w", err) + } + + return os.NewFile(uintptr(tunFD), "tun"), nil +} diff --git a/stage1/server.go b/stage1/server.go new file mode 100644 index 0000000..8f210c0 --- /dev/null +++ b/stage1/server.go @@ -0,0 +1,109 @@ +package stage1 + +import ( + "fmt" + "io" + "log" + "net" + "net/netip" + "runtime/debug" +) + +var ( + network = []byte{10, 1, 1, 0} + serverIP = byte(1) + clientIP = byte(2) + port = uint16(5151) + netName = "testnet" + bufferSize = if_mtu * 2 +) + +func must(err error) { + if err != nil { + panic(err) + } +} + +func RunServer(clientAddrStr string) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + iface, err := openInterface(network, serverIP, netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + clientAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", clientAddrStr, port)) + must(err) + + go readFromIFace(iface, conn, clientIP, clientAddr) + readFromConn(iface, conn) +} + +func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte, remoteAddr netip.AddrPort) { + var ( + n int + packet = make([]byte, bufferSize) + version byte + ip byte + err error + ) + + for { + n, err = iface.Read(packet[:bufferSize]) + must(err) + packet = packet[:n] + + if len(packet) < 20 { + log.Printf("Dropping small packet: %d", n) + continue + } + + packet = packet[:n] + version = packet[0] >> 4 + + switch version { + case 4: + ip = packet[19] + case 6: + ip = packet[39] + default: + log.Printf("Dropping packet with IP version: %d", version) + continue + } + + if ip != remoteIP { + log.Printf("Dropping packet for incorrect IP: %d", ip) + continue + } + + _, err = conn.WriteToUDPAddrPort(packet, remoteAddr) + must(err) + } +} + +func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn) { + var ( + n int + packet = make([]byte, bufferSize) + err error + ) + + for { + // We assume that we're only receiving packets from one source. + n, err = conn.Read(packet[:bufferSize]) + must(err) + + packet = packet[:n] + _, err = iface.Write(packet) + must(err) + } +} diff --git a/stage1/startup.go b/stage1/startup.go new file mode 100644 index 0000000..e164d95 --- /dev/null +++ b/stage1/startup.go @@ -0,0 +1 @@ +package stage1 diff --git a/stage2/README.md b/stage2/README.md new file mode 100644 index 0000000..ef00a03 --- /dev/null +++ b/stage2/README.md @@ -0,0 +1,4 @@ +## Stage2: + +* Point-to-point Tunnel w/ no Encryption +* Server gets client's addr from first packet diff --git a/stage2/client.go b/stage2/client.go new file mode 100644 index 0000000..6d969d1 --- /dev/null +++ b/stage2/client.go @@ -0,0 +1,35 @@ +package stage2 + +import ( + "fmt" + "net" + "net/netip" + "runtime/debug" +) + +func RunClient(serverAddrStr string) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + iface, err := openInterface(network, clientIP, netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", serverAddrStr, port)) + must(err) + + _, err = conn.WriteToUDPAddrPort([]byte{1, 2, 3, 4, 5, 6, 7, 8}, serverAddr) + must(err) + + go readFromIFace(iface, conn, serverIP, serverAddr) + readFromConn(iface, conn) +} diff --git a/stage2/cmd/client/build.sh b/stage2/cmd/client/build.sh new file mode 100755 index 0000000..c7d72ea --- /dev/null +++ b/stage2/cmd/client/build.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +go build +sudo setcap cap_net_admin+iep ./client +./client 144.76.78.93 diff --git a/stage2/cmd/client/main.go b/stage2/cmd/client/main.go new file mode 100644 index 0000000..7217b31 --- /dev/null +++ b/stage2/cmd/client/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "log" + "os" + "vppn/stage2" +) + +func main() { + if len(os.Args) != 2 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + stage2.RunClient(os.Args[1]) +} diff --git a/stage2/cmd/server/build.sh b/stage2/cmd/server/build.sh new file mode 100755 index 0000000..8f90f02 --- /dev/null +++ b/stage2/cmd/server/build.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +go build +scp server kevin:/home/jdl/tmp/ +ssh root@kevin "sudo setcap cap_net_admin+iep /home/jdl/tmp/server" +ssh kevin "/home/jdl/tmp/server" diff --git a/stage2/cmd/server/main.go b/stage2/cmd/server/main.go new file mode 100644 index 0000000..d38dc86 --- /dev/null +++ b/stage2/cmd/server/main.go @@ -0,0 +1,7 @@ +package main + +import "vppn/stage2" + +func main() { + stage2.RunServer() +} diff --git a/stage2/interface.go b/stage2/interface.go new file mode 100644 index 0000000..f890c55 --- /dev/null +++ b/stage2/interface.go @@ -0,0 +1,142 @@ +package stage2 + +import ( + "fmt" + "io" + "net" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +const ( + if_mtu = 1200 + if_queue_len = 1000 +) + +func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { + if len(network) != 4 { + return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) + } + ip := net.IPv4(network[0], network[1], network[2], localIP) + + ////////////////////////// + // Create TUN Interface // + ////////////////////////// + + tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open TUN device: %w", err) + } + + // New interface request. + req, err := unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) + } + + // Flags: + // + // IFF_NO_PI => don't add packet info data to packets sent to the interface. + // IFF_TUN => create a TUN device handling IP packets. + req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) + + err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) + if err != nil { + return nil, fmt.Errorf("failed to set TUN device settings: %w", err) + } + + // Name may not be exactly the same? + name = req.Name() + + ///////////// + // Set MTU // + ///////////// + + // We need a socket file descriptor to set other options for some reason. + sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return nil, fmt.Errorf("failed to open socket: %w", err) + } + defer unix.Close(sockFD) + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create MTU interface request: %w", err) + } + + req.SetUint32(if_mtu) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { + return nil, fmt.Errorf("failed to set interface MTU: %w", err) + } + + ////////////////////// + // Set Queue Length // + ////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + req.SetUint16(if_queue_len) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { + return nil, fmt.Errorf("failed to set interface queue length: %w", err) + } + + ///////////////////// + // Set IP and Mask // + ///////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + if err := req.SetInet4Addr(ip.To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request IP: %w", err) + } + + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { + return nil, fmt.Errorf("failed to set interface IP: %w", err) + } + + // SET MASK - must happen after setting address. + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create mask interface request: %w", err) + } + + if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request mask: %w", err) + } + + if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { + return nil, fmt.Errorf("failed to set interface mask: %w", err) + } + + //////////////////////// + // Bring Interface Up // + //////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create up interface request: %w", err) + } + + // Get current flags. + if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to get interface flags: %w", err) + } + + flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING + + // Set UP flag / broadcast flags. + req.SetUint16(flags) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to set interface up: %w", err) + } + + return os.NewFile(uintptr(tunFD), "tun"), nil +} diff --git a/stage2/server.go b/stage2/server.go new file mode 100644 index 0000000..01581d7 --- /dev/null +++ b/stage2/server.go @@ -0,0 +1,112 @@ +package stage2 + +import ( + "fmt" + "io" + "log" + "net" + "net/netip" + "runtime/debug" +) + +var ( + network = []byte{10, 1, 1, 0} + serverIP = byte(1) + clientIP = byte(2) + port = uint16(5151) + netName = "testnet" + bufferSize = if_mtu * 2 +) + +func must(err error) { + if err != nil { + panic(err) + } +} + +func RunServer() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + iface, err := openInterface(network, serverIP, netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + // Get remoteAddr from a packet. + buf := make([]byte, 8) + _, remoteAddr, err := conn.ReadFromUDPAddrPort(buf) + log.Printf("Got remote addr: %v", remoteAddr) + must(err) + + go readFromIFace(iface, conn, clientIP, remoteAddr) + readFromConn(iface, conn) +} + +func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte, remoteAddr netip.AddrPort) { + var ( + n int + packet = make([]byte, bufferSize) + version byte + ip byte + err error + ) + + for { + n, err = iface.Read(packet[:bufferSize]) + must(err) + packet = packet[:n] + + if len(packet) < 20 { + log.Printf("Dropping small packet: %d", n) + continue + } + + packet = packet[:n] + version = packet[0] >> 4 + + switch version { + case 4: + ip = packet[19] + case 6: + ip = packet[39] + default: + log.Printf("Dropping packet with IP version: %d", version) + continue + } + + if ip != remoteIP { + log.Printf("Dropping packet for incorrect IP: %d", ip) + continue + } + + _, err = conn.WriteToUDPAddrPort(packet, remoteAddr) + must(err) + } +} + +func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn) { + var ( + n int + packet = make([]byte, bufferSize) + err error + ) + + for { + // We assume that we're only receiving packets from one source. + n, err = conn.Read(packet[:bufferSize]) + must(err) + + packet = packet[:n] + _, err = iface.Write(packet) + must(err) + } +} diff --git a/stage2/startup.go b/stage2/startup.go new file mode 100644 index 0000000..65b92ec --- /dev/null +++ b/stage2/startup.go @@ -0,0 +1 @@ +package stage2 diff --git a/stage3/README.md b/stage3/README.md new file mode 100644 index 0000000..dc76e28 --- /dev/null +++ b/stage3/README.md @@ -0,0 +1,16 @@ +## Stage3: + +* Point-to-point Tunnel w/ no Encryption +* Server gets client's addr from first packet +* Add packet counter to detect skipped and late packets + +### Learnings + +* Directional packet loss is an issue. + * Sending to hetzner: ~380 Mbits/sec + * From hetzner: ~800 Mbits/sec +* Runs of dropped packets are generally small < 30 + * Saw a few cases of 100-200 +* Runs of correctly-sequenced packets are generally >> drops +* Late packets aren't so common +* Dropping late packets causes large slow-down. diff --git a/stage3/client.go b/stage3/client.go new file mode 100644 index 0000000..a8b9e98 --- /dev/null +++ b/stage3/client.go @@ -0,0 +1,35 @@ +package stage3 + +import ( + "fmt" + "net" + "net/netip" + "runtime/debug" +) + +func RunClient(serverAddrStr string) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + iface, err := openInterface(network, clientIP, netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", serverAddrStr, port)) + must(err) + + _, err = conn.WriteToUDPAddrPort([]byte{1, 2, 3, 4, 5, 6, 7, 8}, serverAddr) + must(err) + + go readFromIFace(iface, conn, clientIP, serverIP, serverAddr) + readFromConn(iface, conn, serverIP) +} diff --git a/stage3/cmd/client/build.sh b/stage3/cmd/client/build.sh new file mode 100755 index 0000000..c7d72ea --- /dev/null +++ b/stage3/cmd/client/build.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +go build +sudo setcap cap_net_admin+iep ./client +./client 144.76.78.93 diff --git a/stage3/cmd/client/main.go b/stage3/cmd/client/main.go new file mode 100644 index 0000000..e27e22f --- /dev/null +++ b/stage3/cmd/client/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "log" + "os" + "vppn/stage3" +) + +func main() { + if len(os.Args) != 2 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + stage3.RunClient(os.Args[1]) +} diff --git a/stage3/cmd/server/build.sh b/stage3/cmd/server/build.sh new file mode 100755 index 0000000..fcc5787 --- /dev/null +++ b/stage3/cmd/server/build.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +go build +ssh kevin "killall server" +scp server kevin:/home/jdl/tmp/ +ssh root@kevin "sudo setcap cap_net_admin+iep /home/jdl/tmp/server" +ssh kevin "/home/jdl/tmp/server" diff --git a/stage3/cmd/server/main.go b/stage3/cmd/server/main.go new file mode 100644 index 0000000..e8430a4 --- /dev/null +++ b/stage3/cmd/server/main.go @@ -0,0 +1,7 @@ +package main + +import "vppn/stage3" + +func main() { + stage3.RunServer() +} diff --git a/stage3/interface.go b/stage3/interface.go new file mode 100644 index 0000000..fa8de32 --- /dev/null +++ b/stage3/interface.go @@ -0,0 +1,142 @@ +package stage3 + +import ( + "fmt" + "io" + "net" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +const ( + if_mtu = 1200 + if_queue_len = 1000 +) + +func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { + if len(network) != 4 { + return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) + } + ip := net.IPv4(network[0], network[1], network[2], localIP) + + ////////////////////////// + // Create TUN Interface // + ////////////////////////// + + tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open TUN device: %w", err) + } + + // New interface request. + req, err := unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) + } + + // Flags: + // + // IFF_NO_PI => don't add packet info data to packets sent to the interface. + // IFF_TUN => create a TUN device handling IP packets. + req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) + + err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) + if err != nil { + return nil, fmt.Errorf("failed to set TUN device settings: %w", err) + } + + // Name may not be exactly the same? + name = req.Name() + + ///////////// + // Set MTU // + ///////////// + + // We need a socket file descriptor to set other options for some reason. + sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return nil, fmt.Errorf("failed to open socket: %w", err) + } + defer unix.Close(sockFD) + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create MTU interface request: %w", err) + } + + req.SetUint32(if_mtu) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { + return nil, fmt.Errorf("failed to set interface MTU: %w", err) + } + + ////////////////////// + // Set Queue Length // + ////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + req.SetUint16(if_queue_len) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { + return nil, fmt.Errorf("failed to set interface queue length: %w", err) + } + + ///////////////////// + // Set IP and Mask // + ///////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + if err := req.SetInet4Addr(ip.To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request IP: %w", err) + } + + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { + return nil, fmt.Errorf("failed to set interface IP: %w", err) + } + + // SET MASK - must happen after setting address. + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create mask interface request: %w", err) + } + + if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request mask: %w", err) + } + + if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { + return nil, fmt.Errorf("failed to set interface mask: %w", err) + } + + //////////////////////// + // Bring Interface Up // + //////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create up interface request: %w", err) + } + + // Get current flags. + if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to get interface flags: %w", err) + } + + flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING + + // Set UP flag / broadcast flags. + req.SetUint16(flags) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to set interface up: %w", err) + } + + return os.NewFile(uintptr(tunFD), "tun"), nil +} diff --git a/stage3/packet.go b/stage3/packet.go new file mode 100644 index 0000000..d8e96b7 --- /dev/null +++ b/stage3/packet.go @@ -0,0 +1,23 @@ +package stage3 + +import "unsafe" + +const headerSize = 9 + +type packetHeader struct { + SrcIP byte + Counter uint64 +} + +func (h packetHeader) Marshal(buf []byte) int { + buf = buf[:9] + buf[0] = h.SrcIP + *(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter + return headerSize +} + +func (h *packetHeader) Parse(buf []byte) int { + h.SrcIP = buf[0] + h.Counter = *(*uint64)(unsafe.Pointer(&buf[1])) + return headerSize +} diff --git a/stage3/packet_test.go b/stage3/packet_test.go new file mode 100644 index 0000000..ef643e4 --- /dev/null +++ b/stage3/packet_test.go @@ -0,0 +1,22 @@ +package stage3 + +import ( + "reflect" + "testing" +) + +func TestPacketHeader(t *testing.T) { + b := make([]byte, 1024) + + h := packetHeader{ + SrcIP: 8, + Counter: 2354, + } + n := h.Marshal(b) + h2 := packetHeader{} + h2.Parse(b[:n]) + + if !reflect.DeepEqual(h, h2) { + t.Fatal(h, h2) + } +} diff --git a/stage3/server.go b/stage3/server.go new file mode 100644 index 0000000..4bb3d87 --- /dev/null +++ b/stage3/server.go @@ -0,0 +1,147 @@ +package stage3 + +import ( + "fmt" + "io" + "log" + "net" + "net/netip" + "runtime/debug" +) + +var ( + network = []byte{10, 1, 1, 0} + serverIP = byte(1) + clientIP = byte(2) + port = uint16(5151) + netName = "testnet" + bufferSize = if_mtu * 2 +) + +func must(err error) { + if err != nil { + panic(err) + } +} + +func RunServer() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("%v", r) + debug.PrintStack() + } + }() + + iface, err := openInterface(network, serverIP, netName) + must(err) + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + must(err) + + conn, err := net.ListenUDP("udp", myAddr) + must(err) + + // Get remoteAddr from a packet. + buf := make([]byte, 8) + _, remoteAddr, err := conn.ReadFromUDPAddrPort(buf) + log.Printf("Got remote addr: %v", remoteAddr) + must(err) + + go readFromIFace(iface, conn, serverIP, clientIP, remoteAddr) + readFromConn(iface, conn, clientIP) +} + +func readFromIFace(iface io.ReadWriteCloser, conn *net.UDPConn, localIP, remoteIP byte, remoteAddr netip.AddrPort) { + var ( + n int + packet = make([]byte, bufferSize) + version byte + ip byte + err error + counter uint64 + buf = make([]byte, bufferSize) + ) + + for { + n, err = iface.Read(packet[:bufferSize]) + must(err) + packet = packet[:n] + + if len(packet) < 20 { + log.Printf("Dropping small packet: %d", n) + continue + } + + packet = packet[:n] + version = packet[0] >> 4 + + switch version { + case 4: + ip = packet[19] + case 6: + ip = packet[39] + default: + log.Printf("Dropping packet with IP version: %d", version) + continue + } + + if ip != remoteIP { + log.Printf("Dropping packet for incorrect IP: %d", ip) + continue + } + + h := packetHeader{SrcIP: localIP, Counter: counter} + counter++ + buf = buf[:headerSize+len(packet)] + h.Marshal(buf) + copy(buf[headerSize:], packet) + + _, err = conn.WriteToUDPAddrPort(buf, remoteAddr) + must(err) + } +} + +func readFromConn(iface io.ReadWriteCloser, conn *net.UDPConn, remoteIP byte) { + var ( + n int + packet = make([]byte, bufferSize) + err error + counter uint64 + run uint64 + h packetHeader + ) + + for { + // We assume that we're only receiving packets from one source. + n, err = conn.Read(packet[:bufferSize]) + must(err) + + packet = packet[:n] + if len(packet) < headerSize { + fmt.Print("_") + continue + } + + h.Parse(packet) + if h.SrcIP != remoteIP { + fmt.Print("?") + continue + } + + if h.Counter == counter+1 { + run++ + counter = h.Counter + } else if h.Counter > counter+1 { + fmt.Printf("x(%d/%d)", h.Counter-counter+1, run) + run = 0 + counter = h.Counter + } else if h.Counter <= counter { + //log.Printf("Skipped late packet: -%d", counter-h.Counter) + //continue + fmt.Print("<") + } + + _, err = iface.Write(packet[headerSize:]) + must(err) + } +} diff --git a/stage3/startup.go b/stage3/startup.go new file mode 100644 index 0000000..65332e0 --- /dev/null +++ b/stage3/startup.go @@ -0,0 +1 @@ +package stage3