wip
This commit is contained in:
		
							
								
								
									
										61
									
								
								node/cipher-data.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								node/cipher-data.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/aes" | ||||||
|  | 	"crypto/cipher" | ||||||
|  | 	"crypto/rand" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type dataCipher struct { | ||||||
|  | 	key  []byte | ||||||
|  | 	aead cipher.AEAD | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newDataCipher() *dataCipher { | ||||||
|  | 	key := make([]byte, 32) | ||||||
|  | 	if _, err := rand.Read(key); err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	return newDataCipherFromKey(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // key must be 32 bytes. | ||||||
|  | func newDataCipherFromKey(key []byte) *dataCipher { | ||||||
|  | 	block, err := aes.NewCipher(key) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	aead, err := cipher.NewGCM(block) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dataCipher{key: key, aead: aead} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Key() []byte { | ||||||
|  | 	return sc.key | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte { | ||||||
|  | 	const s = dataHeaderSize | ||||||
|  | 	out = out[:s+dataCipherOverhead+len(data)] | ||||||
|  | 	h.Marshal(dataStreamID, out[:s]) | ||||||
|  | 	sc.aead.Seal(out[s:s], out[:s], data, nil) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { | ||||||
|  | 	const s = dataHeaderSize | ||||||
|  | 	if len(encrypted) < s+dataCipherOverhead { | ||||||
|  | 		ok = false | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) | ||||||
|  | 	ok = err == nil | ||||||
|  | 	return | ||||||
|  | } | ||||||
| @@ -22,7 +22,7 @@ func TestDataCipher(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, plaintext := range testCases { | 	for _, plaintext := range testCases { | ||||||
| 		h1 := dataHeader{ | 		h1 := xHeader{ | ||||||
| 			Counter:  235153, | 			Counter:  235153, | ||||||
| 			SourceIP: 4, | 			SourceIP: 4, | ||||||
| 			DestIP:   88, | 			DestIP:   88, | ||||||
| @@ -31,11 +31,13 @@ func TestDataCipher(t *testing.T) { | |||||||
| 		encrypted := make([]byte, bufferSize) | 		encrypted := make([]byte, bufferSize) | ||||||
| 
 | 
 | ||||||
| 		dc1 := newDataCipher() | 		dc1 := newDataCipher() | ||||||
| 		encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | 		encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
|  | 		h2 := xHeader{} | ||||||
|  | 		h2.Parse(encrypted) | ||||||
| 
 | 
 | ||||||
| 		dc2 := newDataCipherFromKey(dc1.Key()) | 		dc2 := newDataCipherFromKey(dc1.Key()) | ||||||
| 
 | 
 | ||||||
| 		decrypted, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | 		decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 			t.Fatal(ok) | 			t.Fatal(ok) | ||||||
| 		} | 		} | ||||||
| @@ -64,7 +66,7 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, plaintext := range testCases { | 	for _, plaintext := range testCases { | ||||||
| 		h1 := dataHeader{ | 		h1 := xHeader{ | ||||||
| 			Counter:  235153, | 			Counter:  235153, | ||||||
| 			SourceIP: 4, | 			SourceIP: 4, | ||||||
| 			DestIP:   88, | 			DestIP:   88, | ||||||
| @@ -73,14 +75,14 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) { | |||||||
| 		encrypted := make([]byte, bufferSize) | 		encrypted := make([]byte, bufferSize) | ||||||
| 
 | 
 | ||||||
| 		dc1 := newDataCipher() | 		dc1 := newDataCipher() | ||||||
| 		encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | 		encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
| 		encrypted[mrand.IntN(len(encrypted))]++ | 		encrypted[mrand.IntN(len(encrypted))]++ | ||||||
| 
 | 
 | ||||||
| 		dc2 := newDataCipherFromKey(dc1.Key()) | 		dc2 := newDataCipherFromKey(dc1.Key()) | ||||||
| 
 | 
 | ||||||
| 		_, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | 		_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | ||||||
| 		if ok { | 		if ok { | ||||||
| 			t.Fatal(ok, h2) | 			t.Fatal(ok) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -89,14 +91,14 @@ func TestDataCipher_ShortCiphertext(t *testing.T) { | |||||||
| 	dc1 := newDataCipher() | 	dc1 := newDataCipher() | ||||||
| 	shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) | 	shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) | ||||||
| 	rand.Read(shortText) | 	rand.Read(shortText) | ||||||
| 	_, _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) | 	_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		t.Fatal(ok) | 		t.Fatal(ok) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func BenchmarkDataCipher_Encrypt(b *testing.B) { | func BenchmarkDataCipher_Encrypt(b *testing.B) { | ||||||
| 	h1 := dataHeader{ | 	h1 := xHeader{ | ||||||
| 		Counter:  235153, | 		Counter:  235153, | ||||||
| 		SourceIP: 4, | 		SourceIP: 4, | ||||||
| 		DestIP:   88, | 		DestIP:   88, | ||||||
| @@ -110,12 +112,12 @@ func BenchmarkDataCipher_Encrypt(b *testing.B) { | |||||||
| 	dc1 := newDataCipher() | 	dc1 := newDataCipher() | ||||||
| 	b.ResetTimer() | 	b.ResetTimer() | ||||||
| 	for i := 0; i < b.N; i++ { | 	for i := 0; i < b.N; i++ { | ||||||
| 		encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | 		encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func BenchmarkDataCipher_Decrypt(b *testing.B) { | func BenchmarkDataCipher_Decrypt(b *testing.B) { | ||||||
| 	h1 := dataHeader{ | 	h1 := xHeader{ | ||||||
| 		Counter:  235153, | 		Counter:  235153, | ||||||
| 		SourceIP: 4, | 		SourceIP: 4, | ||||||
| 		DestIP:   88, | 		DestIP:   88, | ||||||
| @@ -127,12 +129,12 @@ func BenchmarkDataCipher_Decrypt(b *testing.B) { | |||||||
| 	encrypted := make([]byte, bufferSize) | 	encrypted := make([]byte, bufferSize) | ||||||
| 
 | 
 | ||||||
| 	dc1 := newDataCipher() | 	dc1 := newDataCipher() | ||||||
| 	encrypted = dc1.Encrypt(&h1, plaintext, encrypted) | 	encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
| 
 | 
 | ||||||
| 	decrypted := make([]byte, bufferSize) | 	decrypted := make([]byte, bufferSize) | ||||||
| 
 | 
 | ||||||
| 	b.ResetTimer() | 	b.ResetTimer() | ||||||
| 	for i := 0; i < b.N; i++ { | 	for i := 0; i < b.N; i++ { | ||||||
| 		decrypted, _, _ = dc1.Decrypt(encrypted, decrypted) | 		decrypted, _ = dc1.Decrypt(encrypted, decrypted) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
							
								
								
									
										26
									
								
								node/cipher-routing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								node/cipher-routing.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import "golang.org/x/crypto/nacl/box" | ||||||
|  |  | ||||||
|  | type routingCipher struct { | ||||||
|  | 	sharedKey [32]byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newRoutingCipher(privKey, pubKey []byte) routingCipher { | ||||||
|  | 	shared := [32]byte{} | ||||||
|  | 	box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) | ||||||
|  | 	return routingCipher{shared} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte { | ||||||
|  | 	const s = routingHeaderSize | ||||||
|  | 	out = out[:s+routingCipherOverhead+len(data)] | ||||||
|  | 	h.Marshal(routingStreamID, out[:s]) | ||||||
|  | 	box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { | ||||||
|  | 	const s = routingHeaderSize | ||||||
|  | 	return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey) | ||||||
|  | } | ||||||
							
								
								
									
										114
									
								
								node/cipher-routing_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								node/cipher-routing_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/nacl/box" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func newRoutingCipherForTesting() (c1, c2 routingCipher) { | ||||||
|  | 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	pubKey2, privKey2, err := box.GenerateKey(rand.Reader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return newRoutingCipher(privKey1[:], pubKey2[:]), | ||||||
|  | 		newRoutingCipher(privKey2[:], pubKey1[:]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRoutingCipher(t *testing.T) { | ||||||
|  | 	c1, c2 := newRoutingCipherForTesting() | ||||||
|  |  | ||||||
|  | 	maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) | ||||||
|  | 	rand.Read(maxSizePlaintext) | ||||||
|  |  | ||||||
|  | 	testCases := [][]byte{ | ||||||
|  | 		make([]byte, 0), | ||||||
|  | 		{1}, | ||||||
|  | 		{255}, | ||||||
|  | 		{1, 2, 3, 4, 5}, | ||||||
|  | 		[]byte("Hello world"), | ||||||
|  | 		maxSizePlaintext, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, plaintext := range testCases { | ||||||
|  | 		h1 := xHeader{ | ||||||
|  | 			Counter:  235153, | ||||||
|  | 			SourceIP: 4, | ||||||
|  | 			DestIP:   88, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 		encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 		decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) | ||||||
|  | 		if !ok { | ||||||
|  | 			t.Fatal(ok) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !bytes.Equal(decrypted, plaintext) { | ||||||
|  | 			t.Fatal("not equal") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRoutingCipher_ShortCiphertext(t *testing.T) { | ||||||
|  | 	c1, _ := newRoutingCipherForTesting() | ||||||
|  | 	shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1) | ||||||
|  | 	rand.Read(shortText) | ||||||
|  | 	_, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkRoutingCipher_Encrypt(b *testing.B) { | ||||||
|  | 	c1, _ := newRoutingCipherForTesting() | ||||||
|  | 	h1 := xHeader{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkRoutingCipher_Decrypt(b *testing.B) { | ||||||
|  | 	c1, c2 := newRoutingCipherForTesting() | ||||||
|  |  | ||||||
|  | 	h1 := xHeader{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 	decrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		decrypted, _ = c2.Decrypt(encrypted, decrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										6
									
								
								node/cipher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								node/cipher.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | type packetCipher interface { | ||||||
|  | 	Encrypt(h xHeader, data, out []byte) []byte | ||||||
|  | 	Decrypt(encrypted, out []byte) (data []byte, ok bool) | ||||||
|  | } | ||||||
							
								
								
									
										39
									
								
								node/conn.go
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								node/conn.go
									
									
									
									
									
								
							| @@ -35,6 +35,33 @@ func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *conn | |||||||
| 	return w | 	return w | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | 	func (w *connWriter) SendRouting(remoteIP byte, data []byte) { | ||||||
|  | 		dstPeer := w.routing.Get(remoteIP) | ||||||
|  | 		if dstPeer == nil { | ||||||
|  | 			log.Printf("No peer: %d", remoteIP) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		var viaPeer *peer | ||||||
|  |  | ||||||
|  | 		if dstPeer.Addr == zeroAddrPort { | ||||||
|  | 			viaPeer = w.routing.Mediator() | ||||||
|  | 			if viaPeer == nil { | ||||||
|  | 				log.Printf("No mediator: %d", remoteIP) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		w.sendRouting(dstPeer, viaPeer, data) | ||||||
|  | 	} | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | func (w *connWriter) SendData(remoteIP byte, data []byte) { | ||||||
|  | 	// TODO | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TODO: deprecated | ||||||
| func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) { | func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) { | ||||||
| 	dstPeer := w.routing.Get(remoteIP) | 	dstPeer := w.routing.Get(remoteIP) | ||||||
| 	if dstPeer == nil { | 	if dstPeer == nil { | ||||||
| @@ -50,11 +77,11 @@ func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) { | |||||||
| 	var viaPeer *peer | 	var viaPeer *peer | ||||||
| 	if dstPeer.Mediated { | 	if dstPeer.Mediated { | ||||||
| 		viaPeer = w.routing.mediator.Load() | 		viaPeer = w.routing.mediator.Load() | ||||||
| 		if viaPeer == nil || viaPeer.Addr == nil { | 		if viaPeer == nil || viaPeer.Addr == zeroAddrPort { | ||||||
| 			log.Printf("Mediator not connected") | 			log.Printf("Mediator not connected") | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} else if dstPeer.Addr == nil { | 	} else if dstPeer.Addr == zeroAddrPort { | ||||||
| 		log.Printf("Peer doesn't have address: %d", remoteIP) | 		log.Printf("Peer doesn't have address: %d", remoteIP) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -62,6 +89,7 @@ func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) { | |||||||
| 	w.WriteToPeer(dstPeer, viaPeer, stream, data) | 	w.WriteToPeer(dstPeer, viaPeer, stream, data) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TODO: deprecated | ||||||
| func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) { | func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) { | ||||||
| 	w.lock.Lock() | 	w.lock.Lock() | ||||||
|  |  | ||||||
| @@ -89,20 +117,21 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt | |||||||
| 		addr = viaPeer.Addr | 		addr = viaPeer.Addr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, err := w.WriteToUDPAddrPort(buf, *addr); err != nil { | 	if _, err := w.WriteToUDPAddrPort(buf, addr); err != nil { | ||||||
| 		log.Fatalf("Failed to write to UDP port: %v", err) | 		log.Fatalf("Failed to write to UDP port: %v", err) | ||||||
| 	} | 	} | ||||||
| 	w.lock.Unlock() | 	w.lock.Unlock() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TODO: deprecated | ||||||
| func (w *connWriter) Forward(dstIP byte, packet []byte) { | func (w *connWriter) Forward(dstIP byte, packet []byte) { | ||||||
| 	dstPeer := w.routing.Get(dstIP) | 	dstPeer := w.routing.Get(dstIP) | ||||||
| 	if dstPeer == nil || dstPeer.Addr == nil { | 	if dstPeer == nil || dstPeer.Addr == zeroAddrPort { | ||||||
| 		log.Printf("No peer: %d", dstIP) | 		log.Printf("No peer: %d", dstIP) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, err := w.WriteToUDPAddrPort(packet, *dstPeer.Addr); err != nil { | 	if _, err := w.WriteToUDPAddrPort(packet, dstPeer.Addr); err != nil { | ||||||
| 		log.Fatalf("Failed to write to UDP port: %v", err) | 		log.Fatalf("Failed to write to UDP port: %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,97 +0,0 @@ | |||||||
| package node |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"crypto/aes" |  | ||||||
| 	"crypto/cipher" |  | ||||||
| 	"crypto/rand" |  | ||||||
| 	"unsafe" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	dataStreamID       = 1 |  | ||||||
| 	dataHeaderSize     = 12 |  | ||||||
| 	dataCipherOverhead = 16 + 1 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type dataHeader struct { |  | ||||||
| 	Counter  uint64 // Init with fasttime.Now() << 30 to ensure monotonic. |  | ||||||
| 	SourceIP byte |  | ||||||
| 	DestIP   byte |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *dataHeader) Parse(b []byte) { |  | ||||||
| 	h.Counter = *(*uint64)(unsafe.Pointer(&b[0])) |  | ||||||
| 	h.SourceIP = b[8] |  | ||||||
| 	h.DestIP = b[9] |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *dataHeader) Marshal(buf []byte) { |  | ||||||
| 	*(*uint64)(unsafe.Pointer(&buf[0])) = h.Counter |  | ||||||
| 	buf[8] = h.SourceIP |  | ||||||
| 	buf[9] = h.DestIP |  | ||||||
| 	buf[10] = 0 |  | ||||||
| 	buf[11] = 0 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type dataCipher struct { |  | ||||||
| 	key  []byte |  | ||||||
| 	aead cipher.AEAD |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func newDataCipher() *dataCipher { |  | ||||||
| 	key := make([]byte, 32) |  | ||||||
| 	if _, err := rand.Read(key); err != nil { |  | ||||||
| 		panic(err) |  | ||||||
| 	} |  | ||||||
| 	return newDataCipherFromKey(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // key must be 32 bytes. |  | ||||||
| func newDataCipherFromKey(key []byte) *dataCipher { |  | ||||||
| 	block, err := aes.NewCipher(key) |  | ||||||
| 	if err != nil { |  | ||||||
| 		panic(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	aead, err := cipher.NewGCM(block) |  | ||||||
| 	if err != nil { |  | ||||||
| 		panic(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &dataCipher{key: key, aead: aead} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (sc *dataCipher) Key() []byte { |  | ||||||
| 	return sc.key |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (sc *dataCipher) Encrypt(h *dataHeader, data, out []byte) []byte { |  | ||||||
| 	out = out[:dataHeaderSize+dataCipherOverhead+len(data)] |  | ||||||
| 	out[0] = dataStreamID |  | ||||||
|  |  | ||||||
| 	h.Marshal(out[1:]) |  | ||||||
|  |  | ||||||
| 	const s = dataHeaderSize |  | ||||||
| 	sc.aead.Seal(out[1+s:1+s], out[1:1+s], data, nil) |  | ||||||
| 	return out |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, h dataHeader, ok bool) { |  | ||||||
| 	const s = dataHeaderSize |  | ||||||
| 	if len(encrypted) < s+dataCipherOverhead { |  | ||||||
| 		ok = false |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	h.Parse(encrypted[1 : 1+s]) |  | ||||||
|  |  | ||||||
| 	var err error |  | ||||||
|  |  | ||||||
| 	data, err = sc.aead.Open(out[:0], encrypted[1:1+s], encrypted[1+s:], nil) |  | ||||||
| 	ok = err == nil |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| @@ -2,6 +2,41 @@ package node | |||||||
|  |  | ||||||
| import "unsafe" | import "unsafe" | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	routingStreamID       = 2 | ||||||
|  | 	routingHeaderSize     = 24 | ||||||
|  | 	routingCipherOverhead = 16 | ||||||
|  |  | ||||||
|  | 	dataStreamID       = 1 | ||||||
|  | 	dataHeaderSize     = 12 | ||||||
|  | 	dataCipherOverhead = 16 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // TODO: Rename | ||||||
|  | type xHeader struct { | ||||||
|  | 	Counter  uint64 // Init with fasttime.Now() << 30 to ensure monotonic. | ||||||
|  | 	SourceIP byte | ||||||
|  | 	DestIP   byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *xHeader) Parse(b []byte) { | ||||||
|  | 	h.Counter = *(*uint64)(unsafe.Pointer(&b[1])) | ||||||
|  | 	h.SourceIP = b[9] | ||||||
|  | 	h.DestIP = b[10] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *xHeader) Marshal(streamID byte, buf []byte) { | ||||||
|  | 	buf[0] = streamID | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter | ||||||
|  | 	buf[9] = h.SourceIP | ||||||
|  | 	buf[10] = h.DestIP | ||||||
|  | 	buf[11] = 0 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  | // TODO: Remove this code. | ||||||
| const ( | const ( | ||||||
| 	headerSize    = 24 | 	headerSize    = 24 | ||||||
| 	streamData    = 1 | 	streamData    = 1 | ||||||
|   | |||||||
							
								
								
									
										89
									
								
								node/packets.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								node/packets.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | 	"unsafe" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var errMalformedPacket = errors.New("malformed packet") | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	packetTypePing = iota + 1 | ||||||
|  | 	packetTypePong | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type packetWrapper struct { | ||||||
|  | 	SrcIP      byte | ||||||
|  | 	RemoteAddr netip.AddrPort | ||||||
|  | 	Packet     any | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // A pingPacket is sent from a node acting as a client, to a node acting | ||||||
|  | // as a server. It always contains the shared key the client is expecting | ||||||
|  | // to use for data encryption with the server. | ||||||
|  | type pingPacket struct { | ||||||
|  | 	SentAt    int64 // UnixMilli. | ||||||
|  | 	SharedKey [32]byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newPingPacket(sharedKey []byte) (pp pingPacket) { | ||||||
|  | 	pp.SentAt = time.Now().UnixMilli() | ||||||
|  | 	copy(pp.SharedKey[:], sharedKey) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p pingPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	buf = buf[:41] | ||||||
|  | 	buf[0] = packetTypePing | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt) | ||||||
|  | 	copy(buf[9:41], p.SharedKey[:]) | ||||||
|  | 	return buf | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *pingPacket) Parse(buf []byte) error { | ||||||
|  | 	if len(buf) != 41 { | ||||||
|  | 		return errMalformedPacket | ||||||
|  | 	} | ||||||
|  | 	p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) | ||||||
|  | 	copy(p.SharedKey[:], buf[9:41]) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // A pongPacket is sent by a node in a server role in response to a pingPacket. | ||||||
|  | type pongPacket struct { | ||||||
|  | 	SentAt  int64 // UnixMilli. | ||||||
|  | 	RecvdAt int64 // UnixMilli. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newPongPacket(sentAt int64) (pp pongPacket) { | ||||||
|  | 	pp.SentAt = sentAt | ||||||
|  | 	pp.RecvdAt = time.Now().UnixMilli() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p pongPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	buf = buf[:17] | ||||||
|  | 	buf[0] = packetTypePong | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt) | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&buf[9])) = uint64(p.RecvdAt) | ||||||
|  |  | ||||||
|  | 	return buf | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *pongPacket) Parse(buf []byte) error { | ||||||
|  | 	if len(buf) != 17 { | ||||||
|  | 		return errMalformedPacket | ||||||
|  | 	} | ||||||
|  | 	p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) | ||||||
|  | 	p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9])) | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										42
									
								
								node/packets_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								node/packets_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestPacketPing(t *testing.T) { | ||||||
|  | 	sharedKey := make([]byte, 32) | ||||||
|  | 	rand.Read(sharedKey) | ||||||
|  |  | ||||||
|  | 	buf := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	p := newPingPacket(sharedKey) | ||||||
|  | 	out := p.Marshal(buf) | ||||||
|  |  | ||||||
|  | 	p2 := pingPacket{} | ||||||
|  | 	if err := p2.Parse(out); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(p, p2) { | ||||||
|  | 		t.Fatal(p, p2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPacketPong(t *testing.T) { | ||||||
|  | 	buf := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	p := newPongPacket(123566) | ||||||
|  | 	out := p.Marshal(buf) | ||||||
|  |  | ||||||
|  | 	p2 := pongPacket{} | ||||||
|  | 	if err := p2.Parse(out); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(p, p2) { | ||||||
|  | 		t.Fatal(p, p2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -36,7 +36,7 @@ type peerSupervisor struct { | |||||||
| 	// Peer-related items. | 	// Peer-related items. | ||||||
| 	version        int64 // Ony accessed in HandlePeerUpdate. | 	version        int64 // Ony accessed in HandlePeerUpdate. | ||||||
| 	peer           *m.Peer | 	peer           *m.Peer | ||||||
| 	remoteAddrPort *netip.AddrPort | 	remoteAddrPort netip.AddrPort | ||||||
| 	mediated       bool | 	mediated       bool | ||||||
| 	sharedKey      []byte | 	sharedKey      []byte | ||||||
|  |  | ||||||
| @@ -123,9 +123,9 @@ func (s *peerSupervisor) stateInit() stateFunc { | |||||||
| 	addr, ok := netip.AddrFromSlice(s.peer.PublicIP) | 	addr, ok := netip.AddrFromSlice(s.peer.PublicIP) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		addrPort := netip.AddrPortFrom(addr, s.peer.Port) | 		addrPort := netip.AddrPortFrom(addr, s.peer.Port) | ||||||
| 		s.remoteAddrPort = &addrPort | 		s.remoteAddrPort = addrPort | ||||||
| 	} else { | 	} else { | ||||||
| 		s.remoteAddrPort = nil | 		s.remoteAddrPort = zeroAddrPort | ||||||
| 	} | 	} | ||||||
| 	s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey) | 	s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey) | ||||||
|  |  | ||||||
| @@ -153,7 +153,7 @@ func (s *peerSupervisor) stateSelectRole() stateFunc { | |||||||
| 	s.logf("STATE: SelectRole") | 	s.logf("STATE: SelectRole") | ||||||
| 	s.updateRoutingTable(false) | 	s.updateRoutingTable(false) | ||||||
|  |  | ||||||
| 	if s.remoteAddrPort != nil { | 	if s.remoteAddrPort != zeroAddrPort { | ||||||
| 		s.mediated = false | 		s.mediated = false | ||||||
|  |  | ||||||
| 		// If both remote and local are public, one side acts as client, and one | 		// If both remote and local are public, one side acts as client, and one | ||||||
| @@ -186,7 +186,7 @@ func (s *peerSupervisor) stateAccept() stateFunc { | |||||||
| 			switch pkt.Type { | 			switch pkt.Type { | ||||||
|  |  | ||||||
| 			case packetTypePing: | 			case packetTypePing: | ||||||
| 				s.remoteAddrPort = &pkt.Addr | 				s.remoteAddrPort = pkt.Addr | ||||||
| 				s.updateRoutingTable(true) | 				s.updateRoutingTable(true) | ||||||
| 				s.sendPong(pkt.TraceID) | 				s.sendPong(pkt.TraceID) | ||||||
| 				return s.stateConnected | 				return s.stateConnected | ||||||
| @@ -256,8 +256,8 @@ func (s *peerSupervisor) stateConnected() stateFunc { | |||||||
|  |  | ||||||
| 				// Server should always follow remote port. | 				// Server should always follow remote port. | ||||||
| 				if s.localPublic { | 				if s.localPublic { | ||||||
| 					if pkt.Addr != *s.remoteAddrPort { | 					if pkt.Addr != s.remoteAddrPort { | ||||||
| 						s.remoteAddrPort = &pkt.Addr | 						s.remoteAddrPort = pkt.Addr | ||||||
| 						s.updateRoutingTable(true) | 						s.updateRoutingTable(true) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|   | |||||||
| @@ -12,12 +12,18 @@ import ( | |||||||
| 	"vppn/m" | 	"vppn/m" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var zeroAddrPort = netip.AddrPort{} | ||||||
|  |  | ||||||
| type peer struct { | type peer struct { | ||||||
| 	Up        bool // No data will be sent to peers that are down. | 	IP            byte           // The VPN IP. | ||||||
| 	Mediator  bool | 	Up            bool           // No data will be sent to peers that are down. | ||||||
|  | 	Addr          netip.AddrPort // If we have direct connection, otherwise use mediator. | ||||||
|  | 	Mediator      bool           // True if the peer will mediate. | ||||||
|  | 	RoutingCipher routingCipher | ||||||
|  | 	DataCipher    dataCipher | ||||||
|  |  | ||||||
|  | 	// TODO: Deprecated below. | ||||||
| 	Mediated  bool | 	Mediated  bool | ||||||
| 	IP        byte |  | ||||||
| 	Addr      *netip.AddrPort // If we have direct connection, otherwise use mediator. |  | ||||||
| 	SharedKey []byte | 	SharedKey []byte | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -48,6 +54,10 @@ func (r *routingTable) Set(ip byte, p *peer) { | |||||||
| 	r.table[ip].Store(p) | 	r.table[ip].Store(p) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (r *routingTable) Mediator() *peer { | ||||||
|  | 	return r.mediator.Load() | ||||||
|  | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type router struct { | type router struct { | ||||||
|   | |||||||
| @@ -1,18 +1,9 @@ | |||||||
| package node | package node | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" |  | ||||||
| 	"unsafe" | 	"unsafe" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var errMalformedPacket = errors.New("malformed packet") |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	// Used to maintain connection. |  | ||||||
| 	packetTypePing = iota + 1 |  | ||||||
| 	packetTypePong |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type routingPacket struct { | type routingPacket struct { | ||||||
| 	Type    byte   // One of the packetType* constants. | 	Type    byte   // One of the packetType* constants. | ||||||
| 	TraceID uint64 // For matching requests and responses. | 	TraceID uint64 // For matching requests and responses. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user