diff --git a/README.md b/README.md index 1a5ab6a..0e5a4ba 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ ## Roadmap +* `node` package + * rename `peerRepo` to `routingTable` + * create router type with `Get(ip) *peer` and `Mediator() *peer` methods + * connReader / Writer should have access to the peerRepo * Use default port 456 * Remove signing key from hub * Peer: UDP hole-punching diff --git a/cmd/vppn/main.go b/cmd/vppn/main.go index 5daa907..8c04016 100644 --- a/cmd/vppn/main.go +++ b/cmd/vppn/main.go @@ -2,10 +2,10 @@ package main import ( "log" - "vppn/peer" + "vppn/node" ) func main() { log.SetFlags(0) - peer.Main() + node.Main() } diff --git a/node/conn.go b/node/conn.go index 3e95220..2e2bce8 100644 --- a/node/conn.go +++ b/node/conn.go @@ -4,30 +4,26 @@ import ( "log" "net" "net/netip" + "sync" "sync/atomic" "vppn/fasttime" ) -// TODO: -type connRouter interface { - Lookup(byte) *peer - Mediator() *peer -} - type connWriter struct { *net.UDPConn + lock sync.Mutex localIP byte buf []byte counters [256]uint64 - lookup func(byte) *peer + routing *routingTable } -func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connWriter { +func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter { w := &connWriter{ UDPConn: conn, localIP: localIP, buf: make([]byte, bufferSize), - lookup: lookup, + routing: routing, } for i := range w.counters { @@ -37,24 +33,36 @@ func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *co return w } -func (w *connWriter) WriteTo(remoteIP, packetType byte, data []byte) error { - peer := w.lookup(remoteIP) +func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) error { + // TODO: Handle mediator. + peer := w.routing.Get(remoteIP) if peer == nil || peer.Addr == nil { log.Printf("No peer: %d", remoteIP) return nil } + if stream == streamData && !peer.Up { + log.Printf("Peer down: %d", remoteIP) + } + return w.WriteToPeer(peer, stream, data) +} + +func (w *connWriter) WriteToPeer(peer *peer, stream byte, data []byte) error { + w.lock.Lock() + + remoteIP := peer.IP h := header{ - Counter: atomic.AddUint64(&w.counters[remoteIP], 1), - SourceIP: w.localIP, - ViaIP: 0, - DestIP: remoteIP, - PacketType: packetType, + Counter: atomic.AddUint64(&w.counters[remoteIP], 1), + SourceIP: w.localIP, + ViaIP: 0, + DestIP: remoteIP, + Stream: stream, } buf := encryptPacket(&h, peer.SharedKey, data, w.buf) _, err := w.WriteToUDPAddrPort(buf, *peer.Addr) + w.lock.Unlock() return err } @@ -64,15 +72,15 @@ type connReader struct { *net.UDPConn localIP byte dupChecks [256]*dupCheck - lookup func(byte) *peer + routing *routingTable buf []byte } -func newConnReader(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connReader { +func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader { r := &connReader{ UDPConn: conn, localIP: localIP, - lookup: lookup, + routing: routing, buf: make([]byte, bufferSize), } for i := range r.dupChecks { @@ -93,18 +101,20 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data data = buf[:n] if n < headerSize { + log.Printf("Dropping short packet: %d", n) continue // Packet it soo short. } h.Parse(data) + if len(data) != headerSize+int(h.DataSize) { - log.Printf("Incorrect size") - continue // Packet is corrupt. + log.Printf("Malformed packet: %d != %d", len(data), headerSize+int(h.DataSize)) + continue } - peer := r.lookup(h.SourceIP) + peer := r.routing.Get(h.SourceIP) if peer == nil { - log.Printf("No peer...") + log.Printf("No peer: %d...", h.SourceIP) continue } @@ -117,7 +127,7 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data out, data = data, out if r.dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("Duplicate...") + log.Printf("Duplicate: %d", h.Counter) continue } diff --git a/node/crypto.go b/node/crypto.go index 7240bb7..3da0156 100644 --- a/node/crypto.go +++ b/node/crypto.go @@ -1,6 +1,11 @@ package node -import "golang.org/x/crypto/nacl/box" +import ( + "sync" + "vppn/fasttime" + + "golang.org/x/crypto/nacl/box" +) // Encrypting the packet will also set the header's DataSize field. func encryptPacket(h *header, sharedKey, data, out []byte) []byte { @@ -24,3 +29,23 @@ func computeSharedKey(peerPubKey, privKey []byte) []byte { box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey)) return shared[:] } + +var ( + traceIDLock sync.Mutex + traceIDTime uint64 + traceIDCounter uint64 +) + +func newTraceID() (id uint64) { + traceIDLock.Lock() + defer traceIDLock.Unlock() + + now := uint64(fasttime.Now()) + if traceIDTime < now { + traceIDTime = now + traceIDCounter = 0 + } + traceIDCounter++ + + return traceIDTime<<30 + traceIDCounter +} diff --git a/node/crypto_test.go b/node/crypto_test.go index f0ee9f6..be4282c 100644 --- a/node/crypto_test.go +++ b/node/crypto_test.go @@ -33,11 +33,11 @@ func TestEncryptDecryptPacket(t *testing.T) { rand.Read(original) h := header{ - Counter: 2893749238, - SourceIP: 5, - ViaIP: 8, - DestIP: 12, - PacketType: 32, + Counter: 2893749238, + SourceIP: 5, + ViaIP: 8, + DestIP: 12, + Stream: 1, } encrypted := make([]byte, bufferSize) @@ -62,7 +62,6 @@ func TestEncryptDecryptPacket(t *testing.T) { } } -/* func BenchmarkEncryptPacket(b *testing.B) { _, privKey1, err := box.GenerateKey(rand.Reader) if err != nil { @@ -77,16 +76,24 @@ func BenchmarkEncryptPacket(b *testing.B) { sharedEncKey := [32]byte{} box.Precompute(&sharedEncKey, pubKey2, privKey1) - original := make([]byte, MTU) + original := make([]byte, if_mtu) rand.Read(original) - nonce := make([]byte, NONCE_SIZE) + nonce := make([]byte, headerSize) rand.Read(nonce) - encrypted := make([]byte, BUFFER_SIZE) + encrypted := make([]byte, bufferSize) + + h := header{ + Counter: 2893749238, + SourceIP: 5, + ViaIP: 8, + DestIP: 12, + Stream: 1, + } for i := 0; i < b.N; i++ { - encrypted = encryptPacket(sharedEncKey[:], nonce, original, encrypted) + encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) } } @@ -107,18 +114,27 @@ func BenchmarkDecryptPacket(b *testing.B) { sharedDecKey := [32]byte{} box.Precompute(&sharedDecKey, pubKey1, privKey2) - original := make([]byte, MTU) + original := make([]byte, if_mtu) rand.Read(original) - nonce := make([]byte, NONCE_SIZE) + nonce := make([]byte, headerSize) rand.Read(nonce) - encrypted := make([]byte, BUFFER_SIZE) - encrypted = encryptPacket(sharedEncKey[:], nonce, original, encrypted) + h := header{ + Counter: 2893749238, + SourceIP: 5, + ViaIP: 8, + DestIP: 12, + Stream: 1, + } - decrypted := make([]byte, MTU) + encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize)) + decrypted := make([]byte, bufferSize) + var ok bool for i := 0; i < b.N; i++ { - decrypted, _ = decryptPacket(sharedDecKey[:], encrypted, decrypted) + decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) + if !ok { + panic(ok) + } } } -*/ diff --git a/node/dupcheck.go b/node/dupcheck.go index fac7a72..e960bd4 100644 --- a/node/dupcheck.go +++ b/node/dupcheck.go @@ -1,5 +1,7 @@ package node +import "log" + type dupCheck struct { bitSet head int @@ -20,6 +22,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool { // Before head => it's late, say it's a dup. if counter < dc.headCounter { + log.Printf("Late: %d", counter) return true } @@ -27,6 +30,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool { if counter < dc.tailCounter { index := (int(counter-dc.headCounter) + dc.head) % bitSetSize if dc.Get(index) { + log.Printf("Dup: %d, %d", counter, dc.tailCounter) return true } diff --git a/node/files.go b/node/files.go new file mode 100644 index 0000000..6f0ec77 --- /dev/null +++ b/node/files.go @@ -0,0 +1,82 @@ +package node + +import ( + "encoding/json" + "log" + "os" + "path/filepath" + "vppn/m" +) + +func configDir(netName string) string { + d, err := os.UserHomeDir() + if err != nil { + log.Fatalf("Failed to get user home directory: %v", err) + } + return filepath.Join(d, ".vppn", netName) +} + +func peerConfigPath(netName string) string { + return filepath.Join(configDir(netName), "peer-config.json") +} + +func peerStatePath(netName string) string { + return filepath.Join(configDir(netName), "peer-state.json") +} + +func storeJson(x any, outPath string) error { + outDir := filepath.Dir(outPath) + _ = os.MkdirAll(outDir, 0700) + + tmpPath := outPath + ".tmp" + buf, err := json.Marshal(x) + if err != nil { + return err + } + + f, err := os.Create(tmpPath) + if err != nil { + return err + } + + if _, err := f.Write(buf); err != nil { + f.Close() + return err + } + + if err := f.Sync(); err != nil { + f.Close() + return err + } + + if err := f.Close(); err != nil { + return err + } + + return os.Rename(tmpPath, outPath) +} + +func storePeerConfig(netName string, pc m.PeerConfig) error { + return storeJson(pc, peerConfigPath(netName)) +} + +func storeNetworkState(netName string, ps m.NetworkState) error { + return storeJson(ps, peerStatePath(netName)) +} + +func loadJson(dataPath string, ptr any) error { + data, err := os.ReadFile(dataPath) + if err != nil { + return err + } + + return json.Unmarshal(data, ptr) +} + +func loadPeerConfig(netName string) (pc m.PeerConfig, err error) { + return pc, loadJson(peerConfigPath(netName), &pc) +} + +func loadNetworkState(netName string) (ps m.NetworkState, err error) { + return ps, loadJson(peerStatePath(netName), &ps) +} diff --git a/node/header.go b/node/header.go index 44affb9..05ce29b 100644 --- a/node/header.go +++ b/node/header.go @@ -2,15 +2,19 @@ package node import "unsafe" -const headerSize = 24 +const ( + headerSize = 24 + streamData = 1 + streamRouting = 2 +) type header struct { - Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. - SourceIP byte - ViaIP byte - DestIP byte - PacketType byte // The packet type. See PACKET_* constants. - DataSize uint16 // Data size following associated data. + Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. + SourceIP byte + ViaIP byte + DestIP byte + Stream byte // See stream* constants. + DataSize uint16 // Data size following associated data. } func (hdr *header) Parse(nb []byte) { @@ -18,7 +22,7 @@ func (hdr *header) Parse(nb []byte) { hdr.SourceIP = nb[8] hdr.ViaIP = nb[9] hdr.DestIP = nb[10] - hdr.PacketType = nb[11] + hdr.Stream = nb[11] hdr.DataSize = *(*uint16)(unsafe.Pointer(&nb[12])) } @@ -27,6 +31,6 @@ func (hdr header) Marshal(buf []byte) { buf[8] = hdr.SourceIP buf[9] = hdr.ViaIP buf[10] = hdr.DestIP - buf[11] = hdr.PacketType + buf[11] = hdr.Stream *(*uint16)(unsafe.Pointer(&buf[12])) = hdr.DataSize } diff --git a/node/header_test.go b/node/header_test.go index e4ff3a3..6c343bc 100644 --- a/node/header_test.go +++ b/node/header_test.go @@ -4,12 +4,12 @@ import "testing" func TestHeaderMarshalParse(t *testing.T) { nIn := header{ - Counter: 3212, - SourceIP: 34, - ViaIP: 20, - DestIP: 200, - PacketType: 44, - DataSize: 1235, + Counter: 3212, + SourceIP: 34, + ViaIP: 20, + DestIP: 200, + Stream: 44, + DataSize: 1235, } buf := make([]byte, headerSize) diff --git a/node/interface.go b/node/interface.go index 2dc6ba6..c5edf3e 100644 --- a/node/interface.go +++ b/node/interface.go @@ -3,6 +3,7 @@ package node import ( "fmt" "io" + "log" "net" "os" "syscall" @@ -22,23 +23,27 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) return nil, ip, err } - if n < 20 { - continue // Packet too short. - } - buf = buf[:n] version = buf[0] >> 4 switch version { case 4: + if n < 20 { + log.Printf("Short IPv4 packet: %d", len(buf)) + continue + } ip = buf[19] + case 6: if len(buf) < 40 { - continue // Packet too short. + log.Printf("Short IPv6 packet: %d", len(buf)) + continue } ip = buf[39] + default: - continue // Invalid version. + log.Printf("Invalid IP packet version: %v", version) + continue } return buf, ip, nil @@ -47,7 +52,7 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) const ( if_mtu = 1200 - if_queue_len = 1000 + if_queue_len = 2048 ) func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { diff --git a/node/main.go b/node/main.go new file mode 100644 index 0000000..092751a --- /dev/null +++ b/node/main.go @@ -0,0 +1,190 @@ +package node + +import ( + "encoding/json" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "net/netip" + "os" + "runtime/debug" + "vppn/m" +) + +func panicHandler() { + if r := recover(); r != nil { + log.Fatalf("\n %v\n\nstacktrace from panic: %s\n", r, string(debug.Stack())) + } +} + +func Main() { + defer panicHandler() + + var ( + netName string + initURL string + listenIP string + port int + ) + + flag.StringVar(&netName, "name", "", "[REQUIRED] The network name.") + flag.StringVar(&initURL, "init-url", "", "Initializes peer from the hub URL.") + flag.StringVar(&listenIP, "listen-ip", "", "IP address to listen on.") + flag.IntVar(&port, "port", 0, "Port to listen on.") + flag.Parse() + + if netName == "" { + flag.Usage() + os.Exit(1) + } + + if initURL != "" { + mainInit(netName, initURL) + return + } + + main(netName, listenIP, uint16(port)) +} + +func mainInit(netName, initURL string) { + if _, err := loadPeerConfig(netName); err == nil { + log.Fatalf("Network is already initialized.") + } + + resp, err := http.Get(initURL) + if err != nil { + log.Fatalf("Failed to fetch data from hub: %v", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("Failed to read response body: %v", err) + } + + peerConfig := m.PeerConfig{} + if err := json.Unmarshal(data, &peerConfig); err != nil { + log.Fatalf("Failed to parse configuration: %v", err) + } + + if err := storePeerConfig(netName, peerConfig); err != nil { + log.Fatalf("Failed to store configuration: %v", err) + } + + log.Print("Initialization successful.") +} + +// ---------------------------------------------------------------------------- + +func main(netName, listenIP string, port uint16) { + conf, err := loadPeerConfig(netName) + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + port = determinePort(conf.Port, port) + + iface, err := openInterface(conf.Network, conf.PeerIP, netName) + if err != nil { + log.Fatalf("Failed to open interface: %v", err) + } + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", listenIP, port)) + if err != nil { + log.Fatalf("Failed to resolve UDP address: %v", err) + } + + conn, err := net.ListenUDP("udp", myAddr) + if err != nil { + log.Fatalf("Failed to open UDP port: %v", err) + } + + routing := newRoutingTable() + + w := newConnWriter(conn, conf.PeerIP, routing) + r := newConnReader(conn, conf.PeerIP, routing) + + router := newRouter(netName, conf, routing, w) + + go nodeConnReader(r, w, iface, router) + nodeIFaceReader(w, iface, router) +} + +// ---------------------------------------------------------------------------- + +func determinePort(confPort, portFromCommandLine uint16) uint16 { + if portFromCommandLine != 0 { + return portFromCommandLine + } + if confPort != 0 { + return confPort + } + return 456 +} + +// ---------------------------------------------------------------------------- + +func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) { + defer panicHandler() + var ( + remoteAddr netip.AddrPort + h header + buf = make([]byte, bufferSize) + data []byte + err error + ) + + for { + remoteAddr, h, data, err = r.Read(buf) + if err != nil { + log.Fatalf("Failed to read from UDP connection: %v", err) + } + + switch h.Stream { + + case streamData: + if _, err = iface.Write(data); err != nil { + log.Printf("Malformed data from peer %d: %v", h.SourceIP, err) + } + + case streamRouting: + router.HandlePacket(h.SourceIP, remoteAddr, data) + + default: + log.Printf("Dropping unknown stream: %d", h.Stream) + } + } +} + +// ---------------------------------------------------------------------------- + +func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { + + var ( + buf = make([]byte, bufferSize) + packet []byte + remoteIP byte + err error + ) + + for { + + packet, remoteIP, err = readNextPacket(iface, buf) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + if remoteIP == w.localIP { + //log.Printf("Incoming packet for self: %x", packet) + //iface.Write(packet) + continue + } + + if err := w.WriteTo(remoteIP, streamData, packet); err != nil { + log.Fatalf("Failed to write to network: %v", err) + } + } +} diff --git a/node/node.go b/node/node.go new file mode 100644 index 0000000..2b4023a --- /dev/null +++ b/node/node.go @@ -0,0 +1 @@ +package node diff --git a/node/peer.go b/node/peer.go index 7f8be25..2b4023a 100644 --- a/node/peer.go +++ b/node/peer.go @@ -1,30 +1 @@ package node - -import ( - "net/netip" - "sync/atomic" -) - -type peer struct { - IP byte - Addr *netip.AddrPort // If we have direct connection, otherwise use mediator. - SharedKey []byte -} - -type peerRepo [256]*atomic.Pointer[peer] - -func newPeerRepo() peerRepo { - pr := peerRepo{} - for i := range pr { - pr[i] = &atomic.Pointer[peer]{} - } - return pr -} - -func (pr peerRepo) Get(ip byte) *peer { - return pr[ip].Load() -} - -func (pr *peerRepo) Set(ip byte, p *peer) { - pr[ip].Store(p) -} diff --git a/node/peerstate.go b/node/peerstate.go new file mode 100644 index 0000000..2b4023a --- /dev/null +++ b/node/peerstate.go @@ -0,0 +1 @@ +package node diff --git a/node/router.go b/node/router.go new file mode 100644 index 0000000..bd3e5a9 --- /dev/null +++ b/node/router.go @@ -0,0 +1,163 @@ +package node + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/netip" + "net/url" + "sync/atomic" + "time" + "vppn/m" +) + +type peer struct { + Up bool // No data will be sent to peers that are down. + IP byte + Addr *netip.AddrPort // If we have direct connection, otherwise use mediator. + SharedKey []byte +} + +// ---------------------------------------------------------------------------- + +type routingTable struct { + table [256]*atomic.Pointer[peer] + mediator *atomic.Pointer[peer] +} + +func newRoutingTable() *routingTable { + r := routingTable{ + mediator: &atomic.Pointer[peer]{}, + } + + for i := range r.table { + r.table[i] = &atomic.Pointer[peer]{} + } + + return &r +} + +func (r *routingTable) Get(ip byte) *peer { + return r.table[ip].Load() +} + +func (r *routingTable) Set(ip byte, p *peer) { + r.table[ip].Store(p) +} + +// ---------------------------------------------------------------------------- + +type router struct { + netName string + *routingTable + peerSupers [256]*peerSupervisor +} + +func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w *connWriter) *router { + r := &router{ + netName: netName, + routingTable: routingData, + } + + for i := range r.peerSupers { + r.peerSupers[i] = newPeerSupervisor( + conf, + byte(i), + w, + r.routingTable) + } + + // TODO: Handle Mediator + go r.pollHub(conf) + + return r +} + +// ---------------------------------------------------------------------------- + +func (r *router) HandlePacket(sourceIP byte, remoteAddr netip.AddrPort, data []byte) { + p := routingPacket{} + if err := p.Parse(data); err != nil { + log.Printf("Dropping malformed routing packet: %v", err) + return + } + + w := routingPacketWrapper{ + routingPacket: p, + Addr: remoteAddr, + } + + r.peerSupers[sourceIP].HandlePacket(w) +} + +// ---------------------------------------------------------------------------- + +func (r *router) pollHub(conf m.PeerConfig) { + defer panicHandler() + + u, err := url.Parse(conf.HubAddress) + if err != nil { + log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err) + } + u.Path = "/peer/fetch-state/" + + client := &http.Client{Timeout: 8 * time.Second} + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + req.SetBasicAuth("", conf.APIKey) + + // TODO: Before we start polling, load state from the file system. + state, err := loadNetworkState(r.netName) + if err != nil { + log.Printf("Failed to load network state: %v", err) + log.Printf("Polling hub...") + r._pollHub(conf, client, req) + } else { + r.applyNetworkState(conf, state) + } + + for range time.Tick(64 * time.Second) { + r._pollHub(conf, client, req) + } +} + +func (r *router) _pollHub(conf m.PeerConfig, client *http.Client, req *http.Request) { + var state m.NetworkState + + log.Printf("Fetching peer state from %s...", conf.HubAddress) + resp, err := client.Do(req) + if err != nil { + log.Printf("Failed to fetch peer state: %v", err) + return + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + log.Printf("Failed to read body from hub: %v", err) + return + } + + if err := json.Unmarshal(body, &state); err != nil { + log.Printf("Failed to unmarshal response from hub: %v", err) + return + } + + r.applyNetworkState(conf, state) + + if err := storeNetworkState(r.netName, state); err != nil { + log.Printf("Failed to store network state: %v", err) + } +} + +func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) { + for i := range state.Peers { + if i != int(conf.PeerIP) { + r.peerSupers[i].HandlePeerUpdate(state.Peers[i]) + } + } +} diff --git a/node/routingpacket.go b/node/routingpacket.go new file mode 100644 index 0000000..4e35055 --- /dev/null +++ b/node/routingpacket.go @@ -0,0 +1,44 @@ +package node + +import ( + "errors" + "unsafe" +) + +var errMalformedPacket = errors.New("malformed packet") + +const ( + packetTypeInvalid = iota + + // Used to maintain connection. + packetTypePing + packetTypePong +) + +type routingPacket struct { + Type byte // One of the packetType* constants. + TraceID uint64 // For matching requests and responses. +} + +func newRoutingPacket(reqType byte, traceID uint64) routingPacket { + return routingPacket{ + Type: reqType, + TraceID: traceID, + } +} + +func (p routingPacket) Marshal(buf []byte) []byte { + buf = buf[:32] // Reserve 32 bytes just in case we need to add anything. + buf[0] = p.Type + *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.TraceID) + return buf +} + +func (p *routingPacket) Parse(buf []byte) error { + if len(buf) != 32 { + return errMalformedPacket + } + p.Type = buf[0] + p.TraceID = *(*uint64)(unsafe.Pointer(&buf[1])) + return nil +} diff --git a/node/tmp-server.go b/node/tmp-server.go index acae4e5..179a8a4 100644 --- a/node/tmp-server.go +++ b/node/tmp-server.go @@ -1,14 +1,6 @@ package node -import ( - "fmt" - "io" - "log" - "net" - "net/netip" - "runtime/debug" -) - +/* var ( network = []byte{10, 1, 1, 0} serverIP = byte(1) @@ -30,7 +22,7 @@ func must(err error) { type TmpNode struct { network []byte localIP byte - peers peerRepo + router *router port uint16 netName string iface io.ReadWriteCloser @@ -46,7 +38,7 @@ func NewTmpNodeServer() *TmpNode { n := &TmpNode{ localIP: serverIP, network: network, - peers: newPeerRepo(), + router: &router{table: newPeerRepo()}, port: port, netName: netName, pubKey: pubKey1, @@ -63,10 +55,10 @@ func NewTmpNodeServer() *TmpNode { conn, err := net.ListenUDP("udp", myAddr) must(err) - n.w = newConnWriter(conn, n.localIP, n.peers.Get) - n.r = newConnReader(conn, n.localIP, n.peers.Get) + n.w = newConnWriter(conn, n.localIP, n.router) + n.r = newConnReader(conn, n.localIP, n.router) - n.peers.Set(clientIP, &peer{ + n.router.table.Set(clientIP, &peer{ IP: clientIP, SharedKey: computeSharedKey(pubKey2, n.privKey), }) @@ -80,7 +72,7 @@ func NewTmpNodeClient(srvAddrStr string) *TmpNode { n := &TmpNode{ localIP: clientIP, network: network, - peers: newPeerRepo(), + router: &router{table: newPeerRepo()}, port: port, netName: netName, pubKey: pubKey2, @@ -97,13 +89,13 @@ func NewTmpNodeClient(srvAddrStr string) *TmpNode { conn, err := net.ListenUDP("udp", myAddr) must(err) - n.w = newConnWriter(conn, n.localIP, n.peers.Get) - n.r = newConnReader(conn, n.localIP, n.peers.Get) + n.w = newConnWriter(conn, n.localIP, n.router) + n.r = newConnReader(conn, n.localIP, n.router) serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port)) must(err) - n.peers.Set(serverIP, &peer{ + n.router.table.Set(serverIP, &peer{ IP: serverIP, Addr: &serverAddr, SharedKey: computeSharedKey(pubKey1, n.privKey), @@ -129,7 +121,7 @@ func (n *TmpNode) RunServer() { log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr) must(err) - n.peers.Set(h.SourceIP, &peer{ + n.router.table.Set(h.SourceIP, &peer{ IP: h.SourceIP, Addr: &remoteAddr, SharedKey: computeSharedKey(pubKey2, n.privKey), @@ -144,7 +136,7 @@ func (n *TmpNode) RunServer() { func (n *TmpNode) RunClient() { defer func() { if r := recover(); r != nil { - fmt.Printf("%v", r) + fmt.Printf("%v\n", r) debug.PrintStack() } }() @@ -184,6 +176,10 @@ func (node *TmpNode) readFromConn() { // We assume that we're only receiving packets from one source. _, err = node.iface.Write(packet) - must(err) + if err != nil { + log.Printf("Got error: %v", err) + } + //must(err) } } +*/ diff --git a/node/tmp_peerstate.go b/node/tmp_peerstate.go new file mode 100644 index 0000000..5e365ed --- /dev/null +++ b/node/tmp_peerstate.go @@ -0,0 +1,312 @@ +package node + +import ( + "fmt" + "log" + "net/netip" + "time" + "vppn/m" +) + +const ( + connectTimeout = 6 * time.Second + pingInterval = 6 * time.Second + timeoutInterval = 20 * time.Second +) + +type routingPacketWrapper struct { + routingPacket + Addr netip.AddrPort // Source. +} + +type peerSupervisor struct { + // Constants: + localIP byte + localPublic bool + remoteIP byte + privKey []byte + + // Shared data: + w *connWriter + table *routingTable + + packets chan routingPacketWrapper + peerUpdates chan *m.Peer + + // Peer-related items. + version int64 // Ony accessed in HandlePeerUpdate. + peer *m.Peer + remoteAddrPort *netip.AddrPort + sharedKey []byte + + // Used by our state functions. + pingTimer *time.Timer + timeoutTimer *time.Timer + buf []byte +} + +// ---------------------------------------------------------------------------- + +func newPeerSupervisor( + conf m.PeerConfig, + remoteIP byte, + w *connWriter, + table *routingTable, +) *peerSupervisor { + s := &peerSupervisor{ + localIP: conf.PeerIP, + remoteIP: remoteIP, + privKey: conf.EncPrivKey, + w: w, + table: table, + packets: make(chan routingPacketWrapper, 256), + peerUpdates: make(chan *m.Peer, 1), + pingTimer: time.NewTimer(pingInterval), + timeoutTimer: time.NewTimer(timeoutInterval), + buf: make([]byte, bufferSize), + } + + _, s.localPublic = netip.AddrFromSlice(conf.PublicIP) + + go s.mainLoop() + return s +} + +func (s *peerSupervisor) logf(msg string, args ...any) { + msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg + log.Printf(msg, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) mainLoop() { + defer panicHandler() + state := s.stateInit + for { + state = state() + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) { + if p != nil { + if p.Version == s.version { + return + } + s.logf("New peer version: %d", p.Version) + s.version = p.Version + } else { + s.version = 0 + } + + s.peerUpdates <- p +} + +func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) { + select { + case s.packets <- w: + default: + // Drop + } +} + +// ---------------------------------------------------------------------------- + +type stateFunc func() stateFunc + +func (s *peerSupervisor) stateInit() stateFunc { + if s.peer == nil { + return s.stateDisconnected + } + + addr, ok := netip.AddrFromSlice(s.peer.PublicIP) + if ok { + addrPort := netip.AddrPortFrom(addr, s.peer.Port) + s.remoteAddrPort = &addrPort + } else { + s.remoteAddrPort = nil + } + s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey) + + return s.stateSelectRole() +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) stateDisconnected() stateFunc { + s.clearRoutingTable() + + for { + select { + case <-s.packets: + // Drop + case s.peer = <-s.peerUpdates: + return s.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) stateSelectRole() stateFunc { + s.logf("STATE: SelectRole") + s.updateRoutingTable(false) + + if s.remoteAddrPort != nil { + // If both remote and local are public, one side acts as client, and one + // side as server. + if s.localPublic && s.localIP < s.peer.PeerIP { + return s.stateAccept + } + return s.stateDial + } + + // We're public, remote is not => can only wait for connection + if s.localPublic { + return s.stateAccept + } + + // Both non-public => need to use mediator. + return s.stateMediated +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) stateAccept() stateFunc { + s.logf("STATE: Accept") + + for { + + select { + case pkt := <-s.packets: + switch pkt.Type { + + case packetTypePing: + s.remoteAddrPort = &pkt.Addr + s.updateRoutingTable(true) + s.sendPong(pkt.TraceID) + return s.stateConnected + + default: + // Still waiting for ping... + } + + case s.peer = <-s.peerUpdates: + return s.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) stateDial() stateFunc { + s.logf("STATE: Dial") + s.updateRoutingTable(false) + + s.sendPing() + + for { + select { + case pkt := <-s.packets: + + switch pkt.Type { + + case packetTypePong: + s.updateRoutingTable(true) + return s.stateConnected + + default: + // Ignore + } + + case <-s.pingTimer.C: + s.sendPing() + + case s.peer = <-s.peerUpdates: + return s.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) stateConnected() stateFunc { + s.logf("STATE: Connected") + + s.timeoutTimer.Reset(timeoutInterval) + + for { + select { + + case <-s.pingTimer.C: + s.sendPing() + + case <-s.timeoutTimer.C: + s.logf("Timeout") + return s.stateInit + + case pkt := <-s.packets: + switch pkt.Type { + case packetTypePing: + s.sendPong(pkt.TraceID) + + // Server should always follow remote port. + if s.localPublic { + if pkt.Addr != *s.remoteAddrPort { + s.remoteAddrPort = &pkt.Addr + s.updateRoutingTable(true) + } + } + + case packetTypePong: + s.timeoutTimer.Reset(timeoutInterval) + + default: + // Drop packet. + } + + case s.peer = <-s.peerUpdates: + s.logf("New peer: %v", s.peer) + return s.stateInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) stateMediated() stateFunc { + s.logf("STATE: Mediated") + // TODO + select {} +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) clearRoutingTable() { + s.table.Set(s.remoteIP, nil) +} + +func (s *peerSupervisor) updateRoutingTable(up bool) { + s.table.Set(s.remoteIP, &peer{ + Up: up, + IP: s.remoteIP, + Addr: s.remoteAddrPort, + SharedKey: s.sharedKey, + }) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) sendPing() uint64 { + traceID := newTraceID() + pkt := newRoutingPacket(packetTypePing, traceID) + s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) + s.pingTimer.Reset(pingInterval) + return traceID +} + +func (s *peerSupervisor) sendPong(traceID uint64) { + pkt := newRoutingPacket(packetTypePong, traceID) + s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) +}