From c7bc1ecc458da56ce33aaf1a3d4b382d8dfd7901 Mon Sep 17 00:00:00 2001 From: jdl Date: Thu, 19 Dec 2024 06:45:17 +0100 Subject: [PATCH] WIP: separate data and control ciphers. --- node/conn.go | 6 +- node/crypto.go | 4 +- node/crypto_test.go | 19 +++--- node/datacipher.go | 97 ++++++++++++++++++++++++++++ node/datacipher_test.go | 138 ++++++++++++++++++++++++++++++++++++++++ node/dupcheck.go | 4 -- node/dupcheck_test.go | 3 - node/routingpacket.go | 4 +- 8 files changed, 249 insertions(+), 26 deletions(-) create mode 100644 node/datacipher.go create mode 100644 node/datacipher_test.go diff --git a/node/conn.go b/node/conn.go index 9224d57..0823e31 100644 --- a/node/conn.go +++ b/node/conn.go @@ -74,7 +74,7 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt Stream: stream, } - buf := encryptPacket(&h, dstPeer.SharedKey, data, w.buf) + buf := encryptPacketAsym(&h, dstPeer.SharedKey, data, w.buf) if viaPeer != nil { h := header{ @@ -85,7 +85,7 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt Stream: stream, } - buf = encryptPacket(&h, viaPeer.SharedKey, buf, w.buf2) + buf = encryptPacketAsym(&h, viaPeer.SharedKey, buf, w.buf2) addr = viaPeer.Addr } @@ -155,7 +155,7 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data continue } - out, ok := decryptPacket(peer.SharedKey, data, r.buf) + out, ok := decryptPacketAsym(peer.SharedKey, data, r.buf) if !ok { continue } diff --git a/node/crypto.go b/node/crypto.go index cc5904f..0f7710f 100644 --- a/node/crypto.go +++ b/node/crypto.go @@ -8,14 +8,14 @@ import ( ) // Encrypting the packet will also set the header's DataSize field. -func encryptPacket(h *header, sharedKey, data, out []byte) []byte { +func encryptPacketAsym(h *header, sharedKey, data, out []byte) []byte { out = out[:headerSize] h.Marshal(out) b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey)) return out[:len(b)+headerSize] } -func decryptPacket(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) { +func decryptPacketAsym(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) { return box.OpenAfterPrecomputation( out[:0], packetAndHeader[headerSize:], diff --git a/node/crypto_test.go b/node/crypto_test.go index 0a651b0..76f408f 100644 --- a/node/crypto_test.go +++ b/node/crypto_test.go @@ -3,14 +3,13 @@ package node import ( "bytes" "crypto/rand" - "log" "reflect" "testing" "golang.org/x/crypto/nacl/box" ) -func TestEncryptDecryptPacket(t *testing.T) { +func TestEncryptDecryptAsym(t *testing.T) { pubKey1, privKey1, err := box.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) @@ -21,8 +20,6 @@ func TestEncryptDecryptPacket(t *testing.T) { t.Fatal(err) } - log.Printf("\n%#v\n%#v\n%#v\n%#v\n", pubKey1, privKey1, pubKey2, privKey2) - sharedEncKey := [32]byte{} box.Precompute(&sharedEncKey, pubKey2, privKey1) @@ -41,11 +38,11 @@ func TestEncryptDecryptPacket(t *testing.T) { } encrypted := make([]byte, bufferSize) - encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) + encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted) decrypted := make([]byte, bufferSize) var ok bool - decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) + decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted) if !ok { t.Fatal(ok) } @@ -62,7 +59,7 @@ func TestEncryptDecryptPacket(t *testing.T) { } } -func BenchmarkEncryptPacket(b *testing.B) { +func BenchmarkEncryptAsym(b *testing.B) { _, privKey1, err := box.GenerateKey(rand.Reader) if err != nil { b.Fatal(err) @@ -93,11 +90,11 @@ func BenchmarkEncryptPacket(b *testing.B) { } for i := 0; i < b.N; i++ { - encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) + encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted) } } -func BenchmarkDecryptPacket(b *testing.B) { +func BenchmarkDecryptAsym(b *testing.B) { pubKey1, privKey1, err := box.GenerateKey(rand.Reader) if err != nil { b.Fatal(err) @@ -128,11 +125,11 @@ func BenchmarkDecryptPacket(b *testing.B) { Stream: 1, } - encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize)) + encrypted := encryptPacketAsym(&h, sharedEncKey[:], original, make([]byte, bufferSize)) decrypted := make([]byte, bufferSize) var ok bool for i := 0; i < b.N; i++ { - decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) + decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted) if !ok { panic(ok) } diff --git a/node/datacipher.go b/node/datacipher.go new file mode 100644 index 0000000..b631d14 --- /dev/null +++ b/node/datacipher.go @@ -0,0 +1,97 @@ +package node + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "unsafe" +) + +// ---------------------------------------------------------------------------- + +const ( + dataStreamID = 1 + dataHeaderSize = 12 + dataCipherOverhead = 16 + 1 +) + +type dataHeader struct { + Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. + SourceIP byte + DestIP byte +} + +func (h *dataHeader) Parse(b []byte) { + h.Counter = *(*uint64)(unsafe.Pointer(&b[0])) + h.SourceIP = b[8] + h.DestIP = b[9] +} + +func (h *dataHeader) Marshal(buf []byte) { + *(*uint64)(unsafe.Pointer(&buf[0])) = h.Counter + buf[8] = h.SourceIP + buf[9] = h.DestIP + buf[10] = 0 + buf[11] = 0 +} + +// ---------------------------------------------------------------------------- + +type dataCipher struct { + key []byte + aead cipher.AEAD +} + +func newDataCipher() *dataCipher { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + panic(err) + } + return newDataCipherFromKey(key) +} + +// key must be 32 bytes. +func newDataCipherFromKey(key []byte) *dataCipher { + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + panic(err) + } + + return &dataCipher{key: key, aead: aead} +} + +func (sc *dataCipher) Key() []byte { + return sc.key +} + +func (sc *dataCipher) Encrypt(h *dataHeader, data, out []byte) []byte { + out = out[:dataHeaderSize+dataCipherOverhead+len(data)] + out[0] = dataStreamID + + h.Marshal(out[1:]) + + const s = dataHeaderSize + sc.aead.Seal(out[1+s:1+s], out[1:1+s], data, nil) + return out +} + +func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, h dataHeader, ok bool) { + const s = dataHeaderSize + if len(encrypted) < s+dataCipherOverhead { + ok = false + return + } + + h.Parse(encrypted[1 : 1+s]) + + var err error + + data, err = sc.aead.Open(out[:0], encrypted[1:1+s], encrypted[1+s:], nil) + ok = err == nil + return +} diff --git a/node/datacipher_test.go b/node/datacipher_test.go new file mode 100644 index 0000000..8a0f012 --- /dev/null +++ b/node/datacipher_test.go @@ -0,0 +1,138 @@ +package node + +import ( + "bytes" + "crypto/rand" + mrand "math/rand/v2" + "reflect" + "testing" +) + +func TestDataCipher(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := dataHeader{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(&h1, plaintext, encrypted) + + dc2 := newDataCipherFromKey(dc1.Key()) + + decrypted, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("not equal") + } + + if !reflect.DeepEqual(h1, h2) { + t.Fatalf("%v != %v", h1, h2) + } + } +} + +func TestDataCipher_ModifyCiphertext(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := dataHeader{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(&h1, plaintext, encrypted) + encrypted[mrand.IntN(len(encrypted))]++ + + dc2 := newDataCipherFromKey(dc1.Key()) + + _, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if ok { + t.Fatal(ok, h2) + } + } +} + +func TestDataCipher_ShortCiphertext(t *testing.T) { + dc1 := newDataCipher() + shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) + rand.Read(shortText) + _, _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkDataCipher_Encrypt(b *testing.B) { + h1 := dataHeader{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = dc1.Encrypt(&h1, plaintext, encrypted) + } +} + +func BenchmarkDataCipher_Decrypt(b *testing.B) { + h1 := dataHeader{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(&h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _, _ = dc1.Decrypt(encrypted, decrypted) + } +} diff --git a/node/dupcheck.go b/node/dupcheck.go index e960bd4..fac7a72 100644 --- a/node/dupcheck.go +++ b/node/dupcheck.go @@ -1,7 +1,5 @@ package node -import "log" - type dupCheck struct { bitSet head int @@ -22,7 +20,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool { // Before head => it's late, say it's a dup. if counter < dc.headCounter { - log.Printf("Late: %d", counter) return true } @@ -30,7 +27,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool { if counter < dc.tailCounter { index := (int(counter-dc.headCounter) + dc.head) % bitSetSize if dc.Get(index) { - log.Printf("Dup: %d, %d", counter, dc.tailCounter) return true } diff --git a/node/dupcheck_test.go b/node/dupcheck_test.go index 9a939b5..2156b4e 100644 --- a/node/dupcheck_test.go +++ b/node/dupcheck_test.go @@ -1,7 +1,6 @@ package node import ( - "log" "testing" ) @@ -49,8 +48,6 @@ func TestDupCheck(t *testing.T) { for i, tc := range testCases { if ok := dc.IsDup(tc.Counter); ok != tc.Dup { - log.Printf("%b", dc.bitSet) - log.Printf("%+v", *dc) t.Fatal(i, ok, tc) } } diff --git a/node/routingpacket.go b/node/routingpacket.go index 4e35055..64f0374 100644 --- a/node/routingpacket.go +++ b/node/routingpacket.go @@ -8,10 +8,8 @@ import ( var errMalformedPacket = errors.New("malformed packet") const ( - packetTypeInvalid = iota - // Used to maintain connection. - packetTypePing + packetTypePing = iota + 1 packetTypePong )