WIP: separate data and control ciphers.
This commit is contained in:
		| @@ -74,7 +74,7 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt | |||||||
| 		Stream:   stream, | 		Stream:   stream, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	buf := encryptPacket(&h, dstPeer.SharedKey, data, w.buf) | 	buf := encryptPacketAsym(&h, dstPeer.SharedKey, data, w.buf) | ||||||
|  |  | ||||||
| 	if viaPeer != nil { | 	if viaPeer != nil { | ||||||
| 		h := header{ | 		h := header{ | ||||||
| @@ -85,7 +85,7 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt | |||||||
| 			Stream:   stream, | 			Stream:   stream, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		buf = encryptPacket(&h, viaPeer.SharedKey, buf, w.buf2) | 		buf = encryptPacketAsym(&h, viaPeer.SharedKey, buf, w.buf2) | ||||||
| 		addr = viaPeer.Addr | 		addr = viaPeer.Addr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -155,7 +155,7 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		out, ok := decryptPacket(peer.SharedKey, data, r.buf) | 		out, ok := decryptPacketAsym(peer.SharedKey, data, r.buf) | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -8,14 +8,14 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| // Encrypting the packet will also set the header's DataSize field. | // Encrypting the packet will also set the header's DataSize field. | ||||||
| func encryptPacket(h *header, sharedKey, data, out []byte) []byte { | func encryptPacketAsym(h *header, sharedKey, data, out []byte) []byte { | ||||||
| 	out = out[:headerSize] | 	out = out[:headerSize] | ||||||
| 	h.Marshal(out) | 	h.Marshal(out) | ||||||
| 	b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey)) | 	b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey)) | ||||||
| 	return out[:len(b)+headerSize] | 	return out[:len(b)+headerSize] | ||||||
| } | } | ||||||
|  |  | ||||||
| func decryptPacket(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) { | func decryptPacketAsym(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) { | ||||||
| 	return box.OpenAfterPrecomputation( | 	return box.OpenAfterPrecomputation( | ||||||
| 		out[:0], | 		out[:0], | ||||||
| 		packetAndHeader[headerSize:], | 		packetAndHeader[headerSize:], | ||||||
|   | |||||||
| @@ -3,14 +3,13 @@ package node | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"log" |  | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"golang.org/x/crypto/nacl/box" | 	"golang.org/x/crypto/nacl/box" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestEncryptDecryptPacket(t *testing.T) { | func TestEncryptDecryptAsym(t *testing.T) { | ||||||
| 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| @@ -21,8 +20,6 @@ func TestEncryptDecryptPacket(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Printf("\n%#v\n%#v\n%#v\n%#v\n", pubKey1, privKey1, pubKey2, privKey2) |  | ||||||
|  |  | ||||||
| 	sharedEncKey := [32]byte{} | 	sharedEncKey := [32]byte{} | ||||||
| 	box.Precompute(&sharedEncKey, pubKey2, privKey1) | 	box.Precompute(&sharedEncKey, pubKey2, privKey1) | ||||||
|  |  | ||||||
| @@ -41,11 +38,11 @@ func TestEncryptDecryptPacket(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	encrypted := make([]byte, bufferSize) | 	encrypted := make([]byte, bufferSize) | ||||||
| 	encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) | 	encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted) | ||||||
|  |  | ||||||
| 	decrypted := make([]byte, bufferSize) | 	decrypted := make([]byte, bufferSize) | ||||||
| 	var ok bool | 	var ok bool | ||||||
| 	decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) | 	decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		t.Fatal(ok) | 		t.Fatal(ok) | ||||||
| 	} | 	} | ||||||
| @@ -62,7 +59,7 @@ func TestEncryptDecryptPacket(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func BenchmarkEncryptPacket(b *testing.B) { | func BenchmarkEncryptAsym(b *testing.B) { | ||||||
| 	_, privKey1, err := box.GenerateKey(rand.Reader) | 	_, privKey1, err := box.GenerateKey(rand.Reader) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		b.Fatal(err) | 		b.Fatal(err) | ||||||
| @@ -93,11 +90,11 @@ func BenchmarkEncryptPacket(b *testing.B) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for i := 0; i < b.N; i++ { | 	for i := 0; i < b.N; i++ { | ||||||
| 		encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) | 		encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func BenchmarkDecryptPacket(b *testing.B) { | func BenchmarkDecryptAsym(b *testing.B) { | ||||||
| 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		b.Fatal(err) | 		b.Fatal(err) | ||||||
| @@ -128,11 +125,11 @@ func BenchmarkDecryptPacket(b *testing.B) { | |||||||
| 		Stream:   1, | 		Stream:   1, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize)) | 	encrypted := encryptPacketAsym(&h, sharedEncKey[:], original, make([]byte, bufferSize)) | ||||||
| 	decrypted := make([]byte, bufferSize) | 	decrypted := make([]byte, bufferSize) | ||||||
| 	var ok bool | 	var ok bool | ||||||
| 	for i := 0; i < b.N; i++ { | 	for i := 0; i < b.N; i++ { | ||||||
| 		decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) | 		decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted) | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 			panic(ok) | 			panic(ok) | ||||||
| 		} | 		} | ||||||
|   | |||||||
							
								
								
									
										97
									
								
								node/datacipher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								node/datacipher.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/aes" | ||||||
|  | 	"crypto/cipher" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"unsafe" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	dataStreamID       = 1 | ||||||
|  | 	dataHeaderSize     = 12 | ||||||
|  | 	dataCipherOverhead = 16 + 1 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type dataHeader struct { | ||||||
|  | 	Counter  uint64 // Init with fasttime.Now() << 30 to ensure monotonic. | ||||||
|  | 	SourceIP byte | ||||||
|  | 	DestIP   byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *dataHeader) Parse(b []byte) { | ||||||
|  | 	h.Counter = *(*uint64)(unsafe.Pointer(&b[0])) | ||||||
|  | 	h.SourceIP = b[8] | ||||||
|  | 	h.DestIP = b[9] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *dataHeader) Marshal(buf []byte) { | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&buf[0])) = h.Counter | ||||||
|  | 	buf[8] = h.SourceIP | ||||||
|  | 	buf[9] = h.DestIP | ||||||
|  | 	buf[10] = 0 | ||||||
|  | 	buf[11] = 0 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type dataCipher struct { | ||||||
|  | 	key  []byte | ||||||
|  | 	aead cipher.AEAD | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newDataCipher() *dataCipher { | ||||||
|  | 	key := make([]byte, 32) | ||||||
|  | 	if _, err := rand.Read(key); err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	return newDataCipherFromKey(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // key must be 32 bytes. | ||||||
|  | func newDataCipherFromKey(key []byte) *dataCipher { | ||||||
|  | 	block, err := aes.NewCipher(key) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	aead, err := cipher.NewGCM(block) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dataCipher{key: key, aead: aead} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Key() []byte { | ||||||
|  | 	return sc.key | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Encrypt(h *dataHeader, data, out []byte) []byte { | ||||||
|  | 	out = out[:dataHeaderSize+dataCipherOverhead+len(data)] | ||||||
|  | 	out[0] = dataStreamID | ||||||
|  |  | ||||||
|  | 	h.Marshal(out[1:]) | ||||||
|  |  | ||||||
|  | 	const s = dataHeaderSize | ||||||
|  | 	sc.aead.Seal(out[1+s:1+s], out[1:1+s], data, nil) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, h dataHeader, ok bool) { | ||||||
|  | 	const s = dataHeaderSize | ||||||
|  | 	if len(encrypted) < s+dataCipherOverhead { | ||||||
|  | 		ok = false | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.Parse(encrypted[1 : 1+s]) | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	data, err = sc.aead.Open(out[:0], encrypted[1:1+s], encrypted[1+s:], nil) | ||||||
|  | 	ok = err == nil | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										138
									
								
								node/datacipher_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								node/datacipher_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	mrand "math/rand/v2" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDataCipher(t *testing.T) { | ||||||
|  | 	maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(maxSizePlaintext) | ||||||
|  |  | ||||||
|  | 	testCases := [][]byte{ | ||||||
|  | 		make([]byte, 0), | ||||||
|  | 		{1}, | ||||||
|  | 		{255}, | ||||||
|  | 		{1, 2, 3, 4, 5}, | ||||||
|  | 		[]byte("Hello world"), | ||||||
|  | 		maxSizePlaintext, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, plaintext := range testCases { | ||||||
|  | 		h1 := dataHeader{ | ||||||
|  | 			Counter:  235153, | ||||||
|  | 			SourceIP: 4, | ||||||
|  | 			DestIP:   88, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 		dc1 := newDataCipher() | ||||||
|  | 		encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 		dc2 := newDataCipherFromKey(dc1.Key()) | ||||||
|  |  | ||||||
|  | 		decrypted, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | ||||||
|  | 		if !ok { | ||||||
|  | 			t.Fatal(ok) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !bytes.Equal(plaintext, decrypted) { | ||||||
|  | 			t.Fatal("not equal") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !reflect.DeepEqual(h1, h2) { | ||||||
|  | 			t.Fatalf("%v != %v", h1, h2) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDataCipher_ModifyCiphertext(t *testing.T) { | ||||||
|  | 	maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(maxSizePlaintext) | ||||||
|  |  | ||||||
|  | 	testCases := [][]byte{ | ||||||
|  | 		make([]byte, 0), | ||||||
|  | 		{1}, | ||||||
|  | 		{255}, | ||||||
|  | 		{1, 2, 3, 4, 5}, | ||||||
|  | 		[]byte("Hello world"), | ||||||
|  | 		maxSizePlaintext, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, plaintext := range testCases { | ||||||
|  | 		h1 := dataHeader{ | ||||||
|  | 			Counter:  235153, | ||||||
|  | 			SourceIP: 4, | ||||||
|  | 			DestIP:   88, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 		dc1 := newDataCipher() | ||||||
|  | 		encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | ||||||
|  | 		encrypted[mrand.IntN(len(encrypted))]++ | ||||||
|  |  | ||||||
|  | 		dc2 := newDataCipherFromKey(dc1.Key()) | ||||||
|  |  | ||||||
|  | 		_, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | ||||||
|  | 		if ok { | ||||||
|  | 			t.Fatal(ok, h2) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDataCipher_ShortCiphertext(t *testing.T) { | ||||||
|  | 	dc1 := newDataCipher() | ||||||
|  | 	shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) | ||||||
|  | 	rand.Read(shortText) | ||||||
|  | 	_, _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkDataCipher_Encrypt(b *testing.B) { | ||||||
|  | 	h1 := dataHeader{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	dc1 := newDataCipher() | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkDataCipher_Decrypt(b *testing.B) { | ||||||
|  | 	h1 := dataHeader{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	dc1 := newDataCipher() | ||||||
|  | 	encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 	decrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		decrypted, _, _ = dc1.Decrypt(encrypted, decrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,7 +1,5 @@ | |||||||
| package node | package node | ||||||
|  |  | ||||||
| import "log" |  | ||||||
|  |  | ||||||
| type dupCheck struct { | type dupCheck struct { | ||||||
| 	bitSet | 	bitSet | ||||||
| 	head        int | 	head        int | ||||||
| @@ -22,7 +20,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool { | |||||||
|  |  | ||||||
| 	// Before head => it's late, say it's a dup. | 	// Before head => it's late, say it's a dup. | ||||||
| 	if counter < dc.headCounter { | 	if counter < dc.headCounter { | ||||||
| 		log.Printf("Late: %d", counter) |  | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -30,7 +27,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool { | |||||||
| 	if counter < dc.tailCounter { | 	if counter < dc.tailCounter { | ||||||
| 		index := (int(counter-dc.headCounter) + dc.head) % bitSetSize | 		index := (int(counter-dc.headCounter) + dc.head) % bitSetSize | ||||||
| 		if dc.Get(index) { | 		if dc.Get(index) { | ||||||
| 			log.Printf("Dup: %d, %d", counter, dc.tailCounter) |  | ||||||
| 			return true | 			return true | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| package node | package node | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"log" |  | ||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -49,8 +48,6 @@ func TestDupCheck(t *testing.T) { | |||||||
|  |  | ||||||
| 	for i, tc := range testCases { | 	for i, tc := range testCases { | ||||||
| 		if ok := dc.IsDup(tc.Counter); ok != tc.Dup { | 		if ok := dc.IsDup(tc.Counter); ok != tc.Dup { | ||||||
| 			log.Printf("%b", dc.bitSet) |  | ||||||
| 			log.Printf("%+v", *dc) |  | ||||||
| 			t.Fatal(i, ok, tc) | 			t.Fatal(i, ok, tc) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -8,10 +8,8 @@ import ( | |||||||
| var errMalformedPacket = errors.New("malformed packet") | var errMalformedPacket = errors.New("malformed packet") | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	packetTypeInvalid = iota |  | ||||||
|  |  | ||||||
| 	// Used to maintain connection. | 	// Used to maintain connection. | ||||||
| 	packetTypePing | 	packetTypePing = iota + 1 | ||||||
| 	packetTypePong | 	packetTypePong | ||||||
| ) | ) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user