From 1a6503bbda4946dc7687893fc643cd037ff6a6be Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 29 Jan 2025 11:45:09 +0100 Subject: [PATCH] wip --- node/README.md | 5 +- node/connwriter.go | 46 +++--- node/connwriter_test.go | 69 ++------ node/dupcheck.go | 4 +- node/header.go | 12 ++ node/ifreader.go | 2 +- peer/bitset.go | 21 +++ peer/bitset_test.go | 48 ++++++ peer/cipher-control.go | 26 +++ peer/cipher-control_test.go | 122 ++++++++++++++ peer/cipher-data.go | 61 +++++++ peer/cipher-data_test.go | 141 ++++++++++++++++ peer/cipher-discovery.go | 13 ++ peer/connreader.go | 141 ++++++++++++++++ peer/connreader_test.go | 318 ++++++++++++++++++++++++++++++++++++ peer/connwriter.go | 80 +++++++++ peer/connwriter_test.go | 240 +++++++++++++++++++++++++++ peer/controlmessage.go | 58 +++++++ peer/crypto.go | 113 +++++++++++++ peer/crypto_test.go | 213 ++++++++++++++++++++++++ peer/dupcheck.go | 76 +++++++++ peer/dupcheck_test.go | 57 +++++++ peer/errors.go | 10 ++ peer/globals.go | 19 +++ peer/header.go | 49 ++++++ peer/header_test.go | 21 +++ peer/ifreader.go | 100 ++++++++++++ peer/ifreader_test.go | 232 ++++++++++++++++++++++++++ peer/ifwriter.go | 5 + peer/interfaces.go | 28 ++++ peer/mcwriter.go | 62 +++++++ peer/mcwriter_test.go | 102 ++++++++++++ peer/packets-util.go | 190 +++++++++++++++++++++ peer/packets-util_test.go | 56 +++++++ peer/packets.go | 123 ++++++++++++++ peer/packets_test.go | 1 + peer/state.go | 29 ++++ 37 files changed, 2808 insertions(+), 85 deletions(-) create mode 100644 peer/bitset.go create mode 100644 peer/bitset_test.go create mode 100644 peer/cipher-control.go create mode 100644 peer/cipher-control_test.go create mode 100644 peer/cipher-data.go create mode 100644 peer/cipher-data_test.go create mode 100644 peer/cipher-discovery.go create mode 100644 peer/connreader.go create mode 100644 peer/connreader_test.go create mode 100644 peer/connwriter.go create mode 100644 peer/connwriter_test.go create mode 100644 peer/controlmessage.go create mode 100644 peer/crypto.go create mode 100644 peer/crypto_test.go create mode 100644 peer/dupcheck.go create mode 100644 peer/dupcheck_test.go create mode 100644 peer/errors.go create mode 100644 peer/globals.go create mode 100644 peer/header.go create mode 100644 peer/header_test.go create mode 100644 peer/ifreader.go create mode 100644 peer/ifreader_test.go create mode 100644 peer/ifwriter.go create mode 100644 peer/interfaces.go create mode 100644 peer/mcwriter.go create mode 100644 peer/mcwriter_test.go create mode 100644 peer/packets-util.go create mode 100644 peer/packets-util_test.go create mode 100644 peer/packets.go create mode 100644 peer/packets_test.go create mode 100644 peer/state.go diff --git a/node/README.md b/node/README.md index 30b77c4..58b4298 100644 --- a/node/README.md +++ b/node/README.md @@ -2,11 +2,10 @@ ## Refactoring for Testability -* [ ] connWriter - * [ ] Separate send/relay calls +* [x] connWriter * [x] mcWriter * [x] ifWriter -* [ ] ifReader +* [ ] ifReader (testing) * [ ] connReader * [ ] mcReader * [ ] hubPoller diff --git a/node/connwriter.go b/node/connwriter.go index 597b886..62caa75 100644 --- a/node/connwriter.go +++ b/node/connwriter.go @@ -68,18 +68,18 @@ func newConnWriter(conn udpAddrPortWriter, localIP byte) *connWriter { // Not safe for concurrent use. Should only be called by supervisor. func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { - buf := pkt.Marshal(w.cBuf1) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&w.counters[route.IP], 1), - SourceIP: w.localIP, - DestIP: route.IP, - } - buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) + buf := w.encryptControlPacket(pkt, route) w.writeTo(buf, route.RemoteAddr) } +// Relay control packet. Routes must not be nil. func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { + buf := w.encryptControlPacket(pkt, route) + w.relayPacket(buf, w.cBuf1, route, relay) +} + +// Encrypted packet will occupy cBuf2. +func (w *connWriter) encryptControlPacket(pkt marshaller, route *peerRoute) []byte { buf := pkt.Marshal(w.cBuf1) h := header{ StreamID: controlStreamID, @@ -87,12 +87,11 @@ func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) SourceIP: w.localIP, DestIP: route.IP, } - buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) - w.relayPacket(buf, w.cBuf1, route, relay) + return route.ControlCipher.Encrypt(h, buf, w.cBuf2) } // Not safe for concurrent use. Should only be called by ifReader. -func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { +func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { h := header{ StreamID: dataStreamID, Counter: atomic.AddUint64(&w.counters[route.IP], 1), @@ -101,16 +100,21 @@ func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { } enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) - - if route.Direct { - w.writeTo(enc, route.RemoteAddr) - return - } - - w.relayPacket(enc, w.dBuf2, route, relay) + w.writeTo(enc, route.RemoteAddr) } -// TODO: RelayDataPacket +// Relay a data packet. Routes must not be nil. +func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&w.counters[route.IP], 1), + SourceIP: w.localIP, + DestIP: route.IP, + } + + enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) + w.relayPacket(enc, w.dBuf2, route, relay) +} // Safe for concurrent use. Should only be called by connReader. // @@ -122,10 +126,6 @@ func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { } func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) { - if relay == nil || !relay.Up { - return - } - h := header{ StreamID: dataStreamID, Counter: atomic.AddUint64(&w.counters[relay.IP], 1), diff --git a/node/connwriter_test.go b/node/connwriter_test.go index 595d5b7..388fbbc 100644 --- a/node/connwriter_test.go +++ b/node/connwriter_test.go @@ -126,7 +126,7 @@ func TestConnWriter_SendControlPacket_direct(t *testing.T) { } // Testing if we can relay a packet via an intermediary. -func TestConnWriter_SendControlPacket_relay(t *testing.T) { +func TestConnWriter_RelayControlPacket_relay(t *testing.T) { route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() writer := &testUDPAddrPortWriter{} @@ -159,40 +159,6 @@ func TestConnWriter_SendControlPacket_relay(t *testing.T) { } } -// Testing that a nil relay doesn't cause an issue. -func TestConnWriter_SendControlPacket_relay_relayNil(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, nil) - - out := writer.Written() - if len(out) != 0 { - t.Fatal(out) - } - -} - -// Testing that we don't send anything if the relay isn't up. -func TestConnWriter_SendControlPacket_relay_relayNotUp(t *testing.T) { - route, rRoute, relay, _ := testConnWriter_getTestRoutes() - relay.Up = false - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, relay) - - out := writer.Written() - if len(out) != 0 { - t.Fatal(out) - } -} - // Testing that we can send a data packet directly to a remote route. func TestConnWriter_SendDataPacket_direct(t *testing.T) { route, rRoute, _, _ := testConnWriter_getTestRoutes() @@ -202,7 +168,7 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { w := newConnWriter(writer, rRoute.IP) in := []byte("hello world!") - w.SendDataPacket(in, route, nil) + w.SendDataPacket(in, route) out := writer.Written() if len(out) != 1 { @@ -224,14 +190,14 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { } // Testing that we can relay a data packet via a relay. -func TestConnWriter_SendDataPacket_relay(t *testing.T) { +func TestConnWriter_RelayDataPacket_relay(t *testing.T) { route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() writer := &testUDPAddrPortWriter{} w := newConnWriter(writer, rRoute.IP) in := []byte("Hello world!") - w.SendDataPacket(in, route, relay) + w.RelayDataPacket(in, route, relay) out := writer.Written() if len(out) != 1 { @@ -257,35 +223,26 @@ func TestConnWriter_SendDataPacket_relay(t *testing.T) { } } -// Testing that we don't attempt to relay if the relay is nil. -func TestConnWriter_SendDataPacket_relay_relayNil(t *testing.T) { +// Testing that we can send an already encrypted packet. +func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { route, rRoute, _, _ := testConnWriter_getTestRoutes() writer := &testUDPAddrPortWriter{} w := newConnWriter(writer, rRoute.IP) in := []byte("Hello world!") - w.SendDataPacket(in, route, nil) + w.SendEncryptedDataPacket(in, route) out := writer.Written() - if len(out) != 0 { + if len(out) != 1 { t.Fatal(out) } -} -// Testing that we don't attempt to relay if the relay isn't up. -func TestConnWriter_SendDataPacket_relay_relayNotUp(t *testing.T) { - route, rRoute, relay, _ := testConnWriter_getTestRoutes() - relay.Up = false + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.SendDataPacket(in, route, relay) - - out := writer.Written() - if len(out) != 0 { - t.Fatal(out) + if !bytes.Equal(out[0].Data, in) { + t.Fatal(out[0]) } } diff --git a/node/dupcheck.go b/node/dupcheck.go index fac7a72..76792ae 100644 --- a/node/dupcheck.go +++ b/node/dupcheck.go @@ -38,14 +38,14 @@ func (dc *dupCheck) IsDup(counter uint64) bool { delta := counter - dc.tailCounter // Full clear. - if delta >= bitSetSize { + if delta >= bitSetSize-1 { dc.ClearAll() dc.Set(0) dc.tail = 1 dc.head = 2 dc.tailCounter = counter + 1 - dc.headCounter = dc.tailCounter - bitSetSize + dc.headCounter = dc.tailCounter - bitSetSize + 1 return false } diff --git a/node/header.go b/node/header.go index 9d0417a..915fe3e 100644 --- a/node/header.go +++ b/node/header.go @@ -20,6 +20,18 @@ type header struct { Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. } +func parseHeader(b []byte) (h header, ok bool) { + if len(b) < headerSize { + return + } + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) + return h, true +} + func (h *header) Parse(b []byte) { h.Version = b[0] h.StreamID = b[1] diff --git a/node/ifreader.go b/node/ifreader.go index a0e7a54..67d0999 100644 --- a/node/ifreader.go +++ b/node/ifreader.go @@ -57,7 +57,7 @@ func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { return } - if relay := r.relay.Load(); relay != nil { + if relay := r.relay.Load(); relay != nil && relay.Up { r.relayDataPacket(pkt, route, relay) } } diff --git a/peer/bitset.go b/peer/bitset.go new file mode 100644 index 0000000..8d03b50 --- /dev/null +++ b/peer/bitset.go @@ -0,0 +1,21 @@ +package peer + +const bitSetSize = 512 // Multiple of 64. + +type bitSet [bitSetSize / 64]uint64 + +func (bs *bitSet) Set(i int) { + bs[i/64] |= 1 << (i % 64) +} + +func (bs *bitSet) Clear(i int) { + bs[i/64] &= ^(1 << (i % 64)) +} + +func (bs *bitSet) ClearAll() { + clear(bs[:]) +} + +func (bs *bitSet) Get(i int) bool { + return bs[i/64]&(1<<(i%64)) != 0 +} diff --git a/peer/bitset_test.go b/peer/bitset_test.go new file mode 100644 index 0000000..01ae82b --- /dev/null +++ b/peer/bitset_test.go @@ -0,0 +1,48 @@ +package peer + +import ( + "math/rand" + "testing" +) + +func TestBitSet(t *testing.T) { + state := make([]bool, bitSetSize) + for i := range state { + state[i] = rand.Float32() > 0.5 + } + + bs := bitSet{} + + for i := range state { + if state[i] { + bs.Set(i) + } + } + + for i := range state { + if bs.Get(i) != state[i] { + t.Fatal(i, state[i], bs.Get(i)) + } + } + + for i := range state { + if rand.Float32() > 0.5 { + state[i] = false + bs.Clear(i) + } + } + + for i := range state { + if bs.Get(i) != state[i] { + t.Fatal(i, state[i], bs.Get(i)) + } + } + + bs.ClearAll() + + for i := range state { + if bs.Get(i) { + t.Fatal(i, bs.Get(i)) + } + } +} diff --git a/peer/cipher-control.go b/peer/cipher-control.go new file mode 100644 index 0000000..bfecaeb --- /dev/null +++ b/peer/cipher-control.go @@ -0,0 +1,26 @@ +package peer + +import "golang.org/x/crypto/nacl/box" + +type controlCipher struct { + sharedKey [32]byte +} + +func newControlCipher(privKey, pubKey []byte) *controlCipher { + shared := [32]byte{} + box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) + return &controlCipher{shared} +} + +func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte { + const s = controlHeaderSize + out = out[:s+controlCipherOverhead+len(data)] + h.Marshal(out[:s]) + box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) + return out +} + +func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = controlHeaderSize + return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) +} diff --git a/peer/cipher-control_test.go b/peer/cipher-control_test.go new file mode 100644 index 0000000..916d2ea --- /dev/null +++ b/peer/cipher-control_test.go @@ -0,0 +1,122 @@ +package peer + +import ( + "bytes" + "crypto/rand" + "reflect" + "testing" + + "golang.org/x/crypto/nacl/box" +) + +func newControlCipherForTesting() (c1, c2 *controlCipher) { + pubKey1, privKey1, err := box.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + pubKey2, privKey2, err := box.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + return newControlCipher(privKey1[:], pubKey2[:]), + newControlCipher(privKey2[:], pubKey1[:]) +} + +func TestControlCipher(t *testing.T) { + c1, c2 := newControlCipherForTesting() + + maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + StreamID: controlStreamID, + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + encrypted = c1.Encrypt(h1, plaintext, encrypted) + + h2 := header{} + h2.Parse(encrypted) + if !reflect.DeepEqual(h1, h2) { + t.Fatal(h1, h2) + } + + decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Fatal("not equal") + } + } +} + +func TestControlCipher_ShortCiphertext(t *testing.T) { + c1, _ := newControlCipherForTesting() + shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) + rand.Read(shortText) + _, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkControlCipher_Encrypt(b *testing.B) { + c1, _ := newControlCipherForTesting() + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = c1.Encrypt(h1, plaintext, encrypted) + } +} + +func BenchmarkControlCipher_Decrypt(b *testing.B) { + c1, c2 := newControlCipherForTesting() + + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + encrypted = c1.Encrypt(h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _ = c2.Decrypt(encrypted, decrypted) + } +} diff --git a/peer/cipher-data.go b/peer/cipher-data.go new file mode 100644 index 0000000..9b229bb --- /dev/null +++ b/peer/cipher-data.go @@ -0,0 +1,61 @@ +package peer + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "log" +) + +type dataCipher struct { + key [32]byte + aead cipher.AEAD +} + +func newDataCipher() *dataCipher { + key := [32]byte{} + if _, err := rand.Read(key[:]); err != nil { + log.Fatalf("Failed to read random data: %v", err) + } + return newDataCipherFromKey(key) +} + +func newDataCipherFromKey(key [32]byte) *dataCipher { + block, err := aes.NewCipher(key[:]) + if err != nil { + log.Fatalf("Failed to create new cipher: %v", err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + log.Fatalf("Failed to create new GCM: %v", err) + } + + return &dataCipher{key: key, aead: aead} +} + +func (sc *dataCipher) Key() [32]byte { + return sc.key +} + +func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte { + const s = dataHeaderSize + out = out[:s+dataCipherOverhead+len(data)] + h.Marshal(out[:s]) + sc.aead.Seal(out[s:s], out[:s], data, nil) + return out +} + +func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = dataHeaderSize + if len(encrypted) < s+dataCipherOverhead { + ok = false + return + } + + var err error + + data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) + ok = err == nil + return +} diff --git a/peer/cipher-data_test.go b/peer/cipher-data_test.go new file mode 100644 index 0000000..ac9a03a --- /dev/null +++ b/peer/cipher-data_test.go @@ -0,0 +1,141 @@ +package peer + +import ( + "bytes" + "crypto/rand" + mrand "math/rand/v2" + "reflect" + "testing" +) + +func TestDataCipher(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + StreamID: dataStreamID, + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + h2 := header{} + h2.Parse(encrypted) + + dc2 := newDataCipherFromKey(dc1.Key()) + + decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("not equal") + } + + if !reflect.DeepEqual(h1, h2) { + t.Fatalf("%v != %v", h1, h2) + } + } +} + +func TestDataCipher_ModifyCiphertext(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + encrypted[mrand.IntN(len(encrypted))]++ + + dc2 := newDataCipherFromKey(dc1.Key()) + + _, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if ok { + t.Fatal(ok) + } + } +} + +func TestDataCipher_ShortCiphertext(t *testing.T) { + dc1 := newDataCipher() + shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) + rand.Read(shortText) + _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkDataCipher_Encrypt(b *testing.B) { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + } +} + +func BenchmarkDataCipher_Decrypt(b *testing.B) { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _ = dc1.Decrypt(encrypted, decrypted) + } +} diff --git a/peer/cipher-discovery.go b/peer/cipher-discovery.go new file mode 100644 index 0000000..0e66650 --- /dev/null +++ b/peer/cipher-discovery.go @@ -0,0 +1,13 @@ +package peer + +/* +func signData(privKey *[64]byte, h header, data, out []byte) []byte { + out = out[:headerSize] + h.Marshal(out) + return sign.Sign(out, data, privKey) +} + +func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) { + return sign.Open(out[:0], signed[headerSize:], pubKey) +} +*/ diff --git a/peer/connreader.go b/peer/connreader.go new file mode 100644 index 0000000..757b37c --- /dev/null +++ b/peer/connreader.go @@ -0,0 +1,141 @@ +package peer + +import ( + "log" + "net/netip" + "sync/atomic" +) + +type connReader struct { + conn udpReader + iface ifWriter + sender encryptedPacketSender + super controlMsgHandler + localIP byte + routes [256]*atomic.Pointer[peerRoute] + + buf []byte + decBuf []byte + dupChecks [256]*dupCheck +} + +func newConnReader( + conn udpReader, + ifWriter ifWriter, + sender encryptedPacketSender, + super controlMsgHandler, + localIP byte, + routes [256]*atomic.Pointer[peerRoute], +) *connReader { + return &connReader{ + conn: conn, + iface: ifWriter, + sender: sender, + super: super, + localIP: localIP, + routes: routes, + buf: make([]byte, bufferSize), + decBuf: make([]byte, bufferSize), + dupChecks: func() (out [256]*dupCheck) { + for i := range out { + out[i] = newDupCheck(0) + } + return + }(), + } +} + +func (r *connReader) Run() { + for { + r.handleNextPacket() + } +} + +func (r *connReader) logf(s string, args ...any) { + log.Printf("[ConnReader] "+s, args...) +} + +func (r *connReader) handleNextPacket() { + buf := r.buf[:bufferSize] + n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) + } + + if n < headerSize { + return + } + + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) + + buf = buf[:n] + h, ok := parseHeader(buf) + if !ok { + return + } + + route := r.routes[h.SourceIP].Load() + + switch h.StreamID { + case controlStreamID: + r.handleControlPacket(route, remoteAddr, h, buf) + + case dataStreamID: + r.handleDataPacket(route, h, buf) + + default: + r.logf("Unknown stream ID: %d", h.StreamID) + } +} + +func (r *connReader) handleControlPacket( + route *peerRoute, + addr netip.AddrPort, + h header, + enc []byte, +) { + if route.ControlCipher == nil { + return + } + + if h.DestIP != r.localIP { + r.logf("Incorrect destination IP on control packet: %d", h.DestIP) + return + } + + msg, err := decryptControlPacket(route, addr, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt control packet: %v", err) + return + } + + r.super.HandleControlMsg(msg) +} + +func (r *connReader) handleDataPacket(route *peerRoute, h header, enc []byte) { + if !route.Up { + r.logf("Not connected (recv).") + return + } + + data, err := decryptDataPacket(route, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt data packet: %v", err) + return + } + + if h.DestIP == r.localIP { + if _, err := r.iface.Write(data); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } + return + } + + destRoute := r.routes[h.DestIP].Load() + if !destRoute.Up { + r.logf("Not connected (relay): %d", destRoute.IP) + return + } + + r.sender.SendEncryptedDataPacket(data, destRoute) +} diff --git a/peer/connreader_test.go b/peer/connreader_test.go new file mode 100644 index 0000000..7ef4ad8 --- /dev/null +++ b/peer/connreader_test.go @@ -0,0 +1,318 @@ +package peer + +/* +type mockIfWriter struct { + Written [][]byte +} + +func (w *mockIfWriter) Write(b []byte) (int, error) { + w.Written = append(w.Written, bytes.Clone(b)) + return len(b), nil +} + +type mockEncryptedPacket struct { + Packet []byte + Route *peerRoute +} + +type mockEncryptedPacketSender struct { + Sent []mockEncryptedPacket +} + +func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { + m.Sent = append(m.Sent, mockEncryptedPacket{ + Packet: bytes.Clone(pkt), + Route: route, + }) +} + +type mockControlMsgHandler struct { + Messages []any +} + +func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { + m.Messages = append(m.Messages, pkt) +} + +type udpPipe struct { + packets chan []byte +} + +func newUDPPipe() *udpPipe { + return &udpPipe{make(chan []byte, 1024)} +} + +func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + p.packets <- bytes.Clone(b) + return len(b), nil +} + +func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + packet := <-p.packets + copy(b, packet) + return len(packet), netip.AddrPort{}, nil +} + +type connReaderTestHarness struct { + Pipe *udpPipe + R *connReader + WRemote *connWriter + WRelayRemote *connWriter + Remote *peerRoute + RelayRemote *peerRoute + IFace *mockIfWriter + Sender *mockEncryptedPacketSender + Super *mockControlMsgHandler +} + +// Peer 2 is indirect, peer 3 is direct. +func newConnReadeTestHarness() (h connReaderTestHarness) { + pipe := newUDPPipe() + routes := [256]*atomic.Pointer[peerRoute]{} + for i := range routes { + routes[i] = &atomic.Pointer[peerRoute]{} + routes[i].Store(&peerRoute{}) + } + + local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() + routes[2].Store(local) + routes[3].Store(relayLocal) + + h.Pipe = pipe + h.WRemote = newConnWriter(pipe, 2) + h.WRelayRemote = newConnWriter(pipe, 3) + + h.Remote = remote + h.RelayRemote = relayRemote + h.IFace = &mockIfWriter{} + h.Sender = &mockEncryptedPacketSender{} + h.Super = &mockControlMsgHandler{} + h.R = newConnReader( + pipe, + h.IFace, + h.Sender, + h.Super, + 1, + routes) + return h +} + +// Testing that we can receive a control packet. +func TestConnReader_handleControlPacket(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + h.WRemote.SendControlPacket(pkt, h.Remote) + + h.R.handleNextPacket() + + if len(h.Super.Messages) != 1 { + t.Fatal(h.Super.Messages) + } + + msg := h.Super.Messages[0].(controlMsg[synPacket]) + if !reflect.DeepEqual(pkt, msg.Packet) { + t.Fatal(msg.Packet) + } +} + +// Testing that a short packet is ignored. +func TestConnReader_handleNextPacket_short(t *testing.T) { + h := newConnReadeTestHarness() + + h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) + h.R.handleNextPacket() + + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that a packet with an unexpected stream ID is ignored. +func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + var header header + header.Parse(encrypted) + header.StreamID = 100 + header.Marshal(encrypted) + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that control packet without matching control cipher is ignored. +func TestConnReader_handleControlPacket_noCipher(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + var header header + header.Parse(encrypted) + header.SourceIP = 10 + header.Marshal(encrypted) + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that control packet with incrrect destination IP is ignored. +func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + var header header + header.Parse(encrypted) + header.DestIP++ + header.Marshal(encrypted) + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that modified control packet is ignored. +func TestConnReader_handleControlPacket_modified(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted[len(encrypted)-1]++ + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +type emptyPacket struct{} + +func (p emptyPacket) Marshal(buf []byte) []byte { + return buf[:0] +} + +// Testing that an empty control packet is ignored. +func TestConnReader_handleControlPacket_empty(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := emptyPacket{} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that a duplicate control packet is ignored. +func TestConnReader_handleControlPacket_duplicate(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + log.Printf("%d", h.WRemote.counters[1]) + h.WRemote.SendControlPacket(pkt, h.Remote) + log.Printf("%d", h.WRemote.counters[1]) + + // Rewind the counter. + h.WRemote.counters[1] = h.WRemote.counters[1] - 1 + log.Printf("%d", h.WRemote.counters[1]) + h.WRemote.SendControlPacket(pkt, h.Remote) + + h.R.handleNextPacket() + h.R.handleNextPacket() + + if len(h.Super.Messages) != 1 { + t.Fatal(h.Super.Messages) + } + + msg := h.Super.Messages[0].(controlMsg[synPacket]) + if !reflect.DeepEqual(pkt, msg.Packet) { + t.Fatal(msg.Packet) + } +} + +type invalidPacket struct { +} + +func (p invalidPacket) Marshal(b []byte) []byte { + out := b[:256] + clear(out) + return out +} + +// Testing that an invalid control packet is ignored (fails to parse). +func TestConnReader_handleControlPacket_cantParse(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := invalidPacket{} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that we can receive a data packet. +func TestConnReader_handleDataPacket(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := make([]byte, 1024) + rand.Read(pkt) + + h.WRemote.SendDataPacket(pkt, h.Remote) + + h.R.handleNextPacket() + + if len(h.IFace.Written) != 1 { + t.Fatal(h.IFace.Written) + } + + if !bytes.Equal(pkt, h.IFace.Written[0]) { + t.Fatal(h.IFace.Written) + } +} + +// Testing that data packet is ignored if route isn't up. +func TestConnReader_handleDataPacket_routeDown(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := make([]byte, 1024) + rand.Read(pkt) + + h.WRemote.SendDataPacket(pkt, h.Remote) + route := h.R.routes[2].Load() + route.Up = false + + h.R.handleNextPacket() + + if len(h.IFace.Written) != 0 { + t.Fatal(h.IFace.Written) + } +} +*/ +// Testing that a duplicate data packet is ignored. + +// Testing that we send a relayed data packet. + +// Testing that a relayed data packet is ignored if destination isn't up. diff --git a/peer/connwriter.go b/peer/connwriter.go new file mode 100644 index 0000000..928b2a0 --- /dev/null +++ b/peer/connwriter.go @@ -0,0 +1,80 @@ +package peer + +import ( + "log" + "net/netip" + "sync" +) + +// ---------------------------------------------------------------------------- + +type connWriter struct { + localIP byte + conn udpWriter + + // For sending control packets. + cBuf1 []byte + cBuf2 []byte + + // For sending data packets. + dBuf1 []byte + dBuf2 []byte + + // Lock around for sending on UDP Conn. + wLock sync.Mutex +} + +func newConnWriter(conn udpWriter, localIP byte) *connWriter { + w := &connWriter{ + localIP: localIP, + conn: conn, + cBuf1: make([]byte, bufferSize), + cBuf2: make([]byte, bufferSize), + dBuf1: make([]byte, bufferSize), + dBuf2: make([]byte, bufferSize), + } + return w +} + +// Not safe for concurrent use. Should only be called by supervisor. +func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { + enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) + w.writeTo(enc, route.RemoteAddr) +} + +// Relay control packet. Route must not be nil. +func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { + enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) + enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.cBuf1) + w.writeTo(enc, relay.RemoteAddr) +} + +// Not safe for concurrent use. Should only be called by ifReader. +func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { + enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) + w.writeTo(enc, route.RemoteAddr) +} + +// Relay a data packet. Route must not be nil. +func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { + enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) + enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.dBuf2) + w.writeTo(enc, relay.RemoteAddr) +} + +// Safe for concurrent use. Should only be called by connReader. +// +// This function will send pkt to the peer directly. This is used when a peer +// is acting as a relay and is forwarding already encrypted data for another +// peer. +func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { + w.writeTo(pkt, route.RemoteAddr) +} + +func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { + w.wLock.Lock() + if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { + log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) + } + w.wLock.Unlock() +} diff --git a/peer/connwriter_test.go b/peer/connwriter_test.go new file mode 100644 index 0000000..14f128e --- /dev/null +++ b/peer/connwriter_test.go @@ -0,0 +1,240 @@ +package peer + +import ( + "bytes" + "net/netip" + "testing" +) + +// ---------------------------------------------------------------------------- + +type testUDPPacket struct { + Addr netip.AddrPort + Data []byte +} + +type testUDPAddrPortWriter struct { + written []testUDPPacket +} + +func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + w.written = append(w.written, testUDPPacket{ + Addr: addr, + Data: bytes.Clone(b), + }) + return len(b), nil +} + +func (w *testUDPAddrPortWriter) Written() []testUDPPacket { + out := w.written + w.written = []testUDPPacket{} + return out +} + +// ---------------------------------------------------------------------------- + +type testPacket string + +func (p testPacket) Marshal(b []byte) []byte { + b = b[:len(p)] + copy(b, []byte(p)) + return b +} + +// ---------------------------------------------------------------------------- + +func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) { + localKeys := generateKeys() + remoteKeys := generateKeys() + + local = newPeerRoute(2) + local.Up = true + local.Relay = false + local.PubSignKey = remoteKeys.PubSignKey + local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) + local.DataCipher = newDataCipher() + local.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) + + remote = newPeerRoute(1) + remote.Up = true + remote.Relay = false + remote.PubSignKey = localKeys.PubSignKey + remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) + remote.DataCipher = local.DataCipher + remote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) + + rLocalKeys := generateKeys() + rRemoteKeys := generateKeys() + + relayLocal = newPeerRoute(3) + relayLocal.Up = true + relayLocal.Relay = true + relayLocal.Direct = true + relayLocal.PubSignKey = rRemoteKeys.PubSignKey + relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) + relayLocal.DataCipher = newDataCipher() + relayLocal.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) + + relayRemote = newPeerRoute(1) + relayRemote.Up = true + relayRemote.Relay = false + relayRemote.Direct = true + relayRemote.PubSignKey = rLocalKeys.PubSignKey + relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) + relayRemote.DataCipher = relayLocal.DataCipher + relayRemote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) + + return +} + +// ---------------------------------------------------------------------------- + +// Testing if we can send a control packet directly to the remote route. +func TestConnWriter_SendControlPacket_direct(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + route.Direct = true + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.SendControlPacket(in, route) + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + if string(dec) != string(in) { + t.Fatal(dec) + } +} + +// Testing if we can relay a packet via an intermediary. +func TestConnWriter_RelayControlPacket_relay(t *testing.T) { + route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.RelayControlPacket(in, route, relay) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != relay.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if string(dec2) != string(in) { + t.Fatal(dec2) + } +} + +// Testing that we can send a data packet directly to a remote route. +func TestConnWriter_SendDataPacket_direct(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + route.Direct = true + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + + in := []byte("hello world!") + w.SendDataPacket(in, route) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(dec, in) { + t.Fatal(dec) + } +} + +// Testing that we can relay a data packet via a relay. +func TestConnWriter_RelayDataPacket_relay(t *testing.T) { + route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.RelayDataPacket(in, route, relay) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != relay.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(dec2, in) { + t.Fatal(dec2) + } +} + +// Testing that we can send an already encrypted packet. +func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.SendEncryptedDataPacket(in, route) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + if !bytes.Equal(out[0].Data, in) { + t.Fatal(out[0]) + } +} diff --git a/peer/controlmessage.go b/peer/controlmessage.go new file mode 100644 index 0000000..d8e9a17 --- /dev/null +++ b/peer/controlmessage.go @@ -0,0 +1,58 @@ +package peer + +import ( + "net/netip" + "vppn/m" +) + +// ---------------------------------------------------------------------------- + +type controlMsg[T any] struct { + SrcIP byte + SrcAddr netip.AddrPort + // TODO: RecvdAt int64 // Unixmilli. + Packet T +} + +func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { + switch buf[0] { + + case packetTypeSyn: + packet, err := parseSynPacket(buf) + return controlMsg[synPacket]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + + case packetTypeAck: + packet, err := parseAckPacket(buf) + return controlMsg[ackPacket]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + + case packetTypeProbe: + packet, err := parseProbePacket(buf) + return controlMsg[probePacket]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + + default: + return nil, errUnknownPacketType + } +} + +// ---------------------------------------------------------------------------- + +type peerUpdateMsg struct { + PeerIP byte + Peer *m.Peer +} + +// ---------------------------------------------------------------------------- + +type pingTimerMsg struct{} diff --git a/peer/crypto.go b/peer/crypto.go new file mode 100644 index 0000000..f41c8bf --- /dev/null +++ b/peer/crypto.go @@ -0,0 +1,113 @@ +package peer + +import ( + "crypto/rand" + "log" + "net/netip" + "sync/atomic" + + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/sign" +) + +type cryptoKeys struct { + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func generateKeys() cryptoKeys { + pubKey, privKey, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate encryption keys: %v", err) + } + + pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate signing keys: %v", err) + } + + return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} +} + +// ---------------------------------------------------------------------------- + +// Route must have a ControlCipher. +func encryptControlPacket( + localIP byte, + route *peerRoute, + pkt marshaller, + tmp []byte, + out []byte, +) []byte { + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(route.Counter, 1), + SourceIP: localIP, + DestIP: route.IP, + } + tmp = pkt.Marshal(tmp) + return route.ControlCipher.Encrypt(h, tmp, out) +} + +// Returns a controlMsg[PacketType]. Route must have ControlCipher. +func decryptControlPacket( + route *peerRoute, + fromAddr netip.AddrPort, + h header, + encrypted []byte, + tmp []byte, +) (any, error) { + out, ok := route.ControlCipher.Decrypt(encrypted, tmp) + if !ok { + return nil, errDecryptionFailed + } + + if route.DupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + msg, err := parseControlMsg(h.SourceIP, fromAddr, out) + if err != nil { + return nil, err + } + + return msg, nil +} + +// ---------------------------------------------------------------------------- + +func encryptDataPacket( + localIP byte, + destIP byte, + route *peerRoute, + data []byte, + out []byte, +) []byte { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(route.Counter, 1), + SourceIP: localIP, + DestIP: destIP, + } + return route.DataCipher.Encrypt(h, data, out) +} + +func decryptDataPacket( + route *peerRoute, + h header, + encrypted []byte, + out []byte, +) ([]byte, error) { + dec, ok := route.DataCipher.Decrypt(encrypted, out) + if !ok { + return nil, errDecryptionFailed + } + + if route.DupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + return dec, nil +} diff --git a/peer/crypto_test.go b/peer/crypto_test.go new file mode 100644 index 0000000..29ee377 --- /dev/null +++ b/peer/crypto_test.go @@ -0,0 +1,213 @@ +package peer + +import ( + "bytes" + "crypto/rand" + "errors" + "net/netip" + "reflect" + "testing" +) + +func newRoutePairForTesting() (*peerRoute, *peerRoute) { + keys1 := generateKeys() + keys2 := generateKeys() + + r1 := newPeerRoute(1) + r1.PubSignKey = keys1.PubSignKey + r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) + r1.DataCipher = newDataCipher() + + r2 := newPeerRoute(2) + r2.PubSignKey = keys2.PubSignKey + r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) + r2.DataCipher = r1.DataCipher + + return r1, r2 +} + +func TestDecryptControlPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := synPacket{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if err != nil { + t.Fatal(err) + } + + msg, ok := iMsg.(controlMsg[synPacket]) + if !ok { + t.Fatal(ok) + } + + if !reflect.DeepEqual(msg.Packet, in) { + t.Fatal(msg) + } +} + +func TestDecryptControlPacket_decryptionFailed(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := synPacket{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + for i := range enc { + x := bytes.Clone(enc) + x[i]++ + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(i, err) + } + } +} + +func TestDecryptControlPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := synPacket{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { + t.Fatal(err) + } + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } +} + +func TestDecryptControlPacket_invalidPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := testPacket("hello!") + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errUnknownPacketType) { + t.Fatal(err) + } +} + +func TestDecryptDataPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, out) { + t.Fatal(data, out) + } +} + +func TestDecryptDataPacket_incorrectCipher(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + r1.DataCipher = newDataCipher() + _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(err) + } +} + +func TestDecryptDataPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if err != nil { + t.Fatal(err) + } + + _, err = decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } +} diff --git a/peer/dupcheck.go b/peer/dupcheck.go new file mode 100644 index 0000000..09b5b11 --- /dev/null +++ b/peer/dupcheck.go @@ -0,0 +1,76 @@ +package peer + +type dupCheck struct { + bitSet + head int + tail int + headCounter uint64 + tailCounter uint64 // Also next expected counter value. +} + +func newDupCheck(headCounter uint64) *dupCheck { + return &dupCheck{ + headCounter: headCounter, + tailCounter: headCounter + 1, + tail: 1, + } +} + +func (dc *dupCheck) IsDup(counter uint64) bool { + + // Before head => it's late, say it's a dup. + if counter < dc.headCounter { + return true + } + + // It's within the counter bounds. + if counter < dc.tailCounter { + index := (int(counter-dc.headCounter) + dc.head) % bitSetSize + if dc.Get(index) { + return true + } + + dc.Set(index) + return false + } + + // It's more than 1 beyond the tail. + delta := counter - dc.tailCounter + + // Full clear. + if delta >= bitSetSize-1 { + dc.ClearAll() + dc.Set(0) + + dc.tail = 1 + dc.head = 2 + dc.tailCounter = counter + 1 + dc.headCounter = dc.tailCounter - bitSetSize + 1 + + return false + } + + // Clear if necessary. + for i := 0; i < int(delta); i++ { + dc.put(false) + } + + dc.put(true) + return false +} + +func (dc *dupCheck) put(set bool) { + if set { + dc.Set(dc.tail) + } else { + dc.Clear(dc.tail) + } + + dc.tail = (dc.tail + 1) % bitSetSize + dc.tailCounter++ + + if dc.head == dc.tail { + dc.head = (dc.head + 1) % bitSetSize + dc.headCounter++ + } +} diff --git a/peer/dupcheck_test.go b/peer/dupcheck_test.go new file mode 100644 index 0000000..2b50d74 --- /dev/null +++ b/peer/dupcheck_test.go @@ -0,0 +1,57 @@ +package peer + +import ( + "testing" +) + +func TestDupCheck(t *testing.T) { + dc := newDupCheck(0) + + for i := range bitSetSize { + if dc.IsDup(uint64(i)) { + t.Fatal("!") + } + } + + type TestCase struct { + Counter uint64 + Dup bool + } + + testCases := []TestCase{ + {511, true}, + {0, true}, + {1, true}, + {2, true}, + {3, true}, + {63, true}, + {256, true}, + {510, true}, + {511, true}, + {512, false}, + {0, true}, + {512, true}, + {513, false}, + {517, false}, + {512, true}, + {513, true}, + {514, false}, + {515, false}, + {516, false}, + {517, true}, + {2512, false}, + {2512, true}, + {2001, true}, + {2002, false}, + {2002, true}, + {4000, false}, + {4000 - 511, true}, // Too old. + {4000 - 510, false}, // Just in the window. + } + + for i, tc := range testCases { + if ok := dc.IsDup(tc.Counter); ok != tc.Dup { + t.Fatal(i, ok, tc) + } + } +} diff --git a/peer/errors.go b/peer/errors.go new file mode 100644 index 0000000..b1e07e2 --- /dev/null +++ b/peer/errors.go @@ -0,0 +1,10 @@ +package peer + +import "errors" + +var ( + errDecryptionFailed = errors.New("decryption failed") + errDuplicateSeqNum = errors.New("duplicate sequence number") + errMalformedPacket = errors.New("malformed packet") + errUnknownPacketType = errors.New("unknown packet type") +) diff --git a/peer/globals.go b/peer/globals.go new file mode 100644 index 0000000..a4d8d65 --- /dev/null +++ b/peer/globals.go @@ -0,0 +1,19 @@ +package peer + +import ( + "net" + "net/netip" +) + +const ( + bufferSize = 1536 + if_mtu = 1200 + if_queue_len = 2048 + controlCipherOverhead = 16 + dataCipherOverhead = 16 + signOverhead = 64 +) + +var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( + netip.AddrFrom4([4]byte{224, 0, 0, 157}), + 4560)) diff --git a/peer/header.go b/peer/header.go new file mode 100644 index 0000000..08698dd --- /dev/null +++ b/peer/header.go @@ -0,0 +1,49 @@ +package peer + +import "unsafe" + +// ---------------------------------------------------------------------------- + +const ( + headerSize = 12 + controlStreamID = 2 + controlHeaderSize = 24 + dataStreamID = 1 + dataHeaderSize = 12 +) + +type header struct { + Version byte + StreamID byte + SourceIP byte + DestIP byte + Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. +} + +func parseHeader(b []byte) (h header, ok bool) { + if len(b) < headerSize { + return + } + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) + return h, true +} + +func (h *header) Parse(b []byte) { + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) +} + +func (h *header) Marshal(buf []byte) { + buf[0] = h.Version + buf[1] = h.StreamID + buf[2] = h.SourceIP + buf[3] = h.DestIP + *(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter +} diff --git a/peer/header_test.go b/peer/header_test.go new file mode 100644 index 0000000..11e2f8f --- /dev/null +++ b/peer/header_test.go @@ -0,0 +1,21 @@ +package peer + +import "testing" + +func TestHeaderMarshalParse(t *testing.T) { + nIn := header{ + StreamID: 23, + Counter: 3212, + SourceIP: 34, + DestIP: 200, + } + + buf := make([]byte, headerSize) + nIn.Marshal(buf) + + nOut := header{} + nOut.Parse(buf) + if nIn != nOut { + t.Fatal(nIn, nOut) + } +} diff --git a/peer/ifreader.go b/peer/ifreader.go new file mode 100644 index 0000000..61627a2 --- /dev/null +++ b/peer/ifreader.go @@ -0,0 +1,100 @@ +package peer + +import ( + "io" + "log" + "sync/atomic" +) + +type ifReader struct { + iface io.Reader + routes [256]*atomic.Pointer[peerRoute] + relay *atomic.Pointer[peerRoute] + sender dataPacketSender +} + +func newIFReader( + iface io.Reader, + routes [256]*atomic.Pointer[peerRoute], + relay *atomic.Pointer[peerRoute], + sender dataPacketSender, +) *ifReader { + return &ifReader{ + iface: iface, + routes: routes, + relay: relay, + sender: sender, + } +} + +func (r *ifReader) Run() { + var ( + packet = make([]byte, bufferSize) + remoteIP byte + ok bool + ) + + for { + packet = r.readNextPacket(packet) + if remoteIP, ok = r.parsePacket(packet); ok { + r.sendPacket(packet, remoteIP) + } + } +} + +func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { + route := r.routes[remoteIP].Load() + if !route.Up { + log.Printf("Route not connected: %d", remoteIP) + return + } + + // Direct path => early return. + if route.Direct { + r.sender.SendDataPacket(pkt, route) + return + } + + if relay := r.relay.Load(); relay != nil && relay.Up { + r.sender.RelayDataPacket(pkt, route, relay) + } +} + +// Get next packet, returning packet, and destination ip. +func (r *ifReader) readNextPacket(buf []byte) []byte { + n, err := r.iface.Read(buf[:cap(buf)]) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + return buf[:n] +} + +func (r *ifReader) parsePacket(buf []byte) (byte, bool) { + n := len(buf) + if n == 0 { + return 0, false + } + + version := buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + log.Printf("Short IPv4 packet: %d", len(buf)) + return 0, false + } + return buf[19], true + + case 6: + if len(buf) < 40 { + log.Printf("Short IPv6 packet: %d", len(buf)) + return 0, false + } + return buf[39], true + + default: + log.Printf("Invalid IP packet version: %v", version) + return 0, false + } +} diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go new file mode 100644 index 0000000..c5efb30 --- /dev/null +++ b/peer/ifreader_test.go @@ -0,0 +1,232 @@ +package peer + +import ( + "bytes" + "net" + "reflect" + "sync/atomic" + "testing" +) + +// Test that we parse IPv4 packets correctly. +func TestIFReader_parsePacket_ipv4(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 1234) + pkt[0] = 4 << 4 + pkt[19] = 128 + + if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { + t.Fatal(ip, ok) + } +} + +// Test that we parse IPv6 packets correctly. +func TestIFReader_parsePacket_ipv6(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 1234) + pkt[0] = 6 << 4 + pkt[39] = 42 + + if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { + t.Fatal(ip, ok) + } +} + +// Test that empty packets work as expected. +func TestIFReader_parsePacket_emptyPacket(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 0) + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that invalid IP versions fail. +func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + for i := byte(1); i < 16; i++ { + if i == 4 || i == 6 { + continue + } + pkt := make([]byte, 1234) + pkt[0] = i << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(i, ip, ok) + } + } +} + +// Test that short IPv4 packets fail. +func TestIFReader_parsePacket_shortIPv4(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 19) + pkt[0] = 4 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that short IPv6 packets fail. +func TestIFReader_parsePacket_shortIPv6(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 39) + pkt[0] = 6 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that we can read a packet. +func TestIFReader_readNextpacket(t *testing.T) { + in, out := net.Pipe() + r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + defer in.Close() + defer out.Close() + + go in.Write([]byte("hello world!")) + + pkt := r.readNextPacket(make([]byte, bufferSize)) + if !bytes.Equal(pkt, []byte("hello world!")) { + t.Fatalf("%s", pkt) + } +} + +// ---------------------------------------------------------------------------- + +type sentPacket struct { + Relayed bool + Packet []byte + Route peerRoute + Relay peerRoute +} + +type sendPacketTestHarness struct { + Packets []sentPacket +} + +func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *peerRoute) { + h.Packets = append(h.Packets, sentPacket{ + Packet: bytes.Clone(pkt), + Route: *route, + }) +} + +func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *peerRoute) { + h.Packets = append(h.Packets, sentPacket{ + Relayed: true, + Packet: bytes.Clone(pkt), + Route: *route, + Relay: *relay, + }) +} + +func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { + h := &sendPacketTestHarness{} + + routes := [256]*atomic.Pointer[peerRoute]{} + for i := range routes { + routes[i] = &atomic.Pointer[peerRoute]{} + routes[i].Store(&peerRoute{}) + } + relay := &atomic.Pointer[peerRoute]{} + r := newIFReader(nil, routes, relay, h) + return r, h +} + +// Testing that we can send a packet directly. +func TestIFReader_sendPacket_direct(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Up = true + route.Direct = true + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 1 { + t.Fatal(h.Packets) + } + + expected := sentPacket{ + Relayed: false, + Packet: in, + Route: *route, + } + + if !reflect.DeepEqual(h.Packets[0], expected) { + t.Fatal(h.Packets[0]) + } +} + +// Testing that we don't send a packet if route isn't up. +func TestIFReader_sendPacket_directNotUp(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Direct = true + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 0 { + t.Fatal(h.Packets) + } +} + +// Testing that we can send a packet via a relay. +func TestIFReader_sendPacket_relayed(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Up = true + route.Direct = false + + relay := r.routes[3].Load() + r.relay.Store(relay) + relay.Up = true + relay.Direct = true + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 1 { + t.Fatal(h.Packets) + } + + expected := sentPacket{ + Relayed: true, + Packet: in, + Route: *route, + Relay: *relay, + } + + if !reflect.DeepEqual(h.Packets[0], expected) { + t.Fatal(h.Packets[0]) + } +} + +// Testing that we don't try to send on a nil relay IP. +func TestIFReader_sendPacket_nilRealy(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Up = true + route.Direct = false + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 0 { + t.Fatal(h.Packets) + } +} diff --git a/peer/ifwriter.go b/peer/ifwriter.go new file mode 100644 index 0000000..59e2e26 --- /dev/null +++ b/peer/ifwriter.go @@ -0,0 +1,5 @@ +package peer + +import "io" + +type ifWriter io.Writer diff --git a/peer/interfaces.go b/peer/interfaces.go new file mode 100644 index 0000000..84f9c99 --- /dev/null +++ b/peer/interfaces.go @@ -0,0 +1,28 @@ +package peer + +import "net/netip" + +type udpReader interface { + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) +} + +type udpWriter interface { + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) +} + +type marshaller interface { + Marshal([]byte) []byte +} + +type dataPacketSender interface { + SendDataPacket(pkt []byte, route *peerRoute) + RelayDataPacket(pkt []byte, route, relay *peerRoute) +} + +type encryptedPacketSender interface { + SendEncryptedDataPacket(pkt []byte, route *peerRoute) +} + +type controlMsgHandler interface { + HandleControlMsg(pkt any) +} diff --git a/peer/mcwriter.go b/peer/mcwriter.go new file mode 100644 index 0000000..db9a76b --- /dev/null +++ b/peer/mcwriter.go @@ -0,0 +1,62 @@ +package peer + +import ( + "log" + "net" + + "golang.org/x/crypto/nacl/sign" +) + +// ---------------------------------------------------------------------------- + +type mcUDPWriter interface { + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +// ---------------------------------------------------------------------------- + +func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { + h := header{ + SourceIP: localIP, + DestIP: 255, + } + buf := make([]byte, headerSize) + h.Marshal(buf) + out := make([]byte, headerSize+signOverhead) + return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) +} + +func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) { + if len(pkt) != headerSize+signOverhead { + return + } + + h.Parse(pkt[signOverhead:]) + ok = true + return +} + +func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { + _, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) + return ok +} + +// ---------------------------------------------------------------------------- + +type mcWriter struct { + conn mcUDPWriter + discoveryPacket []byte +} + +func newMCWriter(conn mcUDPWriter, localIP byte, signingKey []byte) *mcWriter { + return &mcWriter{ + conn: conn, + discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), + } +} + +func (w *mcWriter) SendLocalDiscovery() { + if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { + log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) + } +} diff --git a/peer/mcwriter_test.go b/peer/mcwriter_test.go new file mode 100644 index 0000000..ffef05d --- /dev/null +++ b/peer/mcwriter_test.go @@ -0,0 +1,102 @@ +package peer + +import ( + "bytes" + "net" + "testing" +) + +// ---------------------------------------------------------------------------- + +// Testing that we can create and verify a local discovery packet. +func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + header, ok := headerFromLocalDiscoveryPacket(created) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 55 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Not valid") + } +} + +// Testing that we don't try to parse short packets. +func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + _, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) + if ok { + t.Fatal(ok) + } +} + +// Testing that modifying a packet makes it invalid. +func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + buf := make([]byte, 1024) + for i := range created { + modified := bytes.Clone(created) + modified[i]++ + if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { + t.Fatal("Verification should have failed.") + } + } +} + +// ---------------------------------------------------------------------------- + +type testUDPWriter struct { + written [][]byte +} + +func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + w.written = append(w.written, bytes.Clone(b)) + return len(b), nil +} + +func (w *testUDPWriter) Written() [][]byte { + out := w.written + w.written = [][]byte{} + return out +} + +// ---------------------------------------------------------------------------- + +// Testing that the mcWriter sends local discovery packets as expected. +func TestMCWriter_SendLocalDiscovery(t *testing.T) { + keys := generateKeys() + writer := &testUDPWriter{} + + mcw := newMCWriter(writer, 42, keys.PrivSignKey) + mcw.SendLocalDiscovery() + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + pkt := out[0] + + header, ok := headerFromLocalDiscoveryPacket(pkt) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 42 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Verification should succeed.") + } +} diff --git a/peer/packets-util.go b/peer/packets-util.go new file mode 100644 index 0000000..bda33b9 --- /dev/null +++ b/peer/packets-util.go @@ -0,0 +1,190 @@ +package peer + +import ( + "net/netip" + "sync/atomic" + "time" + "unsafe" +) + +var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1 + +func newTraceID() uint64 { + return atomic.AddUint64(&traceIDCounter, 1) +} + +// ---------------------------------------------------------------------------- + +type binWriter struct { + b []byte + i int +} + +func newBinWriter(buf []byte) *binWriter { + buf = buf[:cap(buf)] + return &binWriter{buf, 0} +} + +func (w *binWriter) Bool(b bool) *binWriter { + if b { + return w.Byte(1) + } + return w.Byte(0) +} + +func (w *binWriter) Byte(b byte) *binWriter { + w.b[w.i] = b + w.i++ + return w +} + +func (w *binWriter) SharedKey(key [32]byte) *binWriter { + copy(w.b[w.i:w.i+32], key[:]) + w.i += 32 + return w +} + +func (w *binWriter) Uint16(x uint16) *binWriter { + *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 2 + return w +} + +func (w *binWriter) Uint64(x uint64) *binWriter { + *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) Int64(x int64) *binWriter { + *(*int64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { + w.Bool(addrPort.IsValid()) + addr := addrPort.Addr().As16() + copy(w.b[w.i:w.i+16], addr[:]) + w.i += 16 + return w.Uint16(addrPort.Port()) +} + +func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { + for _, addrPort := range l { + w.AddrPort(addrPort) + } + return w +} + +func (w *binWriter) Build() []byte { + return w.b[:w.i] +} + +// ---------------------------------------------------------------------------- + +type binReader struct { + b []byte + i int + err error +} + +func newBinReader(buf []byte) *binReader { + return &binReader{b: buf} +} + +func (r *binReader) hasBytes(n int) bool { + if r.err != nil || (len(r.b)-r.i) < n { + r.err = errMalformedPacket + return false + } + return true +} + +func (r *binReader) Bool(b *bool) *binReader { + var bb byte + r.Byte(&bb) + *b = bb != 0 + return r +} + +func (r *binReader) Byte(b *byte) *binReader { + if !r.hasBytes(1) { + return r + } + *b = r.b[r.i] + r.i++ + return r +} + +func (r *binReader) SharedKey(x *[32]byte) *binReader { + if !r.hasBytes(32) { + return r + } + *x = ([32]byte)(r.b[r.i : r.i+32]) + r.i += 32 + return r +} + +func (r *binReader) Uint16(x *uint16) *binReader { + if !r.hasBytes(2) { + return r + } + *x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) + r.i += 2 + return r +} + +func (r *binReader) Uint64(x *uint64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) Int64(x *int64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*int64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { + if !r.hasBytes(19) { + return r + } + + var ( + valid bool + port uint16 + ) + + r.Bool(&valid) + addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() + r.i += 16 + + r.Uint16(&port) + + if valid { + *x = netip.AddrPortFrom(addr, port) + } else { + *x = netip.AddrPort{} + } + + return r +} + +func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { + for i := range x { + r.AddrPort(&x[i]) + } + return r +} + +func (r *binReader) Error() error { + return r.err +} diff --git a/peer/packets-util_test.go b/peer/packets-util_test.go new file mode 100644 index 0000000..5a518d7 --- /dev/null +++ b/peer/packets-util_test.go @@ -0,0 +1,56 @@ +package peer + +import ( + "net/netip" + "reflect" + "testing" +) + +func TestBinWriteRead(t *testing.T) { + buf := make([]byte, 1024) + + type Item struct { + Type byte + TraceID uint64 + Addrs [8]netip.AddrPort + DestAddr netip.AddrPort + } + + in := Item{ + 1, + 2, + [8]netip.AddrPort{}, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), + } + + in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) + in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) + in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) + in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) + in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) + in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) + in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) + + buf = newBinWriter(buf). + Byte(in.Type). + Uint64(in.TraceID). + AddrPort(in.DestAddr). + AddrPortArray(in.Addrs). + Build() + + out := Item{} + + err := newBinReader(buf). + Byte(&out.Type). + Uint64(&out.TraceID). + AddrPort(&out.DestAddr). + AddrPortArray(&out.Addrs). + Error() + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal(in, out) + } +} diff --git a/peer/packets.go b/peer/packets.go new file mode 100644 index 0000000..f7f1f85 --- /dev/null +++ b/peer/packets.go @@ -0,0 +1,123 @@ +package peer + +import ( + "net/netip" +) + +const ( + packetTypeSyn = iota + 1 + packetTypeSynAck + packetTypeAck + packetTypeProbe + packetTypeAddrDiscovery +) + +// ---------------------------------------------------------------------------- + +type synPacket struct { + TraceID uint64 // TraceID to match response w/ request. + // TODO: SentAt int64 // Unixmilli. + SharedKey [32]byte // Our shared key. + Direct bool + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. +} + +func (p synPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSyn). + Uint64(p.TraceID). + SharedKey(p.SharedKey). + Bool(p.Direct). + AddrPort(p.PossibleAddrs[0]). + AddrPort(p.PossibleAddrs[1]). + AddrPort(p.PossibleAddrs[2]). + AddrPort(p.PossibleAddrs[3]). + AddrPort(p.PossibleAddrs[4]). + AddrPort(p.PossibleAddrs[5]). + AddrPort(p.PossibleAddrs[6]). + AddrPort(p.PossibleAddrs[7]). + Build() +} + +func parseSynPacket(buf []byte) (p synPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + SharedKey(&p.SharedKey). + Bool(&p.Direct). + AddrPort(&p.PossibleAddrs[0]). + AddrPort(&p.PossibleAddrs[1]). + AddrPort(&p.PossibleAddrs[2]). + AddrPort(&p.PossibleAddrs[3]). + AddrPort(&p.PossibleAddrs[4]). + AddrPort(&p.PossibleAddrs[5]). + AddrPort(&p.PossibleAddrs[6]). + AddrPort(&p.PossibleAddrs[7]). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type ackPacket struct { + TraceID uint64 + ToAddr netip.AddrPort + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. +} + +func (p ackPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeAck). + Uint64(p.TraceID). + AddrPort(p.ToAddr). + AddrPort(p.PossibleAddrs[0]). + AddrPort(p.PossibleAddrs[1]). + AddrPort(p.PossibleAddrs[2]). + AddrPort(p.PossibleAddrs[3]). + AddrPort(p.PossibleAddrs[4]). + AddrPort(p.PossibleAddrs[5]). + AddrPort(p.PossibleAddrs[6]). + AddrPort(p.PossibleAddrs[7]). + Build() +} + +func parseAckPacket(buf []byte) (p ackPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + AddrPort(&p.ToAddr). + AddrPort(&p.PossibleAddrs[0]). + AddrPort(&p.PossibleAddrs[1]). + AddrPort(&p.PossibleAddrs[2]). + AddrPort(&p.PossibleAddrs[3]). + AddrPort(&p.PossibleAddrs[4]). + AddrPort(&p.PossibleAddrs[5]). + AddrPort(&p.PossibleAddrs[6]). + AddrPort(&p.PossibleAddrs[7]). + Error() + return +} + +// ---------------------------------------------------------------------------- + +// A probeReqPacket is sent from a client to a server to determine if direct +// UDP communication can be used. +type probePacket struct { + TraceID uint64 +} + +func (p probePacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeProbe). + Uint64(p.TraceID). + Build() +} + +func parseProbePacket(buf []byte) (p probePacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type localDiscoveryPacket struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go new file mode 100644 index 0000000..333deff --- /dev/null +++ b/peer/packets_test.go @@ -0,0 +1 @@ +package peer diff --git a/peer/state.go b/peer/state.go new file mode 100644 index 0000000..2ef248b --- /dev/null +++ b/peer/state.go @@ -0,0 +1,29 @@ +package peer + +import ( + "net/netip" + "time" +) + +type peerRoute struct { + IP byte // VPN IP of peer (last byte). + Up bool // True if data can be sent on the route. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + PubSignKey []byte + ControlCipher *controlCipher + DataCipher *dataCipher + RemoteAddr netip.AddrPort // Remote address if directly connected. + + Counter *uint64 // For sending to. Atomic access only. + DupCheck *dupCheck // For receiving from. Not safe for concurrent use. +} + +func newPeerRoute(ip byte) *peerRoute { + counter := uint64(time.Now().Unix()<<30 + 1) + return &peerRoute{ + IP: ip, + Counter: &counter, + DupCheck: newDupCheck(0), + } +}